Merge pull request #61868 from Tai78641:pr_fix_2d_batch_matmul

PiperOrigin-RevId: 579065619
diff --git a/.bazelrc b/.bazelrc
index 7035378..e9fc2d4 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -55,6 +55,7 @@
 #
 #     rbe_linux_cpu:                  RBE options to build with only CPU support.
 #     rbe_linux_cuda:                 RBE options to build with GPU support using clang.
+#     rbe_linux_cuda_nvcc:            RBE options to build with GPU support using nvcc.
 #
 #     rbe_win_py39: Windows Python 3.9 RBE config
 #
@@ -237,9 +238,12 @@
 # Select supported compute capabilities (supported graphics cards).
 # This is the same as the official TensorFlow builds.
 # See https://developer.nvidia.com/cuda-gpus#compute
-# TODO(angerson, perfinion): What does sm_ vs compute_ mean? How can users
-# select a good value for this? See go/tf-pip-cuda
-build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_75,compute_80"
+# `compute_XY` enables PTX embedding in addition to SASS. PTX
+# is forward compatible beyond the current compute capability major
+# release while SASS is only forward compatible inside the current
+# major release. Example: sm_80 kernels can run on sm_89 GPUs but
+# not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs.
+build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
 
 # Set up compilation CUDA version and paths and use the CUDA Clang toolchain.
 build:cuda_clang_official --config=cuda_clang
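To make the sm_/compute_ note above concrete: a capability list like the one set for cuda_clang maps onto per-architecture nvcc code-generation flags roughly as sketched below. The invocation is an illustration based on standard CUDA tooling, not something taken from this patch.

# Illustrative nvcc flags (assumed): sm_80 emits SASS that runs on sm_80/sm_89
# but not on the next major architecture; compute_90 additionally embeds PTX
# that sm_90 and newer GPUs can JIT-compile at load time.
nvcc -c kernel.cu -o kernel.o \
  -gencode arch=compute_80,code=sm_80 \
  -gencode arch=compute_90,code=compute_90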
@@ -249,7 +253,7 @@
 build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc"
 build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-17/bin/clang"
 build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
-build:cuda_clang_official --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain"
+build:cuda_clang_official --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain"
 
 # Debug config
 build:dbg -c dbg
@@ -482,12 +486,12 @@
 
 build:rbe_linux_cpu --config=rbe_linux
 # Linux cpu and cuda builds share the same toolchain now.
-build:rbe_linux_cpu --host_crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain"
-build:rbe_linux_cpu --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain"
-build:rbe_linux_cpu --extra_toolchains="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain-linux-x86_64"
-build:rbe_linux_cpu --extra_execution_platforms="@sigbuild-r2.14-clang_config_platform//:platform"
-build:rbe_linux_cpu --host_platform="@sigbuild-r2.14-clang_config_platform//:platform"
-build:rbe_linux_cpu --platforms="@sigbuild-r2.14-clang_config_platform//:platform"
+build:rbe_linux_cpu --host_crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain"
+build:rbe_linux_cpu --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain"
+build:rbe_linux_cpu --extra_toolchains="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain-linux-x86_64"
+build:rbe_linux_cpu --extra_execution_platforms="@sigbuild-r2.16-clang_config_platform//:platform"
+build:rbe_linux_cpu --host_platform="@sigbuild-r2.16-clang_config_platform//:platform"
+build:rbe_linux_cpu --platforms="@sigbuild-r2.16-clang_config_platform//:platform"
 # This is needed for all Clang17 builds but must not be present in GCC builds.
 build:rbe_linux_cpu --copt=-Wno-error=unused-command-line-argument
 # This was added in clang-16 by https://reviews.llvm.org/D133574.
@@ -496,7 +500,7 @@
 # See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183.
 build:rbe_linux_cpu --copt=-Wno-gnu-offsetof-extensions
 # Python config is the same across all containers because the binary is the same
-build:rbe_linux_cpu --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14-clang_config_python"
+build:rbe_linux_cpu --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.16-clang_config_python"
 build:rbe_linux_cpu --python_path="/usr/bin/python3"
 # These you may need to change for your own GCP project.
 common:rbe_linux_cpu --remote_instance_name=projects/tensorflow-testing/instances/default_instance
@@ -517,11 +521,40 @@
 build:rbe_linux_cuda --config=rbe_linux_cpu
 # For Remote build execution -- GPU configuration
 build:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1
-build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.14-clang_config_cuda"
-build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.14-clang_config_tensorrt"
-build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.14-clang_config_nccl"
+build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.16-clang_config_cuda"
+build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.16-clang_config_tensorrt"
+build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_config_nccl"
 test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
 
+build:rbe_linux_cuda_nvcc --config=cuda
+build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1
+build:rbe_linux_cuda_nvcc --@local_xla//xla/python:enable_gpu=true
+build:rbe_linux_cuda_nvcc --@local_xla//xla/python:jax_cuda_pip_rpaths=true
+build:rbe_linux_cuda_nvcc --define=xla_python_enable_gpu=true
+build:rbe_linux_cuda_nvcc --config=tensorrt
+build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_75,compute_80"
+build:rbe_linux_cuda_nvcc --action_env=TF_CUDA_VERSION="12"
+build:rbe_linux_cuda_nvcc --action_env=TF_CUDNN_VERSION="8"
+build:rbe_linux_cuda_nvcc --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.2"
+build:rbe_linux_cuda_nvcc --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc"
+build:rbe_linux_cuda_nvcc --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
+build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain"
+build:rbe_linux_cuda_nvcc --config=rbe_linux
+build:rbe_linux_cuda_nvcc --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain"
+build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain-linux-x86_64"
+build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_python3.9"
+build:rbe_linux_cuda_nvcc --python_path="/usr/bin/python3"
+# These you may need to change for your own GCP project.
+common:rbe_linux_cuda_nvcc --remote_instance_name=projects/tensorflow-testing/instances/default_instance
+build:rbe_linux_cuda_nvcc --repo_env=REMOTE_GPU_TESTING=1
+build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_cuda"
+build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_tensorrt"
+build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_nccl"
+test:rbe_linux_cuda_nvcc --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
+
 # TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed
 build:rbe_win --config=rbe_base
 build:rbe_win --crosstool_top="//tensorflow/tools/toolchains/win/tf_win_05022023:toolchain"
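As a usage sketch for the new rbe_linux_cuda_nvcc config defined above (the build target below is illustrative, not mandated by this patch; any buildable target can be substituted):

# Assumed invocation; --config=rbe_linux_cuda_nvcc pulls in the build:/test: lines above.
bazel build --config=rbe_linux_cuda_nvcc //tensorflow/tools/pip_package:build_pip_package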
@@ -576,8 +609,6 @@
 # Here are bazelrc configs for release builds
 # Build TensorFlow v2.
 test:release_base --test_size_filters=small,medium
-# TODO(b/294367488) disable after 2.15 brancut
-test:release_base --flaky_test_attempts=3
 
 # Target the AVX instruction set
 build:release_linux_base --config=avx_linux
@@ -615,7 +646,7 @@
 
 # Use the Clang toolchain to compile
 build:release_cpu_linux --config=release_linux_base
-build:release_cpu_linux --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain"
+build:release_cpu_linux --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain"
 
 build:release_gpu_linux --config=release_cpu_linux
 # Set up compilation CUDA version and paths and use the CUDA Clang toolchain.
@@ -684,7 +715,7 @@
 build:macos   --config=no_tfrt
 build:windows --config=no_tfrt
 build:rocm --config=no_tfrt
-build:no_tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/ir,tensorflow/compiler/mlir/tfrt/ir/mlrt,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/mlrt,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/compiler/mlir/tfrt/transforms/mlrt,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/runtime_fallback/test,tensorflow/core/runtime_fallback/test/gpu,tensorflow/core/runtime_fallback/test/saved_model,tensorflow/core/runtime_fallback/test/testdata,tensorflow/core/tfrt/stubs,tensorflow/core/tfrt/tfrt_session,tensorflow/core/tfrt/mlrt,tensorflow/core/tfrt/mlrt/attribute,tensorflow/core/tfrt/mlrt/kernel,tensorflow/core/tfrt/mlrt/bytecode,tensorflow/core/tfrt/mlrt/interpreter,tensorflow/compiler/mlir/tfrt/translate/mlrt,tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug,tensorflow/core/tfrt/saved_model/python,tensorflow/core/tfrt/graph_executor/python,tensorflow/core/tfrt/saved_model/utils
+build:no_tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/ir,tensorflow/compiler/mlir/tfrt/ir/mlrt,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ifrt,tensorflow/compiler/mlir/tfrt/tests/mlrt,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/compiler/mlir/tfrt/transforms/mlrt,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/runtime_fallback/test,tensorflow/core/runtime_fallback/test/gpu,tensorflow/core/runtime_fallback/test/saved_model,tensorflow/core/runtime_fallback/test/testdata,tensorflow/core/tfrt/stubs,tensorflow/core/tfrt/tfrt_session,tensorflow/core/tfrt/mlrt,tensorflow/core/tfrt/mlrt/attribute,tensorflow/core/tfrt/mlrt/kernel,tensorflow/core/tfrt/mlrt/bytecode,tensorflow/core/tfrt/mlrt/interpreter,tensorflow/compiler/mlir/tfrt/translate/mlrt,tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug,tensorflow/core/tfrt/saved_model/python,tensorflow/core/tfrt/graph_executor/python,tensorflow/core/tfrt/saved_model/utils
 
 # BEGIN TF CACHE HELPER OPTIONS
 # Options when using remote execution
diff --git a/.github/bot_config.yml b/.github/bot_config.yml
index 9ddb1c2..f3508fe 100644
--- a/.github/bot_config.yml
+++ b/.github/bot_config.yml
@@ -18,7 +18,7 @@
    - sushreebarsa
    - SuryanarayanaY
    - tilakrayal
-   - Varsha-anjanappa
+   - Venkat6871
 # A list of assignees for compiler folder
 compiler_assignees:
    - joker-eph
diff --git a/.github/workflows/osv-scanner-scheduled.yml b/.github/workflows/osv-scanner-scheduled.yml
new file mode 100644
index 0000000..bb39d60
--- /dev/null
+++ b/.github/workflows/osv-scanner-scheduled.yml
@@ -0,0 +1,39 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+name: OSV-Scanner Scheduled Scan
+
+on:
+  schedule:
+    - cron: 0 4 * * 1
+
+permissions:
+  # Require writing security events to upload SARIF file to security tab
+  security-events: write
+  # Only need to read contents
+  contents: read
+
+jobs:
+  scan-scheduled:
+    uses: "google/osv-scanner/.github/workflows/osv-scanner-reusable.yml@main"
+    with:
+      scan-args: |-
+        --lockfile=requirements.txt:./requirements_lock_3_9.txt
+        --lockfile=requirements.txt:./requirements_lock_3_10.txt
+        --lockfile=requirements.txt:./requirements_lock_3_11.txt
+        --lockfile=requirements.txt:./requirements_lock_3_12.txt
+        --lockfile=requirements.txt:./ci/official/containers/linux_arm64/devel.requirements.txt
+        --lockfile=requirements.txt:./ci/official/containers/linux_arm64/jax.requirements.txt
+        --lockfile=requirements.txt:./ci/official/containers/linux_arm64/devel.usertools/test.requirements.txt
\ No newline at end of file
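The scheduled job above can be approximated locally with the osv-scanner CLI. The sketch below simply mirrors the workflow's scan-args; the lockfile paths come from the YAML, and the local invocation itself is an assumption rather than part of this patch.

# Local lockfile scan (sketch); add or drop --lockfile entries as needed.
osv-scanner \
  --lockfile=requirements.txt:./requirements_lock_3_9.txt \
  --lockfile=requirements.txt:./requirements_lock_3_12.txt \
  --lockfile=requirements.txt:./ci/official/containers/linux_arm64/devel.requirements.txt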
diff --git a/.github/workflows/sigbuild-docker-branch.yml b/.github/workflows/sigbuild-docker-branch.yml
index 108fe47..9f842f9 100644
--- a/.github/workflows/sigbuild-docker-branch.yml
+++ b/.github/workflows/sigbuild-docker-branch.yml
@@ -34,7 +34,7 @@
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [python3.9, python3.10, python3.11]
+        python-version: [python3.9, python3.10, python3.11, python3.12]
     steps:
       - name: Delete unnecessary tools folder
         run: rm -rf /opt/hostedtoolcache
diff --git a/.github/workflows/sigbuild-docker-presubmit.yml b/.github/workflows/sigbuild-docker-presubmit.yml
index c61e65e..03ae6f1 100644
--- a/.github/workflows/sigbuild-docker-presubmit.yml
+++ b/.github/workflows/sigbuild-docker-presubmit.yml
@@ -32,7 +32,7 @@
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [python3.9, python3.10, python3.11]
+        python-version: [python3.9, python3.10, python3.11, python3.12]
     permissions:
       contents: read
       pull-requests: write
@@ -87,6 +87,7 @@
           message: |
             I pushed these containers:
             
+            - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.12`
             - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.11`
             - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.10`
             - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.9`
diff --git a/.github/workflows/sigbuild-docker.yml b/.github/workflows/sigbuild-docker.yml
index ce9b99c..5549f29 100644
--- a/.github/workflows/sigbuild-docker.yml
+++ b/.github/workflows/sigbuild-docker.yml
@@ -37,7 +37,7 @@
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [python3.9, python3.10, python3.11]
+        python-version: [python3.9, python3.10, python3.11, python3.12]
     steps:
       - name: Delete unnecessary tools folder
         run: rm -rf /opt/hostedtoolcache
diff --git a/.github/workflows/update-rbe.yml b/.github/workflows/update-rbe.yml
index ca22041..1b421ef 100644
--- a/.github/workflows/update-rbe.yml
+++ b/.github/workflows/update-rbe.yml
@@ -105,6 +105,18 @@
         map sigbuild-r2.14-clang-python3.9 2.14-python3.9
         map sigbuild-r2.14-clang-python3.10 2.14-python3.10
         map sigbuild-r2.14-clang-python3.11 2.14-python3.11
+        # TF 2.16
+        map sigbuild-r2.16 2.16-python3.9
+        map sigbuild-r2.16-python3.9 2.16-python3.9
+        map sigbuild-r2.16-python3.10 2.16-python3.10
+        map sigbuild-r2.16-python3.11 2.16-python3.11
+        map sigbuild-r2.16-python3.12 2.16-python3.12
+        # TF 2.16 + Clang (containers are the same, but env vars in configs.bzl are different)
+        map sigbuild-r2.16-clang 2.16-python3.9
+        map sigbuild-r2.16-clang-python3.9 2.16-python3.9
+        map sigbuild-r2.16-clang-python3.10 2.16-python3.10
+        map sigbuild-r2.16-clang-python3.11 2.16-python3.11
+        map sigbuild-r2.16-clang-python3.12 2.16-python3.12
     - name: Create Pull Request with changes
       uses: peter-evans/create-pull-request@2b011faafdcbc9ceb11414d64d0573f37c774b04 # v4.2.3
       with:
diff --git a/RELEASE.md b/RELEASE.md
index 2786e62..75350ae 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -26,9 +26,28 @@
 * <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
 * <NOTES SHOULD BE GROUPED PER AREA>
 
+* `tf.lite`
+    * Added support for `stablehlo.gather`.
+    * Added support for `stablehlo.add`.
+    * Added support for `stablehlo.multiply`.
+    * Added support for `stablehlo.maximum`.
+    * Added support for `stablehlo.minimum`.
+
 ## Keras
 
-<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
+*  `keras.layers.experimental.DynamicEmbedding`
+    * Added `DynamicEmbedding` Keras layer
+    * Added `UpdateEmbeddingCallback`
+    * `DynamicEmbedding` layer allows for the continuous updating of the
+      vocabulary and embeddings during the training process. This layer
+      maintains a hash table to track the most up-to-date vocabulary based on
+      the inputs received by the layer and the eviction policy. When this layer
+      is used with an `UpdateEmbeddingCallback`, which is a time-based callback,
+      the vocabulary lookup tensor is updated at the time interval set in the
+      `UpdateEmbeddingCallback` based on the most up-to-date vocabulary hash
+      table maintained by the layer. If this layer is not used in conjunction
+      with `UpdateEmbeddingCallback`, the behavior of the layer is the same as
+      `keras.layers.Embedding`.
 
 ### Breaking Changes
 
diff --git a/WORKSPACE b/WORKSPACE
index 6a85ffe..a697405 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -31,12 +31,12 @@
 
 python_repository(name = "python_version_repo")
 
-load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
+load("@python_version_repo//:py_version.bzl", "TF_PYTHON_VERSION")
 
 python_register_toolchains(
     name = "python",
     ignore_root_user_error = True,
-    python_version = HERMETIC_PYTHON_VERSION,
+    python_version = TF_PYTHON_VERSION,
 )
 
 load("@python//:defs.bzl", "interpreter")
@@ -62,7 +62,7 @@
     name = "pypi",
     annotations = NUMPY_ANNOTATIONS,
     python_interpreter_target = interpreter,
-    requirements = "//:requirements_lock_" + HERMETIC_PYTHON_VERSION.replace(".", "_") + ".txt",
+    requirements = "//:requirements_lock_" + TF_PYTHON_VERSION.replace(".", "_") + ".txt",
 )
 
 load("@pypi//:requirements.bzl", "install_deps")
diff --git a/ci/official/containers/linux_arm64/devel.packages.txt b/ci/official/containers/linux_arm64/devel.packages.txt
index a8a9cb4..efbae80 100644
--- a/ci/official/containers/linux_arm64/devel.packages.txt
+++ b/ci/official/containers/linux_arm64/devel.packages.txt
@@ -3,6 +3,8 @@
 automake
 build-essential
 ca-certificates
+# TODO(b/308399490) Remove CMake once dm-tree (Keras dependency) has 3.12 wheels
+cmake
 llvm-17
 clang-17
 clang-format-12
diff --git a/tensorflow/core/tfrt/saved_model/python/_pywrap_saved_model_aot_compile.pyi b/ci/official/debug_tfci.sh
old mode 100644
new mode 100755
similarity index 66%
copy from tensorflow/core/tfrt/saved_model/python/_pywrap_saved_model_aot_compile.pyi
copy to ci/official/debug_tfci.sh
index 05aae4b..2498203
--- a/tensorflow/core/tfrt/saved_model/python/_pywrap_saved_model_aot_compile.pyi
+++ b/ci/official/debug_tfci.sh
@@ -1,3 +1,4 @@
+#!/bin/bash
 # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,8 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+# This script dumps some information about the environment. It's most useful
+# for verifying changes to the TFCI scripts system, and most users won't need
+# to interact with it at all.
+source "${BASH_SOURCE%/*}/utilities/setup.sh"
 
-class AotOptions:
-    def __init__(self) -> None: ...
-
-def AotCompileSavedModel(input_model_dir: str = ..., aot_options: AotOptions = ..., output_model_dir: str = ...) -> None: ...
+echo "==TFCI== env outside of tfrun:"
+env
+echo "==TFCI== env inside of tfrun:"
+tfrun env
diff --git a/ci/official/envs/ci_default b/ci/official/envs/ci_default
index b7f433b..eb7938c 100644
--- a/ci/official/envs/ci_default
+++ b/ci/official/envs/ci_default
@@ -1,6 +1,7 @@
 TFCI_BAZEL_BAZELRC_ARGS=()
 TFCI_BAZEL_CONFIG_PREFIX=
 TFCI_BAZEL_COMMON_ARGS=()
+TFCI_PYTHON_VERSION=
 TFCI_BUILD_PIP_PACKAGE_ARGS=()
 TFCI_DOCKER_ARGS=()
 TFCI_DOCKER_ENABLE=1
@@ -15,6 +16,7 @@
 TFCI_NVIDIA_SMI_ENABLE=
 TFCI_OUTPUT_DIR=build_output
 TFCI_LIBTPU_DOWNLOAD_ENABLE=0
+TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE=0
 TFCI_LIBTPU_DOWNLOAD_URL=
 TFCI_UPLOAD_LIB_ENABLE=
 TFCI_UPLOAD_LIB_LATEST_ENABLE=
diff --git a/ci/official/envs/ci_nightly_uploads b/ci/official/envs/ci_nightly_uploads
index e35a712d..ca6671f 100644
--- a/ci/official/envs/ci_nightly_uploads
+++ b/ci/official/envs/ci_nightly_uploads
@@ -1,6 +1,8 @@
-TFCI_UPLOAD_LIB_ENABLE=0
-TFCI_UPLOAD_LIB_LATEST_ENABLE=
-TFCI_UPLOAD_WHL_GCS_ENABLE=
+TFCI_UPLOAD_LIB_ENABLE=1
+TFCI_UPLOAD_LIB_GCS_URI="gs://libtensorflow-nightly/$(date -I)"
+TFCI_UPLOAD_LIB_LATEST_ENABLE=1
+TFCI_UPLOAD_LIB_LATEST_GCS_URI="gs://libtensorflow-nightly/latest"
+TFCI_UPLOAD_WHL_GCS_ENABLE=0
 TFCI_UPLOAD_WHL_GCS_URI=
-TFCI_UPLOAD_WHL_PYPI_ARGS=
-TFCI_UPLOAD_WHL_PYPI_ENABLE=
+TFCI_UPLOAD_WHL_PYPI_ARGS=(--config-file="$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository pypi-warehouse)
+TFCI_UPLOAD_WHL_PYPI_ENABLE=1
diff --git a/ci/official/envs/continuous_linux_arm64_cpu_py310 b/ci/official/envs/continuous_linux_arm64_cpu_py310
index 0dde659..b8d7e5c 100644
--- a/ci/official/envs/continuous_linux_arm64_cpu_py310
+++ b/ci/official/envs/continuous_linux_arm64_cpu_py310
@@ -1,6 +1,7 @@
 # This environment is experimental and should not yet be used for production jobs
 source ci/official/envs/ci_default
+TFCI_PYTHON_VERSION=3.10
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64
-TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --repo_env=TF_PYTHON_VERSION=3.10)
+TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python
 TFCI_DOCKER_REBUILD_ARGS=(--target=tf ci/official/containers/linux_arm64)
diff --git a/ci/official/envs/continuous_linux_arm64_cpu_py311 b/ci/official/envs/continuous_linux_arm64_cpu_py311
index cc05e25..7a0ae9e 100644
--- a/ci/official/envs/continuous_linux_arm64_cpu_py311
+++ b/ci/official/envs/continuous_linux_arm64_cpu_py311
@@ -1,6 +1,7 @@
 # This environment is experimental and should not yet be used for production jobs
 source ci/official/envs/ci_default
+TFCI_PYTHON_VERSION=3.11
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64
-TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --repo_env=TF_PYTHON_VERSION=3.11)
+TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python
 TFCI_DOCKER_REBUILD_ARGS=(--target=tf ci/official/containers/linux_arm64)
diff --git a/ci/official/envs/continuous_linux_arm64_cpu_py39 b/ci/official/envs/continuous_linux_arm64_cpu_py39
index 790a2e0..53aee87 100644
--- a/ci/official/envs/continuous_linux_arm64_cpu_py39
+++ b/ci/official/envs/continuous_linux_arm64_cpu_py39
@@ -1,6 +1,7 @@
 # This environment is experimental and should not yet be used for production jobs
 source ci/official/envs/ci_default
+TFCI_PYTHON_VERSION=3.9
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64
-TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --repo_env=TF_PYTHON_VERSION=3.9)
+TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python
 TFCI_DOCKER_REBUILD_ARGS=(--target=tf ci/official/containers/linux_arm64)
diff --git a/ci/official/envs/continuous_linux_x86_cpu_py310 b/ci/official/envs/continuous_linux_x86_cpu_py310
index b597e69..13b2730 100644
--- a/ci/official/envs/continuous_linux_x86_cpu_py310
+++ b/ci/official/envs/continuous_linux_x86_cpu_py310
@@ -1,5 +1,6 @@
 source ci/official/envs/ci_default
+TFCI_PYTHON_VERSION=3.10
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu
-TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config rbe_linux_cpu --repo_env=TF_PYTHON_VERSION=3.10)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.10 --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config rbe_linux_cpu --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
diff --git a/ci/official/envs/continuous_linux_x86_cpu_py311 b/ci/official/envs/continuous_linux_x86_cpu_py311
index 1ce6c8b..3f92c5c 100644
--- a/ci/official/envs/continuous_linux_x86_cpu_py311
+++ b/ci/official/envs/continuous_linux_x86_cpu_py311
@@ -1,5 +1,6 @@
 source ci/official/envs/ci_default
+TFCI_PYTHON_VERSION=3.11
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu
-TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config rbe_linux_cpu --repo_env=TF_PYTHON_VERSION=3.11)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.11
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.11 --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config rbe_linux_cpu --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
diff --git a/ci/official/envs/continuous_linux_x86_cpu_py39 b/ci/official/envs/continuous_linux_x86_cpu_py39
index de2c986..4ca275c 100644
--- a/ci/official/envs/continuous_linux_x86_cpu_py39
+++ b/ci/official/envs/continuous_linux_x86_cpu_py39
@@ -1,5 +1,6 @@
 source ci/official/envs/ci_default
+TFCI_PYTHON_VERSION=3.9
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu
-TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config rbe_linux_cpu --repo_env=TF_PYTHON_VERSION=3.9)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.11 --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config rbe_linux_cpu --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
diff --git a/ci/official/envs/continuous_linux_x86_cuda_py310 b/ci/official/envs/continuous_linux_x86_cuda_py310
index 69ca107..f09a5d55 100644
--- a/ci/official/envs/continuous_linux_x86_cuda_py310
+++ b/ci/official/envs/continuous_linux_x86_cuda_py310
@@ -1,7 +1,8 @@
 source ci/official/envs/ci_default
+TFCI_PYTHON_VERSION=3.10
 TFCI_NVIDIA_SMI_ENABLE=1
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda
-TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config rbe_linux_cuda --repo_env=TF_PYTHON_VERSION=3.10)
+TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config rbe_linux_cuda --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_DOCKER_ARGS=(--gpus all)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.10 --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
diff --git a/ci/official/envs/continuous_linux_x86_cuda_py311 b/ci/official/envs/continuous_linux_x86_cuda_py311
index da54acb..cd834c2 100644
--- a/ci/official/envs/continuous_linux_x86_cuda_py311
+++ b/ci/official/envs/continuous_linux_x86_cuda_py311
@@ -1,7 +1,8 @@
 source ci/official/envs/ci_default
+TFCI_PYTHON_VERSION=3.11
 TFCI_NVIDIA_SMI_ENABLE=1
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda
-TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config rbe_linux_cuda --repo_env=TF_PYTHON_VERSION=3.11)
+TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config rbe_linux_cuda --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_DOCKER_ARGS=(--gpus all)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.11
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.11 --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
diff --git a/ci/official/envs/continuous_linux_x86_cuda_py39 b/ci/official/envs/continuous_linux_x86_cuda_py39
index a53df57..798dfdf 100644
--- a/ci/official/envs/continuous_linux_x86_cuda_py39
+++ b/ci/official/envs/continuous_linux_x86_cuda_py39
@@ -1,7 +1,8 @@
 source ci/official/envs/ci_default
+TFCI_PYTHON_VERSION=3.9
 TFCI_NVIDIA_SMI_ENABLE=1
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda
-TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config rbe_linux_cuda --repo_env=TF_PYTHON_VERSION=3.9)
+TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config rbe_linux_cuda --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_DOCKER_ARGS=(--gpus all)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.9 --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
diff --git a/ci/official/envs/continuous_macos_arm64_py310 b/ci/official/envs/continuous_macos_arm64_py310
index d6bde2b..a08a335 100644
--- a/ci/official/envs/continuous_macos_arm64_py310
+++ b/ci/official/envs/continuous_macos_arm64_py310
@@ -1,4 +1,5 @@
 source ci/official/envs/ci_default
+TFCI_PYTHON_VERSION=3.10
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64
-TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --repo_env=TF_PYTHON_VERSION=3.10)
+TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_DOCKER_ENABLE=0
diff --git a/ci/official/envs/continuous_macos_arm64_py311 b/ci/official/envs/continuous_macos_arm64_py311
index c227bd8..230d18d 100644
--- a/ci/official/envs/continuous_macos_arm64_py311
+++ b/ci/official/envs/continuous_macos_arm64_py311
@@ -1,4 +1,5 @@
 source ci/official/envs/ci_default
+TFCI_PYTHON_VERSION=3.11
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64
-TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --repo_env=TF_PYTHON_VERSION=3.11)
+TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_DOCKER_ENABLE=0
diff --git a/ci/official/envs/continuous_macos_arm64_py39 b/ci/official/envs/continuous_macos_arm64_py39
index c227bd8..59585ff 100644
--- a/ci/official/envs/continuous_macos_arm64_py39
+++ b/ci/official/envs/continuous_macos_arm64_py39
@@ -1,4 +1,5 @@
 source ci/official/envs/ci_default
+TFCI_PYTHON_VERSION=3.9
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64
-TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --repo_env=TF_PYTHON_VERSION=3.11)
+TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_DOCKER_ENABLE=0
diff --git a/ci/official/envs/disable_all_uploads b/ci/official/envs/disable_all_uploads
index b09169d..6559f80 100644
--- a/ci/official/envs/disable_all_uploads
+++ b/ci/official/envs/disable_all_uploads
@@ -1,7 +1,9 @@
-TFCI_UPLOAD_LIB_ENABLE=0
+TFCI_DOCKER_REBUILD_UPLOAD_ENABLE=0
+TFCI_UPLOAD_LIB_ENABLE=
+TFCI_UPLOAD_LIB_GCS_URI=
 TFCI_UPLOAD_LIB_LATEST_ENABLE=
-TFCI_DOCKER_REBUILD_UPLOAD_ENABLE=
+TFCI_UPLOAD_LIB_LATEST_GCS_URI=
 TFCI_UPLOAD_WHL_GCS_ENABLE=
 TFCI_UPLOAD_WHL_GCS_URI=
-TFCI_UPLOAD_WHL_PYPI_ENABLE=
 TFCI_UPLOAD_WHL_PYPI_ARGS=
+TFCI_UPLOAD_WHL_PYPI_ENABLE=
diff --git a/ci/official/envs/nightly_libtensorflow_linux_x86_cpu b/ci/official/envs/nightly_libtensorflow_linux_x86_cpu
index aa04aef..9fbd23a 100644
--- a/ci/official/envs/nightly_libtensorflow_linux_x86_cpu
+++ b/ci/official/envs/nightly_libtensorflow_linux_x86_cpu
@@ -1,7 +1,8 @@
 source ci/official/envs/ci_default
 source ci/official/envs/ci_nightly_uploads
-TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.10)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.10 --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_PYTHON_VERSION=3.10
+TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
 TFCI_LIB_SUFFIX="-cpu-linux-x86_64"
 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_libtensorflow_linux_x86_cuda b/ci/official/envs/nightly_libtensorflow_linux_x86_cuda
index d168317..0b35c0e 100644
--- a/ci/official/envs/nightly_libtensorflow_linux_x86_cuda
+++ b/ci/official/envs/nightly_libtensorflow_linux_x86_cuda
@@ -1,9 +1,10 @@
 source ci/official/envs/ci_default
 source ci/official/envs/ci_nightly_uploads
-TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.10)
+TFCI_PYTHON_VERSION=3.10
+TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_DOCKER_ARGS=(--gpus all)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.10 --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
 TFCI_LIB_SUFFIX="-gpu-linux-x86_64"
 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
 TFCI_NVIDIA_SMI_ENABLE=1
diff --git a/ci/official/envs/nightly_libtensorflow_macos_arm64 b/ci/official/envs/nightly_libtensorflow_macos_arm64
index a0c1ad1..d29447d 100644
--- a/ci/official/envs/nightly_libtensorflow_macos_arm64
+++ b/ci/official/envs/nightly_libtensorflow_macos_arm64
@@ -1,6 +1,7 @@
 source ci/official/envs/ci_default
 source ci/official/envs/ci_nightly_uploads
-TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.10)
+TFCI_PYTHON_VERSION=3.10
+TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_DOCKER_ENABLE=0
 TFCI_LIB_SUFFIX="-cpu-macos-arm64"
 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_linux_arm64_cpu_py310 b/ci/official/envs/nightly_linux_arm64_cpu_py310
index 255b317..5b7900c 100644
--- a/ci/official/envs/nightly_linux_arm64_cpu_py310
+++ b/ci/official/envs/nightly_linux_arm64_cpu_py310
@@ -1,6 +1,8 @@
 source ci/official/envs/ci_default
-source ci/official/envs/ci_nightly_uploads
-TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.10)
+# Disable arm64 uploads while being worked on
+source ci/official/envs/disable_all_uploads
+TFCI_PYTHON_VERSION=3.10
+TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64
 TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag)
 TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python
diff --git a/ci/official/envs/nightly_linux_arm64_cpu_py311 b/ci/official/envs/nightly_linux_arm64_cpu_py311
index c51b19e..6edb93b 100644
--- a/ci/official/envs/nightly_linux_arm64_cpu_py311
+++ b/ci/official/envs/nightly_linux_arm64_cpu_py311
@@ -1,6 +1,8 @@
 source ci/official/envs/ci_default
-source ci/official/envs/ci_nightly_uploads
-TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.11)
+# Disable arm64 uploads while being worked on
+source ci/official/envs/disable_all_uploads
+TFCI_PYTHON_VERSION=3.11
+TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64
 TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag)
 TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python
diff --git a/ci/official/envs/nightly_linux_arm64_cpu_py312 b/ci/official/envs/nightly_linux_arm64_cpu_py312
new file mode 100644
index 0000000..dfe96fa
--- /dev/null
+++ b/ci/official/envs/nightly_linux_arm64_cpu_py312
@@ -0,0 +1,10 @@
+source ci/official/envs/ci_default
+# Disable arm64 uploads while being worked on
+source ci/official/envs/disable_all_uploads
+TFCI_PYTHON_VERSION=3.12
+TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
+TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64
+TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag)
+TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python
+TFCI_DOCKER_REBUILD_ARGS=(--target=tf ci/official/containers/linux_arm64)
+TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_linux_arm64_cpu_py39 b/ci/official/envs/nightly_linux_arm64_cpu_py39
index 0ecc79c..e3b5161 100644
--- a/ci/official/envs/nightly_linux_arm64_cpu_py39
+++ b/ci/official/envs/nightly_linux_arm64_cpu_py39
@@ -1,6 +1,8 @@
 source ci/official/envs/ci_default
-source ci/official/envs/ci_nightly_uploads
-TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.9)
+# Disable arm64 uploads while being worked on
+source ci/official/envs/disable_all_uploads
+TFCI_PYTHON_VERSION=3.9
+TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64
 TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag)
 TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python
diff --git a/ci/official/envs/nightly_linux_x86_cpu_py310 b/ci/official/envs/nightly_linux_x86_cpu_py310
index d4192c7..574ac7b 100644
--- a/ci/official/envs/nightly_linux_x86_cpu_py310
+++ b/ci/official/envs/nightly_linux_x86_cpu_py310
@@ -1,8 +1,9 @@
 source ci/official/envs/ci_default
 source ci/official/envs/ci_nightly_uploads
-TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.10)
+TFCI_PYTHON_VERSION=3.10
+TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu
 TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.10 --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_linux_x86_cpu_py311 b/ci/official/envs/nightly_linux_x86_cpu_py311
index 15d680a..d1b8bfe 100644
--- a/ci/official/envs/nightly_linux_x86_cpu_py311
+++ b/ci/official/envs/nightly_linux_x86_cpu_py311
@@ -1,8 +1,9 @@
 source ci/official/envs/ci_default
 source ci/official/envs/ci_nightly_uploads
-TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.11)
+TFCI_PYTHON_VERSION=3.11
+TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu
 TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.11
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.11 --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_linux_x86_cpu_py312 b/ci/official/envs/nightly_linux_x86_cpu_py312
new file mode 100644
index 0000000..586fd92
--- /dev/null
+++ b/ci/official/envs/nightly_linux_x86_cpu_py312
@@ -0,0 +1,10 @@
+source ci/official/envs/ci_default
+# Disable 3.12 uploads while being worked on
+source ci/official/envs/disable_all_uploads
+TFCI_PYTHON_VERSION=3.12
+TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
+TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu
+TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_linux_x86_cpu_py39 b/ci/official/envs/nightly_linux_x86_cpu_py39
index a180c1e..2c3e118 100644
--- a/ci/official/envs/nightly_linux_x86_cpu_py39
+++ b/ci/official/envs/nightly_linux_x86_cpu_py39
@@ -1,8 +1,9 @@
 source ci/official/envs/ci_default
 source ci/official/envs/ci_nightly_uploads
-TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.9)
+TFCI_PYTHON_VERSION=3.9
+TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu
 TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.9 --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_linux_x86_cuda_py310 b/ci/official/envs/nightly_linux_x86_cuda_py310
index 0003790..16038d6 100644
--- a/ci/official/envs/nightly_linux_x86_cuda_py310
+++ b/ci/official/envs/nightly_linux_x86_cuda_py310
@@ -1,9 +1,10 @@
 source ci/official/envs/ci_default
 source ci/official/envs/ci_nightly_uploads
-TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.10)
+TFCI_PYTHON_VERSION=3.10
+TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda
 TFCI_BUILD_PIP_PACKAGE_ARGS=(--nightly_flag)
 TFCI_DOCKER_ARGS=(--gpus all)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.10 --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_linux_x86_cuda_py311 b/ci/official/envs/nightly_linux_x86_cuda_py311
index e4bce79..1d0d931 100644
--- a/ci/official/envs/nightly_linux_x86_cuda_py311
+++ b/ci/official/envs/nightly_linux_x86_cuda_py311
@@ -1,9 +1,10 @@
 source ci/official/envs/ci_default
 source ci/official/envs/ci_nightly_uploads
-TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.11)
+TFCI_PYTHON_VERSION=3.11
+TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda
 TFCI_BUILD_PIP_PACKAGE_ARGS=(--nightly_flag)
 TFCI_DOCKER_ARGS=(--gpus all)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.11
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.11 --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_linux_x86_cuda_py312 b/ci/official/envs/nightly_linux_x86_cuda_py312
new file mode 100644
index 0000000..4767f6d
--- /dev/null
+++ b/ci/official/envs/nightly_linux_x86_cuda_py312
@@ -0,0 +1,11 @@
+source ci/official/envs/ci_default
+# Disable 3.12 uploads while being worked on
+source ci/official/envs/disable_all_uploads
+TFCI_PYTHON_VERSION=3.12
+TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
+TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda
+TFCI_BUILD_PIP_PACKAGE_ARGS=(--nightly_flag)
+TFCI_DOCKER_ARGS=(--gpus all)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
\ No newline at end of file
diff --git a/ci/official/envs/nightly_linux_x86_cuda_py39 b/ci/official/envs/nightly_linux_x86_cuda_py39
index 065a914..e3a5d3f 100644
--- a/ci/official/envs/nightly_linux_x86_cuda_py39
+++ b/ci/official/envs/nightly_linux_x86_cuda_py39
@@ -1,9 +1,10 @@
 source ci/official/envs/ci_default
 source ci/official/envs/ci_nightly_uploads
-TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.9)
+TFCI_PYTHON_VERSION=3.9
+TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda
 TFCI_BUILD_PIP_PACKAGE_ARGS=(--nightly_flag)
 TFCI_DOCKER_ARGS=(--gpus all)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.9 --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_linux_x86_tpu_py310 b/ci/official/envs/nightly_linux_x86_tpu_py310
index c2704eb..4e80141 100644
--- a/ci/official/envs/nightly_linux_x86_tpu_py310
+++ b/ci/official/envs/nightly_linux_x86_tpu_py310
@@ -1,10 +1,11 @@
 source ci/official/envs/ci_default
-source ci/official/envs/ci_nightly_uploads
-TFCI_BAZEL_COMMON_ARGS=(--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.10 --config=tpu)
+# Disable tpu uploads while being worked on
+source ci/official/envs/disable_all_uploads
+TFCI_PYTHON_VERSION=3.10
+TFCI_BAZEL_COMMON_ARGS=(--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=tpu)
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu
 TFCI_BUILD_PIP_PACKAGE_ARGS=(--tpu --nightly_flag)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.10 --target=devel tools/tf_sig_build_dockerfiles)
-TFCI_LIBTPU_DOWNLOAD_ENABLE=1
-TFCI_LIBTPU_DOWNLOAD_URL=https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.8.0/libtpu.so
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE=1
 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_linux_x86_tpu_py311 b/ci/official/envs/nightly_linux_x86_tpu_py311
index 7d740c8..e4ae8cc 100644
--- a/ci/official/envs/nightly_linux_x86_tpu_py311
+++ b/ci/official/envs/nightly_linux_x86_tpu_py311
@@ -1,10 +1,11 @@
 source ci/official/envs/ci_default
-source ci/official/envs/ci_nightly_uploads
-TFCI_BAZEL_COMMON_ARGS=(--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.11 --config=tpu)
+# Disable tpu uploads while being worked on
+source ci/official/envs/disable_all_uploads
+TFCI_PYTHON_VERSION=3.11
+TFCI_BAZEL_COMMON_ARGS=(--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=tpu)
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu
 TFCI_BUILD_PIP_PACKAGE_ARGS=(--tpu --nightly_flag)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.11
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.11 --target=devel tools/tf_sig_build_dockerfiles)
-TFCI_LIBTPU_DOWNLOAD_ENABLE=1
-TFCI_LIBTPU_DOWNLOAD_URL=https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.8.0/libtpu.so
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE=1
 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_linux_x86_tpu_py312 b/ci/official/envs/nightly_linux_x86_tpu_py312
new file mode 100644
index 0000000..54d96b1
--- /dev/null
+++ b/ci/official/envs/nightly_linux_x86_tpu_py312
@@ -0,0 +1,11 @@
+source ci/official/envs/ci_default
+# Disable tpu uploads while being worked on
+source ci/official/envs/disable_all_uploads
+TFCI_PYTHON_VERSION=3.12
+TFCI_BAZEL_COMMON_ARGS=(--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=tpu)
+TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu
+TFCI_BUILD_PIP_PACKAGE_ARGS=(--tpu --nightly_flag)
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE=1
+TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_linux_x86_tpu_py39 b/ci/official/envs/nightly_linux_x86_tpu_py39
index 2bb4bcc..4adaa8b 100644
--- a/ci/official/envs/nightly_linux_x86_tpu_py39
+++ b/ci/official/envs/nightly_linux_x86_tpu_py39
@@ -1,10 +1,11 @@
 source ci/official/envs/ci_default
-source ci/official/envs/ci_nightly_uploads
-TFCI_BAZEL_COMMON_ARGS=(--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.9 --config=tpu)
+# Disable tpu uploads while being worked on
+source ci/official/envs/disable_all_uploads
+TFCI_PYTHON_VERSION=3.9
+TFCI_BAZEL_COMMON_ARGS=(--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=tpu)
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu
 TFCI_BUILD_PIP_PACKAGE_ARGS=(--tpu --nightly_flag)
-TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9
-TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=3.9 --target=devel tools/tf_sig_build_dockerfiles)
-TFCI_LIBTPU_DOWNLOAD_ENABLE=1
-TFCI_LIBTPU_DOWNLOAD_URL=https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.8.0/libtpu.so
+TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION}
+TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles)
+TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE=1
 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_macos_arm64_py310 b/ci/official/envs/nightly_macos_arm64_py310
index 6eed041..81fa2c9 100644
--- a/ci/official/envs/nightly_macos_arm64_py310
+++ b/ci/official/envs/nightly_macos_arm64_py310
@@ -1,7 +1,8 @@
 source ci/official/envs/ci_default
 source ci/official/envs/disable_all_uploads
+TFCI_PYTHON_VERSION=3.10
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64
-TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.10)
+TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag)
 TFCI_DOCKER_ENABLE=0
 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_macos_arm64_py311 b/ci/official/envs/nightly_macos_arm64_py311
index 7565068..e8046a3 100644
--- a/ci/official/envs/nightly_macos_arm64_py311
+++ b/ci/official/envs/nightly_macos_arm64_py311
@@ -1,7 +1,8 @@
 source ci/official/envs/ci_default
 source ci/official/envs/disable_all_uploads
+TFCI_PYTHON_VERSION=3.11
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64
-TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.11)
+TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag)
 TFCI_DOCKER_ENABLE=0
 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/nightly_macos_arm64_py312 b/ci/official/envs/nightly_macos_arm64_py312
new file mode 100644
index 0000000..21432f0
--- /dev/null
+++ b/ci/official/envs/nightly_macos_arm64_py312
@@ -0,0 +1,9 @@
+source ci/official/envs/ci_default
+source ci/official/envs/disable_all_uploads
+TFCI_PYTHON_VERSION=3.12
+TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64
+TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
+TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag)
+TFCI_DOCKER_ENABLE=0
+TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
+TFCI_UPLOAD_WHL_GCS_ENABLE=1
diff --git a/ci/official/envs/nightly_macos_arm64_py39 b/ci/official/envs/nightly_macos_arm64_py39
index aac737a..ee58e84 100644
--- a/ci/official/envs/nightly_macos_arm64_py39
+++ b/ci/official/envs/nightly_macos_arm64_py39
@@ -1,7 +1,8 @@
 source ci/official/envs/ci_default
 source ci/official/envs/disable_all_uploads
+TFCI_PYTHON_VERSION=3.9
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64
-TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=3.9)
+TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION)
 TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag)
 TFCI_DOCKER_ENABLE=0
 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1
diff --git a/ci/official/envs/sample b/ci/official/envs/sample
index a64de50..1e01d6a 100644
--- a/ci/official/envs/sample
+++ b/ci/official/envs/sample
@@ -13,7 +13,7 @@
 # Reset bazel common options. This combines a local disk cache and
 # TensorFlow's remote cache to speed up your builds. The "nightly" branch has
 # the most content cached. TFCI_BAZEL_COMMON_ARGS is also where we target
-# different Python versions. You can add e.g. "--repo_env=TF_PYTHON_VERSION=3.9"
+# different Python versions. You can add e.g. "--repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION"
 # to change the Python version to anything available (including the default) in
 # tensorflow/tools/toolchains/python/python_repo.bzl.
 TFCI_BAZEL_COMMON_ARGS=(--config tf_public_cache --disk_cache=build_output/cache)
diff --git a/ci/official/requirements_updater/README.md b/ci/official/requirements_updater/README.md
index 292cb07..ad2e350 100644
--- a/ci/official/requirements_updater/README.md
+++ b/ci/official/requirements_updater/README.md
@@ -1,75 +1,128 @@
-### Hermetic Python
+# Hermetic Python
 
-Hermetic Python allows us not to rely on system-installed python and
-system-installed python packages, instead we register our own python toolchain.
+Hermetic Python allows us to avoid relying on system-installed Python and
+system-installed Python packages. \
+Instead, an independent Python toolchain is registered, ensuring that the right
+dependencies are always used. \
 See https://github.com/bazelbuild/rules_python/ for more details.
 
-#### Hermetic Python toolchain details
+### Specifying the Python version
 
-By default, Python 3.9 is used.
+Note: Only a limited set of minor Python versions is supported at any given time.
 
-To set your own version for hermetic Python toolchain, use `TF_PYTHON_VERSION`
-environment variable, e.g.
+By default, the lowest supported version is used.
+
+To set a different version, use the `TF_PYTHON_VERSION` environment variable,
+e.g.
 
 ```
-export TF_PYTHON_VERSION=3.10
+export TF_PYTHON_VERSION=3.11
 ```
 
-To set a version from argument line, add to your command
+To specify the version via a Bazel command argument, use the following:
 
 ```
---repo_env=TF_PYTHON_VERSION=3.10
+--repo_env=TF_PYTHON_VERSION=3.11
 ```
 
-### Requirements updater
+## Requirements updater
 
-Requirements updater is a standalone tool intended to simplify process of
-updating requirements for multiple versions of Python.
+The requirements updater is a standalone tool intended to simplify the process
+of updating requirements for multiple minor versions of Python.
 
-#### How to update/add requirements
+It takes a file with a set of dependencies and produces a more detailed
+requirements file for each version, with hashes specified for every required
+dependency as well as its sub-dependencies.
 
-By default, the name of the input requirements file is `requirements.in`,
-but it can be set using the `REQUIREMENTS_FILE_NAME` variable, for example:
+### How to update/add requirements
+
+By default, the name of the base requirements file is `requirements.in`, but it
+can be set using the `REQUIREMENTS_FILE_NAME` variable. \
+For example:
+
 ```
-export REQUIREMENTS_FILE_NAME=`my_requirements.in`
+export REQUIREMENTS_FILE_NAME=my_requirements.in
 ```
 
-To set a version from the argument line, add to your command
+To specify the file via a Bazel command argument, use the following:
+
 ```
---repo_env=REQUIREMENTS_FILE_NAME=`my_requirements.in`
+--repo_env=REQUIREMENTS_FILE_NAME=my_requirements.in
 ```
 
-#### How to run the updater
+### How to run the updater
 
 ```
 bash updater.sh
 ```
 
-### How to add a new Python version
+## How to add a new Python version
 
-1) In the `WORKSPACE` file add a new version to `python_versions` argument of
-the `python_register_multi_toolchains` function.
+Note: Updating the
+[rules-python](https://github.com/bazelbuild/rules_python/releases) version may
+be required before going through the steps below, because new Python versions
+become available through `rules-python`. \
+See
+[here](https://github.com/tensorflow/tensorflow/commit/f91457f258fdd78f693044a57efa63a38335d1de)
+and
+[here](https://github.com/tensorflow/tensorflow/commit/052445e04ce20fd747657e0198a1bcec2b6dff5b)
+for examples.
 
-2) In `BUILD.bazel` file add a load statement for the new version, e.g.
+See
+[this commit](https://github.com/tensorflow/tensorflow/commit/5f7f05a80aac9b01325a78ec3fcff0dbedb1cc23)
+as a rough example of the steps below.
 
-```
-load("@python//3.11:defs.bzl",
-     compile_pip_requirements_3_11 = "compile_pip_requirements")
-```
+All the files referenced below are located in the same directory as this README,
+unless indicated otherwise.
 
-Add a new entry for the loaded `compile_pip_requirements`, e.g.
+1) Add the new version to the `VERSIONS` variable inside
+   `tensorflow/tools/toolchains/python/python_repo.bzl`. \
+   While this isn't necessary for running the updater, it is required for
+   actually using the new version with TensorFlow.
 
-```
-compile_pip_requirements_3_11(
-    name = "requirements_3_11",
-    extra_args = ["--allow-unsafe"],
-    requirements_in = "requirements.in",
-    requirements_txt = "requirements_lock_3_11.txt",
-)
-```
+2) In the `WORKSPACE` file, add the new version to the `python_versions`
+   parameter of the `python_register_multi_toolchains` function.
 
-3) Add the version to `SUPPORTED_VERSIONS` in `updater.sh`, after that run the
- requirements updater tool.
+3) In the `BUILD.bazel` file, add a load statement for the new version, e.g.
 
-4) As a result, a new `requirements_lock_3_11.txt` file should appear under the
-root of tensorflow directory.
+   ```
+      load("@python//3.11:defs.bzl",
+           compile_pip_requirements_3_11 = "compile_pip_requirements")
+   ```
+
+   Add a new entry for the loaded `compile_pip_requirements`, e.g.
+
+   ```
+      compile_pip_requirements_3_11(
+          name = "requirements_3_11",
+          extra_args = ["--allow-unsafe"],
+          requirements_in = "requirements.in",
+          requirements_txt = "requirements_lock_3_11.txt",
+      )
+   ```
+
+   ```
+      compile_pip_requirements_3_11(
+          name = "requirements_3_11_release",
+          extra_args = [
+              "--allow-unsafe",
+              "-P keras-nightly",
+              "-P tb-nightly",
+              "-P tf-estimator-nightly",
+          ],
+          requirements_in = "requirements.in",
+          requirements_txt = "requirements_lock_3_11.txt",
+      )
+   ```
+
+4) Add the version to `SUPPORTED_VERSIONS` in both `updater.sh` and
+   `release_updater.sh`.
+
+5) Run the `updater.sh` shell script. \
+   If the base requirements file hasn't yet been updated to account for the new
+   Python version (which will typically require different versions of at least
+   some dependencies), update it now so the script can run successfully.
+
+6) A new `requirements_lock_3_11.txt` file should appear under the root of the
+   `tensorflow` directory.
diff --git a/ci/official/utilities/docker.sh b/ci/official/utilities/docker.sh
index 4d726a9..ea1ecc2 100755
--- a/ci/official/utilities/docker.sh
+++ b/ci/official/utilities/docker.sh
@@ -30,6 +30,7 @@
 if ! docker container inspect tf >/dev/null 2>&1 ; then
   docker run "${TFCI_DOCKER_ARGS[@]}" --name tf -w "$TFCI_GIT_DIR" -itd --rm \
       -v "$TFCI_GIT_DIR:$TFCI_GIT_DIR" \
+      --env TFCI_PYTHON_VERSION \
       "$TFCI_DOCKER_IMAGE" \
     bash
 fi
diff --git a/ci/official/utilities/extract_resultstore_links.py b/ci/official/utilities/extract_resultstore_links.py
index 49d8e98..a801397 100644
--- a/ci/official/utilities/extract_resultstore_links.py
+++ b/ci/official/utilities/extract_resultstore_links.py
@@ -33,7 +33,8 @@
 FAILED_BUILD_LINE = 'FAILED: Build did NOT complete successfully'
 BUILD_STATUS_LINE = 'INFO: Build'
 TESTS_FAILED_RE = re.compile(r'^INFO: Build completed, \d+ tests? FAILED')
-BAZEL_COMMAND_RE = re.compile(r'(^| )(bazel .* (test|build) .+)')
+BAZEL_COMMAND_RE = re.compile(
+    r'(^| )(?P<command>bazel (.*? )?(?P<type>test|build) .+)')
 
 
 class InvokeStatus:
@@ -136,8 +137,8 @@
       if 'bazel ' in backtrack_line and not backtrack_line.endswith('\\'):
         bazel_line = BAZEL_COMMAND_RE.search(backtrack_line)
         if bazel_line:
-          lines['command'] = bazel_line.group(2)
-          lines['command_type'] = bazel_line.group(3)
+          lines['command'] = bazel_line.group('command')
+          lines['command_type'] = bazel_line.group('type')
           break
       k -= 1
       continue
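
For reference, the named groups introduced above can be exercised in isolation; a minimal sketch (the log line is made up):

```
import re

# Same pattern as BAZEL_COMMAND_RE above: named groups capture the full bazel
# command and whether it was a `test` or `build` invocation.
BAZEL_COMMAND_RE = re.compile(
    r'(^| )(?P<command>bazel (.*? )?(?P<type>test|build) .+)')

line = "prefix: bazel --bazelrc=/path/to/rc test //tensorflow/... --config=foo"
m = BAZEL_COMMAND_RE.search(line)
assert m is not None
print(m.group('command'))  # bazel --bazelrc=/path/to/rc test //tensorflow/... --config=foo
print(m.group('type'))     # test
```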
diff --git a/ci/official/utilities/rename_and_verify_wheels.sh b/ci/official/utilities/rename_and_verify_wheels.sh
index cd02b82..4388329 100755
--- a/ci/official/utilities/rename_and_verify_wheels.sh
+++ b/ci/official/utilities/rename_and_verify_wheels.sh
@@ -20,14 +20,14 @@
 set -euxo pipefail
 
 DIR=$1
-find $DIR -iname "*.whl" | while read wheel; do
+find "$DIR" -iname "*.whl" | while read wheel; do
   echo "Checking and renaming $wheel..."
   wheel=$(realpath "$wheel")
   # Repair wheel based upon name/architecture, fallback to x86
   if [[ $wheel == *"aarch64.whl" ]]; then
-    time python3 -m auditwheel repair --plat manylinux2014_aarch64 "$wheel" --wheel-dir build 2>&1 | tee check.txt
+    time python3 -m auditwheel repair --plat manylinux2014_aarch64 "$wheel" --wheel-dir "$DIR" 2>&1 | tee check.txt
   else
-    time python3 -m auditwheel repair --plat manylinux2014_x86_64 "$wheel" --wheel-dir build 2>&1 | tee check.txt
+    time python3 -m auditwheel repair --plat manylinux2014_x86_64 "$wheel" --wheel-dir "$DIR" 2>&1 | tee check.txt
   fi
 
   # We don't need the original wheel if it was renamed
@@ -38,5 +38,5 @@
   fi
   rm check.txt
 
-  TF_WHEEL="$wheel" bats ./ci/official/utilities/wheel_verification.bats --timing
+  TF_WHEEL="$wheel" BUILD_DIR="$DIR" bats ./ci/official/utilities/wheel_verification.bats --timing
 done
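
As a rough Python rendering of the per-wheel branch above (illustrative only, not the CI code; `auditwheel` must be installed and the paths are hypothetical):

```
import subprocess
from pathlib import Path

def repair(wheel: Path, out_dir: Path) -> None:
    # Pick the manylinux platform tag from the wheel file name, falling back
    # to x86_64, and repair the wheel into the same output directory.
    plat = ("manylinux2014_aarch64" if wheel.name.endswith("aarch64.whl")
            else "manylinux2014_x86_64")
    subprocess.run(
        ["python3", "-m", "auditwheel", "repair", "--plat", plat,
         str(wheel), "--wheel-dir", str(out_dir)],
        check=True,
    )
```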
diff --git a/ci/official/utilities/setup.sh b/ci/official/utilities/setup.sh
index 2d1e330..aa4e783 100755
--- a/ci/official/utilities/setup.sh
+++ b/ci/official/utilities/setup.sh
@@ -43,20 +43,52 @@
 # relevant variables in their environment. Because of 'set -o allexport' above
 # (which is equivalent to "set -a"), every variable in the file is exported
 # for other files to use.
+#
+# Separately, if TFCI is set *and* there are also additional TFCI_ variables
+# set in the shell environment, those variables will be restored after the
+# TFCI env has been loaded. This is useful for e.g. on-demand "generic" jobs
+# where the user may wish to change just one option. Conveniently, this method
+# even works for arrays; e.g. TFCI_SOME_ARRAY="(--array --contents)" ends up
+# as TFCI_SOME_ARRAY=(--array --contents) in the storage file and is thus
+# loaded as an array when sourced.
 if [[ -n "${TFCI:-}" ]]; then
-  # Sourcing this twice, the first time with "-u" unset, means that variable
+  FROM_ENV=$(mktemp)
+  # Piping into cat means grep won't abort the process if no errors are found.
+  env | grep TFCI_ | cat > "$FROM_ENV"
+
+  # Sourcing TFCI twice, the first time with "-u" unset, means that variable
   # order does not matter. i.e. "TFCI_BAR=$TFCI_FOO; TFCI_FOO=true" will work.
   # TFCI_FOO is only valid the second time through.
   set +u
   source "$TFCI"
   set -u
   source "$TFCI"
+
+  # Load those stored pre-existing TFCI_ vars, if any
+  if [[ -s "$FROM_ENV" ]]; then
+    echo '==TFCI==: NOTE: Loading the following env parameters, which were'
+    echo 'already set in the shell environment. If you want to disable this'
+    echo 'behavior, create a new shell.'
+    cat "$FROM_ENV"
+    source "$FROM_ENV"
+    rm "$FROM_ENV"
+  fi
 else
   echo '==TFCI==: The $TFCI variable is not set. This is fine as long as you'
   echo 'already sourced a TFCI env file with "set -a; source <path>; set +a".'
   echo 'If you have not, you will see a lot of undefined variable errors.'
 fi
 
+# Force-disable uploads if the job initiator is not Kokoro
+# This is temporary: it's currently standard practice for employees to
+# run nightly jobs for testing purposes. We're aiming to move away from
+# this with more convenient methods, but as long as it's possible to do,
+# we want to make sure those extra jobs don't upload anything.
+# TODO(angerson) Remove this once it's no longer relevant
+if [[ "${KOKORO_BUILD_INITIATOR:-}" != "kokoro" ]]; then
+  source ./ci/official/envs/disable_all_uploads
+fi
+
 # Create and expand to the full path of TFCI_OUTPUT_DIR
 export TFCI_OUTPUT_DIR=$(realpath "$TFCI_OUTPUT_DIR")
 mkdir -p "$TFCI_OUTPUT_DIR"
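
The save-then-restore behavior described in the comment above, sketched in Python purely for illustration (setup.sh implements it with a temp file and `source`):

```
import os

# Snapshot any TFCI_ variables already set in the shell environment.
preexisting = {k: v for k, v in os.environ.items() if k.startswith("TFCI_")}

# ... the env file would be loaded here (setup.sh sources "$TFCI" twice) ...

# Re-apply the pre-existing values so user-provided overrides win.
os.environ.update(preexisting)
```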
diff --git a/ci/official/utilities/wheel_verification.bats b/ci/official/utilities/wheel_verification.bats
index 5af41a7..99d0f32 100644
--- a/ci/official/utilities/wheel_verification.bats
+++ b/ci/official/utilities/wheel_verification.bats
@@ -12,14 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Suite of verification tests for the SINGLE TensorFlow wheel in the "build"
-# directory, or whatever path is set as $TF_WHEEL.
+# Suite of verification tests for the SINGLE TensorFlow wheel in the
+# $BUILD_DIR directory, or whatever path is set as $TF_WHEEL.
 
 setup_file() {
-    python3 -m venv "$BATS_FILE_TMPDIR/venv"
-    cd build
+    cd "$BUILD_DIR"
     if [[ -z "$TF_WHEEL" ]]; then
-        export TF_WHEEL=$(find build -iname "*.whl")
+        export TF_WHEEL=$(find "$BUILD_DIR" -iname "*.whl")
+    fi
+
+    # Setup the env for the python import testing
+    if [[ $TF_WHEEL == *"aarch64.whl" ]]; then
+        python${TFCI_PYTHON_VERSION} -m venv "$BATS_FILE_TMPDIR/venv"
+    else
+        python3 -m venv "$BATS_FILE_TMPDIR/venv"
     fi
 }
 
diff --git a/ci/official/wheel.sh b/ci/official/wheel.sh
index e3e569f..20c6f26 100755
--- a/ci/official/wheel.sh
+++ b/ci/official/wheel.sh
@@ -29,6 +29,14 @@
 if [[ "$TFCI_LIBTPU_DOWNLOAD_ENABLE" == 1 ]]; then
   wget -P ./tensorflow/lib/ "$TFCI_LIBTPU_DOWNLOAD_URL"
 fi
+if [[ "$TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE" == 1 ]]; then
+  # For nightly jobs, libtpu.so comes from the latest nightly libtpu build.
+  # Note: this expects a usable nightly wheel to exist for today's date
+  DATE=$(TZ='America/Los_Angeles' date '+%Y%m%d')
+  tfrun wget "https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev${DATE}-py3-none-any.whl" -O libtpu.whl
+  # -j to discard intermediate directories; -o to overwrite if exists; -d to set output dir
+  tfrun unzip libtpu.whl libtpu/libtpu.so -j -o -d ./tensorflow/lib
+fi
 
 tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_COMMON_ARGS[@]}" //tensorflow/tools/pip_package:build_pip_package
 tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$TFCI_OUTPUT_DIR" "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}"
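
For illustration, a small Python sketch of the nightly libtpu URL the snippet above assembles (assuming a nightly wheel exists for today's Pacific date):

```
from datetime import datetime
from zoneinfo import ZoneInfo  # Python 3.9+

# Same date stamp as `TZ='America/Los_Angeles' date '+%Y%m%d'` in wheel.sh.
date = datetime.now(ZoneInfo("America/Los_Angeles")).strftime("%Y%m%d")
url = ("https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/"
       f"libtpu-nightly/libtpu_nightly-0.1.dev{date}-py3-none-any.whl")
print(url)
```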
diff --git a/ci/official/wheel_test/WORKSPACE b/ci/official/wheel_test/WORKSPACE
index 922d227..cef9033 100644
--- a/ci/official/wheel_test/WORKSPACE
+++ b/ci/official/wheel_test/WORKSPACE
@@ -35,7 +35,7 @@
 
 python_repository(name = "python_version_repo")
 
-load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
+load("@python_version_repo//:py_version.bzl", "TF_PYTHON_VERSION")
 
 # Register multi toolchains
 load("@rules_python//python:repositories.bzl", "python_register_toolchains")
@@ -43,7 +43,7 @@
 python_register_toolchains(
     name = "python",
     ignore_root_user_error = True,
-    python_version = HERMETIC_PYTHON_VERSION,
+    python_version = TF_PYTHON_VERSION,
 )
 
 load("@python//:defs.bzl", "interpreter")
@@ -52,7 +52,7 @@
 pip_parse(
     name = "pypi",
     python_interpreter_target = interpreter,
-    requirements = "//:requirements_lock_" + HERMETIC_PYTHON_VERSION.replace(".", "_") + ".txt",
+    requirements = "//:requirements_lock_" + TF_PYTHON_VERSION.replace(".", "_") + ".txt",
 )
 
 load("@pypi//:requirements.bzl", "install_deps")
diff --git a/configure.py b/configure.py
index cbb74be..c1cb201 100644
--- a/configure.py
+++ b/configure.py
@@ -878,11 +878,12 @@
 
 # Disable clang extension that rejects type definitions within offsetof.
 # This was added in clang-16 by https://reviews.llvm.org/D133574.
+# Still required for clang-17.
 # Can be removed once upb is updated, since a type definition is used within
 # offset of in the current version of ubp. See
 # https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183.
-def disable_clang16_offsetof_extension(clang_version):
-  if int(clang_version.split('.')[0]) == 16:
+def disable_clang_offsetof_extension(clang_version):
+  if int(clang_version.split('.')[0]) in (16, 17):
     write_to_bazelrc('build --copt=-Wno-gnu-offsetof-extensions')
 
 
@@ -1399,7 +1400,7 @@
       # Set up which clang we should use as the cuda / host compiler.
       clang_cuda_compiler_path = set_clang_cuda_compiler_path(environ_cp)
       clang_version = retrieve_clang_version(clang_cuda_compiler_path)
-      disable_clang16_offsetof_extension(clang_version)
+      disable_clang_offsetof_extension(clang_version)
     else:
       # Set up which gcc nvcc should use as the host compiler
       # No need to set this on Windows
@@ -1413,7 +1414,7 @@
       if environ_cp.get('TF_NEED_CLANG') == '1':
         clang_compiler_path = set_clang_compiler_path(environ_cp)
         clang_version = retrieve_clang_version(clang_compiler_path)
-        disable_clang16_offsetof_extension(clang_version)
+        disable_clang_offsetof_extension(clang_version)
 
   # ROCm / CUDA are mutually exclusive.
   # At most 1 GPU platform can be configured.
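
A simplified sketch of the renamed version check only (it returns a boolean rather than writing to .bazelrc; the version strings are hypothetical):

```
def needs_offsetof_workaround(clang_version: str) -> bool:
    # Mirrors disable_clang_offsetof_extension: clang 16 and 17 both need
    # -Wno-gnu-offsetof-extensions until upb is updated.
    return int(clang_version.split('.')[0]) in (16, 17)

print(needs_offsetof_workaround("17.0.4"))  # True
print(needs_offsetof_workaround("18.1.0"))  # False
```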
diff --git a/requirements_lock_3_10.txt b/requirements_lock_3_10.txt
index b9aa6d0..534c341 100644
--- a/requirements_lock_3_10.txt
+++ b/requirements_lock_3_10.txt
@@ -12,105 +12,105 @@
     --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \
     --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8
     # via -r requirements.in
-cachetools==5.3.1 \
-    --hash=sha256:95ef631eeaea14ba2e36f06437f36463aac3a096799e876ee55e5cdccb102590 \
-    --hash=sha256:dce83f2d9b4e1f732a8cd44af8e8fab2dbe46201467fc98b3ef8f269092bf62b
+cachetools==5.3.2 \
+    --hash=sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2 \
+    --hash=sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1
     # via google-auth
 certifi==2023.7.22 \
     --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \
     --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9
     # via requests
-charset-normalizer==3.3.0 \
-    --hash=sha256:02673e456dc5ab13659f85196c534dc596d4ef260e4d86e856c3b2773ce09843 \
-    --hash=sha256:02af06682e3590ab952599fbadac535ede5d60d78848e555aa58d0c0abbde786 \
-    --hash=sha256:03680bb39035fbcffe828eae9c3f8afc0428c91d38e7d61aa992ef7a59fb120e \
-    --hash=sha256:0570d21da019941634a531444364f2482e8db0b3425fcd5ac0c36565a64142c8 \
-    --hash=sha256:09c77f964f351a7369cc343911e0df63e762e42bac24cd7d18525961c81754f4 \
-    --hash=sha256:0d3d5b7db9ed8a2b11a774db2bbea7ba1884430a205dbd54a32d61d7c2a190fa \
-    --hash=sha256:1063da2c85b95f2d1a430f1c33b55c9c17ffaf5e612e10aeaad641c55a9e2b9d \
-    --hash=sha256:12ebea541c44fdc88ccb794a13fe861cc5e35d64ed689513a5c03d05b53b7c82 \
-    --hash=sha256:153e7b6e724761741e0974fc4dcd406d35ba70b92bfe3fedcb497226c93b9da7 \
-    --hash=sha256:15b26ddf78d57f1d143bdf32e820fd8935d36abe8a25eb9ec0b5a71c82eb3895 \
-    --hash=sha256:1872d01ac8c618a8da634e232f24793883d6e456a66593135aeafe3784b0848d \
-    --hash=sha256:187d18082694a29005ba2944c882344b6748d5be69e3a89bf3cc9d878e548d5a \
-    --hash=sha256:1b2919306936ac6efb3aed1fbf81039f7087ddadb3160882a57ee2ff74fd2382 \
-    --hash=sha256:232ac332403e37e4a03d209a3f92ed9071f7d3dbda70e2a5e9cff1c4ba9f0678 \
-    --hash=sha256:23e8565ab7ff33218530bc817922fae827420f143479b753104ab801145b1d5b \
-    --hash=sha256:24817cb02cbef7cd499f7c9a2735286b4782bd47a5b3516a0e84c50eab44b98e \
-    --hash=sha256:249c6470a2b60935bafd1d1d13cd613f8cd8388d53461c67397ee6a0f5dce741 \
-    --hash=sha256:24a91a981f185721542a0b7c92e9054b7ab4fea0508a795846bc5b0abf8118d4 \
-    --hash=sha256:2502dd2a736c879c0f0d3e2161e74d9907231e25d35794584b1ca5284e43f596 \
-    --hash=sha256:250c9eb0f4600361dd80d46112213dff2286231d92d3e52af1e5a6083d10cad9 \
-    --hash=sha256:278c296c6f96fa686d74eb449ea1697f3c03dc28b75f873b65b5201806346a69 \
-    --hash=sha256:2935ffc78db9645cb2086c2f8f4cfd23d9b73cc0dc80334bc30aac6f03f68f8c \
-    --hash=sha256:2f4a0033ce9a76e391542c182f0d48d084855b5fcba5010f707c8e8c34663d77 \
-    --hash=sha256:30a85aed0b864ac88309b7d94be09f6046c834ef60762a8833b660139cfbad13 \
-    --hash=sha256:380c4bde80bce25c6e4f77b19386f5ec9db230df9f2f2ac1e5ad7af2caa70459 \
-    --hash=sha256:3ae38d325b512f63f8da31f826e6cb6c367336f95e418137286ba362925c877e \
-    --hash=sha256:3b447982ad46348c02cb90d230b75ac34e9886273df3a93eec0539308a6296d7 \
-    --hash=sha256:3debd1150027933210c2fc321527c2299118aa929c2f5a0a80ab6953e3bd1908 \
-    --hash=sha256:4162918ef3098851fcd8a628bf9b6a98d10c380725df9e04caf5ca6dd48c847a \
-    --hash=sha256:468d2a840567b13a590e67dd276c570f8de00ed767ecc611994c301d0f8c014f \
-    --hash=sha256:4cc152c5dd831641e995764f9f0b6589519f6f5123258ccaca8c6d34572fefa8 \
-    --hash=sha256:542da1178c1c6af8873e143910e2269add130a299c9106eef2594e15dae5e482 \
-    --hash=sha256:557b21a44ceac6c6b9773bc65aa1b4cc3e248a5ad2f5b914b91579a32e22204d \
-    --hash=sha256:5707a746c6083a3a74b46b3a631d78d129edab06195a92a8ece755aac25a3f3d \
-    --hash=sha256:588245972aca710b5b68802c8cad9edaa98589b1b42ad2b53accd6910dad3545 \
-    --hash=sha256:5adf257bd58c1b8632046bbe43ee38c04e1038e9d37de9c57a94d6bd6ce5da34 \
-    --hash=sha256:619d1c96099be5823db34fe89e2582b336b5b074a7f47f819d6b3a57ff7bdb86 \
-    --hash=sha256:63563193aec44bce707e0c5ca64ff69fa72ed7cf34ce6e11d5127555756fd2f6 \
-    --hash=sha256:67b8cc9574bb518ec76dc8e705d4c39ae78bb96237cb533edac149352c1f39fe \
-    --hash=sha256:6a685067d05e46641d5d1623d7c7fdf15a357546cbb2f71b0ebde91b175ffc3e \
-    --hash=sha256:70f1d09c0d7748b73290b29219e854b3207aea922f839437870d8cc2168e31cc \
-    --hash=sha256:750b446b2ffce1739e8578576092179160f6d26bd5e23eb1789c4d64d5af7dc7 \
-    --hash=sha256:7966951325782121e67c81299a031f4c115615e68046f79b85856b86ebffc4cd \
-    --hash=sha256:7b8b8bf1189b3ba9b8de5c8db4d541b406611a71a955bbbd7385bbc45fcb786c \
-    --hash=sha256:7f5d10bae5d78e4551b7be7a9b29643a95aded9d0f602aa2ba584f0388e7a557 \
-    --hash=sha256:805dfea4ca10411a5296bcc75638017215a93ffb584c9e344731eef0dcfb026a \
-    --hash=sha256:81bf654678e575403736b85ba3a7867e31c2c30a69bc57fe88e3ace52fb17b89 \
-    --hash=sha256:82eb849f085624f6a607538ee7b83a6d8126df6d2f7d3b319cb837b289123078 \
-    --hash=sha256:85a32721ddde63c9df9ebb0d2045b9691d9750cb139c161c80e500d210f5e26e \
-    --hash=sha256:86d1f65ac145e2c9ed71d8ffb1905e9bba3a91ae29ba55b4c46ae6fc31d7c0d4 \
-    --hash=sha256:86f63face3a527284f7bb8a9d4f78988e3c06823f7bea2bd6f0e0e9298ca0403 \
-    --hash=sha256:8eaf82f0eccd1505cf39a45a6bd0a8cf1c70dcfc30dba338207a969d91b965c0 \
-    --hash=sha256:93aa7eef6ee71c629b51ef873991d6911b906d7312c6e8e99790c0f33c576f89 \
-    --hash=sha256:96c2b49eb6a72c0e4991d62406e365d87067ca14c1a729a870d22354e6f68115 \
-    --hash=sha256:9cf3126b85822c4e53aa28c7ec9869b924d6fcfb76e77a45c44b83d91afd74f9 \
-    --hash=sha256:9fe359b2e3a7729010060fbca442ca225280c16e923b37db0e955ac2a2b72a05 \
-    --hash=sha256:a0ac5e7015a5920cfce654c06618ec40c33e12801711da6b4258af59a8eff00a \
-    --hash=sha256:a3f93dab657839dfa61025056606600a11d0b696d79386f974e459a3fbc568ec \
-    --hash=sha256:a4b71f4d1765639372a3b32d2638197f5cd5221b19531f9245fcc9ee62d38f56 \
-    --hash=sha256:aae32c93e0f64469f74ccc730a7cb21c7610af3a775157e50bbd38f816536b38 \
-    --hash=sha256:aaf7b34c5bc56b38c931a54f7952f1ff0ae77a2e82496583b247f7c969eb1479 \
-    --hash=sha256:abecce40dfebbfa6abf8e324e1860092eeca6f7375c8c4e655a8afb61af58f2c \
-    --hash=sha256:abf0d9f45ea5fb95051c8bfe43cb40cda383772f7e5023a83cc481ca2604d74e \
-    --hash=sha256:ac71b2977fb90c35d41c9453116e283fac47bb9096ad917b8819ca8b943abecd \
-    --hash=sha256:ada214c6fa40f8d800e575de6b91a40d0548139e5dc457d2ebb61470abf50186 \
-    --hash=sha256:b09719a17a2301178fac4470d54b1680b18a5048b481cb8890e1ef820cb80455 \
-    --hash=sha256:b1121de0e9d6e6ca08289583d7491e7fcb18a439305b34a30b20d8215922d43c \
-    --hash=sha256:b3b2316b25644b23b54a6f6401074cebcecd1244c0b8e80111c9a3f1c8e83d65 \
-    --hash=sha256:b3d9b48ee6e3967b7901c052b670c7dda6deb812c309439adaffdec55c6d7b78 \
-    --hash=sha256:b5bcf60a228acae568e9911f410f9d9e0d43197d030ae5799e20dca8df588287 \
-    --hash=sha256:b8f3307af845803fb0b060ab76cf6dd3a13adc15b6b451f54281d25911eb92df \
-    --hash=sha256:c2af80fb58f0f24b3f3adcb9148e6203fa67dd3f61c4af146ecad033024dde43 \
-    --hash=sha256:c350354efb159b8767a6244c166f66e67506e06c8924ed74669b2c70bc8735b1 \
-    --hash=sha256:c5a74c359b2d47d26cdbbc7845e9662d6b08a1e915eb015d044729e92e7050b7 \
-    --hash=sha256:c71f16da1ed8949774ef79f4a0260d28b83b3a50c6576f8f4f0288d109777989 \
-    --hash=sha256:d47ecf253780c90ee181d4d871cd655a789da937454045b17b5798da9393901a \
-    --hash=sha256:d7eff0f27edc5afa9e405f7165f85a6d782d308f3b6b9d96016c010597958e63 \
-    --hash=sha256:d97d85fa63f315a8bdaba2af9a6a686e0eceab77b3089af45133252618e70884 \
-    --hash=sha256:db756e48f9c5c607b5e33dd36b1d5872d0422e960145b08ab0ec7fd420e9d649 \
-    --hash=sha256:dc45229747b67ffc441b3de2f3ae5e62877a282ea828a5bdb67883c4ee4a8810 \
-    --hash=sha256:e0fc42822278451bc13a2e8626cf2218ba570f27856b536e00cfa53099724828 \
-    --hash=sha256:e39c7eb31e3f5b1f88caff88bcff1b7f8334975b46f6ac6e9fc725d829bc35d4 \
-    --hash=sha256:e46cd37076971c1040fc8c41273a8b3e2c624ce4f2be3f5dfcb7a430c1d3acc2 \
-    --hash=sha256:e5c1502d4ace69a179305abb3f0bb6141cbe4714bc9b31d427329a95acfc8bdd \
-    --hash=sha256:edfe077ab09442d4ef3c52cb1f9dab89bff02f4524afc0acf2d46be17dc479f5 \
-    --hash=sha256:effe5406c9bd748a871dbcaf3ac69167c38d72db8c9baf3ff954c344f31c4cbe \
-    --hash=sha256:f0d1e3732768fecb052d90d62b220af62ead5748ac51ef61e7b32c266cac9293 \
-    --hash=sha256:f5969baeaea61c97efa706b9b107dcba02784b1601c74ac84f2a532ea079403e \
-    --hash=sha256:f8888e31e3a85943743f8fc15e71536bda1c81d5aa36d014a3c0c44481d7db6e \
-    --hash=sha256:fc52b79d83a3fe3a360902d3f5d79073a993597d48114c29485e9431092905d8
+charset-normalizer==3.3.1 \
+    --hash=sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5 \
+    --hash=sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93 \
+    --hash=sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a \
+    --hash=sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d \
+    --hash=sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c \
+    --hash=sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1 \
+    --hash=sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58 \
+    --hash=sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2 \
+    --hash=sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557 \
+    --hash=sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147 \
+    --hash=sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041 \
+    --hash=sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2 \
+    --hash=sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2 \
+    --hash=sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7 \
+    --hash=sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296 \
+    --hash=sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690 \
+    --hash=sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67 \
+    --hash=sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57 \
+    --hash=sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597 \
+    --hash=sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846 \
+    --hash=sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b \
+    --hash=sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97 \
+    --hash=sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c \
+    --hash=sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62 \
+    --hash=sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa \
+    --hash=sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f \
+    --hash=sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e \
+    --hash=sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821 \
+    --hash=sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3 \
+    --hash=sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4 \
+    --hash=sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb \
+    --hash=sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727 \
+    --hash=sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514 \
+    --hash=sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d \
+    --hash=sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761 \
+    --hash=sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55 \
+    --hash=sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f \
+    --hash=sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c \
+    --hash=sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034 \
+    --hash=sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6 \
+    --hash=sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae \
+    --hash=sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1 \
+    --hash=sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14 \
+    --hash=sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1 \
+    --hash=sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228 \
+    --hash=sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708 \
+    --hash=sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48 \
+    --hash=sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f \
+    --hash=sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5 \
+    --hash=sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f \
+    --hash=sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4 \
+    --hash=sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8 \
+    --hash=sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff \
+    --hash=sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61 \
+    --hash=sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b \
+    --hash=sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97 \
+    --hash=sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b \
+    --hash=sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605 \
+    --hash=sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728 \
+    --hash=sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d \
+    --hash=sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c \
+    --hash=sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf \
+    --hash=sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673 \
+    --hash=sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1 \
+    --hash=sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b \
+    --hash=sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41 \
+    --hash=sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8 \
+    --hash=sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f \
+    --hash=sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4 \
+    --hash=sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008 \
+    --hash=sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9 \
+    --hash=sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5 \
+    --hash=sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f \
+    --hash=sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e \
+    --hash=sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273 \
+    --hash=sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45 \
+    --hash=sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e \
+    --hash=sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656 \
+    --hash=sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e \
+    --hash=sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c \
+    --hash=sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2 \
+    --hash=sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72 \
+    --hash=sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056 \
+    --hash=sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397 \
+    --hash=sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42 \
+    --hash=sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd \
+    --hash=sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3 \
+    --hash=sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213 \
+    --hash=sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf \
+    --hash=sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67
     # via requests
 dill==0.3.7 \
     --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \
@@ -171,61 +171,61 @@
     --hash=sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12 \
     --hash=sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb
     # via tb-nightly
-grpcio==1.59.0 \
-    --hash=sha256:0ae444221b2c16d8211b55326f8ba173ba8f8c76349bfc1768198ba592b58f74 \
-    --hash=sha256:0b84445fa94d59e6806c10266b977f92fa997db3585f125d6b751af02ff8b9fe \
-    --hash=sha256:14890da86a0c0e9dc1ea8e90101d7a3e0e7b1e71f4487fab36e2bfd2ecadd13c \
-    --hash=sha256:15f03bd714f987d48ae57fe092cf81960ae36da4e520e729392a59a75cda4f29 \
-    --hash=sha256:1a839ba86764cc48226f50b924216000c79779c563a301586a107bda9cbe9dcf \
-    --hash=sha256:225e5fa61c35eeaebb4e7491cd2d768cd8eb6ed00f2664fa83a58f29418b39fd \
-    --hash=sha256:228b91ce454876d7eed74041aff24a8f04c0306b7250a2da99d35dd25e2a1211 \
-    --hash=sha256:2ea95cd6abbe20138b8df965b4a8674ec312aaef3147c0f46a0bac661f09e8d0 \
-    --hash=sha256:2f120d27051e4c59db2f267b71b833796770d3ea36ca712befa8c5fff5da6ebd \
-    --hash=sha256:34341d9e81a4b669a5f5dca3b2a760b6798e95cdda2b173e65d29d0b16692857 \
-    --hash=sha256:3859917de234a0a2a52132489c4425a73669de9c458b01c9a83687f1f31b5b10 \
-    --hash=sha256:38823bd088c69f59966f594d087d3a929d1ef310506bee9e3648317660d65b81 \
-    --hash=sha256:38da5310ef84e16d638ad89550b5b9424df508fd5c7b968b90eb9629ca9be4b9 \
-    --hash=sha256:3b8ff795d35a93d1df6531f31c1502673d1cebeeba93d0f9bd74617381507e3f \
-    --hash=sha256:50eff97397e29eeee5df106ea1afce3ee134d567aa2c8e04fabab05c79d791a7 \
-    --hash=sha256:5711c51e204dc52065f4a3327dca46e69636a0b76d3e98c2c28c4ccef9b04c52 \
-    --hash=sha256:598f3530231cf10ae03f4ab92d48c3be1fee0c52213a1d5958df1a90957e6a88 \
-    --hash=sha256:611d9aa0017fa386809bddcb76653a5ab18c264faf4d9ff35cb904d44745f575 \
-    --hash=sha256:61bc72a00ecc2b79d9695220b4d02e8ba53b702b42411397e831c9b0589f08a3 \
-    --hash=sha256:63982150a7d598281fa1d7ffead6096e543ff8be189d3235dd2b5604f2c553e5 \
-    --hash=sha256:6c4b1cc3a9dc1924d2eb26eec8792fedd4b3fcd10111e26c1d551f2e4eda79ce \
-    --hash=sha256:81d86a096ccd24a57fa5772a544c9e566218bc4de49e8c909882dae9d73392df \
-    --hash=sha256:849c47ef42424c86af069a9c5e691a765e304079755d5c29eff511263fad9c2a \
-    --hash=sha256:871371ce0c0055d3db2a86fdebd1e1d647cf21a8912acc30052660297a5a6901 \
-    --hash=sha256:8cd2d38c2d52f607d75a74143113174c36d8a416d9472415eab834f837580cf7 \
-    --hash=sha256:936b2e04663660c600d5173bc2cc84e15adbad9c8f71946eb833b0afc205b996 \
-    --hash=sha256:93e9cb546e610829e462147ce724a9cb108e61647a3454500438a6deef610be1 \
-    --hash=sha256:956f0b7cb465a65de1bd90d5a7475b4dc55089b25042fe0f6c870707e9aabb1d \
-    --hash=sha256:986de4aa75646e963466b386a8c5055c8b23a26a36a6c99052385d6fe8aaf180 \
-    --hash=sha256:aca8a24fef80bef73f83eb8153f5f5a0134d9539b4c436a716256b311dda90a6 \
-    --hash=sha256:acf70a63cf09dd494000007b798aff88a436e1c03b394995ce450be437b8e54f \
-    --hash=sha256:b34c7a4c31841a2ea27246a05eed8a80c319bfc0d3e644412ec9ce437105ff6c \
-    --hash=sha256:b95ec8ecc4f703f5caaa8d96e93e40c7f589bad299a2617bdb8becbcce525539 \
-    --hash=sha256:ba0ca727a173ee093f49ead932c051af463258b4b493b956a2c099696f38aa66 \
-    --hash=sha256:c041a91712bf23b2a910f61e16565a05869e505dc5a5c025d429ca6de5de842c \
-    --hash=sha256:c0488c2b0528e6072010182075615620071371701733c63ab5be49140ed8f7f0 \
-    --hash=sha256:c173a87d622ea074ce79be33b952f0b424fa92182063c3bda8625c11d3585d09 \
-    --hash=sha256:c251d22de8f9f5cca9ee47e4bade7c5c853e6e40743f47f5cc02288ee7a87252 \
-    --hash=sha256:c4dfdb49f4997dc664f30116af2d34751b91aa031f8c8ee251ce4dcfc11277b0 \
-    --hash=sha256:ca87ee6183421b7cea3544190061f6c1c3dfc959e0b57a5286b108511fd34ff4 \
-    --hash=sha256:ceb1e68135788c3fce2211de86a7597591f0b9a0d2bb80e8401fd1d915991bac \
-    --hash=sha256:d09bd2a4e9f5a44d36bb8684f284835c14d30c22d8ec92ce796655af12163588 \
-    --hash=sha256:d0fcf53df684fcc0154b1e61f6b4a8c4cf5f49d98a63511e3f30966feff39cd0 \
-    --hash=sha256:d74f7d2d7c242a6af9d4d069552ec3669965b74fed6b92946e0e13b4168374f9 \
-    --hash=sha256:de2599985b7c1b4ce7526e15c969d66b93687571aa008ca749d6235d056b7205 \
-    --hash=sha256:e5378785dce2b91eb2e5b857ec7602305a3b5cf78311767146464bfa365fc897 \
-    --hash=sha256:ec78aebb9b6771d6a1de7b6ca2f779a2f6113b9108d486e904bde323d51f5589 \
-    --hash=sha256:f1feb034321ae2f718172d86b8276c03599846dc7bb1792ae370af02718f91c5 \
-    --hash=sha256:f21917aa50b40842b51aff2de6ebf9e2f6af3fe0971c31960ad6a3a2b24988f4 \
-    --hash=sha256:f367e4b524cb319e50acbdea57bb63c3b717c5d561974ace0b065a648bb3bad3 \
-    --hash=sha256:f6cfe44a5d7c7d5f1017a7da1c8160304091ca5dc64a0f85bca0d63008c3137a \
-    --hash=sha256:fa66cac32861500f280bb60fe7d5b3e22d68c51e18e65367e38f8669b78cea3b \
-    --hash=sha256:fc8bf2e7bc725e76c0c11e474634a08c8f24bcf7426c0c6d60c8f9c6e70e4d4a \
-    --hash=sha256:fe976910de34d21057bcb53b2c5e667843588b48bf11339da2a75f5c4c5b4055
+grpcio==1.59.2 \
+    --hash=sha256:023088764012411affe7db183d1ada3ad9daf2e23ddc719ff46d7061de661340 \
+    --hash=sha256:08d77e682f2bf730a4961eea330e56d2f423c6a9b91ca222e5b1eb24a357b19f \
+    --hash=sha256:0a4a3833c0e067f3558538727235cd8a49709bff1003200bbdefa2f09334e4b1 \
+    --hash=sha256:0a754aff9e3af63bdc4c75c234b86b9d14e14a28a30c4e324aed1a9b873d755f \
+    --hash=sha256:11168ef43e4a43ff1b1a65859f3e0ef1a173e277349e7fb16923ff108160a8cd \
+    --hash=sha256:128e20f57c5f27cb0157e73756d1586b83c1b513ebecc83ea0ac37e4b0e4e758 \
+    --hash=sha256:1f9524d1d701e399462d2c90ba7c193e49d1711cf429c0d3d97c966856e03d00 \
+    --hash=sha256:1ff16d68bf453275466a9a46739061a63584d92f18a0f5b33d19fc97eb69867c \
+    --hash=sha256:2067274c88bc6de89c278a672a652b4247d088811ece781a4858b09bdf8448e3 \
+    --hash=sha256:2171c39f355ba5b551c5d5928d65aa6c69807fae195b86ef4a7d125bcdb860a9 \
+    --hash=sha256:242adc47725b9a499ee77c6a2e36688fa6c96484611f33b1be4c57ab075a92dd \
+    --hash=sha256:27f879ae604a7fcf371e59fba6f3ff4635a4c2a64768bd83ff0cac503142fef4 \
+    --hash=sha256:2b230028a008ae1d0f430acb227d323ff8a619017415cf334c38b457f814119f \
+    --hash=sha256:3059668df17627f0e0fa680e9ef8c995c946c792612e9518f5cc1503be14e90b \
+    --hash=sha256:31176aa88f36020055ace9adff2405a33c8bdbfa72a9c4980e25d91b2f196873 \
+    --hash=sha256:36f53c2b3449c015880e7d55a89c992c357f176327b0d2873cdaaf9628a37c69 \
+    --hash=sha256:3b4368b33908f683a363f376dfb747d40af3463a6e5044afee07cf9436addf96 \
+    --hash=sha256:3c61d641d4f409c5ae46bfdd89ea42ce5ea233dcf69e74ce9ba32b503c727e29 \
+    --hash=sha256:4abb717e320e74959517dc8e84a9f48fbe90e9abe19c248541e9418b1ce60acd \
+    --hash=sha256:4c93f4abbb54321ee6471e04a00139c80c754eda51064187963ddf98f5cf36a4 \
+    --hash=sha256:535561990e075fa6bd4b16c4c3c1096b9581b7bb35d96fac4650f1181e428268 \
+    --hash=sha256:53c9aa5ddd6857c0a1cd0287225a2a25873a8e09727c2e95c4aebb1be83a766a \
+    --hash=sha256:5d573e70a6fe77555fb6143c12d3a7d3fa306632a3034b4e7c59ca09721546f8 \
+    --hash=sha256:6009386a2df66159f64ac9f20425ae25229b29b9dd0e1d3dd60043f037e2ad7e \
+    --hash=sha256:686e975a5d16602dc0982c7c703948d17184bd1397e16c8ee03511ecb8c4cdda \
+    --hash=sha256:6959fb07e8351e20501ffb8cc4074c39a0b7ef123e1c850a7f8f3afdc3a3da01 \
+    --hash=sha256:6b25ed37c27e652db01be341af93fbcea03d296c024d8a0e680017a268eb85dd \
+    --hash=sha256:6da6dea3a1bacf99b3c2187e296db9a83029ed9c38fd4c52b7c9b7326d13c828 \
+    --hash=sha256:72ca2399097c0b758198f2ff30f7178d680de8a5cfcf3d9b73a63cf87455532e \
+    --hash=sha256:73abb8584b0cf74d37f5ef61c10722adc7275502ab71789a8fe3cb7ef04cf6e2 \
+    --hash=sha256:74100fecaec8a535e380cf5f2fb556ff84957d481c13e54051c52e5baac70541 \
+    --hash=sha256:75c6ecb70e809cf1504465174343113f51f24bc61e22a80ae1c859f3f7034c6d \
+    --hash=sha256:7cf05053242f61ba94014dd3a986e11a083400a32664058f80bf4cf817c0b3a1 \
+    --hash=sha256:9411e24328a2302e279e70cae6e479f1fddde79629fcb14e03e6d94b3956eabf \
+    --hash=sha256:a213acfbf186b9f35803b52e4ca9addb153fc0b67f82a48f961be7000ecf6721 \
+    --hash=sha256:bb7e0fe6ad73b7f06d7e2b689c19a71cf5cc48f0c2bf8608469e51ffe0bd2867 \
+    --hash=sha256:c2504eed520958a5b77cc99458297cb7906308cb92327f35fb7fbbad4e9b2188 \
+    --hash=sha256:c35aa9657f5d5116d23b934568e0956bd50c615127810fffe3ac356a914c176a \
+    --hash=sha256:c5f09cffa619adfb44799fa4a81c2a1ad77c887187613fb0a8f201ab38d89ba1 \
+    --hash=sha256:c978f864b35f2261e0819f5cd88b9830b04dc51bcf055aac3c601e525a10d2ba \
+    --hash=sha256:cbe946b3e6e60a7b4618f091e62a029cb082b109a9d6b53962dd305087c6e4fd \
+    --hash=sha256:cc3e4cd087f07758b16bef8f31d88dbb1b5da5671d2f03685ab52dece3d7a16e \
+    --hash=sha256:cf0dead5a2c5a3347af2cfec7131d4f2a2e03c934af28989c9078f8241a491fa \
+    --hash=sha256:d2794f0e68b3085d99b4f6ff9c089f6fdd02b32b9d3efdfbb55beac1bf22d516 \
+    --hash=sha256:d2fa68a96a30dd240be80bbad838a0ac81a61770611ff7952b889485970c4c71 \
+    --hash=sha256:d6f70406695e3220f09cd7a2f879333279d91aa4a8a1d34303b56d61a8180137 \
+    --hash=sha256:d8f9cd4ad1be90b0cf350a2f04a38a36e44a026cac1e036ac593dc48efe91d52 \
+    --hash=sha256:da2d94c15f88cd40d7e67f7919d4f60110d2b9d5b1e08cf354c2be773ab13479 \
+    --hash=sha256:e1727c1c0e394096bb9af185c6923e8ea55a5095b8af44f06903bcc0e06800a2 \
+    --hash=sha256:e420ced29b5904cdf9ee5545e23f9406189d8acb6750916c2db4793dada065c6 \
+    --hash=sha256:e82c5cf1495244adf5252f925ac5932e5fd288b3e5ab6b70bec5593074b7236c \
+    --hash=sha256:f1ef0d39bc1feb420caf549b3c657c871cad4ebbcf0580c4d03816b0590de0cf \
+    --hash=sha256:f8753a6c88d1d0ba64302309eecf20f70d2770f65ca02d83c2452279085bfcd3 \
+    --hash=sha256:f93dbf58f03146164048be5426ffde298b237a5e059144847e4940f5b80172c3
     # via
     #   -r requirements.in
     #   tb-nightly
@@ -265,12 +265,12 @@
 jax==0.4.7 \
     --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8
     # via -r requirements.in
-keras-nightly==3.0.0.dev2023101703 \
-    --hash=sha256:72674b80300b9672b76e20c3d4c4bd827f6cacd486d9f98e3db864d872e3eaa4 \
-    --hash=sha256:ab5e59bf6a84d048b3241ad9ae368ed8497740804a8e85774fbded34c1322587
+keras-nightly==3.0.0.dev2023103103 \
+    --hash=sha256:25a6a9030d7e067d535ec41ae6700636c90cabada80fbddb3bb8775006807860 \
+    --hash=sha256:c39fccebb3d4cc9d838371981c4d3eef08fc06eadf6cffb39b356cb625bab50f
     # via -r requirements.in
-lit==17.0.3 \
-    --hash=sha256:e6049032462be1e2928686cbd4a6cc5b3c545d83ecd078737fe79412c1f3fcc1
+lit==17.0.4 \
+    --hash=sha256:ee2e180128e770abc6aed3a02de2daf09d81b7d30225e315205d3599c311d304
     # via -r requirements.in
 markdown==3.5 \
     --hash=sha256:4afb124395ce5fc34e6d9886dab977fd9ae987fc6e85689f08278cf0c69d4bf3 \
@@ -526,17 +526,17 @@
     # via
     #   astunparse
     #   tb-nightly
-tb-nightly==2.15.0a20231017 \
-    --hash=sha256:982b8cf32bcab4902eebd2e67b885c127f16106114b34eb4d46ea554af2a713a
+tb-nightly==2.15.0a20231023 \
+    --hash=sha256:8990a52985e3296aa18a6825efc017bcded9e2fb2cbcdd2d01f5c62d4fdc9825
     # via -r requirements.in
 tblib==2.0.0 \
     --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \
     --hash=sha256:a6df30f272c08bf8be66e0775fad862005d950a6b8449b94f7c788731d70ecd7
     # via -r requirements.in
-tensorboard-data-server==0.7.1 \
-    --hash=sha256:255c02b7f5b03dd5c0a88c928e563441ff39e1d4b4a234cdbe09f016e53d9594 \
-    --hash=sha256:9938bd39f5041797b33921066fba0eab03a0dd10d1887a05e62ae58841ad4c3f \
-    --hash=sha256:be8d016a1aa394e6198280d4a3dc37898f56467310c5f5e617cac10a783e055a
+tensorboard-data-server==0.7.2 \
+    --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \
+    --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \
+    --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530
     # via tb-nightly
 termcolor==2.3.0 \
     --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \
@@ -553,13 +553,13 @@
     --hash=sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84 \
     --hash=sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e
     # via requests
-werkzeug==3.0.0 \
-    --hash=sha256:3ffff4dcc32db52ef3cc94dff3000a3c2846890f3a5a51800a27b909c5e770f0 \
-    --hash=sha256:cbb2600f7eabe51dbc0502f58be0b3e1b96b893b05695ea2b35b43d4de2d9962
+werkzeug==3.0.1 \
+    --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \
+    --hash=sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10
     # via tb-nightly
-wheel==0.41.2 \
-    --hash=sha256:0c5ac5ff2afb79ac23ab82bab027a0be7b5dbcf2e54dc50efe4bf507de1f7985 \
-    --hash=sha256:75909db2664838d015e3d9139004ee16711748a52c8f336b52882266540215d8
+wheel==0.41.3 \
+    --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \
+    --hash=sha256:4d4987ce51a49370ea65c0bfd2234e8ce80a12780820d9dc462597a6e60d0841
     # via
     #   -r requirements.in
     #   astunparse
diff --git a/requirements_lock_3_11.txt b/requirements_lock_3_11.txt
index b9aa6d0..534c341 100644
--- a/requirements_lock_3_11.txt
+++ b/requirements_lock_3_11.txt
@@ -12,105 +12,105 @@
     --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \
     --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8
     # via -r requirements.in
-cachetools==5.3.1 \
-    --hash=sha256:95ef631eeaea14ba2e36f06437f36463aac3a096799e876ee55e5cdccb102590 \
-    --hash=sha256:dce83f2d9b4e1f732a8cd44af8e8fab2dbe46201467fc98b3ef8f269092bf62b
+cachetools==5.3.2 \
+    --hash=sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2 \
+    --hash=sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1
     # via google-auth
 certifi==2023.7.22 \
     --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \
     --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9
     # via requests
-charset-normalizer==3.3.0 \
-    --hash=sha256:02673e456dc5ab13659f85196c534dc596d4ef260e4d86e856c3b2773ce09843 \
-    --hash=sha256:02af06682e3590ab952599fbadac535ede5d60d78848e555aa58d0c0abbde786 \
-    --hash=sha256:03680bb39035fbcffe828eae9c3f8afc0428c91d38e7d61aa992ef7a59fb120e \
-    --hash=sha256:0570d21da019941634a531444364f2482e8db0b3425fcd5ac0c36565a64142c8 \
-    --hash=sha256:09c77f964f351a7369cc343911e0df63e762e42bac24cd7d18525961c81754f4 \
-    --hash=sha256:0d3d5b7db9ed8a2b11a774db2bbea7ba1884430a205dbd54a32d61d7c2a190fa \
-    --hash=sha256:1063da2c85b95f2d1a430f1c33b55c9c17ffaf5e612e10aeaad641c55a9e2b9d \
-    --hash=sha256:12ebea541c44fdc88ccb794a13fe861cc5e35d64ed689513a5c03d05b53b7c82 \
-    --hash=sha256:153e7b6e724761741e0974fc4dcd406d35ba70b92bfe3fedcb497226c93b9da7 \
-    --hash=sha256:15b26ddf78d57f1d143bdf32e820fd8935d36abe8a25eb9ec0b5a71c82eb3895 \
-    --hash=sha256:1872d01ac8c618a8da634e232f24793883d6e456a66593135aeafe3784b0848d \
-    --hash=sha256:187d18082694a29005ba2944c882344b6748d5be69e3a89bf3cc9d878e548d5a \
-    --hash=sha256:1b2919306936ac6efb3aed1fbf81039f7087ddadb3160882a57ee2ff74fd2382 \
-    --hash=sha256:232ac332403e37e4a03d209a3f92ed9071f7d3dbda70e2a5e9cff1c4ba9f0678 \
-    --hash=sha256:23e8565ab7ff33218530bc817922fae827420f143479b753104ab801145b1d5b \
-    --hash=sha256:24817cb02cbef7cd499f7c9a2735286b4782bd47a5b3516a0e84c50eab44b98e \
-    --hash=sha256:249c6470a2b60935bafd1d1d13cd613f8cd8388d53461c67397ee6a0f5dce741 \
-    --hash=sha256:24a91a981f185721542a0b7c92e9054b7ab4fea0508a795846bc5b0abf8118d4 \
-    --hash=sha256:2502dd2a736c879c0f0d3e2161e74d9907231e25d35794584b1ca5284e43f596 \
-    --hash=sha256:250c9eb0f4600361dd80d46112213dff2286231d92d3e52af1e5a6083d10cad9 \
-    --hash=sha256:278c296c6f96fa686d74eb449ea1697f3c03dc28b75f873b65b5201806346a69 \
-    --hash=sha256:2935ffc78db9645cb2086c2f8f4cfd23d9b73cc0dc80334bc30aac6f03f68f8c \
-    --hash=sha256:2f4a0033ce9a76e391542c182f0d48d084855b5fcba5010f707c8e8c34663d77 \
-    --hash=sha256:30a85aed0b864ac88309b7d94be09f6046c834ef60762a8833b660139cfbad13 \
-    --hash=sha256:380c4bde80bce25c6e4f77b19386f5ec9db230df9f2f2ac1e5ad7af2caa70459 \
-    --hash=sha256:3ae38d325b512f63f8da31f826e6cb6c367336f95e418137286ba362925c877e \
-    --hash=sha256:3b447982ad46348c02cb90d230b75ac34e9886273df3a93eec0539308a6296d7 \
-    --hash=sha256:3debd1150027933210c2fc321527c2299118aa929c2f5a0a80ab6953e3bd1908 \
-    --hash=sha256:4162918ef3098851fcd8a628bf9b6a98d10c380725df9e04caf5ca6dd48c847a \
-    --hash=sha256:468d2a840567b13a590e67dd276c570f8de00ed767ecc611994c301d0f8c014f \
-    --hash=sha256:4cc152c5dd831641e995764f9f0b6589519f6f5123258ccaca8c6d34572fefa8 \
-    --hash=sha256:542da1178c1c6af8873e143910e2269add130a299c9106eef2594e15dae5e482 \
-    --hash=sha256:557b21a44ceac6c6b9773bc65aa1b4cc3e248a5ad2f5b914b91579a32e22204d \
-    --hash=sha256:5707a746c6083a3a74b46b3a631d78d129edab06195a92a8ece755aac25a3f3d \
-    --hash=sha256:588245972aca710b5b68802c8cad9edaa98589b1b42ad2b53accd6910dad3545 \
-    --hash=sha256:5adf257bd58c1b8632046bbe43ee38c04e1038e9d37de9c57a94d6bd6ce5da34 \
-    --hash=sha256:619d1c96099be5823db34fe89e2582b336b5b074a7f47f819d6b3a57ff7bdb86 \
-    --hash=sha256:63563193aec44bce707e0c5ca64ff69fa72ed7cf34ce6e11d5127555756fd2f6 \
-    --hash=sha256:67b8cc9574bb518ec76dc8e705d4c39ae78bb96237cb533edac149352c1f39fe \
-    --hash=sha256:6a685067d05e46641d5d1623d7c7fdf15a357546cbb2f71b0ebde91b175ffc3e \
-    --hash=sha256:70f1d09c0d7748b73290b29219e854b3207aea922f839437870d8cc2168e31cc \
-    --hash=sha256:750b446b2ffce1739e8578576092179160f6d26bd5e23eb1789c4d64d5af7dc7 \
-    --hash=sha256:7966951325782121e67c81299a031f4c115615e68046f79b85856b86ebffc4cd \
-    --hash=sha256:7b8b8bf1189b3ba9b8de5c8db4d541b406611a71a955bbbd7385bbc45fcb786c \
-    --hash=sha256:7f5d10bae5d78e4551b7be7a9b29643a95aded9d0f602aa2ba584f0388e7a557 \
-    --hash=sha256:805dfea4ca10411a5296bcc75638017215a93ffb584c9e344731eef0dcfb026a \
-    --hash=sha256:81bf654678e575403736b85ba3a7867e31c2c30a69bc57fe88e3ace52fb17b89 \
-    --hash=sha256:82eb849f085624f6a607538ee7b83a6d8126df6d2f7d3b319cb837b289123078 \
-    --hash=sha256:85a32721ddde63c9df9ebb0d2045b9691d9750cb139c161c80e500d210f5e26e \
-    --hash=sha256:86d1f65ac145e2c9ed71d8ffb1905e9bba3a91ae29ba55b4c46ae6fc31d7c0d4 \
-    --hash=sha256:86f63face3a527284f7bb8a9d4f78988e3c06823f7bea2bd6f0e0e9298ca0403 \
-    --hash=sha256:8eaf82f0eccd1505cf39a45a6bd0a8cf1c70dcfc30dba338207a969d91b965c0 \
-    --hash=sha256:93aa7eef6ee71c629b51ef873991d6911b906d7312c6e8e99790c0f33c576f89 \
-    --hash=sha256:96c2b49eb6a72c0e4991d62406e365d87067ca14c1a729a870d22354e6f68115 \
-    --hash=sha256:9cf3126b85822c4e53aa28c7ec9869b924d6fcfb76e77a45c44b83d91afd74f9 \
-    --hash=sha256:9fe359b2e3a7729010060fbca442ca225280c16e923b37db0e955ac2a2b72a05 \
-    --hash=sha256:a0ac5e7015a5920cfce654c06618ec40c33e12801711da6b4258af59a8eff00a \
-    --hash=sha256:a3f93dab657839dfa61025056606600a11d0b696d79386f974e459a3fbc568ec \
-    --hash=sha256:a4b71f4d1765639372a3b32d2638197f5cd5221b19531f9245fcc9ee62d38f56 \
-    --hash=sha256:aae32c93e0f64469f74ccc730a7cb21c7610af3a775157e50bbd38f816536b38 \
-    --hash=sha256:aaf7b34c5bc56b38c931a54f7952f1ff0ae77a2e82496583b247f7c969eb1479 \
-    --hash=sha256:abecce40dfebbfa6abf8e324e1860092eeca6f7375c8c4e655a8afb61af58f2c \
-    --hash=sha256:abf0d9f45ea5fb95051c8bfe43cb40cda383772f7e5023a83cc481ca2604d74e \
-    --hash=sha256:ac71b2977fb90c35d41c9453116e283fac47bb9096ad917b8819ca8b943abecd \
-    --hash=sha256:ada214c6fa40f8d800e575de6b91a40d0548139e5dc457d2ebb61470abf50186 \
-    --hash=sha256:b09719a17a2301178fac4470d54b1680b18a5048b481cb8890e1ef820cb80455 \
-    --hash=sha256:b1121de0e9d6e6ca08289583d7491e7fcb18a439305b34a30b20d8215922d43c \
-    --hash=sha256:b3b2316b25644b23b54a6f6401074cebcecd1244c0b8e80111c9a3f1c8e83d65 \
-    --hash=sha256:b3d9b48ee6e3967b7901c052b670c7dda6deb812c309439adaffdec55c6d7b78 \
-    --hash=sha256:b5bcf60a228acae568e9911f410f9d9e0d43197d030ae5799e20dca8df588287 \
-    --hash=sha256:b8f3307af845803fb0b060ab76cf6dd3a13adc15b6b451f54281d25911eb92df \
-    --hash=sha256:c2af80fb58f0f24b3f3adcb9148e6203fa67dd3f61c4af146ecad033024dde43 \
-    --hash=sha256:c350354efb159b8767a6244c166f66e67506e06c8924ed74669b2c70bc8735b1 \
-    --hash=sha256:c5a74c359b2d47d26cdbbc7845e9662d6b08a1e915eb015d044729e92e7050b7 \
-    --hash=sha256:c71f16da1ed8949774ef79f4a0260d28b83b3a50c6576f8f4f0288d109777989 \
-    --hash=sha256:d47ecf253780c90ee181d4d871cd655a789da937454045b17b5798da9393901a \
-    --hash=sha256:d7eff0f27edc5afa9e405f7165f85a6d782d308f3b6b9d96016c010597958e63 \
-    --hash=sha256:d97d85fa63f315a8bdaba2af9a6a686e0eceab77b3089af45133252618e70884 \
-    --hash=sha256:db756e48f9c5c607b5e33dd36b1d5872d0422e960145b08ab0ec7fd420e9d649 \
-    --hash=sha256:dc45229747b67ffc441b3de2f3ae5e62877a282ea828a5bdb67883c4ee4a8810 \
-    --hash=sha256:e0fc42822278451bc13a2e8626cf2218ba570f27856b536e00cfa53099724828 \
-    --hash=sha256:e39c7eb31e3f5b1f88caff88bcff1b7f8334975b46f6ac6e9fc725d829bc35d4 \
-    --hash=sha256:e46cd37076971c1040fc8c41273a8b3e2c624ce4f2be3f5dfcb7a430c1d3acc2 \
-    --hash=sha256:e5c1502d4ace69a179305abb3f0bb6141cbe4714bc9b31d427329a95acfc8bdd \
-    --hash=sha256:edfe077ab09442d4ef3c52cb1f9dab89bff02f4524afc0acf2d46be17dc479f5 \
-    --hash=sha256:effe5406c9bd748a871dbcaf3ac69167c38d72db8c9baf3ff954c344f31c4cbe \
-    --hash=sha256:f0d1e3732768fecb052d90d62b220af62ead5748ac51ef61e7b32c266cac9293 \
-    --hash=sha256:f5969baeaea61c97efa706b9b107dcba02784b1601c74ac84f2a532ea079403e \
-    --hash=sha256:f8888e31e3a85943743f8fc15e71536bda1c81d5aa36d014a3c0c44481d7db6e \
-    --hash=sha256:fc52b79d83a3fe3a360902d3f5d79073a993597d48114c29485e9431092905d8
+charset-normalizer==3.3.1 \
+    --hash=sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5 \
+    --hash=sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93 \
+    --hash=sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a \
+    --hash=sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d \
+    --hash=sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c \
+    --hash=sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1 \
+    --hash=sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58 \
+    --hash=sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2 \
+    --hash=sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557 \
+    --hash=sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147 \
+    --hash=sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041 \
+    --hash=sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2 \
+    --hash=sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2 \
+    --hash=sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7 \
+    --hash=sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296 \
+    --hash=sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690 \
+    --hash=sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67 \
+    --hash=sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57 \
+    --hash=sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597 \
+    --hash=sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846 \
+    --hash=sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b \
+    --hash=sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97 \
+    --hash=sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c \
+    --hash=sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62 \
+    --hash=sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa \
+    --hash=sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f \
+    --hash=sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e \
+    --hash=sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821 \
+    --hash=sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3 \
+    --hash=sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4 \
+    --hash=sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb \
+    --hash=sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727 \
+    --hash=sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514 \
+    --hash=sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d \
+    --hash=sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761 \
+    --hash=sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55 \
+    --hash=sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f \
+    --hash=sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c \
+    --hash=sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034 \
+    --hash=sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6 \
+    --hash=sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae \
+    --hash=sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1 \
+    --hash=sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14 \
+    --hash=sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1 \
+    --hash=sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228 \
+    --hash=sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708 \
+    --hash=sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48 \
+    --hash=sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f \
+    --hash=sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5 \
+    --hash=sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f \
+    --hash=sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4 \
+    --hash=sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8 \
+    --hash=sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff \
+    --hash=sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61 \
+    --hash=sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b \
+    --hash=sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97 \
+    --hash=sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b \
+    --hash=sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605 \
+    --hash=sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728 \
+    --hash=sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d \
+    --hash=sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c \
+    --hash=sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf \
+    --hash=sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673 \
+    --hash=sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1 \
+    --hash=sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b \
+    --hash=sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41 \
+    --hash=sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8 \
+    --hash=sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f \
+    --hash=sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4 \
+    --hash=sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008 \
+    --hash=sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9 \
+    --hash=sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5 \
+    --hash=sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f \
+    --hash=sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e \
+    --hash=sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273 \
+    --hash=sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45 \
+    --hash=sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e \
+    --hash=sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656 \
+    --hash=sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e \
+    --hash=sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c \
+    --hash=sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2 \
+    --hash=sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72 \
+    --hash=sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056 \
+    --hash=sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397 \
+    --hash=sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42 \
+    --hash=sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd \
+    --hash=sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3 \
+    --hash=sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213 \
+    --hash=sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf \
+    --hash=sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67
     # via requests
 dill==0.3.7 \
     --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \
@@ -171,61 +171,61 @@
     --hash=sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12 \
     --hash=sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb
     # via tb-nightly
-grpcio==1.59.0 \
-    --hash=sha256:0ae444221b2c16d8211b55326f8ba173ba8f8c76349bfc1768198ba592b58f74 \
-    --hash=sha256:0b84445fa94d59e6806c10266b977f92fa997db3585f125d6b751af02ff8b9fe \
-    --hash=sha256:14890da86a0c0e9dc1ea8e90101d7a3e0e7b1e71f4487fab36e2bfd2ecadd13c \
-    --hash=sha256:15f03bd714f987d48ae57fe092cf81960ae36da4e520e729392a59a75cda4f29 \
-    --hash=sha256:1a839ba86764cc48226f50b924216000c79779c563a301586a107bda9cbe9dcf \
-    --hash=sha256:225e5fa61c35eeaebb4e7491cd2d768cd8eb6ed00f2664fa83a58f29418b39fd \
-    --hash=sha256:228b91ce454876d7eed74041aff24a8f04c0306b7250a2da99d35dd25e2a1211 \
-    --hash=sha256:2ea95cd6abbe20138b8df965b4a8674ec312aaef3147c0f46a0bac661f09e8d0 \
-    --hash=sha256:2f120d27051e4c59db2f267b71b833796770d3ea36ca712befa8c5fff5da6ebd \
-    --hash=sha256:34341d9e81a4b669a5f5dca3b2a760b6798e95cdda2b173e65d29d0b16692857 \
-    --hash=sha256:3859917de234a0a2a52132489c4425a73669de9c458b01c9a83687f1f31b5b10 \
-    --hash=sha256:38823bd088c69f59966f594d087d3a929d1ef310506bee9e3648317660d65b81 \
-    --hash=sha256:38da5310ef84e16d638ad89550b5b9424df508fd5c7b968b90eb9629ca9be4b9 \
-    --hash=sha256:3b8ff795d35a93d1df6531f31c1502673d1cebeeba93d0f9bd74617381507e3f \
-    --hash=sha256:50eff97397e29eeee5df106ea1afce3ee134d567aa2c8e04fabab05c79d791a7 \
-    --hash=sha256:5711c51e204dc52065f4a3327dca46e69636a0b76d3e98c2c28c4ccef9b04c52 \
-    --hash=sha256:598f3530231cf10ae03f4ab92d48c3be1fee0c52213a1d5958df1a90957e6a88 \
-    --hash=sha256:611d9aa0017fa386809bddcb76653a5ab18c264faf4d9ff35cb904d44745f575 \
-    --hash=sha256:61bc72a00ecc2b79d9695220b4d02e8ba53b702b42411397e831c9b0589f08a3 \
-    --hash=sha256:63982150a7d598281fa1d7ffead6096e543ff8be189d3235dd2b5604f2c553e5 \
-    --hash=sha256:6c4b1cc3a9dc1924d2eb26eec8792fedd4b3fcd10111e26c1d551f2e4eda79ce \
-    --hash=sha256:81d86a096ccd24a57fa5772a544c9e566218bc4de49e8c909882dae9d73392df \
-    --hash=sha256:849c47ef42424c86af069a9c5e691a765e304079755d5c29eff511263fad9c2a \
-    --hash=sha256:871371ce0c0055d3db2a86fdebd1e1d647cf21a8912acc30052660297a5a6901 \
-    --hash=sha256:8cd2d38c2d52f607d75a74143113174c36d8a416d9472415eab834f837580cf7 \
-    --hash=sha256:936b2e04663660c600d5173bc2cc84e15adbad9c8f71946eb833b0afc205b996 \
-    --hash=sha256:93e9cb546e610829e462147ce724a9cb108e61647a3454500438a6deef610be1 \
-    --hash=sha256:956f0b7cb465a65de1bd90d5a7475b4dc55089b25042fe0f6c870707e9aabb1d \
-    --hash=sha256:986de4aa75646e963466b386a8c5055c8b23a26a36a6c99052385d6fe8aaf180 \
-    --hash=sha256:aca8a24fef80bef73f83eb8153f5f5a0134d9539b4c436a716256b311dda90a6 \
-    --hash=sha256:acf70a63cf09dd494000007b798aff88a436e1c03b394995ce450be437b8e54f \
-    --hash=sha256:b34c7a4c31841a2ea27246a05eed8a80c319bfc0d3e644412ec9ce437105ff6c \
-    --hash=sha256:b95ec8ecc4f703f5caaa8d96e93e40c7f589bad299a2617bdb8becbcce525539 \
-    --hash=sha256:ba0ca727a173ee093f49ead932c051af463258b4b493b956a2c099696f38aa66 \
-    --hash=sha256:c041a91712bf23b2a910f61e16565a05869e505dc5a5c025d429ca6de5de842c \
-    --hash=sha256:c0488c2b0528e6072010182075615620071371701733c63ab5be49140ed8f7f0 \
-    --hash=sha256:c173a87d622ea074ce79be33b952f0b424fa92182063c3bda8625c11d3585d09 \
-    --hash=sha256:c251d22de8f9f5cca9ee47e4bade7c5c853e6e40743f47f5cc02288ee7a87252 \
-    --hash=sha256:c4dfdb49f4997dc664f30116af2d34751b91aa031f8c8ee251ce4dcfc11277b0 \
-    --hash=sha256:ca87ee6183421b7cea3544190061f6c1c3dfc959e0b57a5286b108511fd34ff4 \
-    --hash=sha256:ceb1e68135788c3fce2211de86a7597591f0b9a0d2bb80e8401fd1d915991bac \
-    --hash=sha256:d09bd2a4e9f5a44d36bb8684f284835c14d30c22d8ec92ce796655af12163588 \
-    --hash=sha256:d0fcf53df684fcc0154b1e61f6b4a8c4cf5f49d98a63511e3f30966feff39cd0 \
-    --hash=sha256:d74f7d2d7c242a6af9d4d069552ec3669965b74fed6b92946e0e13b4168374f9 \
-    --hash=sha256:de2599985b7c1b4ce7526e15c969d66b93687571aa008ca749d6235d056b7205 \
-    --hash=sha256:e5378785dce2b91eb2e5b857ec7602305a3b5cf78311767146464bfa365fc897 \
-    --hash=sha256:ec78aebb9b6771d6a1de7b6ca2f779a2f6113b9108d486e904bde323d51f5589 \
-    --hash=sha256:f1feb034321ae2f718172d86b8276c03599846dc7bb1792ae370af02718f91c5 \
-    --hash=sha256:f21917aa50b40842b51aff2de6ebf9e2f6af3fe0971c31960ad6a3a2b24988f4 \
-    --hash=sha256:f367e4b524cb319e50acbdea57bb63c3b717c5d561974ace0b065a648bb3bad3 \
-    --hash=sha256:f6cfe44a5d7c7d5f1017a7da1c8160304091ca5dc64a0f85bca0d63008c3137a \
-    --hash=sha256:fa66cac32861500f280bb60fe7d5b3e22d68c51e18e65367e38f8669b78cea3b \
-    --hash=sha256:fc8bf2e7bc725e76c0c11e474634a08c8f24bcf7426c0c6d60c8f9c6e70e4d4a \
-    --hash=sha256:fe976910de34d21057bcb53b2c5e667843588b48bf11339da2a75f5c4c5b4055
+grpcio==1.59.2 \
+    --hash=sha256:023088764012411affe7db183d1ada3ad9daf2e23ddc719ff46d7061de661340 \
+    --hash=sha256:08d77e682f2bf730a4961eea330e56d2f423c6a9b91ca222e5b1eb24a357b19f \
+    --hash=sha256:0a4a3833c0e067f3558538727235cd8a49709bff1003200bbdefa2f09334e4b1 \
+    --hash=sha256:0a754aff9e3af63bdc4c75c234b86b9d14e14a28a30c4e324aed1a9b873d755f \
+    --hash=sha256:11168ef43e4a43ff1b1a65859f3e0ef1a173e277349e7fb16923ff108160a8cd \
+    --hash=sha256:128e20f57c5f27cb0157e73756d1586b83c1b513ebecc83ea0ac37e4b0e4e758 \
+    --hash=sha256:1f9524d1d701e399462d2c90ba7c193e49d1711cf429c0d3d97c966856e03d00 \
+    --hash=sha256:1ff16d68bf453275466a9a46739061a63584d92f18a0f5b33d19fc97eb69867c \
+    --hash=sha256:2067274c88bc6de89c278a672a652b4247d088811ece781a4858b09bdf8448e3 \
+    --hash=sha256:2171c39f355ba5b551c5d5928d65aa6c69807fae195b86ef4a7d125bcdb860a9 \
+    --hash=sha256:242adc47725b9a499ee77c6a2e36688fa6c96484611f33b1be4c57ab075a92dd \
+    --hash=sha256:27f879ae604a7fcf371e59fba6f3ff4635a4c2a64768bd83ff0cac503142fef4 \
+    --hash=sha256:2b230028a008ae1d0f430acb227d323ff8a619017415cf334c38b457f814119f \
+    --hash=sha256:3059668df17627f0e0fa680e9ef8c995c946c792612e9518f5cc1503be14e90b \
+    --hash=sha256:31176aa88f36020055ace9adff2405a33c8bdbfa72a9c4980e25d91b2f196873 \
+    --hash=sha256:36f53c2b3449c015880e7d55a89c992c357f176327b0d2873cdaaf9628a37c69 \
+    --hash=sha256:3b4368b33908f683a363f376dfb747d40af3463a6e5044afee07cf9436addf96 \
+    --hash=sha256:3c61d641d4f409c5ae46bfdd89ea42ce5ea233dcf69e74ce9ba32b503c727e29 \
+    --hash=sha256:4abb717e320e74959517dc8e84a9f48fbe90e9abe19c248541e9418b1ce60acd \
+    --hash=sha256:4c93f4abbb54321ee6471e04a00139c80c754eda51064187963ddf98f5cf36a4 \
+    --hash=sha256:535561990e075fa6bd4b16c4c3c1096b9581b7bb35d96fac4650f1181e428268 \
+    --hash=sha256:53c9aa5ddd6857c0a1cd0287225a2a25873a8e09727c2e95c4aebb1be83a766a \
+    --hash=sha256:5d573e70a6fe77555fb6143c12d3a7d3fa306632a3034b4e7c59ca09721546f8 \
+    --hash=sha256:6009386a2df66159f64ac9f20425ae25229b29b9dd0e1d3dd60043f037e2ad7e \
+    --hash=sha256:686e975a5d16602dc0982c7c703948d17184bd1397e16c8ee03511ecb8c4cdda \
+    --hash=sha256:6959fb07e8351e20501ffb8cc4074c39a0b7ef123e1c850a7f8f3afdc3a3da01 \
+    --hash=sha256:6b25ed37c27e652db01be341af93fbcea03d296c024d8a0e680017a268eb85dd \
+    --hash=sha256:6da6dea3a1bacf99b3c2187e296db9a83029ed9c38fd4c52b7c9b7326d13c828 \
+    --hash=sha256:72ca2399097c0b758198f2ff30f7178d680de8a5cfcf3d9b73a63cf87455532e \
+    --hash=sha256:73abb8584b0cf74d37f5ef61c10722adc7275502ab71789a8fe3cb7ef04cf6e2 \
+    --hash=sha256:74100fecaec8a535e380cf5f2fb556ff84957d481c13e54051c52e5baac70541 \
+    --hash=sha256:75c6ecb70e809cf1504465174343113f51f24bc61e22a80ae1c859f3f7034c6d \
+    --hash=sha256:7cf05053242f61ba94014dd3a986e11a083400a32664058f80bf4cf817c0b3a1 \
+    --hash=sha256:9411e24328a2302e279e70cae6e479f1fddde79629fcb14e03e6d94b3956eabf \
+    --hash=sha256:a213acfbf186b9f35803b52e4ca9addb153fc0b67f82a48f961be7000ecf6721 \
+    --hash=sha256:bb7e0fe6ad73b7f06d7e2b689c19a71cf5cc48f0c2bf8608469e51ffe0bd2867 \
+    --hash=sha256:c2504eed520958a5b77cc99458297cb7906308cb92327f35fb7fbbad4e9b2188 \
+    --hash=sha256:c35aa9657f5d5116d23b934568e0956bd50c615127810fffe3ac356a914c176a \
+    --hash=sha256:c5f09cffa619adfb44799fa4a81c2a1ad77c887187613fb0a8f201ab38d89ba1 \
+    --hash=sha256:c978f864b35f2261e0819f5cd88b9830b04dc51bcf055aac3c601e525a10d2ba \
+    --hash=sha256:cbe946b3e6e60a7b4618f091e62a029cb082b109a9d6b53962dd305087c6e4fd \
+    --hash=sha256:cc3e4cd087f07758b16bef8f31d88dbb1b5da5671d2f03685ab52dece3d7a16e \
+    --hash=sha256:cf0dead5a2c5a3347af2cfec7131d4f2a2e03c934af28989c9078f8241a491fa \
+    --hash=sha256:d2794f0e68b3085d99b4f6ff9c089f6fdd02b32b9d3efdfbb55beac1bf22d516 \
+    --hash=sha256:d2fa68a96a30dd240be80bbad838a0ac81a61770611ff7952b889485970c4c71 \
+    --hash=sha256:d6f70406695e3220f09cd7a2f879333279d91aa4a8a1d34303b56d61a8180137 \
+    --hash=sha256:d8f9cd4ad1be90b0cf350a2f04a38a36e44a026cac1e036ac593dc48efe91d52 \
+    --hash=sha256:da2d94c15f88cd40d7e67f7919d4f60110d2b9d5b1e08cf354c2be773ab13479 \
+    --hash=sha256:e1727c1c0e394096bb9af185c6923e8ea55a5095b8af44f06903bcc0e06800a2 \
+    --hash=sha256:e420ced29b5904cdf9ee5545e23f9406189d8acb6750916c2db4793dada065c6 \
+    --hash=sha256:e82c5cf1495244adf5252f925ac5932e5fd288b3e5ab6b70bec5593074b7236c \
+    --hash=sha256:f1ef0d39bc1feb420caf549b3c657c871cad4ebbcf0580c4d03816b0590de0cf \
+    --hash=sha256:f8753a6c88d1d0ba64302309eecf20f70d2770f65ca02d83c2452279085bfcd3 \
+    --hash=sha256:f93dbf58f03146164048be5426ffde298b237a5e059144847e4940f5b80172c3
     # via
     #   -r requirements.in
     #   tb-nightly
@@ -265,12 +265,12 @@
 jax==0.4.7 \
     --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8
     # via -r requirements.in
-keras-nightly==3.0.0.dev2023101703 \
-    --hash=sha256:72674b80300b9672b76e20c3d4c4bd827f6cacd486d9f98e3db864d872e3eaa4 \
-    --hash=sha256:ab5e59bf6a84d048b3241ad9ae368ed8497740804a8e85774fbded34c1322587
+keras-nightly==3.0.0.dev2023103103 \
+    --hash=sha256:25a6a9030d7e067d535ec41ae6700636c90cabada80fbddb3bb8775006807860 \
+    --hash=sha256:c39fccebb3d4cc9d838371981c4d3eef08fc06eadf6cffb39b356cb625bab50f
     # via -r requirements.in
-lit==17.0.3 \
-    --hash=sha256:e6049032462be1e2928686cbd4a6cc5b3c545d83ecd078737fe79412c1f3fcc1
+lit==17.0.4 \
+    --hash=sha256:ee2e180128e770abc6aed3a02de2daf09d81b7d30225e315205d3599c311d304
     # via -r requirements.in
 markdown==3.5 \
     --hash=sha256:4afb124395ce5fc34e6d9886dab977fd9ae987fc6e85689f08278cf0c69d4bf3 \
@@ -526,17 +526,17 @@
     # via
     #   astunparse
     #   tb-nightly
-tb-nightly==2.15.0a20231017 \
-    --hash=sha256:982b8cf32bcab4902eebd2e67b885c127f16106114b34eb4d46ea554af2a713a
+tb-nightly==2.15.0a20231023 \
+    --hash=sha256:8990a52985e3296aa18a6825efc017bcded9e2fb2cbcdd2d01f5c62d4fdc9825
     # via -r requirements.in
 tblib==2.0.0 \
     --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \
     --hash=sha256:a6df30f272c08bf8be66e0775fad862005d950a6b8449b94f7c788731d70ecd7
     # via -r requirements.in
-tensorboard-data-server==0.7.1 \
-    --hash=sha256:255c02b7f5b03dd5c0a88c928e563441ff39e1d4b4a234cdbe09f016e53d9594 \
-    --hash=sha256:9938bd39f5041797b33921066fba0eab03a0dd10d1887a05e62ae58841ad4c3f \
-    --hash=sha256:be8d016a1aa394e6198280d4a3dc37898f56467310c5f5e617cac10a783e055a
+tensorboard-data-server==0.7.2 \
+    --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \
+    --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \
+    --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530
     # via tb-nightly
 termcolor==2.3.0 \
     --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \
@@ -553,13 +553,13 @@
     --hash=sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84 \
     --hash=sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e
     # via requests
-werkzeug==3.0.0 \
-    --hash=sha256:3ffff4dcc32db52ef3cc94dff3000a3c2846890f3a5a51800a27b909c5e770f0 \
-    --hash=sha256:cbb2600f7eabe51dbc0502f58be0b3e1b96b893b05695ea2b35b43d4de2d9962
+werkzeug==3.0.1 \
+    --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \
+    --hash=sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10
     # via tb-nightly
-wheel==0.41.2 \
-    --hash=sha256:0c5ac5ff2afb79ac23ab82bab027a0be7b5dbcf2e54dc50efe4bf507de1f7985 \
-    --hash=sha256:75909db2664838d015e3d9139004ee16711748a52c8f336b52882266540215d8
+wheel==0.41.3 \
+    --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \
+    --hash=sha256:4d4987ce51a49370ea65c0bfd2234e8ce80a12780820d9dc462597a6e60d0841
     # via
     #   -r requirements.in
     #   astunparse
diff --git a/requirements_lock_3_12.txt b/requirements_lock_3_12.txt
index 853381c..e22b3e7 100644
--- a/requirements_lock_3_12.txt
+++ b/requirements_lock_3_12.txt
@@ -1,7 +1,9 @@
 absl-py==2.0.0 \
     --hash=sha256:9a28abb62774ae4e8edbe2dd4c49ffcd45a6a848952a5eccc6a49f3f0fc1e2f3 \
     --hash=sha256:d9690211c5fcfefcdd1a45470ac2b5c5acd45241c3af71eed96bc5441746c0d5
-    # via tb-nightly
+    # via
+    #   keras-nightly
+    #   tb-nightly
 astor==0.7.1 \
     --hash=sha256:95c30d87a6c2cf89aa628b87398466840f0ad8652f88eb173125a6df8533fb8d \
     --hash=sha256:fb503b9e2fdd05609fbf557b916b4a7824171203701660f0c55bbf5a7a68713e
@@ -10,110 +12,151 @@
     --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \
     --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8
     # via -r requirements.in
-cachetools==5.3.1 \
-    --hash=sha256:95ef631eeaea14ba2e36f06437f36463aac3a096799e876ee55e5cdccb102590 \
-    --hash=sha256:dce83f2d9b4e1f732a8cd44af8e8fab2dbe46201467fc98b3ef8f269092bf62b
+cachetools==5.3.2 \
+    --hash=sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2 \
+    --hash=sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1
     # via google-auth
 certifi==2023.7.22 \
     --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \
     --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9
     # via requests
-charset-normalizer==3.3.0 \
-    --hash=sha256:02673e456dc5ab13659f85196c534dc596d4ef260e4d86e856c3b2773ce09843 \
-    --hash=sha256:02af06682e3590ab952599fbadac535ede5d60d78848e555aa58d0c0abbde786 \
-    --hash=sha256:03680bb39035fbcffe828eae9c3f8afc0428c91d38e7d61aa992ef7a59fb120e \
-    --hash=sha256:0570d21da019941634a531444364f2482e8db0b3425fcd5ac0c36565a64142c8 \
-    --hash=sha256:09c77f964f351a7369cc343911e0df63e762e42bac24cd7d18525961c81754f4 \
-    --hash=sha256:0d3d5b7db9ed8a2b11a774db2bbea7ba1884430a205dbd54a32d61d7c2a190fa \
-    --hash=sha256:1063da2c85b95f2d1a430f1c33b55c9c17ffaf5e612e10aeaad641c55a9e2b9d \
-    --hash=sha256:12ebea541c44fdc88ccb794a13fe861cc5e35d64ed689513a5c03d05b53b7c82 \
-    --hash=sha256:153e7b6e724761741e0974fc4dcd406d35ba70b92bfe3fedcb497226c93b9da7 \
-    --hash=sha256:15b26ddf78d57f1d143bdf32e820fd8935d36abe8a25eb9ec0b5a71c82eb3895 \
-    --hash=sha256:1872d01ac8c618a8da634e232f24793883d6e456a66593135aeafe3784b0848d \
-    --hash=sha256:187d18082694a29005ba2944c882344b6748d5be69e3a89bf3cc9d878e548d5a \
-    --hash=sha256:1b2919306936ac6efb3aed1fbf81039f7087ddadb3160882a57ee2ff74fd2382 \
-    --hash=sha256:232ac332403e37e4a03d209a3f92ed9071f7d3dbda70e2a5e9cff1c4ba9f0678 \
-    --hash=sha256:23e8565ab7ff33218530bc817922fae827420f143479b753104ab801145b1d5b \
-    --hash=sha256:24817cb02cbef7cd499f7c9a2735286b4782bd47a5b3516a0e84c50eab44b98e \
-    --hash=sha256:249c6470a2b60935bafd1d1d13cd613f8cd8388d53461c67397ee6a0f5dce741 \
-    --hash=sha256:24a91a981f185721542a0b7c92e9054b7ab4fea0508a795846bc5b0abf8118d4 \
-    --hash=sha256:2502dd2a736c879c0f0d3e2161e74d9907231e25d35794584b1ca5284e43f596 \
-    --hash=sha256:250c9eb0f4600361dd80d46112213dff2286231d92d3e52af1e5a6083d10cad9 \
-    --hash=sha256:278c296c6f96fa686d74eb449ea1697f3c03dc28b75f873b65b5201806346a69 \
-    --hash=sha256:2935ffc78db9645cb2086c2f8f4cfd23d9b73cc0dc80334bc30aac6f03f68f8c \
-    --hash=sha256:2f4a0033ce9a76e391542c182f0d48d084855b5fcba5010f707c8e8c34663d77 \
-    --hash=sha256:30a85aed0b864ac88309b7d94be09f6046c834ef60762a8833b660139cfbad13 \
-    --hash=sha256:380c4bde80bce25c6e4f77b19386f5ec9db230df9f2f2ac1e5ad7af2caa70459 \
-    --hash=sha256:3ae38d325b512f63f8da31f826e6cb6c367336f95e418137286ba362925c877e \
-    --hash=sha256:3b447982ad46348c02cb90d230b75ac34e9886273df3a93eec0539308a6296d7 \
-    --hash=sha256:3debd1150027933210c2fc321527c2299118aa929c2f5a0a80ab6953e3bd1908 \
-    --hash=sha256:4162918ef3098851fcd8a628bf9b6a98d10c380725df9e04caf5ca6dd48c847a \
-    --hash=sha256:468d2a840567b13a590e67dd276c570f8de00ed767ecc611994c301d0f8c014f \
-    --hash=sha256:4cc152c5dd831641e995764f9f0b6589519f6f5123258ccaca8c6d34572fefa8 \
-    --hash=sha256:542da1178c1c6af8873e143910e2269add130a299c9106eef2594e15dae5e482 \
-    --hash=sha256:557b21a44ceac6c6b9773bc65aa1b4cc3e248a5ad2f5b914b91579a32e22204d \
-    --hash=sha256:5707a746c6083a3a74b46b3a631d78d129edab06195a92a8ece755aac25a3f3d \
-    --hash=sha256:588245972aca710b5b68802c8cad9edaa98589b1b42ad2b53accd6910dad3545 \
-    --hash=sha256:5adf257bd58c1b8632046bbe43ee38c04e1038e9d37de9c57a94d6bd6ce5da34 \
-    --hash=sha256:619d1c96099be5823db34fe89e2582b336b5b074a7f47f819d6b3a57ff7bdb86 \
-    --hash=sha256:63563193aec44bce707e0c5ca64ff69fa72ed7cf34ce6e11d5127555756fd2f6 \
-    --hash=sha256:67b8cc9574bb518ec76dc8e705d4c39ae78bb96237cb533edac149352c1f39fe \
-    --hash=sha256:6a685067d05e46641d5d1623d7c7fdf15a357546cbb2f71b0ebde91b175ffc3e \
-    --hash=sha256:70f1d09c0d7748b73290b29219e854b3207aea922f839437870d8cc2168e31cc \
-    --hash=sha256:750b446b2ffce1739e8578576092179160f6d26bd5e23eb1789c4d64d5af7dc7 \
-    --hash=sha256:7966951325782121e67c81299a031f4c115615e68046f79b85856b86ebffc4cd \
-    --hash=sha256:7b8b8bf1189b3ba9b8de5c8db4d541b406611a71a955bbbd7385bbc45fcb786c \
-    --hash=sha256:7f5d10bae5d78e4551b7be7a9b29643a95aded9d0f602aa2ba584f0388e7a557 \
-    --hash=sha256:805dfea4ca10411a5296bcc75638017215a93ffb584c9e344731eef0dcfb026a \
-    --hash=sha256:81bf654678e575403736b85ba3a7867e31c2c30a69bc57fe88e3ace52fb17b89 \
-    --hash=sha256:82eb849f085624f6a607538ee7b83a6d8126df6d2f7d3b319cb837b289123078 \
-    --hash=sha256:85a32721ddde63c9df9ebb0d2045b9691d9750cb139c161c80e500d210f5e26e \
-    --hash=sha256:86d1f65ac145e2c9ed71d8ffb1905e9bba3a91ae29ba55b4c46ae6fc31d7c0d4 \
-    --hash=sha256:86f63face3a527284f7bb8a9d4f78988e3c06823f7bea2bd6f0e0e9298ca0403 \
-    --hash=sha256:8eaf82f0eccd1505cf39a45a6bd0a8cf1c70dcfc30dba338207a969d91b965c0 \
-    --hash=sha256:93aa7eef6ee71c629b51ef873991d6911b906d7312c6e8e99790c0f33c576f89 \
-    --hash=sha256:96c2b49eb6a72c0e4991d62406e365d87067ca14c1a729a870d22354e6f68115 \
-    --hash=sha256:9cf3126b85822c4e53aa28c7ec9869b924d6fcfb76e77a45c44b83d91afd74f9 \
-    --hash=sha256:9fe359b2e3a7729010060fbca442ca225280c16e923b37db0e955ac2a2b72a05 \
-    --hash=sha256:a0ac5e7015a5920cfce654c06618ec40c33e12801711da6b4258af59a8eff00a \
-    --hash=sha256:a3f93dab657839dfa61025056606600a11d0b696d79386f974e459a3fbc568ec \
-    --hash=sha256:a4b71f4d1765639372a3b32d2638197f5cd5221b19531f9245fcc9ee62d38f56 \
-    --hash=sha256:aae32c93e0f64469f74ccc730a7cb21c7610af3a775157e50bbd38f816536b38 \
-    --hash=sha256:aaf7b34c5bc56b38c931a54f7952f1ff0ae77a2e82496583b247f7c969eb1479 \
-    --hash=sha256:abecce40dfebbfa6abf8e324e1860092eeca6f7375c8c4e655a8afb61af58f2c \
-    --hash=sha256:abf0d9f45ea5fb95051c8bfe43cb40cda383772f7e5023a83cc481ca2604d74e \
-    --hash=sha256:ac71b2977fb90c35d41c9453116e283fac47bb9096ad917b8819ca8b943abecd \
-    --hash=sha256:ada214c6fa40f8d800e575de6b91a40d0548139e5dc457d2ebb61470abf50186 \
-    --hash=sha256:b09719a17a2301178fac4470d54b1680b18a5048b481cb8890e1ef820cb80455 \
-    --hash=sha256:b1121de0e9d6e6ca08289583d7491e7fcb18a439305b34a30b20d8215922d43c \
-    --hash=sha256:b3b2316b25644b23b54a6f6401074cebcecd1244c0b8e80111c9a3f1c8e83d65 \
-    --hash=sha256:b3d9b48ee6e3967b7901c052b670c7dda6deb812c309439adaffdec55c6d7b78 \
-    --hash=sha256:b5bcf60a228acae568e9911f410f9d9e0d43197d030ae5799e20dca8df588287 \
-    --hash=sha256:b8f3307af845803fb0b060ab76cf6dd3a13adc15b6b451f54281d25911eb92df \
-    --hash=sha256:c2af80fb58f0f24b3f3adcb9148e6203fa67dd3f61c4af146ecad033024dde43 \
-    --hash=sha256:c350354efb159b8767a6244c166f66e67506e06c8924ed74669b2c70bc8735b1 \
-    --hash=sha256:c5a74c359b2d47d26cdbbc7845e9662d6b08a1e915eb015d044729e92e7050b7 \
-    --hash=sha256:c71f16da1ed8949774ef79f4a0260d28b83b3a50c6576f8f4f0288d109777989 \
-    --hash=sha256:d47ecf253780c90ee181d4d871cd655a789da937454045b17b5798da9393901a \
-    --hash=sha256:d7eff0f27edc5afa9e405f7165f85a6d782d308f3b6b9d96016c010597958e63 \
-    --hash=sha256:d97d85fa63f315a8bdaba2af9a6a686e0eceab77b3089af45133252618e70884 \
-    --hash=sha256:db756e48f9c5c607b5e33dd36b1d5872d0422e960145b08ab0ec7fd420e9d649 \
-    --hash=sha256:dc45229747b67ffc441b3de2f3ae5e62877a282ea828a5bdb67883c4ee4a8810 \
-    --hash=sha256:e0fc42822278451bc13a2e8626cf2218ba570f27856b536e00cfa53099724828 \
-    --hash=sha256:e39c7eb31e3f5b1f88caff88bcff1b7f8334975b46f6ac6e9fc725d829bc35d4 \
-    --hash=sha256:e46cd37076971c1040fc8c41273a8b3e2c624ce4f2be3f5dfcb7a430c1d3acc2 \
-    --hash=sha256:e5c1502d4ace69a179305abb3f0bb6141cbe4714bc9b31d427329a95acfc8bdd \
-    --hash=sha256:edfe077ab09442d4ef3c52cb1f9dab89bff02f4524afc0acf2d46be17dc479f5 \
-    --hash=sha256:effe5406c9bd748a871dbcaf3ac69167c38d72db8c9baf3ff954c344f31c4cbe \
-    --hash=sha256:f0d1e3732768fecb052d90d62b220af62ead5748ac51ef61e7b32c266cac9293 \
-    --hash=sha256:f5969baeaea61c97efa706b9b107dcba02784b1601c74ac84f2a532ea079403e \
-    --hash=sha256:f8888e31e3a85943743f8fc15e71536bda1c81d5aa36d014a3c0c44481d7db6e \
-    --hash=sha256:fc52b79d83a3fe3a360902d3f5d79073a993597d48114c29485e9431092905d8
+charset-normalizer==3.3.1 \
+    --hash=sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5 \
+    --hash=sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93 \
+    --hash=sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a \
+    --hash=sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d \
+    --hash=sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c \
+    --hash=sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1 \
+    --hash=sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58 \
+    --hash=sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2 \
+    --hash=sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557 \
+    --hash=sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147 \
+    --hash=sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041 \
+    --hash=sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2 \
+    --hash=sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2 \
+    --hash=sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7 \
+    --hash=sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296 \
+    --hash=sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690 \
+    --hash=sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67 \
+    --hash=sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57 \
+    --hash=sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597 \
+    --hash=sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846 \
+    --hash=sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b \
+    --hash=sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97 \
+    --hash=sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c \
+    --hash=sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62 \
+    --hash=sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa \
+    --hash=sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f \
+    --hash=sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e \
+    --hash=sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821 \
+    --hash=sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3 \
+    --hash=sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4 \
+    --hash=sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb \
+    --hash=sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727 \
+    --hash=sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514 \
+    --hash=sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d \
+    --hash=sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761 \
+    --hash=sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55 \
+    --hash=sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f \
+    --hash=sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c \
+    --hash=sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034 \
+    --hash=sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6 \
+    --hash=sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae \
+    --hash=sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1 \
+    --hash=sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14 \
+    --hash=sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1 \
+    --hash=sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228 \
+    --hash=sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708 \
+    --hash=sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48 \
+    --hash=sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f \
+    --hash=sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5 \
+    --hash=sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f \
+    --hash=sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4 \
+    --hash=sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8 \
+    --hash=sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff \
+    --hash=sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61 \
+    --hash=sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b \
+    --hash=sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97 \
+    --hash=sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b \
+    --hash=sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605 \
+    --hash=sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728 \
+    --hash=sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d \
+    --hash=sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c \
+    --hash=sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf \
+    --hash=sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673 \
+    --hash=sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1 \
+    --hash=sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b \
+    --hash=sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41 \
+    --hash=sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8 \
+    --hash=sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f \
+    --hash=sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4 \
+    --hash=sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008 \
+    --hash=sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9 \
+    --hash=sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5 \
+    --hash=sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f \
+    --hash=sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e \
+    --hash=sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273 \
+    --hash=sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45 \
+    --hash=sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e \
+    --hash=sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656 \
+    --hash=sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e \
+    --hash=sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c \
+    --hash=sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2 \
+    --hash=sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72 \
+    --hash=sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056 \
+    --hash=sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397 \
+    --hash=sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42 \
+    --hash=sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd \
+    --hash=sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3 \
+    --hash=sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213 \
+    --hash=sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf \
+    --hash=sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67
     # via requests
 dill==0.3.7 \
     --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \
     --hash=sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03
     # via -r requirements.in
+dm-tree==0.1.8 \
+    --hash=sha256:054b461f8176f4bce7a21f7b1870f873a1ced3bdbe1282c816c550bb43c71fa6 \
+    --hash=sha256:0d3172394079a86c3a759179c65f64c48d1a42b89495fcf38976d11cc3bb952c \
+    --hash=sha256:0e9620ccf06393eb6b613b5e366469304622d4ea96ae6540b28a33840e6c89cf \
+    --hash=sha256:0fcaabbb14e7980377439e7140bd05552739ca5e515ecb3119f234acee4b9430 \
+    --hash=sha256:1607ce49aa42f010d1e5e616d92ce899d66835d4d8bea49679582435285515de \
+    --hash=sha256:181c35521d480d0365f39300542cb6cd7fd2b77351bb43d7acfda15aef63b317 \
+    --hash=sha256:1d7c26e431fc93cc7e0cba867eb000db6a05f6f2b25af11ac4e9dada88fc5bca \
+    --hash=sha256:1fe962015b2fe1282892b28ebe962faed53c7f98d942da9a4625cbf27baef913 \
+    --hash=sha256:250b692fb75f45f02e2f58fbef9ab338904ef334b90557565621fa251df267cf \
+    --hash=sha256:2869228d9c619074de501a3c10dc7f07c75422f8fab36ecdcb859b6f1b1ec3ef \
+    --hash=sha256:28c52cbf4f8b3dbd0beaedf44f69fa85eec5e9dede612e08035e06ada6ec9426 \
+    --hash=sha256:2f7915660f59c09068e428613c480150180df1060561fd0d1470684ae7007bd1 \
+    --hash=sha256:343a4a4ebaa127451ff971254a4be4084eb4bdc0b2513c32b46f6f728fd03f9e \
+    --hash=sha256:35cc164a79336bfcfafb47e5f297898359123bbd3330c1967f0c4994f9cf9f60 \
+    --hash=sha256:378cc8ad93c5fe3590f405a309980721f021c790ca1bdf9b15bb1d59daec57f5 \
+    --hash=sha256:39070ba268c0491af9fe7a58644d99e8b4f2cde6e5884ba3380bddc84ed43d5f \
+    --hash=sha256:5483dca4d7eb1a0d65fe86d3b6a53ae717face83c1f17e0887b1a4a64ae5c410 \
+    --hash=sha256:694c3654cfd2a81552c08ec66bb5c4a3d48fa292b9a181880fb081c36c5b9134 \
+    --hash=sha256:803bfc53b4659f447ac694dbd04235f94a73ef7c1fd1e0df7c84ac41e0bc963b \
+    --hash=sha256:81fce77f22a302d7a5968aebdf4efafef4def7ce96528719a354e6990dcd49c7 \
+    --hash=sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393 \
+    --hash=sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571 \
+    --hash=sha256:8ed3564abed97c806db122c2d3e1a2b64c74a63debe9903aad795167cc301368 \
+    --hash=sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80 \
+    --hash=sha256:ad16ceba90a56ec47cf45b21856d14962ac314787975ef786efb5e6e9ca75ec7 \
+    --hash=sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d \
+    --hash=sha256:b095ba4f8ca1ba19350fd53cf1f8f3eb0bd406aa28af64a6dfc86707b32a810a \
+    --hash=sha256:b9bd9b9ccb59409d33d51d84b7668010c04c2af7d4a371632874c1ca356cff3d \
+    --hash=sha256:b9f89a454e98806b44fe9d40ec9eee61f848388f7e79ac2371a55679bd5a3ac6 \
+    --hash=sha256:bb2d109f42190225112da899b9f3d46d0d5f26aef501c61e43529fe9322530b5 \
+    --hash=sha256:c5c8c12e3fda754ef6af94161bacdaeda816d941995fac415d6855c6c386af68 \
+    --hash=sha256:d1612fcaecd79023dbc6a6ae48d51a80beb5c385d6f3f6d71688e57bc8d07de8 \
+    --hash=sha256:d16e1f2a073604cfcc09f7131ae8d534674f43c3aef4c25742eae295bc60d04f \
+    --hash=sha256:d20f2faa3672b52e5013f4077117bfb99c4cfc0b445d3bde1584c34032b57436 \
+    --hash=sha256:d40fa4106ca6edc66760246a08f500ec0c85ef55c762fb4a363f6ee739ba02ee \
+    --hash=sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb \
+    --hash=sha256:e4d714371bb08839e4e5e29024fc95832d9affe129825ef38836b143028bd144 \
+    --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \
+    --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d
+    # via keras-nightly
 gast==0.4.0 \
     --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \
     --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4
@@ -128,61 +171,61 @@
     --hash=sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12 \
     --hash=sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb
     # via tb-nightly
-grpcio==1.59.0 \
-    --hash=sha256:0ae444221b2c16d8211b55326f8ba173ba8f8c76349bfc1768198ba592b58f74 \
-    --hash=sha256:0b84445fa94d59e6806c10266b977f92fa997db3585f125d6b751af02ff8b9fe \
-    --hash=sha256:14890da86a0c0e9dc1ea8e90101d7a3e0e7b1e71f4487fab36e2bfd2ecadd13c \
-    --hash=sha256:15f03bd714f987d48ae57fe092cf81960ae36da4e520e729392a59a75cda4f29 \
-    --hash=sha256:1a839ba86764cc48226f50b924216000c79779c563a301586a107bda9cbe9dcf \
-    --hash=sha256:225e5fa61c35eeaebb4e7491cd2d768cd8eb6ed00f2664fa83a58f29418b39fd \
-    --hash=sha256:228b91ce454876d7eed74041aff24a8f04c0306b7250a2da99d35dd25e2a1211 \
-    --hash=sha256:2ea95cd6abbe20138b8df965b4a8674ec312aaef3147c0f46a0bac661f09e8d0 \
-    --hash=sha256:2f120d27051e4c59db2f267b71b833796770d3ea36ca712befa8c5fff5da6ebd \
-    --hash=sha256:34341d9e81a4b669a5f5dca3b2a760b6798e95cdda2b173e65d29d0b16692857 \
-    --hash=sha256:3859917de234a0a2a52132489c4425a73669de9c458b01c9a83687f1f31b5b10 \
-    --hash=sha256:38823bd088c69f59966f594d087d3a929d1ef310506bee9e3648317660d65b81 \
-    --hash=sha256:38da5310ef84e16d638ad89550b5b9424df508fd5c7b968b90eb9629ca9be4b9 \
-    --hash=sha256:3b8ff795d35a93d1df6531f31c1502673d1cebeeba93d0f9bd74617381507e3f \
-    --hash=sha256:50eff97397e29eeee5df106ea1afce3ee134d567aa2c8e04fabab05c79d791a7 \
-    --hash=sha256:5711c51e204dc52065f4a3327dca46e69636a0b76d3e98c2c28c4ccef9b04c52 \
-    --hash=sha256:598f3530231cf10ae03f4ab92d48c3be1fee0c52213a1d5958df1a90957e6a88 \
-    --hash=sha256:611d9aa0017fa386809bddcb76653a5ab18c264faf4d9ff35cb904d44745f575 \
-    --hash=sha256:61bc72a00ecc2b79d9695220b4d02e8ba53b702b42411397e831c9b0589f08a3 \
-    --hash=sha256:63982150a7d598281fa1d7ffead6096e543ff8be189d3235dd2b5604f2c553e5 \
-    --hash=sha256:6c4b1cc3a9dc1924d2eb26eec8792fedd4b3fcd10111e26c1d551f2e4eda79ce \
-    --hash=sha256:81d86a096ccd24a57fa5772a544c9e566218bc4de49e8c909882dae9d73392df \
-    --hash=sha256:849c47ef42424c86af069a9c5e691a765e304079755d5c29eff511263fad9c2a \
-    --hash=sha256:871371ce0c0055d3db2a86fdebd1e1d647cf21a8912acc30052660297a5a6901 \
-    --hash=sha256:8cd2d38c2d52f607d75a74143113174c36d8a416d9472415eab834f837580cf7 \
-    --hash=sha256:936b2e04663660c600d5173bc2cc84e15adbad9c8f71946eb833b0afc205b996 \
-    --hash=sha256:93e9cb546e610829e462147ce724a9cb108e61647a3454500438a6deef610be1 \
-    --hash=sha256:956f0b7cb465a65de1bd90d5a7475b4dc55089b25042fe0f6c870707e9aabb1d \
-    --hash=sha256:986de4aa75646e963466b386a8c5055c8b23a26a36a6c99052385d6fe8aaf180 \
-    --hash=sha256:aca8a24fef80bef73f83eb8153f5f5a0134d9539b4c436a716256b311dda90a6 \
-    --hash=sha256:acf70a63cf09dd494000007b798aff88a436e1c03b394995ce450be437b8e54f \
-    --hash=sha256:b34c7a4c31841a2ea27246a05eed8a80c319bfc0d3e644412ec9ce437105ff6c \
-    --hash=sha256:b95ec8ecc4f703f5caaa8d96e93e40c7f589bad299a2617bdb8becbcce525539 \
-    --hash=sha256:ba0ca727a173ee093f49ead932c051af463258b4b493b956a2c099696f38aa66 \
-    --hash=sha256:c041a91712bf23b2a910f61e16565a05869e505dc5a5c025d429ca6de5de842c \
-    --hash=sha256:c0488c2b0528e6072010182075615620071371701733c63ab5be49140ed8f7f0 \
-    --hash=sha256:c173a87d622ea074ce79be33b952f0b424fa92182063c3bda8625c11d3585d09 \
-    --hash=sha256:c251d22de8f9f5cca9ee47e4bade7c5c853e6e40743f47f5cc02288ee7a87252 \
-    --hash=sha256:c4dfdb49f4997dc664f30116af2d34751b91aa031f8c8ee251ce4dcfc11277b0 \
-    --hash=sha256:ca87ee6183421b7cea3544190061f6c1c3dfc959e0b57a5286b108511fd34ff4 \
-    --hash=sha256:ceb1e68135788c3fce2211de86a7597591f0b9a0d2bb80e8401fd1d915991bac \
-    --hash=sha256:d09bd2a4e9f5a44d36bb8684f284835c14d30c22d8ec92ce796655af12163588 \
-    --hash=sha256:d0fcf53df684fcc0154b1e61f6b4a8c4cf5f49d98a63511e3f30966feff39cd0 \
-    --hash=sha256:d74f7d2d7c242a6af9d4d069552ec3669965b74fed6b92946e0e13b4168374f9 \
-    --hash=sha256:de2599985b7c1b4ce7526e15c969d66b93687571aa008ca749d6235d056b7205 \
-    --hash=sha256:e5378785dce2b91eb2e5b857ec7602305a3b5cf78311767146464bfa365fc897 \
-    --hash=sha256:ec78aebb9b6771d6a1de7b6ca2f779a2f6113b9108d486e904bde323d51f5589 \
-    --hash=sha256:f1feb034321ae2f718172d86b8276c03599846dc7bb1792ae370af02718f91c5 \
-    --hash=sha256:f21917aa50b40842b51aff2de6ebf9e2f6af3fe0971c31960ad6a3a2b24988f4 \
-    --hash=sha256:f367e4b524cb319e50acbdea57bb63c3b717c5d561974ace0b065a648bb3bad3 \
-    --hash=sha256:f6cfe44a5d7c7d5f1017a7da1c8160304091ca5dc64a0f85bca0d63008c3137a \
-    --hash=sha256:fa66cac32861500f280bb60fe7d5b3e22d68c51e18e65367e38f8669b78cea3b \
-    --hash=sha256:fc8bf2e7bc725e76c0c11e474634a08c8f24bcf7426c0c6d60c8f9c6e70e4d4a \
-    --hash=sha256:fe976910de34d21057bcb53b2c5e667843588b48bf11339da2a75f5c4c5b4055
+grpcio==1.59.2 \
+    --hash=sha256:023088764012411affe7db183d1ada3ad9daf2e23ddc719ff46d7061de661340 \
+    --hash=sha256:08d77e682f2bf730a4961eea330e56d2f423c6a9b91ca222e5b1eb24a357b19f \
+    --hash=sha256:0a4a3833c0e067f3558538727235cd8a49709bff1003200bbdefa2f09334e4b1 \
+    --hash=sha256:0a754aff9e3af63bdc4c75c234b86b9d14e14a28a30c4e324aed1a9b873d755f \
+    --hash=sha256:11168ef43e4a43ff1b1a65859f3e0ef1a173e277349e7fb16923ff108160a8cd \
+    --hash=sha256:128e20f57c5f27cb0157e73756d1586b83c1b513ebecc83ea0ac37e4b0e4e758 \
+    --hash=sha256:1f9524d1d701e399462d2c90ba7c193e49d1711cf429c0d3d97c966856e03d00 \
+    --hash=sha256:1ff16d68bf453275466a9a46739061a63584d92f18a0f5b33d19fc97eb69867c \
+    --hash=sha256:2067274c88bc6de89c278a672a652b4247d088811ece781a4858b09bdf8448e3 \
+    --hash=sha256:2171c39f355ba5b551c5d5928d65aa6c69807fae195b86ef4a7d125bcdb860a9 \
+    --hash=sha256:242adc47725b9a499ee77c6a2e36688fa6c96484611f33b1be4c57ab075a92dd \
+    --hash=sha256:27f879ae604a7fcf371e59fba6f3ff4635a4c2a64768bd83ff0cac503142fef4 \
+    --hash=sha256:2b230028a008ae1d0f430acb227d323ff8a619017415cf334c38b457f814119f \
+    --hash=sha256:3059668df17627f0e0fa680e9ef8c995c946c792612e9518f5cc1503be14e90b \
+    --hash=sha256:31176aa88f36020055ace9adff2405a33c8bdbfa72a9c4980e25d91b2f196873 \
+    --hash=sha256:36f53c2b3449c015880e7d55a89c992c357f176327b0d2873cdaaf9628a37c69 \
+    --hash=sha256:3b4368b33908f683a363f376dfb747d40af3463a6e5044afee07cf9436addf96 \
+    --hash=sha256:3c61d641d4f409c5ae46bfdd89ea42ce5ea233dcf69e74ce9ba32b503c727e29 \
+    --hash=sha256:4abb717e320e74959517dc8e84a9f48fbe90e9abe19c248541e9418b1ce60acd \
+    --hash=sha256:4c93f4abbb54321ee6471e04a00139c80c754eda51064187963ddf98f5cf36a4 \
+    --hash=sha256:535561990e075fa6bd4b16c4c3c1096b9581b7bb35d96fac4650f1181e428268 \
+    --hash=sha256:53c9aa5ddd6857c0a1cd0287225a2a25873a8e09727c2e95c4aebb1be83a766a \
+    --hash=sha256:5d573e70a6fe77555fb6143c12d3a7d3fa306632a3034b4e7c59ca09721546f8 \
+    --hash=sha256:6009386a2df66159f64ac9f20425ae25229b29b9dd0e1d3dd60043f037e2ad7e \
+    --hash=sha256:686e975a5d16602dc0982c7c703948d17184bd1397e16c8ee03511ecb8c4cdda \
+    --hash=sha256:6959fb07e8351e20501ffb8cc4074c39a0b7ef123e1c850a7f8f3afdc3a3da01 \
+    --hash=sha256:6b25ed37c27e652db01be341af93fbcea03d296c024d8a0e680017a268eb85dd \
+    --hash=sha256:6da6dea3a1bacf99b3c2187e296db9a83029ed9c38fd4c52b7c9b7326d13c828 \
+    --hash=sha256:72ca2399097c0b758198f2ff30f7178d680de8a5cfcf3d9b73a63cf87455532e \
+    --hash=sha256:73abb8584b0cf74d37f5ef61c10722adc7275502ab71789a8fe3cb7ef04cf6e2 \
+    --hash=sha256:74100fecaec8a535e380cf5f2fb556ff84957d481c13e54051c52e5baac70541 \
+    --hash=sha256:75c6ecb70e809cf1504465174343113f51f24bc61e22a80ae1c859f3f7034c6d \
+    --hash=sha256:7cf05053242f61ba94014dd3a986e11a083400a32664058f80bf4cf817c0b3a1 \
+    --hash=sha256:9411e24328a2302e279e70cae6e479f1fddde79629fcb14e03e6d94b3956eabf \
+    --hash=sha256:a213acfbf186b9f35803b52e4ca9addb153fc0b67f82a48f961be7000ecf6721 \
+    --hash=sha256:bb7e0fe6ad73b7f06d7e2b689c19a71cf5cc48f0c2bf8608469e51ffe0bd2867 \
+    --hash=sha256:c2504eed520958a5b77cc99458297cb7906308cb92327f35fb7fbbad4e9b2188 \
+    --hash=sha256:c35aa9657f5d5116d23b934568e0956bd50c615127810fffe3ac356a914c176a \
+    --hash=sha256:c5f09cffa619adfb44799fa4a81c2a1ad77c887187613fb0a8f201ab38d89ba1 \
+    --hash=sha256:c978f864b35f2261e0819f5cd88b9830b04dc51bcf055aac3c601e525a10d2ba \
+    --hash=sha256:cbe946b3e6e60a7b4618f091e62a029cb082b109a9d6b53962dd305087c6e4fd \
+    --hash=sha256:cc3e4cd087f07758b16bef8f31d88dbb1b5da5671d2f03685ab52dece3d7a16e \
+    --hash=sha256:cf0dead5a2c5a3347af2cfec7131d4f2a2e03c934af28989c9078f8241a491fa \
+    --hash=sha256:d2794f0e68b3085d99b4f6ff9c089f6fdd02b32b9d3efdfbb55beac1bf22d516 \
+    --hash=sha256:d2fa68a96a30dd240be80bbad838a0ac81a61770611ff7952b889485970c4c71 \
+    --hash=sha256:d6f70406695e3220f09cd7a2f879333279d91aa4a8a1d34303b56d61a8180137 \
+    --hash=sha256:d8f9cd4ad1be90b0cf350a2f04a38a36e44a026cac1e036ac593dc48efe91d52 \
+    --hash=sha256:da2d94c15f88cd40d7e67f7919d4f60110d2b9d5b1e08cf354c2be773ab13479 \
+    --hash=sha256:e1727c1c0e394096bb9af185c6923e8ea55a5095b8af44f06903bcc0e06800a2 \
+    --hash=sha256:e420ced29b5904cdf9ee5545e23f9406189d8acb6750916c2db4793dada065c6 \
+    --hash=sha256:e82c5cf1495244adf5252f925ac5932e5fd288b3e5ab6b70bec5593074b7236c \
+    --hash=sha256:f1ef0d39bc1feb420caf549b3c657c871cad4ebbcf0580c4d03816b0590de0cf \
+    --hash=sha256:f8753a6c88d1d0ba64302309eecf20f70d2770f65ca02d83c2452279085bfcd3 \
+    --hash=sha256:f93dbf58f03146164048be5426ffde298b237a5e059144847e4940f5b80172c3
     # via
     #   -r requirements.in
     #   tb-nightly
@@ -212,7 +255,9 @@
     --hash=sha256:d93adc48ceeb33347eb24a634fb787efc7ae4644e6ea4ba733d099605045c049 \
     --hash=sha256:f42e6c30698b520f0295d70157c4e202a9e402406f50dc08f5a7bc416b24e52d \
     --hash=sha256:fd6f6d1384a9f491732cee233b99cd4bfd6e838a8815cc86722f9d2ee64032af
-    # via -r requirements.in
+    # via
+    #   -r requirements.in
+    #   keras-nightly
 idna==3.4 \
     --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \
     --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2
@@ -220,17 +265,21 @@
 jax==0.4.7 \
     --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8
     # via -r requirements.in
-keras-nightly==2.15.0.dev2023092207 \
-    --hash=sha256:ec9488751d201931a202570ba07788ce211a63ce678c4ea0294efb77cf579d32 \
-    --hash=sha256:f60e0105c18ac451ed52ffef039a4c6682b322c9b760904c12a9be08dcbf7746
+keras-nightly==3.0.0.dev2023103103 \
+    --hash=sha256:25a6a9030d7e067d535ec41ae6700636c90cabada80fbddb3bb8775006807860 \
+    --hash=sha256:c39fccebb3d4cc9d838371981c4d3eef08fc06eadf6cffb39b356cb625bab50f
     # via -r requirements.in
-lit==17.0.2 \
-    --hash=sha256:d6a551eab550f81023c82a260cd484d63970d2be9fd7588111208e7d2ff62212
+lit==17.0.4 \
+    --hash=sha256:ee2e180128e770abc6aed3a02de2daf09d81b7d30225e315205d3599c311d304
     # via -r requirements.in
 markdown==3.5 \
     --hash=sha256:4afb124395ce5fc34e6d9886dab977fd9ae987fc6e85689f08278cf0c69d4bf3 \
     --hash=sha256:a807eb2e4778d9156c8f07876c6e4d50b5494c5665c4834f67b06459dfd877b3
     # via tb-nightly
+markdown-it-py==3.0.0 \
+    --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \
+    --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb
+    # via rich
 markupsafe==2.1.3 \
     --hash=sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e \
     --hash=sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e \
@@ -293,6 +342,10 @@
     --hash=sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2 \
     --hash=sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11
     # via werkzeug
+mdurl==0.1.2 \
+    --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \
+    --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba
+    # via markdown-it-py
 ml-dtypes==0.3.1 \
     --hash=sha256:3d8ca0acbd377082792d8b97081ba580abdad67c6afb7f827012c675b052f058 \
     --hash=sha256:42a8980afd8b7c8e270e8b5c260237286b5b26acd276fcb758d13cd7cb567e99 \
@@ -312,43 +365,48 @@
     --hash=sha256:f83ff080df8910c0f987f615b03e4f8198638e0c00c6e679ea8892dda909763b \
     --hash=sha256:fcae2c69715410d96906e1dfe8f017d9f78a0d10e0df91aae52e91f51fdfe45e
     # via jax
-numpy==1.26.0 ; python_version >= "3.12" \
-    --hash=sha256:020cdbee66ed46b671429c7265cf00d8ac91c046901c55684954c3958525dab2 \
-    --hash=sha256:0621f7daf973d34d18b4e4bafb210bbaf1ef5e0100b5fa750bd9cde84c7ac292 \
-    --hash=sha256:0792824ce2f7ea0c82ed2e4fecc29bb86bee0567a080dacaf2e0a01fe7654369 \
-    --hash=sha256:09aaee96c2cbdea95de76ecb8a586cb687d281c881f5f17bfc0fb7f5890f6b91 \
-    --hash=sha256:166b36197e9debc4e384e9c652ba60c0bacc216d0fc89e78f973a9760b503388 \
-    --hash=sha256:186ba67fad3c60dbe8a3abff3b67a91351100f2661c8e2a80364ae6279720299 \
-    --hash=sha256:306545e234503a24fe9ae95ebf84d25cba1fdc27db971aa2d9f1ab6bba19a9dd \
-    --hash=sha256:436c8e9a4bdeeee84e3e59614d38c3dbd3235838a877af8c211cfcac8a80b8d3 \
-    --hash=sha256:4a873a8180479bc829313e8d9798d5234dfacfc2e8a7ac188418189bb8eafbd2 \
-    --hash=sha256:4acc65dd65da28060e206c8f27a573455ed724e6179941edb19f97e58161bb69 \
-    --hash=sha256:51be5f8c349fdd1a5568e72713a21f518e7d6707bcf8503b528b88d33b57dc68 \
-    --hash=sha256:546b7dd7e22f3c6861463bebb000646fa730e55df5ee4a0224408b5694cc6148 \
-    --hash=sha256:5671338034b820c8d58c81ad1dafc0ed5a00771a82fccc71d6438df00302094b \
-    --hash=sha256:637c58b468a69869258b8ae26f4a4c6ff8abffd4a8334c830ffb63e0feefe99a \
-    --hash=sha256:767254ad364991ccfc4d81b8152912e53e103ec192d1bb4ea6b1f5a7117040be \
-    --hash=sha256:7d484292eaeb3e84a51432a94f53578689ffdea3f90e10c8b203a99be5af57d8 \
-    --hash=sha256:7f6bad22a791226d0a5c7c27a80a20e11cfe09ad5ef9084d4d3fc4a299cca505 \
-    --hash=sha256:86f737708b366c36b76e953c46ba5827d8c27b7a8c9d0f471810728e5a2fe57c \
-    --hash=sha256:8c6adc33561bd1d46f81131d5352348350fc23df4d742bb246cdfca606ea1208 \
-    --hash=sha256:914b28d3215e0c721dc75db3ad6d62f51f630cb0c277e6b3bcb39519bed10bd8 \
-    --hash=sha256:b44e6a09afc12952a7d2a58ca0a2429ee0d49a4f89d83a0a11052da696440e49 \
-    --hash=sha256:bb0d9a1aaf5f1cb7967320e80690a1d7ff69f1d47ebc5a9bea013e3a21faec95 \
-    --hash=sha256:c0b45c8b65b79337dee5134d038346d30e109e9e2e9d43464a2970e5c0e93229 \
-    --hash=sha256:c2e698cb0c6dda9372ea98a0344245ee65bdc1c9dd939cceed6bb91256837896 \
-    --hash=sha256:c78a22e95182fb2e7874712433eaa610478a3caf86f28c621708d35fa4fd6e7f \
-    --hash=sha256:e062aa24638bb5018b7841977c360d2f5917268d125c833a686b7cbabbec496c \
-    --hash=sha256:e5e18e5b14a7560d8acf1c596688f4dfd19b4f2945b245a71e5af4ddb7422feb \
-    --hash=sha256:eae430ecf5794cb7ae7fa3808740b015aa80747e5266153128ef055975a72b99 \
-    --hash=sha256:ee84ca3c58fe48b8ddafdeb1db87388dce2c3c3f701bf447b05e4cfcc3679112 \
-    --hash=sha256:f042f66d0b4ae6d48e70e28d487376204d3cbf43b84c03bac57e28dac6151581 \
-    --hash=sha256:f8db2f125746e44dce707dd44d4f4efeea8d7e2b43aace3f8d1f235cfa2733dd \
-    --hash=sha256:f93fc78fe8bf15afe2b8d6b6499f1c73953169fad1e9a8dd086cdff3190e7fdf
+namex==0.0.7 \
+    --hash=sha256:84ba65bc4d22bd909e3d26bf2ffb4b9529b608cb3f9a4336f776b04204ced69b \
+    --hash=sha256:8a4f062945f405d77cb66b907f16aa2fd83681945e998be840eb6c4154d40108
+    # via keras-nightly
+numpy==1.26.1 ; python_version >= "3.12" \
+    --hash=sha256:06934e1a22c54636a059215d6da99e23286424f316fddd979f5071093b648668 \
+    --hash=sha256:1c59c046c31a43310ad0199d6299e59f57a289e22f0f36951ced1c9eac3665b9 \
+    --hash=sha256:1d1bd82d539607951cac963388534da3b7ea0e18b149a53cf883d8f699178c0f \
+    --hash=sha256:1e11668d6f756ca5ef534b5be8653d16c5352cbb210a5c2a79ff288e937010d5 \
+    --hash=sha256:3649d566e2fc067597125428db15d60eb42a4e0897fc48d28cb75dc2e0454e53 \
+    --hash=sha256:59227c981d43425ca5e5c01094d59eb14e8772ce6975d4b2fc1e106a833d5ae2 \
+    --hash=sha256:6081aed64714a18c72b168a9276095ef9155dd7888b9e74b5987808f0dd0a974 \
+    --hash=sha256:6965888d65d2848e8768824ca8288db0a81263c1efccec881cb35a0d805fcd2f \
+    --hash=sha256:76ff661a867d9272cd2a99eed002470f46dbe0943a5ffd140f49be84f68ffc42 \
+    --hash=sha256:78ca54b2f9daffa5f323f34cdf21e1d9779a54073f0018a3094ab907938331a2 \
+    --hash=sha256:82e871307a6331b5f09efda3c22e03c095d957f04bf6bc1804f30048d0e5e7af \
+    --hash=sha256:8ab9163ca8aeb7fd32fe93866490654d2f7dda4e61bc6297bf72ce07fdc02f67 \
+    --hash=sha256:9696aa2e35cc41e398a6d42d147cf326f8f9d81befcb399bc1ed7ffea339b64e \
+    --hash=sha256:97e5d6a9f0702c2863aaabf19f0d1b6c2628fbe476438ce0b5ce06e83085064c \
+    --hash=sha256:9f42284ebf91bdf32fafac29d29d4c07e5e9d1af862ea73686581773ef9e73a7 \
+    --hash=sha256:a03fb25610ef560a6201ff06df4f8105292ba56e7cdd196ea350d123fc32e24e \
+    --hash=sha256:a5b411040beead47a228bde3b2241100454a6abde9df139ed087bd73fc0a4908 \
+    --hash=sha256:af22f3d8e228d84d1c0c44c1fbdeb80f97a15a0abe4f080960393a00db733b66 \
+    --hash=sha256:afd5ced4e5a96dac6725daeb5242a35494243f2239244fad10a90ce58b071d24 \
+    --hash=sha256:b9d45d1dbb9de84894cc50efece5b09939752a2d75aab3a8b0cef6f3a35ecd6b \
+    --hash=sha256:bb894accfd16b867d8643fc2ba6c8617c78ba2828051e9a69511644ce86ce83e \
+    --hash=sha256:c8c6c72d4a9f831f328efb1312642a1cafafaa88981d9ab76368d50d07d93cbe \
+    --hash=sha256:cd7837b2b734ca72959a1caf3309457a318c934abef7a43a14bb984e574bbb9a \
+    --hash=sha256:cdd9ec98f0063d93baeb01aad472a1a0840dee302842a2746a7a8e92968f9575 \
+    --hash=sha256:d1cfc92db6af1fd37a7bb58e55c8383b4aa1ba23d012bdbba26b4bcca45ac297 \
+    --hash=sha256:d1d2c6b7dd618c41e202c59c1413ef9b2c8e8a15f5039e344af64195459e3104 \
+    --hash=sha256:d2984cb6caaf05294b8466966627e80bf6c7afd273279077679cb010acb0e5ab \
+    --hash=sha256:d58e8c51a7cf43090d124d5073bc29ab2755822181fcad978b12e144e5e5a4b3 \
+    --hash=sha256:d78f269e0c4fd365fc2992c00353e4530d274ba68f15e968d8bc3c69ce5f5244 \
+    --hash=sha256:dcfaf015b79d1f9f9c9fd0731a907407dc3e45769262d657d754c3a028586124 \
+    --hash=sha256:e44ccb93f30c75dfc0c3aa3ce38f33486a75ec9abadabd4e59f114994a9c4617 \
+    --hash=sha256:e509cbc488c735b43b5ffea175235cec24bbc57b227ef1acc691725beb230d1c
     # via
     #   -r requirements.in
     #   h5py
     #   jax
+    #   keras-nightly
     #   ml-dtypes
     #   opt-einsum
     #   scipy
@@ -386,21 +444,23 @@
     --hash=sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597 \
     --hash=sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a
     # via tb-nightly
-psutil==5.9.5 \
-    --hash=sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d \
-    --hash=sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217 \
-    --hash=sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4 \
-    --hash=sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c \
-    --hash=sha256:5b9b8cb93f507e8dbaf22af6a2fd0ccbe8244bf30b1baad6b3954e935157ae3f \
-    --hash=sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da \
-    --hash=sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4 \
-    --hash=sha256:8c5f7c5a052d1d567db4ddd231a9d27a74e8e4a9c3f44b1032762bd7b9fdcd42 \
-    --hash=sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5 \
-    --hash=sha256:acf2aef9391710afded549ff602b5887d7a2349831ae4c26be7c807c0a39fac4 \
-    --hash=sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9 \
-    --hash=sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f \
-    --hash=sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30 \
-    --hash=sha256:ea8518d152174e1249c4f2a1c89e3e6065941df2fa13a1ab45327716a23c2b48
+psutil==5.9.6 \
+    --hash=sha256:10e8c17b4f898d64b121149afb136c53ea8b68c7531155147867b7b1ac9e7e28 \
+    --hash=sha256:18cd22c5db486f33998f37e2bb054cc62fd06646995285e02a51b1e08da97017 \
+    --hash=sha256:3ebf2158c16cc69db777e3c7decb3c0f43a7af94a60d72e87b2823aebac3d602 \
+    --hash=sha256:51dc3d54607c73148f63732c727856f5febec1c7c336f8f41fcbd6315cce76ac \
+    --hash=sha256:6e5fb8dc711a514da83098bc5234264e551ad980cec5f85dabf4d38ed6f15e9a \
+    --hash=sha256:70cb3beb98bc3fd5ac9ac617a327af7e7f826373ee64c80efd4eb2856e5051e9 \
+    --hash=sha256:748c9dd2583ed86347ed65d0035f45fa8c851e8d90354c122ab72319b5f366f4 \
+    --hash=sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c \
+    --hash=sha256:92e0cc43c524834af53e9d3369245e6cc3b130e78e26100d1f63cdb0abeb3d3c \
+    --hash=sha256:a6f01f03bf1843280f4ad16f4bde26b817847b4c1a0db59bf6419807bc5ce05c \
+    --hash=sha256:c69596f9fc2f8acd574a12d5f8b7b1ba3765a641ea5d60fb4736bf3c08a8214a \
+    --hash=sha256:ca2780f5e038379e520281e4c032dddd086906ddff9ef0d1b9dcf00710e5071c \
+    --hash=sha256:daecbcbd29b289aac14ece28eca6a3e60aa361754cf6da3dfb20d4d32b6c7f57 \
+    --hash=sha256:e4b92ddcd7dd4cdd3f900180ea1e104932c7bce234fb88976e2a3b296441225a \
+    --hash=sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d \
+    --hash=sha256:ff18b8d1a784b810df0b0fff3bcb50ab941c3b8e2c8de5726f9c71c601c611aa
     # via portpicker
 pyasn1==0.5.0 \
     --hash=sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57 \
@@ -412,6 +472,10 @@
     --hash=sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c \
     --hash=sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d
     # via google-auth
+pygments==2.16.1 \
+    --hash=sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692 \
+    --hash=sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29
+    # via rich
 requests==2.31.0 \
     --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \
     --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1
@@ -423,6 +487,10 @@
     --hash=sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5 \
     --hash=sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a
     # via google-auth-oauthlib
+rich==13.6.0 \
+    --hash=sha256:2b38e2fe9ca72c9a00170a1a2d20c63c790d0e10ef1fe35eba76e1e7b1d7d245 \
+    --hash=sha256:5c14d22737e6d5084ef4771b62d5d4363165b403455a30a1c8ca39dc7b644bef
+    # via keras-nightly
 rsa==4.9 \
     --hash=sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7 \
     --hash=sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21
@@ -462,40 +530,40 @@
     # via
     #   astunparse
     #   tb-nightly
-tb-nightly==2.15.0a20231013 \
-    --hash=sha256:b116fd7d40d198465a087bd145e85bd306c76e0266ece6d0c9e7fb6e73054987
+tb-nightly==2.15.0a20231023 \
+    --hash=sha256:8990a52985e3296aa18a6825efc017bcded9e2fb2cbcdd2d01f5c62d4fdc9825
     # via -r requirements.in
 tblib==2.0.0 \
     --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \
     --hash=sha256:a6df30f272c08bf8be66e0775fad862005d950a6b8449b94f7c788731d70ecd7
     # via -r requirements.in
-tensorboard-data-server==0.7.1 \
-    --hash=sha256:255c02b7f5b03dd5c0a88c928e563441ff39e1d4b4a234cdbe09f016e53d9594 \
-    --hash=sha256:9938bd39f5041797b33921066fba0eab03a0dd10d1887a05e62ae58841ad4c3f \
-    --hash=sha256:be8d016a1aa394e6198280d4a3dc37898f56467310c5f5e617cac10a783e055a
+tensorboard-data-server==0.7.2 \
+    --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \
+    --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \
+    --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530
     # via tb-nightly
 termcolor==2.3.0 \
     --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \
     --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a
     # via -r requirements.in
-tf-estimator-nightly==2.15.0.dev2023101308 \
-    --hash=sha256:d33078d257275f368d2b816141d47c8d86d819e97e4970093b4500314c12ce39
+tf-estimator-nightly==2.15.0.dev2023101608 \
+    --hash=sha256:fc045b32fb1a607da93799b3da0642527195a716cac424367f3c5f4edc2ec21e
     # via -r requirements.in
 typing-extensions==4.8.0 \
     --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \
     --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef
     # via -r requirements.in
-urllib3==2.0.6 \
-    --hash=sha256:7a7c7003b000adf9e7ca2a377c9688bbc54ed41b985789ed576570342a375cd2 \
-    --hash=sha256:b19e1a85d206b56d7df1d5e683df4a7725252a964e3993648dd0fb5a1c157564
+urllib3==2.0.7 \
+    --hash=sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84 \
+    --hash=sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e
     # via requests
-werkzeug==3.0.0 \
-    --hash=sha256:3ffff4dcc32db52ef3cc94dff3000a3c2846890f3a5a51800a27b909c5e770f0 \
-    --hash=sha256:cbb2600f7eabe51dbc0502f58be0b3e1b96b893b05695ea2b35b43d4de2d9962
+werkzeug==3.0.1 \
+    --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \
+    --hash=sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10
     # via tb-nightly
-wheel==0.41.2 \
-    --hash=sha256:0c5ac5ff2afb79ac23ab82bab027a0be7b5dbcf2e54dc50efe4bf507de1f7985 \
-    --hash=sha256:75909db2664838d015e3d9139004ee16711748a52c8f336b52882266540215d8
+wheel==0.41.3 \
+    --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \
+    --hash=sha256:4d4987ce51a49370ea65c0bfd2234e8ce80a12780820d9dc462597a6e60d0841
     # via
     #   -r requirements.in
     #   astunparse
diff --git a/requirements_lock_3_9.txt b/requirements_lock_3_9.txt
index 82031a1..7fe16f9 100644
--- a/requirements_lock_3_9.txt
+++ b/requirements_lock_3_9.txt
@@ -12,105 +12,105 @@
     --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \
     --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8
     # via -r requirements.in
-cachetools==5.3.1 \
-    --hash=sha256:95ef631eeaea14ba2e36f06437f36463aac3a096799e876ee55e5cdccb102590 \
-    --hash=sha256:dce83f2d9b4e1f732a8cd44af8e8fab2dbe46201467fc98b3ef8f269092bf62b
+cachetools==5.3.2 \
+    --hash=sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2 \
+    --hash=sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1
     # via google-auth
 certifi==2023.7.22 \
     --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \
     --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9
     # via requests
-charset-normalizer==3.3.0 \
-    --hash=sha256:02673e456dc5ab13659f85196c534dc596d4ef260e4d86e856c3b2773ce09843 \
-    --hash=sha256:02af06682e3590ab952599fbadac535ede5d60d78848e555aa58d0c0abbde786 \
-    --hash=sha256:03680bb39035fbcffe828eae9c3f8afc0428c91d38e7d61aa992ef7a59fb120e \
-    --hash=sha256:0570d21da019941634a531444364f2482e8db0b3425fcd5ac0c36565a64142c8 \
-    --hash=sha256:09c77f964f351a7369cc343911e0df63e762e42bac24cd7d18525961c81754f4 \
-    --hash=sha256:0d3d5b7db9ed8a2b11a774db2bbea7ba1884430a205dbd54a32d61d7c2a190fa \
-    --hash=sha256:1063da2c85b95f2d1a430f1c33b55c9c17ffaf5e612e10aeaad641c55a9e2b9d \
-    --hash=sha256:12ebea541c44fdc88ccb794a13fe861cc5e35d64ed689513a5c03d05b53b7c82 \
-    --hash=sha256:153e7b6e724761741e0974fc4dcd406d35ba70b92bfe3fedcb497226c93b9da7 \
-    --hash=sha256:15b26ddf78d57f1d143bdf32e820fd8935d36abe8a25eb9ec0b5a71c82eb3895 \
-    --hash=sha256:1872d01ac8c618a8da634e232f24793883d6e456a66593135aeafe3784b0848d \
-    --hash=sha256:187d18082694a29005ba2944c882344b6748d5be69e3a89bf3cc9d878e548d5a \
-    --hash=sha256:1b2919306936ac6efb3aed1fbf81039f7087ddadb3160882a57ee2ff74fd2382 \
-    --hash=sha256:232ac332403e37e4a03d209a3f92ed9071f7d3dbda70e2a5e9cff1c4ba9f0678 \
-    --hash=sha256:23e8565ab7ff33218530bc817922fae827420f143479b753104ab801145b1d5b \
-    --hash=sha256:24817cb02cbef7cd499f7c9a2735286b4782bd47a5b3516a0e84c50eab44b98e \
-    --hash=sha256:249c6470a2b60935bafd1d1d13cd613f8cd8388d53461c67397ee6a0f5dce741 \
-    --hash=sha256:24a91a981f185721542a0b7c92e9054b7ab4fea0508a795846bc5b0abf8118d4 \
-    --hash=sha256:2502dd2a736c879c0f0d3e2161e74d9907231e25d35794584b1ca5284e43f596 \
-    --hash=sha256:250c9eb0f4600361dd80d46112213dff2286231d92d3e52af1e5a6083d10cad9 \
-    --hash=sha256:278c296c6f96fa686d74eb449ea1697f3c03dc28b75f873b65b5201806346a69 \
-    --hash=sha256:2935ffc78db9645cb2086c2f8f4cfd23d9b73cc0dc80334bc30aac6f03f68f8c \
-    --hash=sha256:2f4a0033ce9a76e391542c182f0d48d084855b5fcba5010f707c8e8c34663d77 \
-    --hash=sha256:30a85aed0b864ac88309b7d94be09f6046c834ef60762a8833b660139cfbad13 \
-    --hash=sha256:380c4bde80bce25c6e4f77b19386f5ec9db230df9f2f2ac1e5ad7af2caa70459 \
-    --hash=sha256:3ae38d325b512f63f8da31f826e6cb6c367336f95e418137286ba362925c877e \
-    --hash=sha256:3b447982ad46348c02cb90d230b75ac34e9886273df3a93eec0539308a6296d7 \
-    --hash=sha256:3debd1150027933210c2fc321527c2299118aa929c2f5a0a80ab6953e3bd1908 \
-    --hash=sha256:4162918ef3098851fcd8a628bf9b6a98d10c380725df9e04caf5ca6dd48c847a \
-    --hash=sha256:468d2a840567b13a590e67dd276c570f8de00ed767ecc611994c301d0f8c014f \
-    --hash=sha256:4cc152c5dd831641e995764f9f0b6589519f6f5123258ccaca8c6d34572fefa8 \
-    --hash=sha256:542da1178c1c6af8873e143910e2269add130a299c9106eef2594e15dae5e482 \
-    --hash=sha256:557b21a44ceac6c6b9773bc65aa1b4cc3e248a5ad2f5b914b91579a32e22204d \
-    --hash=sha256:5707a746c6083a3a74b46b3a631d78d129edab06195a92a8ece755aac25a3f3d \
-    --hash=sha256:588245972aca710b5b68802c8cad9edaa98589b1b42ad2b53accd6910dad3545 \
-    --hash=sha256:5adf257bd58c1b8632046bbe43ee38c04e1038e9d37de9c57a94d6bd6ce5da34 \
-    --hash=sha256:619d1c96099be5823db34fe89e2582b336b5b074a7f47f819d6b3a57ff7bdb86 \
-    --hash=sha256:63563193aec44bce707e0c5ca64ff69fa72ed7cf34ce6e11d5127555756fd2f6 \
-    --hash=sha256:67b8cc9574bb518ec76dc8e705d4c39ae78bb96237cb533edac149352c1f39fe \
-    --hash=sha256:6a685067d05e46641d5d1623d7c7fdf15a357546cbb2f71b0ebde91b175ffc3e \
-    --hash=sha256:70f1d09c0d7748b73290b29219e854b3207aea922f839437870d8cc2168e31cc \
-    --hash=sha256:750b446b2ffce1739e8578576092179160f6d26bd5e23eb1789c4d64d5af7dc7 \
-    --hash=sha256:7966951325782121e67c81299a031f4c115615e68046f79b85856b86ebffc4cd \
-    --hash=sha256:7b8b8bf1189b3ba9b8de5c8db4d541b406611a71a955bbbd7385bbc45fcb786c \
-    --hash=sha256:7f5d10bae5d78e4551b7be7a9b29643a95aded9d0f602aa2ba584f0388e7a557 \
-    --hash=sha256:805dfea4ca10411a5296bcc75638017215a93ffb584c9e344731eef0dcfb026a \
-    --hash=sha256:81bf654678e575403736b85ba3a7867e31c2c30a69bc57fe88e3ace52fb17b89 \
-    --hash=sha256:82eb849f085624f6a607538ee7b83a6d8126df6d2f7d3b319cb837b289123078 \
-    --hash=sha256:85a32721ddde63c9df9ebb0d2045b9691d9750cb139c161c80e500d210f5e26e \
-    --hash=sha256:86d1f65ac145e2c9ed71d8ffb1905e9bba3a91ae29ba55b4c46ae6fc31d7c0d4 \
-    --hash=sha256:86f63face3a527284f7bb8a9d4f78988e3c06823f7bea2bd6f0e0e9298ca0403 \
-    --hash=sha256:8eaf82f0eccd1505cf39a45a6bd0a8cf1c70dcfc30dba338207a969d91b965c0 \
-    --hash=sha256:93aa7eef6ee71c629b51ef873991d6911b906d7312c6e8e99790c0f33c576f89 \
-    --hash=sha256:96c2b49eb6a72c0e4991d62406e365d87067ca14c1a729a870d22354e6f68115 \
-    --hash=sha256:9cf3126b85822c4e53aa28c7ec9869b924d6fcfb76e77a45c44b83d91afd74f9 \
-    --hash=sha256:9fe359b2e3a7729010060fbca442ca225280c16e923b37db0e955ac2a2b72a05 \
-    --hash=sha256:a0ac5e7015a5920cfce654c06618ec40c33e12801711da6b4258af59a8eff00a \
-    --hash=sha256:a3f93dab657839dfa61025056606600a11d0b696d79386f974e459a3fbc568ec \
-    --hash=sha256:a4b71f4d1765639372a3b32d2638197f5cd5221b19531f9245fcc9ee62d38f56 \
-    --hash=sha256:aae32c93e0f64469f74ccc730a7cb21c7610af3a775157e50bbd38f816536b38 \
-    --hash=sha256:aaf7b34c5bc56b38c931a54f7952f1ff0ae77a2e82496583b247f7c969eb1479 \
-    --hash=sha256:abecce40dfebbfa6abf8e324e1860092eeca6f7375c8c4e655a8afb61af58f2c \
-    --hash=sha256:abf0d9f45ea5fb95051c8bfe43cb40cda383772f7e5023a83cc481ca2604d74e \
-    --hash=sha256:ac71b2977fb90c35d41c9453116e283fac47bb9096ad917b8819ca8b943abecd \
-    --hash=sha256:ada214c6fa40f8d800e575de6b91a40d0548139e5dc457d2ebb61470abf50186 \
-    --hash=sha256:b09719a17a2301178fac4470d54b1680b18a5048b481cb8890e1ef820cb80455 \
-    --hash=sha256:b1121de0e9d6e6ca08289583d7491e7fcb18a439305b34a30b20d8215922d43c \
-    --hash=sha256:b3b2316b25644b23b54a6f6401074cebcecd1244c0b8e80111c9a3f1c8e83d65 \
-    --hash=sha256:b3d9b48ee6e3967b7901c052b670c7dda6deb812c309439adaffdec55c6d7b78 \
-    --hash=sha256:b5bcf60a228acae568e9911f410f9d9e0d43197d030ae5799e20dca8df588287 \
-    --hash=sha256:b8f3307af845803fb0b060ab76cf6dd3a13adc15b6b451f54281d25911eb92df \
-    --hash=sha256:c2af80fb58f0f24b3f3adcb9148e6203fa67dd3f61c4af146ecad033024dde43 \
-    --hash=sha256:c350354efb159b8767a6244c166f66e67506e06c8924ed74669b2c70bc8735b1 \
-    --hash=sha256:c5a74c359b2d47d26cdbbc7845e9662d6b08a1e915eb015d044729e92e7050b7 \
-    --hash=sha256:c71f16da1ed8949774ef79f4a0260d28b83b3a50c6576f8f4f0288d109777989 \
-    --hash=sha256:d47ecf253780c90ee181d4d871cd655a789da937454045b17b5798da9393901a \
-    --hash=sha256:d7eff0f27edc5afa9e405f7165f85a6d782d308f3b6b9d96016c010597958e63 \
-    --hash=sha256:d97d85fa63f315a8bdaba2af9a6a686e0eceab77b3089af45133252618e70884 \
-    --hash=sha256:db756e48f9c5c607b5e33dd36b1d5872d0422e960145b08ab0ec7fd420e9d649 \
-    --hash=sha256:dc45229747b67ffc441b3de2f3ae5e62877a282ea828a5bdb67883c4ee4a8810 \
-    --hash=sha256:e0fc42822278451bc13a2e8626cf2218ba570f27856b536e00cfa53099724828 \
-    --hash=sha256:e39c7eb31e3f5b1f88caff88bcff1b7f8334975b46f6ac6e9fc725d829bc35d4 \
-    --hash=sha256:e46cd37076971c1040fc8c41273a8b3e2c624ce4f2be3f5dfcb7a430c1d3acc2 \
-    --hash=sha256:e5c1502d4ace69a179305abb3f0bb6141cbe4714bc9b31d427329a95acfc8bdd \
-    --hash=sha256:edfe077ab09442d4ef3c52cb1f9dab89bff02f4524afc0acf2d46be17dc479f5 \
-    --hash=sha256:effe5406c9bd748a871dbcaf3ac69167c38d72db8c9baf3ff954c344f31c4cbe \
-    --hash=sha256:f0d1e3732768fecb052d90d62b220af62ead5748ac51ef61e7b32c266cac9293 \
-    --hash=sha256:f5969baeaea61c97efa706b9b107dcba02784b1601c74ac84f2a532ea079403e \
-    --hash=sha256:f8888e31e3a85943743f8fc15e71536bda1c81d5aa36d014a3c0c44481d7db6e \
-    --hash=sha256:fc52b79d83a3fe3a360902d3f5d79073a993597d48114c29485e9431092905d8
+charset-normalizer==3.3.1 \
+    --hash=sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5 \
+    --hash=sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93 \
+    --hash=sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a \
+    --hash=sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d \
+    --hash=sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c \
+    --hash=sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1 \
+    --hash=sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58 \
+    --hash=sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2 \
+    --hash=sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557 \
+    --hash=sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147 \
+    --hash=sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041 \
+    --hash=sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2 \
+    --hash=sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2 \
+    --hash=sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7 \
+    --hash=sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296 \
+    --hash=sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690 \
+    --hash=sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67 \
+    --hash=sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57 \
+    --hash=sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597 \
+    --hash=sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846 \
+    --hash=sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b \
+    --hash=sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97 \
+    --hash=sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c \
+    --hash=sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62 \
+    --hash=sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa \
+    --hash=sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f \
+    --hash=sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e \
+    --hash=sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821 \
+    --hash=sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3 \
+    --hash=sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4 \
+    --hash=sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb \
+    --hash=sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727 \
+    --hash=sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514 \
+    --hash=sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d \
+    --hash=sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761 \
+    --hash=sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55 \
+    --hash=sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f \
+    --hash=sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c \
+    --hash=sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034 \
+    --hash=sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6 \
+    --hash=sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae \
+    --hash=sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1 \
+    --hash=sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14 \
+    --hash=sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1 \
+    --hash=sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228 \
+    --hash=sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708 \
+    --hash=sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48 \
+    --hash=sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f \
+    --hash=sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5 \
+    --hash=sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f \
+    --hash=sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4 \
+    --hash=sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8 \
+    --hash=sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff \
+    --hash=sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61 \
+    --hash=sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b \
+    --hash=sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97 \
+    --hash=sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b \
+    --hash=sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605 \
+    --hash=sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728 \
+    --hash=sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d \
+    --hash=sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c \
+    --hash=sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf \
+    --hash=sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673 \
+    --hash=sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1 \
+    --hash=sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b \
+    --hash=sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41 \
+    --hash=sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8 \
+    --hash=sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f \
+    --hash=sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4 \
+    --hash=sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008 \
+    --hash=sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9 \
+    --hash=sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5 \
+    --hash=sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f \
+    --hash=sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e \
+    --hash=sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273 \
+    --hash=sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45 \
+    --hash=sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e \
+    --hash=sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656 \
+    --hash=sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e \
+    --hash=sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c \
+    --hash=sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2 \
+    --hash=sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72 \
+    --hash=sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056 \
+    --hash=sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397 \
+    --hash=sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42 \
+    --hash=sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd \
+    --hash=sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3 \
+    --hash=sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213 \
+    --hash=sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf \
+    --hash=sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67
     # via requests
 dill==0.3.7 \
     --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \
@@ -171,61 +171,61 @@
     --hash=sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12 \
     --hash=sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb
     # via tb-nightly
-grpcio==1.59.0 \
-    --hash=sha256:0ae444221b2c16d8211b55326f8ba173ba8f8c76349bfc1768198ba592b58f74 \
-    --hash=sha256:0b84445fa94d59e6806c10266b977f92fa997db3585f125d6b751af02ff8b9fe \
-    --hash=sha256:14890da86a0c0e9dc1ea8e90101d7a3e0e7b1e71f4487fab36e2bfd2ecadd13c \
-    --hash=sha256:15f03bd714f987d48ae57fe092cf81960ae36da4e520e729392a59a75cda4f29 \
-    --hash=sha256:1a839ba86764cc48226f50b924216000c79779c563a301586a107bda9cbe9dcf \
-    --hash=sha256:225e5fa61c35eeaebb4e7491cd2d768cd8eb6ed00f2664fa83a58f29418b39fd \
-    --hash=sha256:228b91ce454876d7eed74041aff24a8f04c0306b7250a2da99d35dd25e2a1211 \
-    --hash=sha256:2ea95cd6abbe20138b8df965b4a8674ec312aaef3147c0f46a0bac661f09e8d0 \
-    --hash=sha256:2f120d27051e4c59db2f267b71b833796770d3ea36ca712befa8c5fff5da6ebd \
-    --hash=sha256:34341d9e81a4b669a5f5dca3b2a760b6798e95cdda2b173e65d29d0b16692857 \
-    --hash=sha256:3859917de234a0a2a52132489c4425a73669de9c458b01c9a83687f1f31b5b10 \
-    --hash=sha256:38823bd088c69f59966f594d087d3a929d1ef310506bee9e3648317660d65b81 \
-    --hash=sha256:38da5310ef84e16d638ad89550b5b9424df508fd5c7b968b90eb9629ca9be4b9 \
-    --hash=sha256:3b8ff795d35a93d1df6531f31c1502673d1cebeeba93d0f9bd74617381507e3f \
-    --hash=sha256:50eff97397e29eeee5df106ea1afce3ee134d567aa2c8e04fabab05c79d791a7 \
-    --hash=sha256:5711c51e204dc52065f4a3327dca46e69636a0b76d3e98c2c28c4ccef9b04c52 \
-    --hash=sha256:598f3530231cf10ae03f4ab92d48c3be1fee0c52213a1d5958df1a90957e6a88 \
-    --hash=sha256:611d9aa0017fa386809bddcb76653a5ab18c264faf4d9ff35cb904d44745f575 \
-    --hash=sha256:61bc72a00ecc2b79d9695220b4d02e8ba53b702b42411397e831c9b0589f08a3 \
-    --hash=sha256:63982150a7d598281fa1d7ffead6096e543ff8be189d3235dd2b5604f2c553e5 \
-    --hash=sha256:6c4b1cc3a9dc1924d2eb26eec8792fedd4b3fcd10111e26c1d551f2e4eda79ce \
-    --hash=sha256:81d86a096ccd24a57fa5772a544c9e566218bc4de49e8c909882dae9d73392df \
-    --hash=sha256:849c47ef42424c86af069a9c5e691a765e304079755d5c29eff511263fad9c2a \
-    --hash=sha256:871371ce0c0055d3db2a86fdebd1e1d647cf21a8912acc30052660297a5a6901 \
-    --hash=sha256:8cd2d38c2d52f607d75a74143113174c36d8a416d9472415eab834f837580cf7 \
-    --hash=sha256:936b2e04663660c600d5173bc2cc84e15adbad9c8f71946eb833b0afc205b996 \
-    --hash=sha256:93e9cb546e610829e462147ce724a9cb108e61647a3454500438a6deef610be1 \
-    --hash=sha256:956f0b7cb465a65de1bd90d5a7475b4dc55089b25042fe0f6c870707e9aabb1d \
-    --hash=sha256:986de4aa75646e963466b386a8c5055c8b23a26a36a6c99052385d6fe8aaf180 \
-    --hash=sha256:aca8a24fef80bef73f83eb8153f5f5a0134d9539b4c436a716256b311dda90a6 \
-    --hash=sha256:acf70a63cf09dd494000007b798aff88a436e1c03b394995ce450be437b8e54f \
-    --hash=sha256:b34c7a4c31841a2ea27246a05eed8a80c319bfc0d3e644412ec9ce437105ff6c \
-    --hash=sha256:b95ec8ecc4f703f5caaa8d96e93e40c7f589bad299a2617bdb8becbcce525539 \
-    --hash=sha256:ba0ca727a173ee093f49ead932c051af463258b4b493b956a2c099696f38aa66 \
-    --hash=sha256:c041a91712bf23b2a910f61e16565a05869e505dc5a5c025d429ca6de5de842c \
-    --hash=sha256:c0488c2b0528e6072010182075615620071371701733c63ab5be49140ed8f7f0 \
-    --hash=sha256:c173a87d622ea074ce79be33b952f0b424fa92182063c3bda8625c11d3585d09 \
-    --hash=sha256:c251d22de8f9f5cca9ee47e4bade7c5c853e6e40743f47f5cc02288ee7a87252 \
-    --hash=sha256:c4dfdb49f4997dc664f30116af2d34751b91aa031f8c8ee251ce4dcfc11277b0 \
-    --hash=sha256:ca87ee6183421b7cea3544190061f6c1c3dfc959e0b57a5286b108511fd34ff4 \
-    --hash=sha256:ceb1e68135788c3fce2211de86a7597591f0b9a0d2bb80e8401fd1d915991bac \
-    --hash=sha256:d09bd2a4e9f5a44d36bb8684f284835c14d30c22d8ec92ce796655af12163588 \
-    --hash=sha256:d0fcf53df684fcc0154b1e61f6b4a8c4cf5f49d98a63511e3f30966feff39cd0 \
-    --hash=sha256:d74f7d2d7c242a6af9d4d069552ec3669965b74fed6b92946e0e13b4168374f9 \
-    --hash=sha256:de2599985b7c1b4ce7526e15c969d66b93687571aa008ca749d6235d056b7205 \
-    --hash=sha256:e5378785dce2b91eb2e5b857ec7602305a3b5cf78311767146464bfa365fc897 \
-    --hash=sha256:ec78aebb9b6771d6a1de7b6ca2f779a2f6113b9108d486e904bde323d51f5589 \
-    --hash=sha256:f1feb034321ae2f718172d86b8276c03599846dc7bb1792ae370af02718f91c5 \
-    --hash=sha256:f21917aa50b40842b51aff2de6ebf9e2f6af3fe0971c31960ad6a3a2b24988f4 \
-    --hash=sha256:f367e4b524cb319e50acbdea57bb63c3b717c5d561974ace0b065a648bb3bad3 \
-    --hash=sha256:f6cfe44a5d7c7d5f1017a7da1c8160304091ca5dc64a0f85bca0d63008c3137a \
-    --hash=sha256:fa66cac32861500f280bb60fe7d5b3e22d68c51e18e65367e38f8669b78cea3b \
-    --hash=sha256:fc8bf2e7bc725e76c0c11e474634a08c8f24bcf7426c0c6d60c8f9c6e70e4d4a \
-    --hash=sha256:fe976910de34d21057bcb53b2c5e667843588b48bf11339da2a75f5c4c5b4055
+grpcio==1.59.2 \
+    --hash=sha256:023088764012411affe7db183d1ada3ad9daf2e23ddc719ff46d7061de661340 \
+    --hash=sha256:08d77e682f2bf730a4961eea330e56d2f423c6a9b91ca222e5b1eb24a357b19f \
+    --hash=sha256:0a4a3833c0e067f3558538727235cd8a49709bff1003200bbdefa2f09334e4b1 \
+    --hash=sha256:0a754aff9e3af63bdc4c75c234b86b9d14e14a28a30c4e324aed1a9b873d755f \
+    --hash=sha256:11168ef43e4a43ff1b1a65859f3e0ef1a173e277349e7fb16923ff108160a8cd \
+    --hash=sha256:128e20f57c5f27cb0157e73756d1586b83c1b513ebecc83ea0ac37e4b0e4e758 \
+    --hash=sha256:1f9524d1d701e399462d2c90ba7c193e49d1711cf429c0d3d97c966856e03d00 \
+    --hash=sha256:1ff16d68bf453275466a9a46739061a63584d92f18a0f5b33d19fc97eb69867c \
+    --hash=sha256:2067274c88bc6de89c278a672a652b4247d088811ece781a4858b09bdf8448e3 \
+    --hash=sha256:2171c39f355ba5b551c5d5928d65aa6c69807fae195b86ef4a7d125bcdb860a9 \
+    --hash=sha256:242adc47725b9a499ee77c6a2e36688fa6c96484611f33b1be4c57ab075a92dd \
+    --hash=sha256:27f879ae604a7fcf371e59fba6f3ff4635a4c2a64768bd83ff0cac503142fef4 \
+    --hash=sha256:2b230028a008ae1d0f430acb227d323ff8a619017415cf334c38b457f814119f \
+    --hash=sha256:3059668df17627f0e0fa680e9ef8c995c946c792612e9518f5cc1503be14e90b \
+    --hash=sha256:31176aa88f36020055ace9adff2405a33c8bdbfa72a9c4980e25d91b2f196873 \
+    --hash=sha256:36f53c2b3449c015880e7d55a89c992c357f176327b0d2873cdaaf9628a37c69 \
+    --hash=sha256:3b4368b33908f683a363f376dfb747d40af3463a6e5044afee07cf9436addf96 \
+    --hash=sha256:3c61d641d4f409c5ae46bfdd89ea42ce5ea233dcf69e74ce9ba32b503c727e29 \
+    --hash=sha256:4abb717e320e74959517dc8e84a9f48fbe90e9abe19c248541e9418b1ce60acd \
+    --hash=sha256:4c93f4abbb54321ee6471e04a00139c80c754eda51064187963ddf98f5cf36a4 \
+    --hash=sha256:535561990e075fa6bd4b16c4c3c1096b9581b7bb35d96fac4650f1181e428268 \
+    --hash=sha256:53c9aa5ddd6857c0a1cd0287225a2a25873a8e09727c2e95c4aebb1be83a766a \
+    --hash=sha256:5d573e70a6fe77555fb6143c12d3a7d3fa306632a3034b4e7c59ca09721546f8 \
+    --hash=sha256:6009386a2df66159f64ac9f20425ae25229b29b9dd0e1d3dd60043f037e2ad7e \
+    --hash=sha256:686e975a5d16602dc0982c7c703948d17184bd1397e16c8ee03511ecb8c4cdda \
+    --hash=sha256:6959fb07e8351e20501ffb8cc4074c39a0b7ef123e1c850a7f8f3afdc3a3da01 \
+    --hash=sha256:6b25ed37c27e652db01be341af93fbcea03d296c024d8a0e680017a268eb85dd \
+    --hash=sha256:6da6dea3a1bacf99b3c2187e296db9a83029ed9c38fd4c52b7c9b7326d13c828 \
+    --hash=sha256:72ca2399097c0b758198f2ff30f7178d680de8a5cfcf3d9b73a63cf87455532e \
+    --hash=sha256:73abb8584b0cf74d37f5ef61c10722adc7275502ab71789a8fe3cb7ef04cf6e2 \
+    --hash=sha256:74100fecaec8a535e380cf5f2fb556ff84957d481c13e54051c52e5baac70541 \
+    --hash=sha256:75c6ecb70e809cf1504465174343113f51f24bc61e22a80ae1c859f3f7034c6d \
+    --hash=sha256:7cf05053242f61ba94014dd3a986e11a083400a32664058f80bf4cf817c0b3a1 \
+    --hash=sha256:9411e24328a2302e279e70cae6e479f1fddde79629fcb14e03e6d94b3956eabf \
+    --hash=sha256:a213acfbf186b9f35803b52e4ca9addb153fc0b67f82a48f961be7000ecf6721 \
+    --hash=sha256:bb7e0fe6ad73b7f06d7e2b689c19a71cf5cc48f0c2bf8608469e51ffe0bd2867 \
+    --hash=sha256:c2504eed520958a5b77cc99458297cb7906308cb92327f35fb7fbbad4e9b2188 \
+    --hash=sha256:c35aa9657f5d5116d23b934568e0956bd50c615127810fffe3ac356a914c176a \
+    --hash=sha256:c5f09cffa619adfb44799fa4a81c2a1ad77c887187613fb0a8f201ab38d89ba1 \
+    --hash=sha256:c978f864b35f2261e0819f5cd88b9830b04dc51bcf055aac3c601e525a10d2ba \
+    --hash=sha256:cbe946b3e6e60a7b4618f091e62a029cb082b109a9d6b53962dd305087c6e4fd \
+    --hash=sha256:cc3e4cd087f07758b16bef8f31d88dbb1b5da5671d2f03685ab52dece3d7a16e \
+    --hash=sha256:cf0dead5a2c5a3347af2cfec7131d4f2a2e03c934af28989c9078f8241a491fa \
+    --hash=sha256:d2794f0e68b3085d99b4f6ff9c089f6fdd02b32b9d3efdfbb55beac1bf22d516 \
+    --hash=sha256:d2fa68a96a30dd240be80bbad838a0ac81a61770611ff7952b889485970c4c71 \
+    --hash=sha256:d6f70406695e3220f09cd7a2f879333279d91aa4a8a1d34303b56d61a8180137 \
+    --hash=sha256:d8f9cd4ad1be90b0cf350a2f04a38a36e44a026cac1e036ac593dc48efe91d52 \
+    --hash=sha256:da2d94c15f88cd40d7e67f7919d4f60110d2b9d5b1e08cf354c2be773ab13479 \
+    --hash=sha256:e1727c1c0e394096bb9af185c6923e8ea55a5095b8af44f06903bcc0e06800a2 \
+    --hash=sha256:e420ced29b5904cdf9ee5545e23f9406189d8acb6750916c2db4793dada065c6 \
+    --hash=sha256:e82c5cf1495244adf5252f925ac5932e5fd288b3e5ab6b70bec5593074b7236c \
+    --hash=sha256:f1ef0d39bc1feb420caf549b3c657c871cad4ebbcf0580c4d03816b0590de0cf \
+    --hash=sha256:f8753a6c88d1d0ba64302309eecf20f70d2770f65ca02d83c2452279085bfcd3 \
+    --hash=sha256:f93dbf58f03146164048be5426ffde298b237a5e059144847e4940f5b80172c3
     # via
     #   -r requirements.in
     #   tb-nightly
@@ -269,12 +269,12 @@
 jax==0.4.7 \
     --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8
     # via -r requirements.in
-keras-nightly==3.0.0.dev2023101703 \
-    --hash=sha256:72674b80300b9672b76e20c3d4c4bd827f6cacd486d9f98e3db864d872e3eaa4 \
-    --hash=sha256:ab5e59bf6a84d048b3241ad9ae368ed8497740804a8e85774fbded34c1322587
+keras-nightly==3.0.0.dev2023103103 \
+    --hash=sha256:25a6a9030d7e067d535ec41ae6700636c90cabada80fbddb3bb8775006807860 \
+    --hash=sha256:c39fccebb3d4cc9d838371981c4d3eef08fc06eadf6cffb39b356cb625bab50f
     # via -r requirements.in
-lit==17.0.3 \
-    --hash=sha256:e6049032462be1e2928686cbd4a6cc5b3c545d83ecd078737fe79412c1f3fcc1
+lit==17.0.4 \
+    --hash=sha256:ee2e180128e770abc6aed3a02de2daf09d81b7d30225e315205d3599c311d304
     # via -r requirements.in
 markdown==3.5 \
     --hash=sha256:4afb124395ce5fc34e6d9886dab977fd9ae987fc6e85689f08278cf0c69d4bf3 \
@@ -530,17 +530,17 @@
     # via
     #   astunparse
     #   tb-nightly
-tb-nightly==2.15.0a20231017 \
-    --hash=sha256:982b8cf32bcab4902eebd2e67b885c127f16106114b34eb4d46ea554af2a713a
+tb-nightly==2.15.0a20231023 \
+    --hash=sha256:8990a52985e3296aa18a6825efc017bcded9e2fb2cbcdd2d01f5c62d4fdc9825
     # via -r requirements.in
 tblib==2.0.0 \
     --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \
     --hash=sha256:a6df30f272c08bf8be66e0775fad862005d950a6b8449b94f7c788731d70ecd7
     # via -r requirements.in
-tensorboard-data-server==0.7.1 \
-    --hash=sha256:255c02b7f5b03dd5c0a88c928e563441ff39e1d4b4a234cdbe09f016e53d9594 \
-    --hash=sha256:9938bd39f5041797b33921066fba0eab03a0dd10d1887a05e62ae58841ad4c3f \
-    --hash=sha256:be8d016a1aa394e6198280d4a3dc37898f56467310c5f5e617cac10a783e055a
+tensorboard-data-server==0.7.2 \
+    --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \
+    --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \
+    --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530
     # via tb-nightly
 termcolor==2.3.0 \
     --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \
@@ -557,13 +557,13 @@
     --hash=sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84 \
     --hash=sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e
     # via requests
-werkzeug==3.0.0 \
-    --hash=sha256:3ffff4dcc32db52ef3cc94dff3000a3c2846890f3a5a51800a27b909c5e770f0 \
-    --hash=sha256:cbb2600f7eabe51dbc0502f58be0b3e1b96b893b05695ea2b35b43d4de2d9962
+werkzeug==3.0.1 \
+    --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \
+    --hash=sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10
     # via tb-nightly
-wheel==0.41.2 \
-    --hash=sha256:0c5ac5ff2afb79ac23ab82bab027a0be7b5dbcf2e54dc50efe4bf507de1f7985 \
-    --hash=sha256:75909db2664838d015e3d9139004ee16711748a52c8f336b52882266540215d8
+wheel==0.41.3 \
+    --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \
+    --hash=sha256:4d4987ce51a49370ea65c0bfd2234e8ce80a12780820d9dc462597a6e60d0841
     # via
     #   -r requirements.in
     #   astunparse
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 18ed141..ef01b60 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -1056,6 +1056,7 @@
         "//third_party/py/envlogger/...",
         "//third_party/py/gldm/...",
         "//third_party/py/guesslang/...",
+        "//third_party/py/keras/...",
         "//third_party/py/tf_keras/...",
         "//third_party/yggdrasil_decision_forests/...",
         "//waymo/ml/cn/...",
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index bf598c4..5490149 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -1817,6 +1817,21 @@
   return results;
 }
 
+TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResultsNoSerialization(
+    TF_Graph* graph, const TF_Buffer* graph_def,
+    const TF_ImportGraphDefOptions* options, TF_Status* status) {
+  const GraphDef* graph_def_ptr =
+      reinterpret_cast<const GraphDef*>(graph_def->data);
+  auto results = new TF_ImportGraphDefResults();
+  mutex_lock l(graph->mu);
+  GraphImportGraphDefLocked(graph, *graph_def_ptr, options, results, status);
+  if (!status->status.ok()) {
+    delete results;
+    return nullptr;
+  }
+  return results;
+}
+
 void TF_GraphImportGraphDefWithReturnOutputs(
     TF_Graph* graph, const TF_Buffer* graph_def,
     const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 2f4cf60..9812b0a 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -834,6 +834,14 @@
                                   const TF_ImportGraphDefOptions* options,
                                   TF_Status* status);
 
+// Has the same behavior as TF_GraphImportGraphDefWithResults, but instead of
+// taking in a serialized tensorflow::GraphDef, it takes in a *pointer* to the
+// C++ *in-memory representation* of the GraphDef, stored in `graph_def->data`.
+TF_CAPI_EXPORT extern TF_ImportGraphDefResults*
+TF_GraphImportGraphDefWithResultsNoSerialization(
+    TF_Graph* graph, const TF_Buffer* graph_def,
+    const TF_ImportGraphDefOptions* options, TF_Status* status);
+
 // Import the graph serialized in `graph_def` into `graph`.
 // Convenience function for when only return outputs are needed.
 //
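For context, a minimal sketch of how a caller can use the new no-serialization entry point, mirroring the updated c_api_test.cc below. It assumes `graph`, `opts`, `s` (a TF_Status*), and a serialized `graph_def` TF_Buffer* already exist, as they do in the tests; the TF_Buffer handed to the new function carries a pointer to the in-memory GraphDef rather than serialized bytes.

    // Sketch only; `graph`, `opts`, `s`, and `graph_def` are assumed to be set
    // up as in the existing import tests.
    tensorflow::GraphDef graph_def_proto;
    CHECK(tensorflow::ParseProtoUnlimited(&graph_def_proto, graph_def->data,
                                          graph_def->length));

    TF_Buffer graph_def_buffer;
    graph_def_buffer.data = reinterpret_cast<const void*>(&graph_def_proto);
    graph_def_buffer.length = sizeof(tensorflow::GraphDef*);

    TF_ImportGraphDefResults* results =
        TF_GraphImportGraphDefWithResultsNoSerialization(graph, &graph_def_buffer,
                                                         opts, s);
    // On error the function returns nullptr and `s` holds the status.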
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 008e2d7..e50221a 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -23,6 +23,7 @@
 
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/c_test_util.h"
+#include "tensorflow/c/tf_buffer.h"
 #include "tensorflow/c/tf_buffer_internal.h"
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/cc/saved_model/signature_constants.h"
@@ -764,8 +765,15 @@
   EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
   TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
   EXPECT_EQ(1, TF_ImportGraphDefOptionsNumReturnOperations(opts));
+  tensorflow::GraphDef graph_def_proto;
+  ASSERT_TRUE(tensorflow::ParseProtoUnlimited(&graph_def_proto, graph_def->data,
+                                              graph_def->length));
+  TF_Buffer graph_def_buffer;
+  graph_def_buffer.data = reinterpret_cast<const void*>(&graph_def_proto);
+  graph_def_buffer.length = sizeof(tensorflow::GraphDef*);
   TF_ImportGraphDefResults* results =
-      TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
+      TF_GraphImportGraphDefWithResultsNoSerialization(graph, &graph_def_buffer,
+                                                       opts, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
 
   TF_Operation* scalar2 = TF_GraphOperationByName(graph, "imported2/scalar");
@@ -956,8 +964,16 @@
   TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
   TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0});
   TF_ImportGraphDefOptionsAddInputMapping(opts, "fake", 0, {scalar, 0});
+
+  tensorflow::GraphDef graph_def_proto;
+  ASSERT_TRUE(tensorflow::ParseProtoUnlimited(&graph_def_proto, graph_def->data,
+                                              graph_def->length));
+  TF_Buffer graph_def_buffer;
+  graph_def_buffer.data = reinterpret_cast<const void*>(&graph_def_proto);
+  graph_def_buffer.length = sizeof(tensorflow::GraphDef*);
   TF_ImportGraphDefResults* results =
-      TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
+      TF_GraphImportGraphDefWithResultsNoSerialization(graph, &graph_def_buffer,
+                                                       opts, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
 
   // Check unused input mappings
diff --git a/tensorflow/c/experimental/next_pluggable_device/BUILD b/tensorflow/c/experimental/next_pluggable_device/BUILD
index c5cd8d1..5c7bbdd 100644
--- a/tensorflow/c/experimental/next_pluggable_device/BUILD
+++ b/tensorflow/c/experimental/next_pluggable_device/BUILD
@@ -57,6 +57,7 @@
     hdrs = ["tensor_pjrt_buffer_util.h"],
     visibility = ["//visibility:public"],
     deps = [
+        "//tensorflow/compiler/jit:pjrt_tensor_buffer_util",
         "//tensorflow/core:framework",
         "//tensorflow/core/tfrt/common:async_value_tensor",
         "//tensorflow/core/tfrt/common:global_state",
@@ -64,9 +65,11 @@
         "//tensorflow/core/tfrt/common:pjrt_util",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:statusor",
         "@local_xla//xla/pjrt:pjrt_c_api_client",
+        "@local_xla//xla/pjrt:pjrt_client",
         "@local_xla//xla/pjrt/c:pjrt_c_api_hdrs",
     ],
 )
@@ -80,14 +83,18 @@
         "//tensorflow/core:framework_types_hdr",
         "//tensorflow/core/tfrt/common:async_value_tensor",
         "//tensorflow/core/tfrt/common:pjrt_util",
+        "@com_google_absl//absl/log:check",
         "@com_google_googletest//:gtest_main",
         "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_tsl//tsl/platform:casts",
         "@local_tsl//tsl/platform:status_matchers",
         "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc",
+        "@local_xla//xla:shape_util",
         "@local_xla//xla/pjrt:pjrt_api",
         "@local_xla//xla/pjrt:pjrt_c_api_client",
         "@local_xla//xla/pjrt:tfrt_cpu_pjrt_client",
         "@local_xla//xla/pjrt/c:pjrt_c_api_cpu",
+        "@local_xla//xla/pjrt/c:pjrt_c_api_hdrs",
         "@local_xla//xla/pjrt/c:pjrt_c_api_wrapper_impl",
     ],
 )
diff --git a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.cc b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.cc
index 02f9388..18a851e 100644
--- a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.cc
+++ b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.cc
@@ -15,11 +15,15 @@
 #include "tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h"
 
 #include <memory>
+#include <utility>
 
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/jit/pjrt_tensor_buffer_util.h"
 #include "xla/pjrt/c/pjrt_c_api.h"
 #include "xla/pjrt/pjrt_c_api_client.h"
+#include "xla/pjrt/pjrt_client.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
@@ -50,14 +54,16 @@
 absl::Status SetPjRtCBufferToTensor(PJRT_Buffer* c_buffer,
                                     xla::PjRtCApiClient* c_api_client,
                                     Tensor* tensor) {
+  auto buffer = std::make_unique<xla::PjRtCApiBuffer>(c_api_client, c_buffer);
   tensorflow::AsyncValueTensor* av_tensor =
       tensorflow::AsyncValueTensor::FromTensor(tensor);
   if (av_tensor == nullptr) {
-    return absl::InternalError(
-        "The tensor to set PjRtBuffer is not an AsyncValueTensor.");
+    TF_ASSIGN_OR_RETURN(
+        *tensor, MakeTensorFromPjRtBuffer(tensor->dtype(), tensor->shape(),
+                                          std::move(buffer)));
+  } else {
+    av_tensor->SetBuffer(std::move(buffer));
   }
-  av_tensor->SetBuffer(
-      std::make_unique<xla::PjRtCApiBuffer>(c_api_client, c_buffer));
   return absl::OkStatus();
 }
 
diff --git a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc
index 06fbd7b..c72f0cf 100644
--- a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc
+++ b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc
@@ -14,20 +14,27 @@
 ==============================================================================*/
 #include "tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h"
 
+#include <cstdint>
 #include <memory>
+#include <optional>
 #include <utility>
 #include <vector>
 
 #include <gtest/gtest.h>
+#include "absl/log/check.h"
+#include "xla/pjrt/c/pjrt_c_api.h"
 #include "xla/pjrt/c/pjrt_c_api_cpu.h"
 #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
 #include "xla/pjrt/pjrt_api.h"
 #include "xla/pjrt/pjrt_c_api_client.h"
 #include "xla/pjrt/tfrt_cpu_pjrt_client.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/tfrt/common/async_value_tensor.h"
 #include "tensorflow/core/tfrt/common/pjrt_util.h"
 #include "tsl/lib/core/status_test_util.h"
+#include "tsl/platform/casts.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/protobuf/error_codes.pb.h"
 
@@ -38,6 +45,27 @@
 using ::testing::NotNull;
 using ::tsl::testing::StatusIs;
 
+PJRT_Buffer* CreateCBuffer() {
+  auto status = pjrt::PjrtApi(DEVICE_CPU);
+  if (!status.ok()) {
+    CHECK_OK(pjrt::SetPjrtApi(DEVICE_CPU, GetPjrtApi()));
+  }
+  auto pjrt_client = xla::GetCApiClient(DEVICE_CPU);
+  CHECK_OK(pjrt_client.status());
+  auto c_api_client = down_cast<xla::PjRtCApiClient*>(pjrt_client->get());
+  std::vector<int32_t> data(1, 0);
+  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::S32, {1});
+
+  auto buffer = c_api_client->pjrt_c_client()->client->BufferFromHostBuffer(
+      data.data(), shape.element_type(), shape.dimensions(),
+      /*byte_strides=*/std::nullopt,
+      xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
+      c_api_client->pjrt_c_client()->client->addressable_devices()[0]);
+  CHECK_OK(buffer.status());
+
+  return new PJRT_Buffer{std::move(*buffer), c_api_client->pjrt_c_client()};
+}
+
 TEST(TensorPjRtBufferUtilTest, GetPjRtCBufferFromTensorNoBuffer) {
   auto allocator = std::make_unique<AsyncValueAllocator>();
   tensorflow::Tensor tensor(allocator.get(), DT_FLOAT, {1});
@@ -103,36 +131,18 @@
 
 TEST(TensorPjRtBufferUtilTest, SetPjRtCBufferToTensorNotAsyncValueTensor) {
   tensorflow::Tensor tensor(DT_FLOAT, {1});
+  TF_ASSERT_OK_AND_ASSIGN(auto pjrt_client, xla::GetCApiClient(DEVICE_CPU));
+  PJRT_Buffer* c_buffer = CreateCBuffer();
 
-  EXPECT_THAT(
-      SetPjRtCBufferToTensor(nullptr, nullptr, &tensor),
-      StatusIs(
-          error::INTERNAL,
-          HasSubstr(absl::StrCat(
-              "The tensor to set PjRtBuffer is not an AsyncValueTensor"))));
+  TF_EXPECT_OK(SetPjRtCBufferToTensor(
+      c_buffer, down_cast<xla::PjRtCApiClient*>(pjrt_client.get()), &tensor));
 }
 
 TEST(TensorPjRtBufferUtilTest, SetPjRtCBufferToTensorSuccess) {
   auto allocator = std::make_unique<AsyncValueAllocator>();
-  auto status = pjrt::PjrtApi(DEVICE_CPU);
-  if (!status.ok()) {
-    TF_ASSERT_OK(pjrt::SetPjrtApi(DEVICE_CPU, GetPjrtApi()));
-  }
-  TF_ASSERT_OK_AND_ASSIGN(auto pjrt_client, xla::GetCApiClient(DEVICE_CPU));
-  auto c_api_client = down_cast<xla::PjRtCApiClient*>(pjrt_client.get());
-  std::vector<int32_t> data(1, 0);
-  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::S32, {1});
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto buffer,
-      c_api_client->pjrt_c_client()->client->BufferFromHostBuffer(
-          data.data(), shape.element_type(), shape.dimensions(),
-          /*byte_strides=*/std::nullopt,
-          xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
-          nullptr,
-          c_api_client->pjrt_c_client()->client->addressable_devices()[0]));
   tensorflow::Tensor tensor(allocator.get(), DT_FLOAT, {1});
-  auto c_buffer =
-      new PJRT_Buffer{std::move(buffer), c_api_client->pjrt_c_client()};
+  TF_ASSERT_OK_AND_ASSIGN(auto pjrt_client, xla::GetCApiClient(DEVICE_CPU));
+  PJRT_Buffer* c_buffer = CreateCBuffer();
 
   TF_EXPECT_OK(SetPjRtCBufferToTensor(
       c_buffer, down_cast<xla::PjRtCApiClient*>(pjrt_client.get()), &tensor));
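In short, SetPjRtCBufferToTensor no longer requires the destination to be an AsyncValueTensor: when it is not, the tensor is rebuilt from the PJRT_Buffer via MakeTensorFromPjRtBuffer instead of returning an INTERNAL error. A minimal caller sketch, reusing the CreateCBuffer() helper added above and assuming the CPU PJRT plugin is registered:

    // Sketch only; mirrors SetPjRtCBufferToTensorNotAsyncValueTensor above.
    tensorflow::Tensor tensor(DT_FLOAT, {1});           // plain tensor, no AsyncValueTensor
    auto pjrt_client = xla::GetCApiClient(DEVICE_CPU);  // StatusOr<std::unique_ptr<PjRtClient>>
    CHECK_OK(pjrt_client.status());
    PJRT_Buffer* c_buffer = CreateCBuffer();
    absl::Status status = SetPjRtCBufferToTensor(
        c_buffer, down_cast<xla::PjRtCApiClient*>(pjrt_client->get()), &tensor);
    // `tensor` now wraps the PjRt buffer; before this change the call failed.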
diff --git a/tensorflow/c/tf_status.h b/tensorflow/c/tf_status.h
index 31e5a4f..438f860 100644
--- a/tensorflow/c/tf_status.h
+++ b/tensorflow/c/tf_status.h
@@ -29,6 +29,7 @@
 // TF_Code holds an error code.  The enum values here are identical to
 // corresponding values in error_codes.proto.
 typedef TSL_Code TF_Code;
+// LINT.IfChange
 #define TF_OK TSL_OK
 #define TF_CANCELLED TSL_CANCELLED
 #define TF_UNKNOWN TSL_UNKNOWN
@@ -46,6 +47,7 @@
 #define TF_INTERNAL TSL_INTERNAL
 #define TF_UNAVAILABLE TSL_UNAVAILABLE
 #define TF_DATA_LOSS TSL_DATA_LOSS
+// LINT.ThenChange(//tensorflow/python/py_exception_registry_wrapper.cc)
 
 // --------------------------------------------------------------------------
 
diff --git a/tensorflow/cc/framework/cc_op_gen_util.cc b/tensorflow/cc/framework/cc_op_gen_util.cc
index d0c65d1..0a64525 100644
--- a/tensorflow/cc/framework/cc_op_gen_util.cc
+++ b/tensorflow/cc/framework/cc_op_gen_util.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/cc/framework/cc_op_gen_util.h"
 
+#include <cmath>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
@@ -29,6 +30,7 @@
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/strcat.h"
 #include "tensorflow/core/platform/types.h"
 #include "tsl/platform/statusor.h"
 
@@ -206,7 +208,12 @@
       return strings::StrCat(attr_value.i());
     case AttrValue::kF: {
       const float f = attr_value.f();
-      return strings::StrCat(attr_value.f(), floorf(f) == f ? ".0" : "", "f");
+      if (std::isinf(f)) {
+        return strings::StrCat(f < 0.0f ? "-" : "+",
+                               "std::numeric_limits<float>::infinity()");
+      } else {
+        return strings::StrCat(attr_value.f(), floorf(f) == f ? ".0" : "", "f");
+      }
     }
     case AttrValue::kB:
       return attr_value.b() ? "true" : "false";
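The kF change above keeps generated op wrappers compilable when a float attr default is infinite, presumably because the plain StrCat path emits an "inf"-style token, which is not a valid C++ float literal. A standalone sketch of the emission rule, using absl::StrCat in place of the internal strings::StrCat purely for illustration:

    #include <cmath>
    #include <limits>
    #include <string>
    #include "absl/strings/str_cat.h"

    // Illustrative re-statement of the rule above; not the real generator.
    std::string FloatAttrLiteral(float f) {
      if (std::isinf(f)) {
        return absl::StrCat(f < 0.0f ? "-" : "+",
                            "std::numeric_limits<float>::infinity()");
      }
      return absl::StrCat(f, std::floor(f) == f ? ".0" : "", "f");
    }
    // FloatAttrLiteral(1.5f) == "1.5f", FloatAttrLiteral(2.0f) == "2.0f",
    // FloatAttrLiteral(INFINITY) == "+std::numeric_limits<float>::infinity()".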
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD
index b09ebe0..935b37b 100644
--- a/tensorflow/cc/saved_model/BUILD
+++ b/tensorflow/cc/saved_model/BUILD
@@ -522,6 +522,7 @@
     visibility = [
         "//learning/brain/contrib/hub/server/distro:__subpackages__",
         "//learning/brain/contrib/tpu_modeling:__subpackages__",
+        "//learning/metadata/artifactoid/cc:__subpackages__",
         "//learning/tfx/pipeline/util:__subpackages__",
         "//tensorflow/python/saved_model:__subpackages__",
     ],
diff --git a/tensorflow/compiler/aot/benchmark_main.template b/tensorflow/compiler/aot/benchmark_main.template
index a4df6ed..dc5a903 100644
--- a/tensorflow/compiler/aot/benchmark_main.template
+++ b/tensorflow/compiler/aot/benchmark_main.template
@@ -19,7 +19,7 @@
 // clang-format on
 
 #include "tensorflow/compiler/aot/benchmark.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "unsupported/Eigen/CXX11/Tensor"
 
 // Macros that expand to tokens based on the entry point name.
 // clang-format off
diff --git a/tensorflow/compiler/jit/device_compiler_client.cc b/tensorflow/compiler/jit/device_compiler_client.cc
index 747aa7d..5d0042c 100644
--- a/tensorflow/compiler/jit/device_compiler_client.cc
+++ b/tensorflow/compiler/jit/device_compiler_client.cc
@@ -37,8 +37,6 @@
   build_options.set_alias_passthrough_params(options.alias_passthrough_params);
   build_options.mutable_debug_options()->set_xla_detailed_logging(
       options.detailed_logging);
-  build_options.mutable_debug_options()->set_xla_enable_dumping(
-      options.detailed_logging);
   if (tensorflow::OpDeterminismRequired()) {
     build_options.mutable_debug_options()->set_xla_gpu_deterministic_ops(true);
   }
diff --git a/tensorflow/compiler/jit/device_compiler_client_test.cc b/tensorflow/compiler/jit/device_compiler_client_test.cc
index 4ac2e7f..f42ae36 100644
--- a/tensorflow/compiler/jit/device_compiler_client_test.cc
+++ b/tensorflow/compiler/jit/device_compiler_client_test.cc
@@ -60,5 +60,17 @@
   EXPECT_EQ(build_option.device_ordinal(), -1);
 }
 
+TEST(GetExecutableOptionTest, DumpingWithoutDetailedLogging) {
+  XlaCompiler::Options options;
+  options.detailed_logging = false;
+  XlaCompiler::CompilationResult result;
+
+  auto build_option =
+      GetExecutableBuildOptions(options, result, /*default_device_ordinal=*/-1);
+
+  EXPECT_FALSE(build_option.debug_options().xla_detailed_logging());
+  EXPECT_TRUE(build_option.debug_options().xla_enable_dumping());
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/get_compiler_ir.cc b/tensorflow/compiler/jit/get_compiler_ir.cc
index 37987cb..d147bfb 100644
--- a/tensorflow/compiler/jit/get_compiler_ir.cc
+++ b/tensorflow/compiler/jit/get_compiler_ir.cc
@@ -17,7 +17,6 @@
 
 #include <cstdint>
 #include <deque>
-#include <iterator>
 #include <memory>
 #include <string>
 #include <utility>
@@ -74,8 +73,6 @@
   build_options.set_alias_passthrough_params(options.alias_passthrough_params);
   build_options.mutable_debug_options()->set_xla_detailed_logging(
       options.detailed_logging);
-  build_options.mutable_debug_options()->set_xla_enable_dumping(
-      options.detailed_logging);
   // If the embed_ir_in_executable is set, hlo_proto will be dumped in
   // executable. The hlo_proto contains HLO modules and buffer assignment.
   build_options.mutable_debug_options()->set_xla_embed_ir_in_executable(
@@ -148,10 +145,10 @@
 BuildXlaCompilerArgumentFromTensorSpec(
     const FunctionBody* fbody, absl::Span<int const> must_be_constant_idxs,
     absl::Span<const Tensor* const> inputs,
-    absl::Span<VariableInfo const> variable_args, Device* device,
+    absl::Span<VariableInfo const> variable_args,
     absl::Span<const ArgShapeAndDType> flat_arg_shape_and_dtype) {
   TF_RET_CHECK(fbody != nullptr);
-  auto& input_args = fbody->fdef.signature().input_arg();
+  auto& input_args = fbody->record->fdef().signature().input_arg();
   int input_arg_size = input_args.size();
   std::vector<XlaCompiler::Argument> args;
   args.reserve(input_arg_size);
@@ -261,7 +258,7 @@
       flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
 
   // `input_args` includes both concrete_fn input args and captured_input here.
-  auto& input_args = fbody->fdef.signature().input_arg();
+  auto& input_args = fbody->record->fdef().signature().input_arg();
   // Here input_arg_size = len(flat_args) + len(captured_input)
   int input_arg_size = input_args.size();
 
@@ -326,7 +323,7 @@
 
   if (compiler_arg_source == CompilerArgSource::TENSOR_SPEC) {
     args = BuildXlaCompilerArgumentFromTensorSpec(fbody, constant_arg_indices,
-                                                  inputs, variable_infos, dev,
+                                                  inputs, variable_infos,
                                                   input_arg_shape_and_dtype);
   } else if (compiler_arg_source == CompilerArgSource::CONCRETE_INPUT) {
     args = XlaComputationLaunchContext::BuildXlaCompilerArguments(
diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc
index 598d401..337a93a 100644
--- a/tensorflow/compiler/jit/xla_kernel_creator.cc
+++ b/tensorflow/compiler/jit/xla_kernel_creator.cc
@@ -79,7 +79,8 @@
   Device* dev = flr->device();
   Status s;
   auto props = std::make_shared<NodeProperties>(
-      &fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
+      &fbody->record->fdef().signature(), node_def, fbody->arg_types,
+      fbody->ret_types);
   OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
                                     dev->GetAllocator(AllocatorAttributes()),
                                     flr, dev->resource_manager(), props,
diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc
index 70a6b76..a1ceac5 100644
--- a/tensorflow/compiler/jit/xla_platform_info.cc
+++ b/tensorflow/compiler/jit/xla_platform_info.cc
@@ -43,6 +43,7 @@
 #include "tensorflow/core/tfrt/common/global_state.h"
 #include "tensorflow/core/tfrt/common/pjrt_util.h"
 #include "tensorflow/core/tpu/tpu_defs.h"
+#include "tsl/framework/device_type.h"
 
 namespace tensorflow {
 namespace {
@@ -277,11 +278,32 @@
   const auto& device_type = platform_info.device_type();
   const std::string& compiler_name =
       GetPjRtDeviceCompilerResourceName(device_type);
+  const std::string& profiler_name =
+      GetPjRtDeviceCompilationProfilerResourceName(device_type);
+  bool deleted_old_device_compiler = false;
 
   // Lookup the DeviceCompiler, create one if not found.
   Status s = rm->Lookup<PjRtDeviceCompiler>(
       rm->default_container(), compiler_name, pjrt_device_compiler);
-  if (!s.ok()) {
+  if (s.ok() && device_type == DEVICE_TPU) {
+    auto* existing_pjrt_client = (*pjrt_device_compiler)->client();
+    TF_ASSIGN_OR_RETURN(auto* latest_pjrt_client, GetPjRtClient(device_type));
+
+    if (existing_pjrt_client != latest_pjrt_client) {
+      // PjRtClient has changed. Delete the PjRtDeviceCompiler (and the cache
+      // within) and create a new one.
+      TF_RETURN_IF_ERROR(rm->Delete<PjRtDeviceCompiler>(rm->default_container(),
+                                                        compiler_name));
+      TF_RETURN_IF_ERROR(rm->Delete<DeviceCompilationProfiler>(
+          rm->default_container(), profiler_name));
+
+      deleted_old_device_compiler = true;
+    }
+  }
+
+  // TODO(b/308698131): Try consolidating all PJRT-related state into one class
+  // instead of directly storing it in the ResourceMgr.
+  if (!s.ok() || deleted_old_device_compiler) {
     DeviceType compilation_device_type("");
     xla::PjRtClient* pjrt_client = nullptr;
     TF_RETURN_IF_ERROR(GetCompilationDeviceTypeAndPjRtClient(
@@ -296,8 +318,6 @@
         }));
   }
 
-  const std::string& profiler_name =
-      GetPjRtDeviceCompilationProfilerResourceName(device_type);
   TF_RETURN_IF_ERROR(rm->LookupOrCreate<DeviceCompilationProfiler>(
       rm->default_container(), profiler_name, profiler,
       [](DeviceCompilationProfiler** profiler) {
@@ -321,20 +341,21 @@
 }
 
 XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
-  auto device = static_cast<Device*>(device_base);
   se::Platform::Id platform_id = nullptr;
   const XlaDevice::Metadata* xla_device_metadata = nullptr;
   const PjRtBaseDevice::Metadata* pjrt_device_metadata = nullptr;
   std::shared_ptr<se::DeviceMemoryAllocator> custom_allocator;
 
-  if (device->device_type() == DEVICE_CPU) {
+  const std::string& device_type = device_base->device_type();
+  if (device_type == DEVICE_CPU) {
     platform_id = se::host::kHostPlatformId;
-  } else if (device->device_type() == DEVICE_GPU) {
+  } else if (device_type == DEVICE_GPU) {
+    auto device = static_cast<Device*>(device_base);
     platform_id = device->tensorflow_accelerator_device_info()
                       ->stream->parent()
                       ->platform()
                       ->id();
-  } else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata)
+  } else if (XlaDevice::GetMetadataFromDevice(device_base, &xla_device_metadata)
                  .ok()) {
     // If we are on an XlaDevice, use the underlying XLA platform's allocator
     // directly. We could use the StreamExecutor's allocator which may
@@ -348,12 +369,12 @@
     platform_id = xla_device_metadata->platform()->id();
     custom_allocator =
         xla_device_metadata->client()->backend().shared_memory_allocator();
-  } else if (auto metadata = PjRtBaseDevice::GetMetadataFromDevice(device);
+  } else if (auto metadata = PjRtBaseDevice::GetMetadataFromDevice(device_base);
              metadata.ok()) {
     pjrt_device_metadata = *metadata;
   }
 
-  return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
+  return XlaPlatformInfo(DeviceType(device_type), platform_id,
                          xla_device_metadata, pjrt_device_metadata,
                          custom_allocator);
 }
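A rough, self-contained illustration of the invalidation logic added above, using hypothetical stub types rather than the real ResourceMgr/PjRtDeviceCompiler API: a cached compiler is reused only while it still points at the current PjRtClient; otherwise the stale entry is dropped and a fresh one is created.

#include <map>
#include <memory>
#include <string>

struct PjRtClientStub {};  // stands in for xla::PjRtClient
struct CompilerStub { const PjRtClientStub* client; };

class CompilerCacheSketch {
 public:
  CompilerStub* GetOrCreate(const std::string& name,
                            const PjRtClientStub* latest_client) {
    auto it = cache_.find(name);
    // A cached entry is only valid while its PjRtClient is still the latest
    // one; a stale entry is deleted so a fresh compiler gets built below.
    if (it != cache_.end() && it->second->client != latest_client) {
      cache_.erase(it);
      it = cache_.end();
    }
    if (it == cache_.end()) {
      it = cache_.emplace(name, std::make_unique<CompilerStub>(
                                    CompilerStub{latest_client}))
               .first;
    }
    return it->second.get();
  }

 private:
  std::map<std::string, std::unique_ptr<CompilerStub>> cache_;
};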
diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index 28e2753..f3272f8 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -54,13 +54,13 @@
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
         "//tensorflow/compiler/mlir/tensorflow:mlprogram_util",
-        "//tensorflow/compiler/mlir/tensorflow/transforms:bridge_pass_test_pipeline_registration",
         "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
         "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_test_passes",
         "//tensorflow/compiler/mlir/tensorflow/transforms:tf_graph_optimization_pass",
         "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes",  # buildcleaner:keep
         "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops",
         "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util",
+        "//tensorflow/compiler/mlir/tf2xla/internal/passes:clustering_passes",
         "//tensorflow/compiler/mlir/tf2xla/transforms:tf_xla_passes",
         "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf",
         "//tensorflow/compiler/mlir/tosa:tf_passes",
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 87de941..8117705 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -12,6 +12,7 @@
         "//learning/brain/mobile/programmability:__subpackages__",
         "//tensorflow/lite/experimental/tf_runtime:__subpackages__",
         "//tensorflow/lite/testing:__subpackages__",
+        "//third_party/odml/infra/genai/conversion/per_layer:__subpackages__",
     ],
     licenses = ["notice"],
 )
@@ -873,8 +874,8 @@
         "transforms/lift_tflite_flex_ops.h",
     ],
     deps = [
+        ":tensorflow_lite",
         ":tensorflow_lite_passes_inc_gen",
-        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:convert_attr",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
@@ -883,6 +884,7 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/strings",
         "@flatbuffers",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD
index c39e377..be13a15 100644
--- a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD
+++ b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD
@@ -7,6 +7,7 @@
     "@llvm-project//mlir:tblgen.bzl",
     "gentbl_cc_library",
 )
+# copybara:uncomment load("//tools/build_defs/proto/cpp:cc_proto_library.bzl", "cc_proto_library")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/device-transform-gpu.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/device-transform-gpu.mlir
index b53508f..cbd2dd6 100644
--- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/device-transform-gpu.mlir
+++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/device-transform-gpu.mlir
@@ -153,8 +153,8 @@
 }
 
 // CHECK:       func @padSliceTo4D(%[[VAL_0:.*]]: tensor<4x384x32xf32>) -> tensor<1x384x32xf32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<[1, 1, 384, 32]> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[1, 1, 384, 32]> : tensor<4xi32>}> : () -> tensor<4xi32>
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant dense<[1, 4, 384, 32]> : tensor<4xi32>
 // CHECK-DAG:       %[[VAL_4:.*]] = "tfl.pseudo_const"() {value = dense<[1, 384, 32]> : tensor<3xi32>
 // CHECK:           %[[VAL_5:.*]] = "tfl.reshape"(%[[VAL_0]], %[[VAL_3]]) : (tensor<4x384x32xf32>, tensor<4xi32>) -> tensor<1x4x384x32xf32>
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
index 1a55e51..4f044e1 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
@@ -675,6 +675,11 @@
   std::optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
   CreateMetadataVector();
 
+  // Encodes the `tfl.metadata_buffer` array attribute of the module into the
+  // metadata_buffer section of the final model. Returns 0 (no vector) if the
+  // module has no such attribute.
+  VectorBufferOffset<int32_t> CreateMetadataBufferVector();
+
   // Builds and returns list of tfl.SignatureDef sections in the model.
   std::optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
   CreateSignatureDefs(const std::vector<SignatureDefData>& signature_defs);
@@ -727,8 +732,18 @@
   BufferOffset<flatbuffers::Vector<unsigned int>> BuildStablehloPrecisionConfig(
       ::mlir::ArrayAttr precisionConfig);
 
+  std::optional<BufferOffset<tflite::Operator>> BuildStablehloGatherOp(
+      mlir::stablehlo::GatherOp gather_op, const std::vector<int32_t>& operands,
+      const std::vector<int32_t>& results);
+
   std::optional<BufferOffset<tflite::Operator>> BuildStablehloScatterOp(
-      mlir::stablehlo::ScatterOp shlo_op, const std::vector<int32_t>& operands,
+      mlir::stablehlo::ScatterOp scatter_op,
+      const std::vector<int32_t>& operands,
+      const std::vector<int32_t>& results);
+
+  std::optional<BufferOffset<tflite::Operator>> BuildStablehloReduceWindowOp(
+      mlir::stablehlo::ReduceWindowOp reduce_window_op,
+      const std::vector<int32_t>& operands,
       const std::vector<int32_t>& results);
 
   std::optional<BufferOffset<tflite::Operator>> BuildStablehloRngBitGeneratorOp(
@@ -1432,6 +1447,43 @@
 }
 
 std::optional<BufferOffset<tflite::Operator>>
+Translator::BuildStablehloGatherOp(mlir::stablehlo::GatherOp gather_op,
+                                   const std::vector<int32_t>& operands,
+                                   const std::vector<int32_t>& results) {
+  std::string op_name =
+      gather_op.getOperation()->getName().getStringRef().str();
+  uint32_t opcode_index =
+      GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_GATHER);
+
+  std::vector<int64_t> offset_dims_vec(
+      gather_op.getDimensionNumbers().getOffsetDims().begin(),
+      gather_op.getDimensionNumbers().getOffsetDims().end());
+  std::vector<int64_t> collapsed_slice_dims_vec(
+      gather_op.getDimensionNumbers().getCollapsedSliceDims().begin(),
+      gather_op.getDimensionNumbers().getCollapsedSliceDims().end());
+  std::vector<int64_t> start_index_map_vec(
+      gather_op.getDimensionNumbers().getStartIndexMap().begin(),
+      gather_op.getDimensionNumbers().getStartIndexMap().end());
+
+  auto offset_dims = builder_.CreateVector(offset_dims_vec);
+  auto collapsed_slice_dims = builder_.CreateVector(collapsed_slice_dims_vec);
+  auto start_index_map = builder_.CreateVector(start_index_map_vec);
+  auto slice_sizes = builder_.CreateVector(
+      mlir::GetOptionalVector<int64_t>(gather_op.getSliceSizes()));
+
+  auto gather_option = tflite::CreateStablehloGatherOptions(
+      builder_, offset_dims, collapsed_slice_dims, start_index_map,
+      gather_op.getDimensionNumbers().getIndexVectorDim(), slice_sizes,
+      gather_op.getIndicesAreSorted());
+
+  return tflite::CreateOperator(
+      builder_, opcode_index, builder_.CreateVector(operands),
+      builder_.CreateVector(results), tflite::BuiltinOptions_NONE, 0, 0,
+      tflite::CustomOptionsFormat_FLEXBUFFERS, 0, 0, 0, 0,
+      tflite::BuiltinOptions2_StablehloGatherOptions, gather_option.Union());
+}
+
+std::optional<BufferOffset<tflite::Operator>>
 Translator::BuildStablehloScatterOp(mlir::stablehlo::ScatterOp scatter_op,
                                     const std::vector<int32_t>& operands,
                                     const std::vector<int32_t>& results) {
@@ -1484,6 +1536,46 @@
 }
 
 std::optional<BufferOffset<tflite::Operator>>
+Translator::BuildStablehloReduceWindowOp(
+    mlir::stablehlo::ReduceWindowOp reduce_window_op,
+    const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
+  std::string op_name =
+      reduce_window_op.getOperation()->getName().getStringRef().str();
+  uint32_t opcode_index =
+      GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_REDUCE_WINDOW);
+
+  auto window_dimensions = builder_.CreateVector(
+      mlir::GetVector<int64_t>(reduce_window_op.getWindowDimensions()));
+  auto window_strides = builder_.CreateVector(
+      mlir::GetOptionalVector<int64_t>(reduce_window_op.getWindowStrides()));
+  auto base_dilations = builder_.CreateVector(
+      mlir::GetOptionalVector<int64_t>(reduce_window_op.getBaseDilations()));
+  auto window_dilations = builder_.CreateVector(
+      mlir::GetOptionalVector<int64_t>(reduce_window_op.getWindowDilations()));
+  auto padding = builder_.CreateVector(
+      mlir::GetOptionalVector<int64_t>(reduce_window_op.getPadding()));
+
+  auto& body = reduce_window_op.getBody();
+  int32_t subgraph_index = UnnamedRegionToSubgraph(
+      &body, tflite::BuiltinOperator_STABLEHLO_REDUCE_WINDOW);
+  if (subgraph_index < 0) return std::nullopt;
+
+  auto reduce_window_option = tflite::CreateStablehloReduceWindowOptions(
+      builder_, window_dimensions, window_strides, base_dilations,
+      window_dilations, padding, subgraph_index);
+
+  return tflite::CreateOperator(
+      builder_, opcode_index, /*inputs=*/builder_.CreateVector(operands),
+      /*outputs=*/builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
+      /*builtin_options=*/0, /*custom_options=*/0,
+      tflite::CustomOptionsFormat_FLEXBUFFERS, /*mutating_variable_inputs=*/0,
+      /*intermediates=*/0, /*large_custom_options_offset=*/0,
+      /*large_custom_options_size=*/0,
+      tflite::BuiltinOptions2_StablehloReduceWindowOptions,
+      reduce_window_option.Union());
+}
+
+std::optional<BufferOffset<tflite::Operator>>
 Translator::BuildStablehloRngBitGeneratorOp(
     mlir::stablehlo::RngBitGeneratorOp rng_op,
     const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
@@ -1590,6 +1682,28 @@
             llvm::dyn_cast<mlir::stablehlo::RngBitGeneratorOp>(inst)) {
       return BuildStablehloRngBitGeneratorOp(shlo_op, operands, results);
     }
+    if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::GatherOp>(inst)) {
+      return BuildStablehloGatherOp(shlo_op, operands, results);
+    }
+    if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::AddOp>(inst)) {
+      return BuildStablehloOperatorwithoutOptions(
+          inst, operands, results, tflite::BuiltinOperator_STABLEHLO_ADD);
+    }
+    if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::MulOp>(inst)) {
+      return BuildStablehloOperatorwithoutOptions(
+          inst, operands, results, tflite::BuiltinOperator_STABLEHLO_MULTIPLY);
+    }
+    if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::ReduceWindowOp>(inst)) {
+      return BuildStablehloReduceWindowOp(shlo_op, operands, results);
+    }
+    if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::MaxOp>(inst)) {
+      return BuildStablehloOperatorwithoutOptions(
+          inst, operands, results, tflite::BuiltinOperator_STABLEHLO_MAXIMUM);
+    }
+    if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::MinOp>(inst)) {
+      return BuildStablehloOperatorwithoutOptions(
+          inst, operands, results, tflite::BuiltinOperator_STABLEHLO_MINIMUM);
+    }
     // for ops don't have kernels, only serialize when conversion is set to true
     if (convert_stablehlo_) {
       if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::LogisticOp>(inst)) {
@@ -1598,25 +1712,10 @@
             tflite::BuiltinOperator_STABLEHLO_LOGISTIC);
       }
 
-      if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::AddOp>(inst)) {
-        return BuildStablehloOperatorwithoutOptions(
-            inst, operands, results, tflite::BuiltinOperator_STABLEHLO_ADD);
-      }
-
-      if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::MulOp>(inst)) {
-        return BuildStablehloOperatorwithoutOptions(
-            inst, operands, results,
-            tflite::BuiltinOperator_STABLEHLO_MULTIPLY);
-      }
-
       if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::DivOp>(inst)) {
         return BuildStablehloOperatorwithoutOptions(
             inst, operands, results, tflite::BuiltinOperator_STABLEHLO_DIVIDE);
       }
-      if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::MaxOp>(inst)) {
-        return BuildStablehloOperatorwithoutOptions(
-            inst, operands, results, tflite::BuiltinOperator_STABLEHLO_MAXIMUM);
-      }
       if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::ReshapeOp>(inst)) {
         return BuildStablehloOperatorwithoutOptions(
             inst, operands, results, tflite::BuiltinOperator_STABLEHLO_RESHAPE);
@@ -1654,10 +1753,6 @@
         return BuildStablehloOperatorwithoutOptions(
             inst, operands, results, tflite::BuiltinOperator_STABLEHLO_LOG);
       }
-      if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::MinOp>(inst)) {
-        return BuildStablehloOperatorwithoutOptions(
-            inst, operands, results, tflite::BuiltinOperator_STABLEHLO_MINIMUM);
-      }
       if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::NegOp>(inst)) {
         return BuildStablehloOperatorwithoutOptions(
             inst, operands, results, tflite::BuiltinOperator_STABLEHLO_NEGATE);
@@ -2027,39 +2122,6 @@
             tflite::BuiltinOptions2_StablehloDotGeneralOptions,
             dot_geneoral_option.Union());
       }
-      if (auto shlo_op =
-              llvm::dyn_cast<mlir::stablehlo::ReduceWindowOp>(inst)) {
-        std::string op_name = inst->getName().getStringRef().str();
-        uint32_t opcode_index = GetOpcodeIndex(
-            op_name, tflite::BuiltinOperator_STABLEHLO_REDUCE_WINDOW);
-
-        auto window_dimensions = builder_.CreateVector(
-            mlir::GetOptionalVector<int64_t>(shlo_op.getWindowDimensions()));
-        auto window_strides = builder_.CreateVector(
-            mlir::GetOptionalVector<int64_t>(shlo_op.getWindowStrides()));
-        auto base_dilations = builder_.CreateVector(
-            mlir::GetOptionalVector<int64_t>(shlo_op.getBaseDilations()));
-        auto window_dilations = builder_.CreateVector(
-            mlir::GetOptionalVector<int64_t>(shlo_op.getWindowDilations()));
-        auto padding = builder_.CreateVector(
-            mlir::GetOptionalVector<int64_t>(shlo_op.getPadding()));
-
-        auto& body = shlo_op.getBody();
-        int32_t subgraph_index = UnnamedRegionToSubgraph(
-            &body, tflite::BuiltinOperator_STABLEHLO_REDUCE_WINDOW);
-        if (subgraph_index < 0) return std::nullopt;
-
-        auto reduce_window_option = tflite::CreateStablehloReduceWindowOptions(
-            builder_, window_dimensions, window_strides, base_dilations,
-            window_dilations, padding, subgraph_index);
-
-        return tflite::CreateOperator(
-            builder_, opcode_index, builder_.CreateVector(operands),
-            builder_.CreateVector(results), tflite::BuiltinOptions_NONE, 0, 0,
-            tflite::CustomOptionsFormat_FLEXBUFFERS, 0, 0, 0, 0,
-            tflite::BuiltinOptions2_StablehloReduceWindowOptions,
-            reduce_window_option.Union());
-      }
       if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::SortOp>(inst)) {
         std::string op_name = inst->getName().getStringRef().str();
         uint32_t opcode_index =
@@ -2104,40 +2166,6 @@
             tflite::BuiltinOptions2_StablehloWhileOptions,
             while_option.Union());
       }
-      if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::GatherOp>(inst)) {
-        std::string op_name = inst->getName().getStringRef().str();
-        uint32_t opcode_index =
-            GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_GATHER);
-
-        std::vector<int64_t> offset_dims_vec(
-            shlo_op.getDimensionNumbers().getOffsetDims().begin(),
-            shlo_op.getDimensionNumbers().getOffsetDims().end());
-        std::vector<int64_t> collapsed_slice_dims_vec(
-            shlo_op.getDimensionNumbers().getCollapsedSliceDims().begin(),
-            shlo_op.getDimensionNumbers().getCollapsedSliceDims().end());
-        std::vector<int64_t> start_index_map_vec(
-            shlo_op.getDimensionNumbers().getStartIndexMap().begin(),
-            shlo_op.getDimensionNumbers().getStartIndexMap().end());
-
-        auto offset_dims = builder_.CreateVector(offset_dims_vec);
-        auto collapsed_slice_dims =
-            builder_.CreateVector(collapsed_slice_dims_vec);
-        auto start_index_map = builder_.CreateVector(start_index_map_vec);
-        auto slice_sizes = builder_.CreateVector(
-            mlir::GetOptionalVector<int64_t>(shlo_op.getSliceSizes()));
-
-        auto gather_option = tflite::CreateStablehloGatherOptions(
-            builder_, offset_dims, collapsed_slice_dims, start_index_map,
-            shlo_op.getDimensionNumbers().getIndexVectorDim(), slice_sizes,
-            shlo_op.getIndicesAreSorted());
-
-        return tflite::CreateOperator(
-            builder_, opcode_index, builder_.CreateVector(operands),
-            builder_.CreateVector(results), tflite::BuiltinOptions_NONE, 0, 0,
-            tflite::CustomOptionsFormat_FLEXBUFFERS, 0, 0, 0, 0,
-            tflite::BuiltinOptions2_StablehloGatherOptions,
-            gather_option.Union());
-      }
       if (auto shlo_op = llvm::dyn_cast<mlir::stablehlo::TransposeOp>(inst)) {
         std::string op_name = inst->getName().getStringRef().str();
         uint32_t opcode_index = GetOpcodeIndex(
@@ -2597,6 +2625,18 @@
   return builder_.CreateVector(metadata);
 }
 
+VectorBufferOffset<int32_t> Translator::CreateMetadataBufferVector() {
+  auto array_attr =
+      module_->getAttrOfType<mlir::ArrayAttr>("tfl.metadata_buffer");
+  std::vector<int32_t> metadata_buffer;
+  if (!array_attr) return 0;
+  for (auto value : array_attr.getAsValueRange<mlir::IntegerAttr>()) {
+    metadata_buffer.push_back(value.getSExtValue());
+  }
+
+  return builder_.CreateVector(metadata_buffer);
+}
+
 // Helper method that returns list of all strings in a StringAttr identified
 // by 'attr_key' and values are separated by a comma.
 llvm::SmallVector<llvm::StringRef, 2> GetStringsFromAttrWithSeparator(
@@ -2999,7 +3039,8 @@
 
   // Build the model and finish the model building process.
   auto description = builder_.CreateString(model_description.data());
-  VectorBufferOffset<int32_t> metadata_buffer = 0;  // Deprecated
+  VectorBufferOffset<int32_t> metadata_buffer =
+      CreateMetadataBufferVector();  // Deprecated
   auto metadata = CreateMetadataVector();
   if (!metadata) return std::nullopt;
 
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index 4094374..6eb2aee 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -1904,6 +1904,11 @@
                     mlir::UnitAttr::get(builder.getContext()));
   }
 
+  if (!model->metadata_buffer.empty()) {
+    module->setAttr("tfl.metadata_buffer",
+                    builder.getI32ArrayAttr(model->metadata_buffer));
+  }
+
   if (use_stablehlo_constant) {
     module->setAttr("tfl.metadata",
                     builder.getDictionaryAttr(builder.getNamedAttr(
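A minimal sketch of how the `tfl.metadata_buffer` module attribute wired up above can be read back into plain buffer indices on export (the sample indices below are hypothetical):

#include <cstdint>
#include <vector>

#include "mlir/IR/BuiltinAttributes.h"

// Collects the int32 buffer indices carried by a `tfl.metadata_buffer`
// ArrayAttr; returns an empty vector when the attribute is absent, mirroring
// Translator::CreateMetadataBufferVector above.
std::vector<int32_t> MetadataBufferIndices(mlir::ArrayAttr array_attr) {
  std::vector<int32_t> indices;
  if (!array_attr) return indices;
  for (auto value : array_attr.getAsValueRange<mlir::IntegerAttr>()) {
    indices.push_back(static_cast<int32_t>(value.getSExtValue()));
  }
  return indices;
}

// Usage: a module carrying
//   module attributes {tfl.metadata_buffer = [1 : i32, 3 : i32]} { ... }
// yields {1, 3}; the importer recreates the same attribute from
// model->metadata_buffer via builder.getI32ArrayAttr(...).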
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index b5e8fb2..4b915af 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -385,20 +385,21 @@
 //===----------------------------------------------------------------------===//
 // TFL op common constraints.
 //===----------------------------------------------------------------------===//
+def OperandsSameElementTypeConstraintBasePred :
+  Or<[TCopVTEtIsSameAs<0, 1>,
+    // Two operands' values are both quantized and their type have the same
+    // underlying storage type.
+    And<[
+      SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(0))",
+        quant_QuantizedType.predicate>,
+      CPred<"quant::QuantizedType::castToStorageType("
+                "getElementTypeOrSelf($_op.getOperand(0))) == "
+            "quant::QuantizedType::castToStorageType("
+                "getElementTypeOrSelf($_op.getOperand(1)))">]>]>;
 
 class OperandsSameElementTypeConstraintBase<string op> :
   PredOpTrait<op # " operands have same element type",
-    Or<[
-      TCopVTEtIsSameAs<0, 1>,
-      // Two operands' values are both quantized and their type have the same
-      // underlying storage type.
-      And<[
-        SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(0))",
-          quant_QuantizedType.predicate>,
-        CPred<"quant::QuantizedType::castToStorageType("
-                  "getElementTypeOrSelf($_op.getOperand(0))) == "
-              "quant::QuantizedType::castToStorageType("
-                  "getElementTypeOrSelf($_op.getOperand(1)))">]>]>>;
+    OperandsSameElementTypeConstraintBasePred>;
 
 // This is a constraint for most of the binary ops, e.g., add, mul, div, etc.
 // Binary ops lhs & rhs should have the same value type, and is capable to
@@ -538,8 +539,9 @@
 }
 
 def TFL_AddOp : TFL_Op<"add", [
-    TFL_RuntimePredOpTrait<"Operands do not have valid shapes",
-      CPred<"TFL::VerifyAddOpShapeConstraints(llvm::cast<AddOp>($_op))">>,
+    TFL_RuntimePredOpTrait<"Operands should have valid shapes and matching element types",
+      And<[CPred<"TFL::VerifyAddOpShapeConstraints(llvm::cast<AddOp>($_op))">,
+           OperandsSameElementTypeConstraintBasePred]>>,
     ResultsBroadcastableShape,
     Pure,
     Commutative,
diff --git a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc
index 83d0a0e..7984739 100644
--- a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc
+++ b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc
@@ -160,7 +160,7 @@
   EXPECT_EQ(collected_errors.count(NewConverterErrorData(
                 "MockFailurePass",
                 "Failed at tf.Const op\nsee current operation: %0 = "
-                "\"tf.Const\"() {value = dense<1> : tensor<4xi32>} : () -> "
+                "\"tf.Const\"() <{value = dense<1> : tensor<4xi32>}> : () -> "
                 "tensor<4xi32>\nError code: ERROR_NEEDS_FLEX_OPS",
                 ConverterErrorData::ERROR_NEEDS_FLEX_OPS, "tf.Const",
                 mlir::FileLineColLoc::get(input_file_id, 2, 9))),
@@ -168,22 +168,23 @@
   EXPECT_EQ(collected_errors.count(NewConverterErrorData(
                 "MockFailurePass",
                 "Failed at tf.Const op\nsee current operation: %1 = "
-                "\"tf.Const\"() {value = dense<0> : tensor<4xi32>} : () -> "
+                "\"tf.Const\"() <{value = dense<0> : tensor<4xi32>}> : () -> "
                 "tensor<4xi32>\nError code: ERROR_NEEDS_FLEX_OPS",
                 ConverterErrorData::ERROR_NEEDS_FLEX_OPS, "tf.Const",
                 mlir::FileLineColLoc::get(input_file_id, 2, 9))),
             1);
-  EXPECT_EQ(collected_errors.count(NewConverterErrorData(
-                "MockFailurePass",
-                "Failed at tf.StridedSlice op\nsee current operation: %2 = "
-                "\"tf.StridedSlice\"(%arg0, %1, %1, %0) {begin_mask = 11 : "
-                "i64, device = \"\", ellipsis_mask = 0 : i64, end_mask = 11 : "
-                "i64, new_axis_mask = 4 : i64, shrink_axis_mask = 0 : i64} : "
-                "(tensor<*xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) "
-                "-> tensor<*xf32>\nError code: ERROR_NEEDS_FLEX_OPS",
-                ConverterErrorData::ERROR_NEEDS_FLEX_OPS, "tf.StridedSlice",
-                mlir::FileLineColLoc::get(input_file_id, 4, 10))),
-            1);
+  EXPECT_EQ(
+      collected_errors.count(NewConverterErrorData(
+          "MockFailurePass",
+          "Failed at tf.StridedSlice op\nsee current operation: %2 = "
+          "\"tf.StridedSlice\"(%arg0, %1, %1, %0) <{begin_mask = 11 : "
+          "i64, ellipsis_mask = 0 : i64, end_mask = 11 : i64, new_axis_mask = "
+          "4 : i64, shrink_axis_mask = 0 : i64}> {device = \"\"} : "
+          "(tensor<*xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) "
+          "-> tensor<*xf32>\nError code: ERROR_NEEDS_FLEX_OPS",
+          ConverterErrorData::ERROR_NEEDS_FLEX_OPS, "tf.StridedSlice",
+          mlir::FileLineColLoc::get(input_file_id, 4, 10))),
+      1);
 
   // Check the location information.
   std::vector<std::string> locations;
diff --git a/tensorflow/compiler/mlir/lite/quantization/numerical_utils.cc b/tensorflow/compiler/mlir/lite/quantization/numerical_utils.cc
index 92220e3..b9bc0b1 100644
--- a/tensorflow/compiler/mlir/lite/quantization/numerical_utils.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/numerical_utils.cc
@@ -26,8 +26,17 @@
 namespace mlir {
 namespace quant {
 
-// This method is adopted from TFLite:
-// ["tensorflow/lite/kernels/internal/quantization_util.cc"]
+// Converts a double-precision floating-point multiplier to a quantized
+// multiplier.
+//
+// Args:
+//   double_multiplier: The double-precision floating-point multiplier.
+//
+// Returns:
+//   A pair of integers: a 32-bit fixed-point multiplier and a shift amount.
+//   The shift is the binary exponent produced by frexp, so
+//   double_multiplier ~= quantized_multiplier * 2^(shift - 31).
 QuantizedMultiplier QuantizeMultiplier(double double_multiplier) {
   if (double_multiplier < 1e-6) {
     return {0, 0};
@@ -35,30 +44,36 @@
 
   int32_t shift;
   const double q = frexp(double_multiplier, &shift);
-  auto q_fixed = static_cast<int64_t>(round(q * (1LL << 31)));
-  assert(q_fixed <= (1LL << 31));
-  if (q_fixed == (1LL << 31)) {
-    q_fixed /= 2;
+  int64_t quantized_multiplier = round(q * (1LL << 31));
+  assert(quantized_multiplier <= (1LL << 31));
+  if (quantized_multiplier == (1LL << 31)) {
+    quantized_multiplier /= 2;
     ++shift;
   }
-  assert(q_fixed <= std::numeric_limits<int32_t>::max());
-  // A shift amount smaller than -31 would cause all bits to be shifted out
-  // and thus all results would be zero. We implement that instead with
-  // q_fixed==0, so as to avoid hitting issues with right-shift
-  // operations with shift amounts greater than 31. Note that this happens
-  // roughly when abs(double_multiplier) < 2^-31 and the present handling means
-  // that we're effectively flushing tiny double_multiplier's to zero.
-  // We could conceivably handle values in the range (roughly) [32, 63]
-  // as 'denormals' i.e. (shift==0, q_fixed < 2^30). In that point of view
-  // the present handling is just doing 'flush denormals to zero'. We could
-  // reconsider and actually generate nonzero denormals if a need arises.
-  if (shift < -31) {
-    shift = 0;
-    q_fixed = 0;
+  assert(quantized_multiplier <= std::numeric_limits<int32_t>::max());
+
+  // Check that the shift amount is not greater than 31 or less than -31.
+  if (shift > 31 || shift < -31) {
+    return {0, 0};
   }
-  return {static_cast<int32_t>(q_fixed), shift};
+
+  return {static_cast<int32_t>(quantized_multiplier), shift};
 }
 
+// Calculates the quantized range for a given scale, zero point, minimum and
+// maximum values, and quantization range.
+//
+// Args:
+//   scale: The scale factor for the quantized values.
+//   zero_point: The zero point for the quantized values.
+//   rmin: The real-valued minimum of the range to be quantized, if specified.
+//   rmax: The real-valued maximum of the range to be quantized, if specified.
+//   qmin: The smallest value of the quantization range.
+//   qmax: The largest value of the quantization range.
+//
+// Returns:
+//   A quantized range, represented as a pair of integers: the minimum and
+//   maximum quantized values.
 QuantizedRange CalculateQuantizedRange(double scale, int32_t zero_point,
                                        std::optional<double> rmin,
                                        std::optional<double> rmax, int32_t qmin,
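A self-contained sketch of the frexp-based decomposition used by QuantizeMultiplier above; the worked value in main() is only an illustration. The invariant is double_multiplier ~= quantized_multiplier * 2^(shift - 31).

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <utility>

// Splits a real multiplier m into a Q31 fixed-point value and a binary
// exponent: frexp gives m == q * 2^shift with q in [0.5, 1), and q is then
// rounded to a 31-bit fixed-point integer.
std::pair<int32_t, int32_t> QuantizeMultiplierSketch(double m) {
  if (m < 1e-6) return {0, 0};
  int shift = 0;
  const double q = std::frexp(m, &shift);
  int64_t quantized = static_cast<int64_t>(std::round(q * (1LL << 31)));
  if (quantized == (1LL << 31)) {  // rounding pushed q up to exactly 1.0
    quantized /= 2;
    ++shift;
  }
  if (shift > 31 || shift < -31) return {0, 0};  // out-of-range shift: flush to zero
  return {static_cast<int32_t>(quantized), shift};
}

int main() {
  // 1.5 == 0.75 * 2^1, so expect {round(0.75 * 2^31), 1} == {1610612736, 1}.
  auto [quantized, shift] = QuantizeMultiplierSketch(1.5);
  std::printf("%d %d\n", quantized, shift);
  return 0;
}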
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_default.mlir b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_default.mlir
index 589d438..9c6d9b8 100644
--- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_default.mlir
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_default.mlir
@@ -4,7 +4,7 @@
 func.func @bias_add(%arg0: tensor<1x10x10x32xf32>, %arg1: tensor<32xf32>) -> tensor<1x10x10x32xf32> {
   %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
   func.return %0 : tensor<1x10x10x32xf32>
-// CHECK: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
+// CHECK: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%arg0, %arg1) <{data_format = "NHWC"}> {T = "tfdtype$DT_FLOAT"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
 // CHECK: return %[[BIASADD_0]] : tensor<1x10x10x32xf32>
 }
 
@@ -30,8 +30,8 @@
   %1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
   %2 = "tf.AddV2"(%0, %1): (tensor<15x28x28x1xf32>, tensor<1xf32>) -> tensor<15x28x28x1xf32>
   func.return %2 : tensor<15x28x28x1xf32>
-// CHECK: %[[CONST_0:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
-// CHECK: %[[CONV2DBACKPROPINPUT_0:.*]] = "tf.Conv2DBackpropInput"(%arg0, %arg1, %arg2) {dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
+// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[CONV2DBACKPROPINPUT_0:.*]] = "tf.Conv2DBackpropInput"(%arg0, %arg1, %arg2) <{dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 2, 2, 1]}> : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
 // CHECK: %[[ADDV2_0:.*]] = "tf.AddV2"(%[[CONV2DBACKPROPINPUT_0]], %[[CONST_0]]) : (tensor<15x28x28x1xf32>, tensor<1xf32>) -> tensor<15x28x28x1xf32>
 // CHECK: return %[[ADDV2_0]] : tensor<15x28x28x1xf32>
 }
@@ -42,8 +42,8 @@
   %1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
   %2 = "tf.Sub"(%0, %1): (tensor<15x28x28x1xf32>, tensor<1xf32>) -> tensor<15x28x28x1xf32>
   func.return %2 : tensor<15x28x28x1xf32>
-// CHECK: %[[CONST_0:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
-// CHECK: %[[CONV2DBACKPROPINPUT_0:.*]] = "tf.Conv2DBackpropInput"(%arg0, %arg1, %arg2) {dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
+// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[CONV2DBACKPROPINPUT_0:.*]] = "tf.Conv2DBackpropInput"(%arg0, %arg1, %arg2) <{dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 2, 2, 1]}> : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
 // CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[CONV2DBACKPROPINPUT_0]], %[[CONST_0]]) : (tensor<15x28x28x1xf32>, tensor<1xf32>) -> tensor<15x28x28x1xf32>
 // CHECK: return %[[SUB_0]] : tensor<15x28x28x1xf32>
 }
@@ -71,7 +71,7 @@
   %0 = "tf.Identity"(%cst) {device = ""} : (tensor<2xf32>) -> tensor<*xf32>
   %1 = "tf.AddV2"(%0, %cst_1) {device = ""} : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
   func.return %1 : tensor<*xf32>
-// CHECK: %[[CONST_0:.*]] = "tf.Const"() {value = dense<[2.177590e-01, 2.89503098]> : tensor<2xf32>} : () -> tensor<*xf32>
+// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<[2.177590e-01, 2.89503098]> : tensor<2xf32>}> : () -> tensor<*xf32>
 // CHECK: return %[[CONST_0]] : tensor<*xf32>
 }
 
@@ -80,7 +80,7 @@
   %0 = "tf.Identity"(%arg0) {device = ""} : (tensor<2xf32>) -> tensor<*xf32>
   %1 = "tf.AddV2"(%0, %cst_1) {device = ""} : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
   func.return %1 : tensor<*xf32>
-// CHECK: %[[CONST_0:.*]] = "tf.Const"() {device = "", value = dense<1.000000e-03> : tensor<f32>} : () -> tensor<f32>
+// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<1.000000e-03> : tensor<f32>}> {device = ""} : () -> tensor<f32>
 // CHECK: %[[IDENTITY_0:.*]] = "tf.Identity"(%arg0) {device = ""} : (tensor<2xf32>) -> tensor<*xf32>
 // CHECK: %[[ADDV2_0:.*]] = "tfl.custom"(%0, %cst) {custom_code = "FlexAddV2", custom_option = #tfl<const_bytes : "0x0541646456320016120541646456321A001A002A070A015412023001320000021F191414042801">} : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
 // CHECK: return %[[ADDV2_0]] : tensor<*xf32>
@@ -95,7 +95,7 @@
   %2 = "tf.Conv2DBackpropInput"(%arg0, %arg1, %arg2) {strides = [1, 2, 2, 1], padding="SAME", dilations=[1, 1, 1, 1]}: (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
   %3 = "tf.AddV2"(%2, %1): (tensor<15x28x28x1xf32>, tensor<1xf32>) -> tensor<15x28x28x1xf32>
   func.return %2 : tensor<15x28x28x1xf32>
-// CHECK: %[[CONV2DBACKPROPINPUT_0:.*]] = "tf.Conv2DBackpropInput"(%arg0, %arg1, %arg2) {dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
+// CHECK: %[[CONV2DBACKPROPINPUT_0:.*]] = "tf.Conv2DBackpropInput"(%arg0, %arg1, %arg2) <{dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 2, 2, 1]}> : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
 // CHECK: return %[[CONV2DBACKPROPINPUT_0]] : tensor<15x28x28x1xf32>
 }
 
@@ -108,10 +108,10 @@
   %1 = "tf.Maximum"(%0, %cst_0) : (tensor<1x3x4x2xf32>, tensor<f32>) -> tensor<1x3x4x2xf32>
   %2 = "tf.Minimum"(%1, %cst_1) : (tensor<1x3x4x2xf32>, tensor<f32>) -> tensor<1x3x4x2xf32>
   func.return %2 : tensor<1x3x4x2xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<1x1x3x2xf32>} : () -> tensor<1x1x3x2xf32>
-// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<-1.000000e+00> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
-// CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %[[CONST_0]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x3x4x3xf32>, tensor<1x1x3x2xf32>) -> tensor<1x3x4x2xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<1x1x3x2xf32>}> : () -> tensor<1x1x3x2xf32>
+// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<-1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %[[CONST_0]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x3x4x3xf32>, tensor<1x1x3x2xf32>) -> tensor<1x3x4x2xf32>
 // CHECK: %[[MAXIMUM_0:.*]] = "tf.Maximum"(%[[CONV2D_0]], %[[CONST_1]]) : (tensor<1x3x4x2xf32>, tensor<f32>) -> tensor<1x3x4x2xf32>
 // CHECK: %[[MINIMUM_0:.*]] = "tf.Minimum"(%[[MAXIMUM_0]], %[[CONST_2]]) : (tensor<1x3x4x2xf32>, tensor<f32>) -> tensor<1x3x4x2xf32>
 // CHECK: return %[[MINIMUM_0]] : tensor<1x3x4x2xf32>
@@ -126,10 +126,10 @@
   %1 = "tf.Minimum"(%0, %cst_1) : (tensor<1x3x4x2xf32>, tensor<f32>) -> tensor<1x3x4x2xf32>
   %2 = "tf.Maximum"(%1, %cst_0) : (tensor<1x3x4x2xf32>, tensor<f32>) -> tensor<1x3x4x2xf32>
   func.return %2 : tensor<1x3x4x2xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<1x1x3x2xf32>} : () -> tensor<1x1x3x2xf32>
-// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<-1.000000e+00> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
-// CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %[[CONST_0]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x3x4x3xf32>, tensor<1x1x3x2xf32>) -> tensor<1x3x4x2xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<1x1x3x2xf32>}> : () -> tensor<1x1x3x2xf32>
+// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<-1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %[[CONST_0]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x3x4x3xf32>, tensor<1x1x3x2xf32>) -> tensor<1x3x4x2xf32>
 // CHECK: %[[MINIMUM_0:.*]] = "tf.Minimum"(%[[CONV2D_0]], %[[CONST_2]]) : (tensor<1x3x4x2xf32>, tensor<f32>) -> tensor<1x3x4x2xf32>
 // CHECK: %[[MAXIMUM_0:.*]] = "tf.Maximum"(%[[MINIMUM_0]], %[[CONST_1]]) : (tensor<1x3x4x2xf32>, tensor<f32>) -> tensor<1x3x4x2xf32>
 // CHECK: return %[[MAXIMUM_0]] : tensor<1x3x4x2xf32>
@@ -144,10 +144,10 @@
   %1 = "tf.Minimum"(%0, %cst_1) : (tensor<1x3x4x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32>
   %2 = "tf.Maximum"(%1, %cst_0) : (tensor<1x3x4x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32>
   func.return %2 : tensor<1x3x4x2xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<1x1x3x2xf32>} : () -> tensor<1x1x3x2xf32>
-// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<[-1.000000e+00, -3.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() {value = dense<[1.000000e+00, 3.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %[[CONST_0]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x3x4x3xf32>, tensor<1x1x3x2xf32>) -> tensor<1x3x4x2xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<1x1x3x2xf32>}> : () -> tensor<1x1x3x2xf32>
+// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<[-1.000000e+00, -3.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() <{value = dense<[1.000000e+00, 3.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %[[CONST_0]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x3x4x3xf32>, tensor<1x1x3x2xf32>) -> tensor<1x3x4x2xf32>
 // CHECK: %[[CUSTOM_0:.*]] = "tfl.custom"(%[[CONV2D_0]], %[[CONST_2]]) {custom_code = "FlexMinimum", custom_option = #tfl<const_bytes : "0x074D696E696D756D001812074D696E696D756D1A001A002A070A01541202300132000002231B1414042801">} : (tensor<1x3x4x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32>
 // CHECK: %[[CUSTOM_1:.*]] = "tfl.custom"(%[[CUSTOM_0]], %[[CONST_1]]) {custom_code = "FlexMaximum", custom_option = #tfl<const_bytes : "0x074D6178696D756D001812074D6178696D756D1A001A002A070A01541202300132000002231B1414042801">} : (tensor<1x3x4x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32>
 // CHECK: return %[[CUSTOM_1]] : tensor<1x3x4x2xf32>
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_legacy.mlir b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_legacy.mlir
index dea1a9c..5835d7d 100644
--- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_legacy.mlir
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/fallback_to_flex_ops_legacy.mlir
@@ -4,7 +4,7 @@
 func.func @bias_add(%arg0: tensor<1x10x10x32xf32>, %arg1: tensor<32xf32>) -> tensor<1x10x10x32xf32> {
   %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
   func.return %0 : tensor<1x10x10x32xf32>
-// CHECK: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
+// CHECK: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%arg0, %arg1) <{data_format = "NHWC"}> {T = "tfdtype$DT_FLOAT"} : (tensor<1x10x10x32xf32>, tensor<32xf32>) -> tensor<1x10x10x32xf32>
 // CHECK: return %[[BIASADD_0]] : tensor<1x10x10x32xf32>
 }
 
@@ -30,8 +30,8 @@
   %1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
   %2 = "tf.AddV2"(%0, %1): (tensor<15x28x28x1xf32>, tensor<1xf32>) -> tensor<15x28x28x1xf32>
   func.return %2 : tensor<15x28x28x1xf32>
-// CHECK: %[[CONST_0:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
-// CHECK: %[[CONV2DBACKPROPINPUT_0:.*]] = "tf.Conv2DBackpropInput"(%arg0, %arg1, %arg2) {dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
+// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[CONV2DBACKPROPINPUT_0:.*]] = "tf.Conv2DBackpropInput"(%arg0, %arg1, %arg2) <{dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 2, 2, 1]}> : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
 // CHECK: %[[ADDV2_0:.*]] = "tf.AddV2"(%[[CONV2DBACKPROPINPUT_0]], %[[CONST_0]]) {no_fallback} : (tensor<15x28x28x1xf32>, tensor<1xf32>) -> tensor<15x28x28x1xf32>
 // CHECK: return %[[ADDV2_0]] : tensor<15x28x28x1xf32>
 }
@@ -42,8 +42,8 @@
   %1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
   %2 = "tf.Sub"(%0, %1): (tensor<15x28x28x1xf32>, tensor<1xf32>) -> tensor<15x28x28x1xf32>
   func.return %2 : tensor<15x28x28x1xf32>
-// CHECK: %[[CONST_0:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
-// CHECK: %[[CONV2DBACKPROPINPUT_0:.*]] = "tf.Conv2DBackpropInput"(%arg0, %arg1, %arg2) {dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
+// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[CONV2DBACKPROPINPUT_0:.*]] = "tf.Conv2DBackpropInput"(%arg0, %arg1, %arg2) <{dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 2, 2, 1]}> : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
 // CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[CONV2DBACKPROPINPUT_0]], %[[CONST_0]]) {no_fallback} : (tensor<15x28x28x1xf32>, tensor<1xf32>) -> tensor<15x28x28x1xf32>
 // CHECK: return %[[SUB_0]] : tensor<15x28x28x1xf32>
 }
@@ -73,6 +73,6 @@
   %2 = "tf.Conv2DBackpropInput"(%arg0, %arg1, %arg2) {strides = [1, 2, 2, 1], padding="SAME", dilations=[1, 1, 1, 1]}: (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
   %3 = "tf.AddV2"(%2, %1): (tensor<15x28x28x1xf32>, tensor<1xf32>) -> tensor<15x28x28x1xf32>
   func.return %2 : tensor<15x28x28x1xf32>
-// CHECK: %[[CONV2DBACKPROPINPUT_0:.*]] = "tf.Conv2DBackpropInput"(%arg0, %arg1, %arg2) {dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
+// CHECK: %[[CONV2DBACKPROPINPUT_0:.*]] = "tf.Conv2DBackpropInput"(%arg0, %arg1, %arg2) <{dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 2, 2, 1]}> : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
 // CHECK: return %[[CONV2DBACKPROPINPUT_0]] : tensor<15x28x28x1xf32>
 }
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant.mlir b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant.mlir
index 83f07de..dd93ae2 100644
--- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant.mlir
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant.mlir
@@ -36,7 +36,7 @@
   %1 = "quantfork.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
   func.return %1 : tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
 
-// CHECK:  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 5 : i64}
+// CHECK:  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) <{narrow_range = false, num_bits = 5 : i64}>
 // CHECK:  %1 = "quantfork.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
 // CHECK:  return %1
 }
@@ -51,7 +51,7 @@
   %rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 5, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
   func.return %rst : tensor<8xf32>
 
-// CHECK: %[[CONSTANT:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>}
+// CHECK: %[[CONSTANT:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<8xf32>}>
 // CHECK: %[[QUANTIZE:.*]] = "quantfork.qcast"(%[[CONSTANT]]) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
 // CHECK: %[[DEQUANTIZE:.*]] = "quantfork.dcast"(%[[QUANTIZE]])
 // CHECK: return %[[DEQUANTIZE]] : tensor<8xf32>
@@ -79,7 +79,7 @@
   %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32>
   func.return %rst : tensor<256x8x7x16xf32>
 
-// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}>
 // CHECK: %[[QUANTIZE:.*]] = "quantfork.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
 // CHECK: %[[DEQUANTIZE:.*]] = "quantfork.dcast"(%[[QUANTIZE]])
 // CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
@@ -98,7 +98,7 @@
   %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32>
   func.return %rst : tensor<256x8x7x16xf32>
 
-// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}>
 // CHECK: %[[QUANTIZE:.*]] = "quantfork.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
 // CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
 // CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
@@ -119,7 +119,7 @@
   %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
   func.return %rst : tensor<256x30x30x16xf32>
 
-// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}>
 // CHECK: %[[QUANTIZE:.*]] = "quantfork.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
 // CHECK: %[[DEQUANTIZE:.*]] = "quantfork.dcast"(%[[QUANTIZE]])
 // CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
@@ -138,7 +138,7 @@
   %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
   func.return %rst : tensor<256x30x30x16xf32>
 
-// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}>
 // CHECK: %[[QUANTIZE:.*]] = "quantfork.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
 // CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
 // CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant_4bit.mlir b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant_4bit.mlir
index 54c3de4..519226a 100644
--- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant_4bit.mlir
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant_4bit.mlir
@@ -36,7 +36,7 @@
   %1 = "quantfork.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i4:f32, 1.000000e+00:-8>>
   func.return %1 : tensor<8x!quant.uniform<i4:f32, 1.000000e+00:-8>>
 
-// CHECK:  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64}
+// CHECK:  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) <{narrow_range = false, num_bits = 3 : i64}>
 // CHECK:  %1 = "quantfork.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i4:f32, 1.000000e+00:-8>>
 // CHECK:  return %1
 }
@@ -51,7 +51,7 @@
   %rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
   func.return %rst : tensor<8xf32>
 
-// CHECK: %[[CONSTANT:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>}
+// CHECK: %[[CONSTANT:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<8xf32>}>
 // CHECK: %[[QUANTIZE:.*]] = "quantfork.qcast"(%[[CONSTANT]]) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i4:f32, 1.000000e+00:-8>>
 // CHECK: %[[DEQUANTIZE:.*]] = "quantfork.dcast"(%[[QUANTIZE]])
 // CHECK: return %[[DEQUANTIZE]] : tensor<8xf32>
@@ -79,7 +79,7 @@
   %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32>
   func.return %rst : tensor<256x8x7x16xf32>
 
-// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}>
 // CHECK: %[[QUANTIZE:.*]] = "quantfork.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i4:f32, 1.000000e+00:-8>>
 // CHECK: %[[DEQUANTIZE:.*]] = "quantfork.dcast"(%[[QUANTIZE]])
 // CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
@@ -98,7 +98,7 @@
   %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32>
   func.return %rst : tensor<256x8x7x16xf32>
 
-// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}>
 // CHECK: %[[QUANTIZE:.*]] = "quantfork.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i4:f32:3,
 // CHECK-SAME: {1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,
 // CHECK-SAME: 1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8}>>
@@ -119,7 +119,7 @@
   %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
   func.return %rst : tensor<256x30x30x16xf32>
 
-// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}>
 // CHECK: %[[QUANTIZE:.*]] = "quantfork.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i4:f32, 1.000000e+00:-8>>
 // CHECK: %[[DEQUANTIZE:.*]] = "quantfork.dcast"(%[[QUANTIZE]])
 // CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
@@ -138,7 +138,7 @@
   %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
   func.return %rst : tensor<256x30x30x16xf32>
 
-// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}>
 // CHECK: %[[QUANTIZE:.*]] = "quantfork.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i4:f32:3,
 // CHECK-SAME: {1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,
 // CHECK-SAME: 1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8,1.000000e+00:-8}>>
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD
index ce3de9b..2459f3d 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD
+++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD
@@ -193,6 +193,7 @@
         ":drop_savedmodel_semantics",
         ":fold_broadcast_pass",
         ":fuse_convolution_pass",
+        ":legalize_tf_xla_call_module_to_stablehlo_pass",
         ":optimize",
         ":rename_entrypoint_to_main",
         ":smuggle_disallowed_ops",
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/call_xla_module_to_stablehlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/call_xla_module_to_stablehlo.mlir
new file mode 100644
index 0000000..7958402
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/call_xla_module_to_stablehlo.mlir
@@ -0,0 +1,26 @@
+//RUN: tf_tfl_translate --enable-stablehlo-conversion --input-mlir %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
+
+
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1660 : i32}} {
+  func.func @main(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "args_tf_0", outputs = "Identity"}} {
+    %0 = tf_executor.graph {
+      %outputs, %control = tf_executor.island wraps "tf.Identity"(%arg0) {device = ""} : (tensor<2x3xi32>) -> tensor<2x3xi32>
+      %outputs_0, %control_1 = tf_executor.island wraps "tf.XlaSharding"(%outputs) {_XlaSharding = "", device = "", sharding = "", unspecified_dims = []} : (tensor<2x3xi32>) -> tensor<2x3xi32>
+      %outputs_2, %control_3 = tf_executor.island wraps "tf.XlaCallModule"(%outputs_0) {Sout = [#tf_type.shape<2x3>], device = "", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "ML\EFR\01StableHLO_v0.9.0\00\01\17\05\01\03\01\03\05\03\07\07\09\0B\03]?\0B\01)\07\0F\0B+\0B\0F\0B\0B\0B3\0B\0B\0B\0B\0F\0B\0F\0B\13\0B\03\17\0F\13\0B\0B\0B\0F\13\0B\0B\0B\0B\01\05\0B\0F\03\07\17\17\07\02\D7\1F\11\03\05\05\0D\03\09\09\0B\0D\03\0F\03\05\11\05\0F\11\01\00\05\11\05\13\05\15\03\0B\15)\171\193\05;\1B=\05\17\05\19\05\1B\05\1D\1D\1F\01\05\1F\1D#%\05!\17'\A9\01\05#\03\03+\0D\03-/\1D%\1D'#\07\03\035\0D\0379\1D)\1D+\1D-\1D/\01\09\01\02\02)\05\09\0D\09\11\03\05\03\05\1B\04C\05\01\11\01\07\07\03\01\05\03\11\01\13\07\03\05\0B\03\05\1D\05\06!\03\05\05\01\01\07\04\01\03\03\06\03\01\05\01\00f\051\0F\0B\03!\1B\1D[;\05\1F\15\1D\15\1D%)9\13\15\19\11\0F\0B\11builtin\00vhlo\00module\00func_v1\00multiply_v1\00return_v1\00sym_name\00jax.uses_shape_polymorphism\00mhlo.num_partitions\00mhlo.num_replicas\00jit_jax_model\00arg_attrs\00function_type\00res_attrs\00sym_visibility\00x\00jit(jax_model)/jit(main)/mul\00experimental/users/ypang/lite/convert_ulm.py\00mhlo.sharding\00{replicated}\00jax.result_info\00\00main\00public\00", platforms = ["CPU"], version = 8 : i64} : (tensor<2x3xi32>) -> tensor<2x3xi32>
+      %control_4 = tf_executor.island(%control_3) wraps "tf.NoOp"() {device = ""} : () -> ()
+      %outputs_5, %control_6 = tf_executor.island wraps "tf.PreventGradient"(%outputs_2) {device = "", message = "The jax2tf-converted function does not support gradients. Use `with_gradient` parameter to enable gradients"} : (tensor<2x3xi32>) -> tensor<2x3xi32>
+      %outputs_7, %control_8 = tf_executor.island wraps "tf.Identity"(%outputs_5) {device = ""} : (tensor<2x3xi32>) -> tensor<2x3xi32>
+      %outputs_9, %control_10 = tf_executor.island(%control_4) wraps "tf.Identity"(%outputs_7) {device = ""} : (tensor<2x3xi32>) -> tensor<2x3xi32>
+      tf_executor.fetch %outputs_9 : tensor<2x3xi32>
+    }
+    return %0 : tensor<2x3xi32>
+  }
+}
+
+// CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {keep_stablehlo_constant = "true"}, tfl.schema_version = 3 : i32} {
+// CHECK-NEXT:  func.func @main(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> attributes {tf.entry_function = {inputs = "args_tf_0", outputs = "Identity"}} {
+// CHECK-NEXT:    %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.sharding = ""} : (tensor<2x3xi32>) -> tensor<2x3xi32>
+// CHECK-NEXT:    %1 = stablehlo.multiply %0, %0 : tensor<2x3xi32>
+// CHECK-NEXT:    return %1 : tensor<2x3xi32>
+// CHECK-NEXT:  }
+// CHECK-NEXT: }
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir
index 18cb7ce..593cdbf 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir
+++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir
@@ -562,15 +562,15 @@
 // CHECK-LABEL:   func @floordiv_broadcast_i32(
 // CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<2x3xi32>,
 // CHECK-SAME:                                 %[[VAL_1:.*]]: tensor<3xi32>) -> tensor<2x3xi32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
 // CHECK:           %[[VAL_3:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_2]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
-// CHECK:           %[[VAL_4:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32>
+// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<0> : tensor<3xi32>}> : () -> tensor<3xi32>
 // CHECK:           %[[VAL_5:.*]] = "tf.Less"(%[[VAL_1]], %[[VAL_4]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
-// CHECK:           %[[VAL_6:.*]] = "tf.Equal"(%[[VAL_3]], %[[VAL_5]]) {incompatible_shape_error = true} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
+// CHECK:           %[[VAL_6:.*]] = "tf.Equal"(%[[VAL_3]], %[[VAL_5]]) <{incompatible_shape_error = true}> : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
 // CHECK:           %[[VAL_7:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
 // CHECK:           %[[VAL_8:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor<2x3xi32>) -> tensor<2x3xi32>
 // CHECK:           %[[VAL_9:.*]] = "tf.Abs"(%[[VAL_1]]) : (tensor<3xi32>) -> tensor<3xi32>
-// CHECK:           %[[VAL_10:.*]] = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32>
+// CHECK:           %[[VAL_10:.*]] = "tf.Const"() <{value = dense<1> : tensor<3xi32>}> : () -> tensor<3xi32>
 // CHECK:           %[[VAL_11:.*]] = "tf.Sub"(%[[VAL_9]], %[[VAL_10]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
 // CHECK:           %[[VAL_12:.*]] = "tf.AddV2"(%[[VAL_8]], %[[VAL_11]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
 // CHECK:           %[[VAL_13:.*]] = "tf.Neg"(%[[VAL_12]]) : (tensor<2x3xi32>) -> tensor<2x3xi32>
@@ -601,15 +601,15 @@
 // CHECK-LABEL:   func @floordiv_reverse_broadcast_i32(
 // CHECK-SAME:                                         %[[VAL_0:.*]]: tensor<3xi32>,
 // CHECK-SAME:                                         %[[VAL_1:.*]]: tensor<2x3xi32>) -> tensor<2x3xi32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32>
+// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<3xi32>}> : () -> tensor<3xi32>
 // CHECK:           %[[VAL_3:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_2]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
-// CHECK:           %[[VAL_4:.*]] = "tf.Const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<0> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
 // CHECK:           %[[VAL_5:.*]] = "tf.Less"(%[[VAL_1]], %[[VAL_4]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
-// CHECK:           %[[VAL_6:.*]] = "tf.Equal"(%[[VAL_3]], %[[VAL_5]]) {incompatible_shape_error = true} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
+// CHECK:           %[[VAL_6:.*]] = "tf.Equal"(%[[VAL_3]], %[[VAL_5]]) <{incompatible_shape_error = true}> : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
 // CHECK:           %[[VAL_7:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
 // CHECK:           %[[VAL_8:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor<3xi32>) -> tensor<3xi32>
 // CHECK:           %[[VAL_9:.*]] = "tf.Abs"(%[[VAL_1]]) : (tensor<2x3xi32>) -> tensor<2x3xi32>
-// CHECK:           %[[VAL_10:.*]] = "tf.Const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+// CHECK:           %[[VAL_10:.*]] = "tf.Const"() <{value = dense<1> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
 // CHECK:           %[[VAL_11:.*]] = "tf.Sub"(%[[VAL_9]], %[[VAL_10]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
 // CHECK:           %[[VAL_12:.*]] = "tf.AddV2"(%[[VAL_8]], %[[VAL_11]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
 // CHECK:           %[[VAL_13:.*]] = "tf.Neg"(%[[VAL_12]]) : (tensor<2x3xi32>) -> tensor<2x3xi32>
@@ -669,7 +669,7 @@
 // CHECK-LABEL:   func @equal(
 // CHECK-SAME:                %[[VAL_0:.*]]: tensor<2xi32>,
 // CHECK-SAME:                %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
 // CHECK:           return %[[VAL_2]] : tensor<2xi1>
 // CHECK:         }
 func.func @equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
@@ -680,7 +680,7 @@
 // CHECK-LABEL:   func @equal_dynamic(
 // CHECK-SAME:                        %[[VAL_0:.*]]: tensor<?xi32>,
 // CHECK-SAME:                        %[[VAL_1:.*]]: tensor<1xi32>) -> tensor<?xi1> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
+// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
 // CHECK:           return %[[VAL_2]] : tensor<?xi1>
 // CHECK:         }
 func.func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
@@ -691,7 +691,7 @@
 // CHECK-LABEL:   func @equal_broadcast(
 // CHECK-SAME:                          %[[VAL_0:.*]]: tensor<1x1xi32>,
 // CHECK-SAME:                          %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1x1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<1x1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
 // CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
 // CHECK:         }
 func.func @equal_broadcast(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
@@ -703,7 +703,7 @@
 // CHECK-LABEL:   func @equal_broadcast_chlo(
 // CHECK-SAME:                               %[[VAL_0:.*]]: tensor<1xi32>,
 // CHECK-SAME:                               %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
 // CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
 // CHECK:         }
 func.func @equal_broadcast_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
@@ -714,7 +714,7 @@
 // CHECK-LABEL:   func @equal_broadcast_no_incompatible_shapes_error(
 // CHECK-SAME:                                                       %[[VAL_0:.*]]: tensor<2xi32>,
 // CHECK-SAME:                                                       %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
 // CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
 // CHECK:         }
 func.func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
@@ -725,7 +725,7 @@
 // CHECK-LABEL:   func @equal_incompatible_shape_broadcastable(
 // CHECK-SAME:                                                 %[[VAL_0:.*]]: tensor<?xi32>,
 // CHECK-SAME:                                                 %[[VAL_1:.*]]: tensor<1xi32>) -> tensor<?xi1> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
+// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
 // CHECK:           return %[[VAL_2]] : tensor<?xi1>
 // CHECK:         }
 func.func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
@@ -743,7 +743,7 @@
 // CHECK-LABEL:   func @notequal(
 // CHECK-SAME:                   %[[VAL_0:.*]]: tensor<2xi32>,
 // CHECK-SAME:                   %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
-// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
 // CHECK:           return %[[VAL_2]] : tensor<2xi1>
 // CHECK:         }
 func.func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
@@ -754,7 +754,7 @@
 // CHECK-LABEL:   func @notequal_broadcast(
 // CHECK-SAME:                             %[[VAL_0:.*]]: tensor<1x1xi32>,
 // CHECK-SAME:                             %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
-// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1x1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<1x1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
 // CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
 // CHECK:         }
 func.func @notequal_broadcast(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
@@ -766,7 +766,7 @@
 // CHECK-LABEL:   func @notequal_broadcast_chlo(
 // CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<1xi32>,
 // CHECK-SAME:                                  %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
-// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
 // CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
 // CHECK:         }
 func.func @notequal_broadcast_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
@@ -777,7 +777,7 @@
 // CHECK-LABEL:   func @notequal_broadcast_no_incompatible_shapes_error(
 // CHECK-SAME:                                                          %[[VAL_0:.*]]: tensor<2xi32>,
 // CHECK-SAME:                                                          %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
-// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
+// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
 // CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
 // CHECK:         }
 func.func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
@@ -788,7 +788,7 @@
 // CHECK-LABEL:   func @notequal_incompatible_shape_broadcastable(
 // CHECK-SAME:                                                    %[[VAL_0:.*]]: tensor<?xi32>,
 // CHECK-SAME:                                                    %[[VAL_1:.*]]: tensor<1xi32>) -> tensor<?xi1> {
-// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
+// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
 // CHECK:           return %[[VAL_2]] : tensor<?xi1>
 // CHECK:         }
 func.func @notequal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
@@ -942,7 +942,7 @@
 // CHECK-LABEL:   func @concat_v2(
 // CHECK-SAME:                    %[[VAL_0:.*]]: tensor<3x3xf32>,
 // CHECK-SAME:                    %[[VAL_1:.*]]: tensor<3x3xf32>) -> tensor<6x3xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
+// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
 // CHECK:           %[[VAL_3:.*]] = "tf.ConcatV2"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<6x3xf32>
 // CHECK:           return %[[VAL_3]] : tensor<6x3xf32>
 // CHECK:         }
@@ -954,7 +954,7 @@
 // CHECK-LABEL:   func @concat_v2_1d_axis(
 // CHECK-SAME:                            %[[VAL_0:.*]]: tensor<3x3xf32>,
 // CHECK-SAME:                            %[[VAL_1:.*]]: tensor<3x3xf32>) -> tensor<3x6xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
 // CHECK:           %[[VAL_3:.*]] = "tf.ConcatV2"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<3x6xf32>
 // CHECK:           return %[[VAL_3]] : tensor<3x6xf32>
 // CHECK:         }
@@ -964,7 +964,7 @@
 }
 
 // CHECK-LABEL:   func @const() -> tensor<2xi32> {
-// CHECK:           %[[VAL_0:.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
+// CHECK:           %[[VAL_0:.*]] = "tf.Const"() <{value = dense<0> : tensor<2xi32>}> : () -> tensor<2xi32>
 // CHECK:           return %[[VAL_0]] : tensor<2xi32>
 // CHECK:         }
 func.func @const() -> tensor<2xi32> {
@@ -974,7 +974,7 @@
 
 // CHECK-LABEL:   func @relu(
 // CHECK-SAME:               %[[VAL_0:.*]]: tensor<1xi32>) -> tensor<1xi32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK:           %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %[[VAL_2:.*]] = "tf.Maximum"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
 // CHECK:           return %[[VAL_2]] : tensor<1xi32>
 // CHECK:         }
@@ -986,7 +986,7 @@
 
 // CHECK-LABEL:   func @relu_unranked(
 // CHECK-SAME:                        %[[VAL_0:.*]]: tensor<?xi32>) -> tensor<?xi32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK:           %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %[[VAL_2:.*]] = "tf.Maximum"(%[[VAL_0]], %[[VAL_1]]) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
 // CHECK:           return %[[VAL_2]] : tensor<?xi32>
 // CHECK:         }
@@ -998,8 +998,8 @@
 
 // CHECK-LABEL:   func @relu6(
 // CHECK-SAME:                %[[VAL_0:.*]]: tensor<1xi32>) -> tensor<1xi32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<6> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %[[VAL_3:.*]] = "tf.Minimum"(%[[VAL_0]], %[[VAL_2]]) : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
 // CHECK:           %[[VAL_4:.*]] = "tf.Maximum"(%[[VAL_3]], %[[VAL_1]]) : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
 // CHECK:           return %[[VAL_4]] : tensor<1xi32>
@@ -1014,8 +1014,8 @@
 
 // CHECK-LABEL:   func @relu6_unranked(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<?xi32>) -> tensor<?xi32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<6> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %[[VAL_3:.*]] = "tf.Minimum"(%[[VAL_0]], %[[VAL_2]]) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
 // CHECK:           %[[VAL_4:.*]] = "tf.Maximum"(%[[VAL_3]], %[[VAL_1]]) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
 // CHECK:           return %[[VAL_4]] : tensor<?xi32>
@@ -1031,9 +1031,9 @@
 // CHECK-LABEL:   func @relu_grad(
 // CHECK-SAME:                    %[[VAL_0:.*]]: tensor<4x8xf32>,
 // CHECK-SAME:                    %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<4x8xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK:           %[[VAL_3:.*]] = "tf.Greater"(%[[VAL_1]], %[[VAL_2]]) : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
-// CHECK:           %[[VAL_4:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<4x8xf32>} : () -> tensor<4x8xf32>
+// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<4x8xf32>}> : () -> tensor<4x8xf32>
 // CHECK:           %[[VAL_5:.*]] = "tf.Select"(%[[VAL_3]], %[[VAL_0]], %[[VAL_4]]) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
 // CHECK:           return %[[VAL_5]] : tensor<4x8xf32>
 // CHECK:         }
@@ -1133,9 +1133,9 @@
 
 // CHECK-LABEL:   func @transpose_2d(
 // CHECK-SAME:                       %[[VAL_0:.*]]: tensor<2x3xf32>) -> tensor<3x2xf32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
 // CHECK:           %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32>
 // CHECK:           return %[[VAL_4]] : tensor<3x2xf32>
 // CHECK:         }
@@ -1148,9 +1148,9 @@
 
 // CHECK-LABEL:   func @transpose_3d_int32(
 // CHECK-SAME:                             %[[VAL_0:.*]]: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64>
-// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<[2, 1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[2, 1, 0]> : tensor<3xi64>}> : () -> tensor<3xi64>
+// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<[2, 1, 0]> : tensor<3xi64>}> : () -> tensor<3xi64>
 // CHECK:           %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32>
 // CHECK:           return %[[VAL_4]] : tensor<3x2x1xf32>
 // CHECK:         }
@@ -1163,9 +1163,9 @@
 
 // CHECK-LABEL:   func @transpose_3d(
 // CHECK-SAME:                       %[[VAL_0:.*]]: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64>
-// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<[2, 1, 0]> : tensor<3xi64>}> : () -> tensor<3xi64>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[2, 1, 0]> : tensor<3xi64>}> : () -> tensor<3xi64>
+// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<[2, 1, 0]> : tensor<3xi64>}> : () -> tensor<3xi64>
 // CHECK:           %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32>
 // CHECK:           return %[[VAL_4]] : tensor<3x2x1xf32>
 // CHECK:         }
@@ -1178,9 +1178,9 @@
 
 // CHECK-LABEL:   func @transpose_dynamic_2d(
 // CHECK-SAME:                               %[[VAL_0:.*]]: tensor<?x4xf32>) -> tensor<4x?xf32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
 // CHECK:           %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<?x4xf32>, tensor<2xi64>) -> tensor<4x?xf32>
 // CHECK:           return %[[VAL_4]] : tensor<4x?xf32>
 // CHECK:         }
@@ -1193,9 +1193,9 @@
 
 // CHECK-LABEL:   func @transpose_unranked_2d(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<*xf32>) -> tensor<*xf32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
 // CHECK:           %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<*xf32>, tensor<2xi64>) -> tensor<*xf32>
 // CHECK:           return %[[VAL_4]] : tensor<*xf32>
 // CHECK:         }
@@ -1488,9 +1488,9 @@
 
 // CHECK-LABEL:   func @sigmoid(
 // CHECK-SAME:                  %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32>
 // CHECK:           %[[VAL_4:.*]] = "tf.Mul"(%[[VAL_0]], %[[VAL_3]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
 // CHECK:           %[[VAL_5:.*]] = "tf.Tanh"(%[[VAL_4]]) : (tensor<2xf32>) -> tensor<2xf32>
 // CHECK:           %[[VAL_6:.*]] = "tf.Mul"(%[[VAL_5]], %[[VAL_3]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
@@ -1671,10 +1671,10 @@
 // CHECK-LABEL:   func @sign(
 // CHECK-SAME:               %[[VAL_0:.*]]: tensor<1x2x3x4xf32>,
 // CHECK-SAME:               %[[VAL_1:.*]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
-// CHECK:           %[[VAL_3:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32>
-// CHECK:           %[[VAL_4:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
-// CHECK:           %[[VAL_5:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32>
+// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
+// CHECK:           %[[VAL_3:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1x2x3x4xf32>}> : () -> tensor<1x2x3x4xf32>
+// CHECK:           %[[VAL_4:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
+// CHECK:           %[[VAL_5:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1x2x3x4xf32>}> : () -> tensor<1x2x3x4xf32>
 // CHECK:           %[[VAL_6:.*]] = "tf.Sign"(%[[VAL_0]]) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
 // CHECK:           %[[VAL_7:.*]] = "tf.Select"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
 // CHECK:           %[[VAL_8:.*]] = "tf.Select"(%[[VAL_2]], %[[VAL_3]], %[[VAL_7]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
@@ -1693,7 +1693,7 @@
 
 // CHECK-LABEL:   func @size_rank_one_i32(
 // CHECK-SAME:                            %[[VAL_0:.*]]: tensor<f32>) -> tensor<i32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK:           %[[VAL_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           return %[[VAL_1]] : tensor<i32>
 // CHECK:         }
 func.func @size_rank_one_i32(%arg0: tensor<f32>) -> tensor<i32> {
@@ -1703,7 +1703,7 @@
 
 // CHECK-LABEL:   func @size_rank_one_i64(
 // CHECK-SAME:                            %[[VAL_0:.*]]: tensor<f32>) -> tensor<i64> {
-// CHECK:           %[[VAL_1:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+// CHECK:           %[[VAL_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
 // CHECK:           return %[[VAL_1]] : tensor<i64>
 // CHECK:         }
 func.func @size_rank_one_i64(%arg0: tensor<f32>) -> tensor<i64> {
@@ -1724,7 +1724,7 @@
 
 // CHECK-LABEL:   func @convert_i32_f32(
 // CHECK-SAME:                          %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xf32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.Cast"(%[[VAL_0]]) {Truncate = false} : (tensor<2xi32>) -> tensor<2xf32>
+// CHECK:           %[[VAL_1:.*]] = "tf.Cast"(%[[VAL_0]]) <{Truncate = false}> : (tensor<2xi32>) -> tensor<2xf32>
 // CHECK:           return %[[VAL_1]] : tensor<2xf32>
 // CHECK:         }
 func.func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> {
@@ -1734,11 +1734,11 @@
 
 // CHECK-LABEL:   func @convert_slice(
 // CHECK-SAME:                        %[[VAL_0:.*]]: tensor<1x4672xf32>) -> tensor<1x519xf32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<[0, 4153]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<[1, 4672]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() {value = dense<1> : tensor<2xi64>} : () -> tensor<2xi64>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<[0, 4153]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[1, 4672]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<1> : tensor<2xi64>}> : () -> tensor<2xi64>
 // CHECK:           %[[VAL_4:.*]] = "tf.StridedSlice"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]])
-// CHECK-SAME:          {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}
+// CHECK-SAME:          <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}>
 // CHECK-SAME:          (tensor<1x4672xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x519xf32>
 // CHECK:           return %[[VAL_4]] : tensor<1x519xf32>
 // CHECK:         }
@@ -1786,7 +1786,7 @@
 // CHECK-SAME:                            %[[VAL_1:.*]]: tensor<256xf32>) -> tensor<1xf32> {
 // CHECK:           %[[VAL_2:.*]] = arith.constant dense<[256, 1]> : tensor<2xi64>
 // CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_2]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<256x1xf32>
-// CHECK:           %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_0]], %[[VAL_3]]) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
+// CHECK:           %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_0]], %[[VAL_3]]) <{adj_x = false, adj_y = false}> : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
 // CHECK:           %[[VAL_5:.*]] = arith.constant dense<1> : tensor<1xi64>
 // CHECK:           %[[VAL_6:.*]] = "tf.Reshape"(%[[VAL_4]], %[[VAL_5]]) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32>
 // CHECK:           return %[[VAL_6]] : tensor<1xf32>
@@ -1803,7 +1803,7 @@
 // CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32>
 // CHECK:           %[[VAL_4:.*]] = arith.constant dense<[256, 1]> : tensor<2xi64>
 // CHECK:           %[[VAL_5:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_4]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<256x1xf32>
-// CHECK:           %[[VAL_6:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_5]]) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
+// CHECK:           %[[VAL_6:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_5]]) <{adj_x = false, adj_y = false}> : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
 // CHECK:           %[[VAL_7:.*]] = arith.constant dense<> : tensor<0xi64>
 // CHECK:           %[[VAL_8:.*]] = "tf.Reshape"(%[[VAL_6]], %[[VAL_7]]) : (tensor<1x1xf32>, tensor<0xi64>) -> tensor<f32>
 // CHECK:           return %[[VAL_8]] : tensor<f32>
@@ -1816,7 +1816,7 @@
 // CHECK-LABEL:   func @convert_dot_2d_2d(
 // CHECK-SAME:                            %[[VAL_0:.*]]: tensor<1x256xf32>,
 // CHECK-SAME:                            %[[VAL_1:.*]]: tensor<256x1xf32>) -> tensor<1x1xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.BatchMatMulV3"(%[[VAL_0]], %[[VAL_1]]) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
+// CHECK:           %[[VAL_2:.*]] = "tf.BatchMatMulV3"(%[[VAL_0]], %[[VAL_1]]) <{adj_x = false, adj_y = false}> : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
 // CHECK:           return %[[VAL_2]] : tensor<1x1xf32>
 // CHECK:         }
 func.func @convert_dot_2d_2d(%arg0: tensor<1x256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1x1xf32> {
@@ -1861,9 +1861,9 @@
 // CHECK-LABEL: func @dynamic_broadcast_in_dim_general_case_expand_back_dims(
 // CHECK-SAME:                               %[[ARG_0:.*]]: tensor<?x3000xf32>,
 // CHECK-SAME:                               %[[ARG_1:.*]]: tensor<4xi32>) -> tensor<?x3000x2x4xf32> {
-// CHECK          %[[CST_0:.*]] = "tf.Const"() {value = dense<2> : tensor<i64>} : () -> tensor<i64>
+// CHECK          %[[CST_0:.*]] = "tf.Const"() <{value = dense<2> : tensor<i64>}> : () -> tensor<i64>
 // CHECK          %[[VAL_0:.*]] = "tf.ExpandDims"(%[[ARG_0]], %[[CST_0]]) : (tensor<?x3000xf32>, tensor<i64>) -> tensor<?x3000x1xf32>
-// CHECK          %[[CST_1:.*]] = "tf.Const"() {value = dense<3> : tensor<i64>} : () -> tensor<i64>
+// CHECK          %[[CST_1:.*]] = "tf.Const"() <{value = dense<3> : tensor<i64>}> : () -> tensor<i64>
 // CHECK          %[[VAL_1:.*]] = "tf.ExpandDims"(%[[VAL_0]], %[[CST_1]]) : (tensor<?x3000x1xf32>, tensor<i64>) -> tensor<?x3000x1x1xf32>
 // CHECK          %[[VAL_2:.*]] = "tf.BroadcastTo"(%[[VAL_1]], %[[ARG_1]]) : (tensor<?x3000x1x1xf32>, tensor<4xi32>) -> tensor<?x3000x2x4xf32>
 // CHECK          return %[[VAL_2]] : tensor<?x3000x2x4xf32>
@@ -1875,7 +1875,7 @@
 // CHECK-LABEL: func @dynamic_broadcast_in_dim_general_case_expand_middle_dim(
 // CHECK-SAME:                               %[[ARG_0:.*]]: tensor<?x750x768xf32>,
 // CHECK-SAME:                               %[[ARG_1:.*]]: tensor<4xi32>) -> tensor<?x750x1x768xf32> {
-// CHECK          %[[CST_0:.*]] = "tf.Const"() {value = dense<2> : tensor<i64>} : () -> tensor<i64>
+// CHECK          %[[CST_0:.*]] = "tf.Const"() <{value = dense<2> : tensor<i64>}> : () -> tensor<i64>
 // CHECK          %[[VAL_0:.*]] = "tf.ExpandDims"(%[[ARG_0]], %[[CST_0]]) : (tensor<?x750x768xf32>, tensor<i64>) -> tensor<?x750x1x768xf32>
 // CHECK          %[[VAL_1:.*]] = "tf.BroadcastTo"(%[[VAL_0]], %[[ARG_1]]) : (tensor<?x750x1x768xf32>, tensor<4xi32>) -> tensor<?x750x1x768xf32>
 // CHECK          return %[[VAL_1]] : tensor<?x750x1x768xf32>
@@ -1887,15 +1887,15 @@
 // CHECK-LABEL:   func @convert_dot_general(
 // CHECK-SAME:                              %[[VAL_0:.*]]: tensor<3x2x6x5x1xf32>,
 // CHECK-SAME:                              %[[VAL_1:.*]]: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Const"() {value = dense<[0, 3, 4, 1, 2]> : tensor<5xi64>} : () -> tensor<5xi64>
+// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[0, 3, 4, 1, 2]> : tensor<5xi64>}> : () -> tensor<5xi64>
 // CHECK:           %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : (tensor<3x2x6x5x1xf32>, tensor<5xi64>) -> tensor<3x5x1x2x6xf32>
-// CHECK:           %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_1]], %[[VAL_4]]) : (tensor<3x2x4x6xf32>, tensor<4xi64>) -> tensor<3x2x6x4xf32>
 // CHECK:           %[[VAL_6:.*]] = arith.constant dense<[3, 5, 12]> : tensor<3xi64>
 // CHECK:           %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_3]], %[[VAL_6]]) : (tensor<3x5x1x2x6xf32>, tensor<3xi64>) -> tensor<3x5x12xf32>
 // CHECK:           %[[VAL_8:.*]] = arith.constant dense<[3, 12, 4]> : tensor<3xi64>
 // CHECK:           %[[VAL_9:.*]] = "tf.Reshape"(%[[VAL_5]], %[[VAL_8]]) : (tensor<3x2x6x4xf32>, tensor<3xi64>) -> tensor<3x12x4xf32>
-// CHECK:           %[[VAL_10:.*]] = "tf.BatchMatMulV3"(%[[VAL_7]], %[[VAL_9]]) {adj_x = false, adj_y = false} : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32>
+// CHECK:           %[[VAL_10:.*]] = "tf.BatchMatMulV3"(%[[VAL_7]], %[[VAL_9]]) <{adj_x = false, adj_y = false}> : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32>
 // CHECK:           %[[VAL_11:.*]] = arith.constant dense<[3, 5, 1, 4]> : tensor<4xi64>
 // CHECK:           %[[VAL_12:.*]] = "tf.Reshape"(%[[VAL_10]], %[[VAL_11]]) : (tensor<3x5x4xf32>, tensor<4xi64>) -> tensor<3x5x1x4xf32>
 // CHECK:           return %[[VAL_12]] : tensor<3x5x1x4xf32>
@@ -1929,7 +1929,7 @@
 // CHECK-SAME:                                       %[[VAL_1:.*]]: tensor<1024x1024xf32>) -> tensor<1x1x1024xf32> {
 // CHECK:           %[[VAL_2:.*]] = arith.constant dense<[1, 1024]> : tensor<2xi64>
 // CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : {{.*}} -> tensor<1x1024xf32>
-// CHECK:           %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_1]]) {adj_x = false, adj_y = false} : {{.*}} -> tensor<1x1024xf32>
+// CHECK:           %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_1]]) <{adj_x = false, adj_y = false}> : {{.*}} -> tensor<1x1024xf32>
 // CHECK:           %[[VAL_5:.*]] = arith.constant dense<[1, 1, 1024]> : tensor<3xi64>
 // CHECK:           %[[VAL_6:.*]] = "tf.Reshape"(%[[VAL_4]], %[[VAL_5]]) : {{.*}} -> tensor<1x1x1024xf32>
 // CHECK:           return %[[VAL_6]] : tensor<1x1x1024xf32>
@@ -1952,7 +1952,7 @@
 // CHECK-SAME:                              %[[VAL_1:.*]]: tensor<256x8xi8>) -> tensor<8xi32> {
 // CHECK:           %[[VAL_2:.*]] = arith.constant dense<[1, 256]> : tensor<2xi64>
 // CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<256xi8>, tensor<2xi64>) -> tensor<1x256xi8>
-// CHECK:           %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_1]]) {adj_x = false, adj_y = false} : (tensor<1x256xi8>, tensor<256x8xi8>) -> tensor<1x8xi32>
+// CHECK:           %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_1]]) <{adj_x = false, adj_y = false}> : (tensor<1x256xi8>, tensor<256x8xi8>) -> tensor<1x8xi32>
 // CHECK:           %[[VAL_5:.*]] = arith.constant dense<8> : tensor<1xi64>
 // CHECK:           %[[VAL_6:.*]] = "tf.Reshape"(%[[VAL_4]], %[[VAL_5]]) : (tensor<1x8xi32>, tensor<1xi64>) -> tensor<8xi32>
 // CHECK:           return %[[VAL_6]] : tensor<8xi32>
@@ -1970,26 +1970,26 @@
 // CHECK-LABEL:   func @convert_dot_general_dynamic_rhs_out_dim(
 // CHECK-SAME:                              %arg0: tensor<4x4x256xf32>,
 // CHECK-SAME:                              %arg1: tensor<4x?x256xf32>) -> tensor<4x4x?xf32> {
-// CHECK-DAG:       %cst = "tf.Const"() {value = dense<[0, 2, 1]> : tensor<3xi64>} : () -> tensor<3xi64>
+// CHECK-DAG:       %cst = "tf.Const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
 // CHECK:           %0 = "tf.Transpose"(%arg1, %cst) : (tensor<4x?x256xf32>, tensor<3xi64>) -> tensor<4x256x?xf32>
 // CHECK:           %1 = "tf.Shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
-// CHECK-DAG:       %cst_0 = "tf.Const"() {value = dense<[-1, 0, -1]> : tensor<3xi32>} : () -> tensor<3xi32>
-// CHECK-DAG:       %cst_1 = "tf.Const"() {value = dense<[-1, -1, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
-// CHECK-DAG:       %cst_2 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_0 = "tf.Const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG:       %cst_1 = "tf.Const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG:       %cst_2 = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %2 = "tf.UnsortedSegmentProd"(%1, %cst_0, %cst_2) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
 // CHECK:           %3 = "tf.UnsortedSegmentProd"(%1, %cst_1, %cst_2) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-DAG:       %cst_3 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32>
-// CHECK-DAG:       %cst_4 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_3 = "tf.Const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK-DAG:       %cst_4 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %4 = "tf.Concat"(%cst_4, %cst_3, %3, %2) : (tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
 // CHECK:           %5 = "tf.Reshape"(%0, %4) : (tensor<4x256x?xf32>, tensor<3xi32>) -> tensor<4x256x?xf32>
-// CHECK:           %6 = "tf.BatchMatMulV3"(%arg0, %5) {adj_x = false, adj_y = false} : (tensor<4x4x256xf32>, tensor<4x256x?xf32>) -> tensor<4x4x?xf32>
+// CHECK:           %6 = "tf.BatchMatMulV3"(%arg0, %5) <{adj_x = false, adj_y = false}> : (tensor<4x4x256xf32>, tensor<4x256x?xf32>) -> tensor<4x4x?xf32>
 // CHECK:           %7 = "tf.Shape"(%arg0) : (tensor<4x4x256xf32>) -> tensor<3xi32>
 // CHECK:           %8 = "tf.Shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
-// CHECK-DAG:       %cst_5 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK:           %9 = "tf.Gather"(%7, %cst_5) {validate_indices = true} : (tensor<3xi32>, tensor<2xi64>) -> tensor<2xi32>
-// CHECK-DAG:       %cst_6 = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %10 = "tf.Gather"(%8, %cst_6) {validate_indices = true} : (tensor<3xi32>, tensor<1xi64>) -> tensor<1xi32>
-// CHECK-DAG:       %cst_7 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_5 = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK:           %9 = "tf.Gather"(%7, %cst_5) <{validate_indices = true}> : (tensor<3xi32>, tensor<2xi64>) -> tensor<2xi32>
+// CHECK-DAG:       %cst_6 = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %10 = "tf.Gather"(%8, %cst_6) <{validate_indices = true}> : (tensor<3xi32>, tensor<1xi64>) -> tensor<1xi32>
+// CHECK-DAG:       %cst_7 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %11 = "tf.Concat"(%cst_7, %9, %10) : (tensor<i32>, tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32>
 // CHECK:           %12 = "tf.Reshape"(%6, %11) : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32>
 // CHECK:           return %12 : tensor<4x4x?xf32>
@@ -2008,38 +2008,38 @@
 // CHECK-LABEL:   func @convert_dot_general_dynamic_batch_dim(
 // CHECK-SAME:                              %arg0: tensor<2x?x2x3xf32>,
 // CHECK-SAME:                              %arg1: tensor<2x?x4x3xf32>) -> tensor<2x?x2x4xf32> {
-// CHECK-DAG:       %cst = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %cst = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %0 = "tf.Transpose"(%arg1, %cst) : (tensor<2x?x4x3xf32>, tensor<4xi64>) -> tensor<2x?x3x4xf32>
 // CHECK:           %1 = "tf.Shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32>
-// CHECK-DAG:       %cst_0 = "tf.Const"() {value = dense<[-1, -1, 0, -1]> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       %cst_1 = "tf.Const"() {value = dense<[-1, -1, -1, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       %cst_2 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_0 = "tf.Const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       %cst_1 = "tf.Const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       %cst_2 = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %2 = "tf.UnsortedSegmentProd"(%1, %cst_0, %cst_2) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
 // CHECK:           %3 = "tf.UnsortedSegmentProd"(%1, %cst_1, %cst_2) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-DAG:       %cst_3 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK:           %4 = "tf.Gather"(%1, %cst_3) {validate_indices = true} : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
-// CHECK-DAG:       %cst_4 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_3 = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK:           %4 = "tf.Gather"(%1, %cst_3) <{validate_indices = true}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
+// CHECK-DAG:       %cst_4 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %5 = "tf.Concat"(%cst_4, %4, %2, %3) : (tensor<i32>, tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
 // CHECK:           %6 = "tf.Reshape"(%arg0, %5) : (tensor<2x?x2x3xf32>, tensor<4xi32>) -> tensor<2x?x2x3xf32>
 // CHECK:           %7 = "tf.Shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32>
-// CHECK-DAG:       %cst_5 = "tf.Const"() {value = dense<[-1, -1, 0, -1]> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       %cst_6 = "tf.Const"() {value = dense<[-1, -1, -1, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       %cst_7 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_5 = "tf.Const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       %cst_6 = "tf.Const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       %cst_7 = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %8 = "tf.UnsortedSegmentProd"(%7, %cst_5, %cst_7) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
 // CHECK:           %9 = "tf.UnsortedSegmentProd"(%7, %cst_6, %cst_7) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-DAG:       %cst_8 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK:           %10 = "tf.Gather"(%7, %cst_8) {validate_indices = true} : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
-// CHECK-DAG:       %cst_9 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_8 = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK:           %10 = "tf.Gather"(%7, %cst_8) <{validate_indices = true}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
+// CHECK-DAG:       %cst_9 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %11 = "tf.Concat"(%cst_9, %10, %9, %8) : (tensor<i32>, tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
 // CHECK:           %12 = "tf.Reshape"(%0, %11) : (tensor<2x?x3x4xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32>
-// CHECK:           %13 = "tf.BatchMatMulV3"(%6, %12) {adj_x = false, adj_y = false} : (tensor<2x?x2x3xf32>, tensor<2x?x3x4xf32>) -> tensor<2x?x2x4xf32>
+// CHECK:           %13 = "tf.BatchMatMulV3"(%6, %12) <{adj_x = false, adj_y = false}> : (tensor<2x?x2x3xf32>, tensor<2x?x3x4xf32>) -> tensor<2x?x2x4xf32>
 // CHECK:           %14 = "tf.Shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32>
 // CHECK:           %15 = "tf.Shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32>
-// CHECK-DAG:       %cst_10 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi64>} : () -> tensor<3xi64>
-// CHECK:           %16 = "tf.Gather"(%14, %cst_10) {validate_indices = true} : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32>
-// CHECK:           %cst_11 = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %17 = "tf.Gather"(%15, %cst_11) {validate_indices = true} : (tensor<4xi32>, tensor<1xi64>) -> tensor<1xi32>
-// CHECK-DAG:       %cst_12 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_10 = "tf.Const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64>
+// CHECK:           %16 = "tf.Gather"(%14, %cst_10) <{validate_indices = true}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32>
+// CHECK:           %cst_11 = "tf.Const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %17 = "tf.Gather"(%15, %cst_11) <{validate_indices = true}> : (tensor<4xi32>, tensor<1xi64>) -> tensor<1xi32>
+// CHECK-DAG:       %cst_12 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %18 = "tf.Concat"(%cst_12, %16, %17) : (tensor<i32>, tensor<3xi32>, tensor<1xi32>) -> tensor<4xi32>
 // CHECK:           %19 = "tf.Reshape"(%13, %18) : (tensor<2x?x2x4xf32>, tensor<4xi32>) -> tensor<2x?x2x4xf32>
 // CHECK:           return %19 : tensor<2x?x2x4xf32>
@@ -2058,36 +2058,36 @@
 // CHECK-LABEL:   func @convert_dot_general_dynamic_lhs_rhs_out_dims(
 // CHECK-SAME:                              %arg0: tensor<2x2x?x3xf32>,
 // CHECK-SAME:                              %arg1: tensor<2x4x?x3xf32>) -> tensor<2x2x?x4x?xf32> {
-// CHECK-DAG:       %cst = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %cst = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %0 = "tf.Transpose"(%arg1, %cst) : (tensor<2x4x?x3xf32>, tensor<4xi64>) -> tensor<2x3x4x?xf32>
 // CHECK:           %1 = "tf.Shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32>
-// CHECK-DAG:       %cst_0 = "tf.Const"() {value = dense<[-1, 0, 0, -1]> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       %cst_1 = "tf.Const"() {value = dense<[-1, -1, -1, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       %cst_2 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_0 = "tf.Const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       %cst_1 = "tf.Const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       %cst_2 = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %2 = "tf.UnsortedSegmentProd"(%1, %cst_0, %cst_2) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
 // CHECK:           %3 = "tf.UnsortedSegmentProd"(%1, %cst_1, %cst_2) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-DAG:       %cst_3 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
-// CHECK-DAG:       %cst_4 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_3 = "tf.Const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK-DAG:       %cst_4 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %4 = "tf.Concat"(%cst_4, %cst_3, %2, %3) : (tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
 // CHECK:           %5 = "tf.Reshape"(%arg0, %4) : (tensor<2x2x?x3xf32>, tensor<3xi32>) -> tensor<2x?x3xf32>
 // CHECK:           %6 = "tf.Shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32>
-// CHECK-DAG:       %cst_5 = "tf.Const"() {value = dense<[-1, 0, 0, -1]> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       %cst_6 = "tf.Const"() {value = dense<[-1, -1, -1, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       %cst_7 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_5 = "tf.Const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       %cst_6 = "tf.Const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       %cst_7 = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %7 = "tf.UnsortedSegmentProd"(%6, %cst_5, %cst_7) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
 // CHECK:           %8 = "tf.UnsortedSegmentProd"(%6, %cst_6, %cst_7) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-DAG:       %cst_8 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
-// CHECK-DAG:       %cst_9 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_8 = "tf.Const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK-DAG:       %cst_9 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %9 = "tf.Concat"(%cst_9, %cst_8, %8, %7) : (tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
 // CHECK:           %10 = "tf.Reshape"(%0, %9) : (tensor<2x3x4x?xf32>, tensor<3xi32>) -> tensor<2x3x?xf32>
-// CHECK:           %11 = "tf.BatchMatMulV3"(%5, %10) {adj_x = false, adj_y = false} : (tensor<2x?x3xf32>, tensor<2x3x?xf32>) -> tensor<2x?x?xf32>
+// CHECK:           %11 = "tf.BatchMatMulV3"(%5, %10) <{adj_x = false, adj_y = false}> : (tensor<2x?x3xf32>, tensor<2x3x?xf32>) -> tensor<2x?x?xf32>
 // CHECK:           %12 = "tf.Shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32>
 // CHECK:           %13 = "tf.Shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32>
-// CHECK-DAG:       %cst_10 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi64>} : () -> tensor<3xi64>
-// CHECK:           %14 = "tf.Gather"(%12, %cst_10) {validate_indices = true} : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32>
-// CHECK-DAG:       %cst_11 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK:           %15 = "tf.Gather"(%13, %cst_11) {validate_indices = true} : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
-// CHECK-DAG:       %cst_12 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_10 = "tf.Const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64>
+// CHECK:           %14 = "tf.Gather"(%12, %cst_10) <{validate_indices = true}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32>
+// CHECK-DAG:       %cst_11 = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK:           %15 = "tf.Gather"(%13, %cst_11) <{validate_indices = true}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
+// CHECK-DAG:       %cst_12 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %16 = "tf.Concat"(%cst_12, %14, %15) : (tensor<i32>, tensor<3xi32>, tensor<2xi32>) -> tensor<5xi32>
 // CHECK:           %17 = "tf.Reshape"(%11, %16) : (tensor<2x?x?xf32>, tensor<5xi32>) -> tensor<2x2x?x4x?xf32>
 // CHECK:           return %17 : tensor<2x2x?x4x?xf32>
@@ -2107,26 +2107,26 @@
 // CHECK-SAME:                              %arg0: tensor<4x4x?xf32>,
 // CHECK-SAME:                              %arg1: tensor<4x?x256xf32>) -> tensor<4x4x256xf32> {
 // CHECK:           %0 = "tf.Shape"(%arg0) : (tensor<4x4x?xf32>) -> tensor<3xi32>
-// CHECK-DAG:       %cst = "tf.Const"() {value = dense<[-1, 0, -1]> : tensor<3xi32>} : () -> tensor<3xi32>
-// CHECK-DAG:       %cst_0 = "tf.Const"() {value = dense<[-1, -1, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
-// CHECK-DAG:       %cst_1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst = "tf.Const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG:       %cst_0 = "tf.Const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG:       %cst_1 = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %1 = "tf.UnsortedSegmentProd"(%0, %cst, %cst_1) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
 // CHECK:           %2 = "tf.UnsortedSegmentProd"(%0, %cst_0, %cst_1) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-DAG:       %cst_2 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32>
-// CHECK-DAG:       %cst_3 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_2 = "tf.Const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK-DAG:       %cst_3 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %3 = "tf.Concat"(%cst_3, %cst_2, %1, %2) : (tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
 // CHECK:           %4 = "tf.Reshape"(%arg0, %3) : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32>
 // CHECK:           %5 = "tf.Shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
-// CHECK-DAG:       %cst_4 = "tf.Const"() {value = dense<[-1, -1, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
-// CHECK-DAG:       %cst_5 = "tf.Const"() {value = dense<[-1, 0, -1]> : tensor<3xi32>} : () -> tensor<3xi32>
-// CHECK-DAG:       %cst_6 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_4 = "tf.Const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG:       %cst_5 = "tf.Const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK-DAG:       %cst_6 = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %6 = "tf.UnsortedSegmentProd"(%5, %cst_4, %cst_6) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
 // CHECK:           %7 = "tf.UnsortedSegmentProd"(%5, %cst_5, %cst_6) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
-// CHECK-DAG:       %cst_7 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32>
-// CHECK-DAG:       %cst_8 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %cst_7 = "tf.Const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK-DAG:       %cst_8 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %8 = "tf.Concat"(%cst_8, %cst_7, %7, %6) : (tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
 // CHECK:           %9 = "tf.Reshape"(%arg1, %8) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x?x256xf32>
-// CHECK:           %10 = "tf.BatchMatMulV3"(%4, %9) {adj_x = false, adj_y = false} : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32>
+// CHECK:           %10 = "tf.BatchMatMulV3"(%4, %9) <{adj_x = false, adj_y = false}> : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32>
 // CHECK:           return %10 : tensor<4x4x256xf32>
 // CHECK:           }
 func.func @convert_dot_general_dynamic_contracting_dim(%arg0: tensor<4x4x?xf32>, %arg1: tensor<4x?x256xf32>) -> tensor<4x4x256xf32> {
@@ -2145,14 +2145,14 @@
 // CHECK-SAME:                              %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[16, 32, 256, 1]> : tensor<4xi64>
 // CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<16x32x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16>
-// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<16x32x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16>
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<[1, 256, 256, 1]> : tensor<4xi64>
 // CHECK:           %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x256x256xbf16>, tensor<4xi64>) -> tensor<1x256x256x1xbf16>
-// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x256x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16>
-// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16>
-// CHECK:           %[[VAL_11:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16>
+// CHECK:           %[[VAL_11:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<16x32x1x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16>
 // CHECK:           %[[VAL_13:.*]] = arith.constant dense<[16, 32, 256]> : tensor<3xi64>
 // CHECK:           %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<16x32x256x1xbf16>, tensor<3xi64>) -> tensor<16x32x256xbf16>
@@ -2177,14 +2177,14 @@
 // CHECK-SAME:                              %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<?x32x256xbf16> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[-9223372036854775808, 32, 256, 1]> : tensor<4xi64>
 // CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<?x32x256xbf16>, tensor<4xi64>) -> tensor<?x32x256x1xbf16>
-// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<?x32x256x1xbf16>, tensor<4xi64>) -> tensor<?x32x1x256xbf16>
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<[1, 256, 256, 1]> : tensor<4xi64>
 // CHECK:           %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x256x256xbf16>, tensor<4xi64>) -> tensor<1x256x256x1xbf16>
-// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x256x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16>
-// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<?x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<?x32x1x256xbf16>
-// CHECK:           %[[VAL_11:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<?x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<?x32x1x256xbf16>
+// CHECK:           %[[VAL_11:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<?x32x1x256xbf16>, tensor<4xi64>) -> tensor<?x32x256x1xbf16>
 // CHECK:           %[[VAL_13:.*]] = arith.constant dense<[-9223372036854775808, 32, 256]> : tensor<3xi64>
 // CHECK:           %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<?x32x256x1xbf16>, tensor<3xi64>) -> tensor<?x32x256xbf16>
@@ -2211,14 +2211,14 @@
 // CHECK-SAME:                              %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[16, 32, 256, 1]> : tensor<4xi64>
 // CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<16x32x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16>
-// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<16x32x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16>
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<[1, 256, 256, 1]> : tensor<4xi64>
 // CHECK:           %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x256x256xbf16>, tensor<4xi64>) -> tensor<1x256x256x1xbf16>
-// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x256x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16>
-// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16>
-// CHECK:           %[[VAL_11:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16>
+// CHECK:           %[[VAL_11:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<16x32x1x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16>
 // CHECK:           %[[VAL_13:.*]] = arith.constant dense<[16, 32, 256]> : tensor<3xi64>
 // CHECK:           %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<16x32x256x1xbf16>, tensor<3xi64>) -> tensor<16x32x256xbf16>
@@ -2240,14 +2240,14 @@
 // CHECK-SAME:                              %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<?x32x256xbf16> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[-9223372036854775808, 32, 256, 1]> : tensor<4xi64>
 // CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<?x32x256xbf16>, tensor<4xi64>) -> tensor<?x32x256x1xbf16>
-// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<?x32x256x1xbf16>, tensor<4xi64>) -> tensor<?x32x1x256xbf16>
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<[1, 256, 256, 1]> : tensor<4xi64>
 // CHECK:           %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x256x256xbf16>, tensor<4xi64>) -> tensor<1x256x256x1xbf16>
-// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x256x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16>
-// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<?x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<?x32x1x256xbf16>
-// CHECK:           %[[VAL_11:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<?x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<?x32x1x256xbf16>
+// CHECK:           %[[VAL_11:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<?x32x1x256xbf16>, tensor<4xi64>) -> tensor<?x32x256x1xbf16>
 // CHECK:           %[[VAL_13:.*]] = arith.constant dense<[-9223372036854775808, 32, 256]> : tensor<3xi64>
 // CHECK:           %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<?x32x256x1xbf16>, tensor<3xi64>) -> tensor<?x32x256xbf16>
@@ -2271,14 +2271,14 @@
 // CHECK-SAME:                                                              %[[VAL_1:.*]]: tensor<256x1x256xbf16>) -> tensor<256x16x32xbf16> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[32, 16, 256, 1]> : tensor<4xi64>
 // CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<32x16x256xbf16>, tensor<4xi64>) -> tensor<32x16x256x1xbf16>
-// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() {value = dense<[1, 0, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[1, 0, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<32x16x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16>
 // CHECK:           %[[VAL_6:.*]] = arith.constant dense<[256, 1, 256, 1]> : tensor<4xi64>
 // CHECK:           %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<256x1x256xbf16>, tensor<4xi64>) -> tensor<256x1x256x1xbf16>
-// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() {value = dense<[1, 3, 2, 0]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[1, 3, 2, 0]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<256x1x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16>
-// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16>
-// CHECK-DAG:       %[[VAL_11:.*]] = "tf.Const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16>
+// CHECK-DAG:       %[[VAL_11:.*]] = "tf.Const"() <{value = dense<[3, 0, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<16x32x1x256xbf16>, tensor<4xi64>) -> tensor<256x16x32x1xbf16>
 // CHECK:           %[[VAL_13:.*]] = arith.constant dense<[256, 16, 32]> : tensor<3xi64>
 // CHECK:           %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<256x16x32x1xbf16>, tensor<3xi64>) -> tensor<256x16x32xbf16>
@@ -2343,14 +2343,14 @@
 // CHECK-SAME:                              %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[16, 32, 256, 1]> : tensor<4xi64>
 // CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<16x32x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16>
-// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<16x32x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16>
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<[1, 256, 256, 1]> : tensor<4xi64>
 // CHECK:           %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x256x256xbf16>, tensor<4xi64>) -> tensor<1x256x256x1xbf16>
-// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x256x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16>
-// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16>
-// CHECK:           %[[VAL_11:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16>
+// CHECK:           %[[VAL_11:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<16x32x1x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16>
 // CHECK:           %[[VAL_13:.*]] = arith.constant dense<[16, 32, 256]> : tensor<3xi64>
 // CHECK:           %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<16x32x256x1xbf16>, tensor<3xi64>) -> tensor<16x32x256xbf16>
@@ -2372,7 +2372,7 @@
 // CHECK-LABEL:   func.func @convert_conv1d_missing_windows_strides_fallback_2(
 // CHECK-SAME:                              %[[VAL_0:.*]]: tensor<1x64x64x4xbf16>,
 // CHECK-SAME:                              %[[VAL_1:.*]]: tensor<3x3x4x320xbf16>) -> tensor<1x62x62x320xbf16> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x64x64x4xbf16>, tensor<3x3x4x320xbf16>) -> tensor<1x62x62x320xbf16>
+// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x64x64x4xbf16>, tensor<3x3x4x320xbf16>) -> tensor<1x62x62x320xbf16>
 // CHECK:           return %[[VAL_2]] : tensor<1x62x62x320xbf16>
 // CHECK:         }
 func.func @convert_conv1d_missing_windows_strides_fallback_2(%arg0: tensor<1x64x64x4xbf16>, %arg1: tensor<3x3x4x320xbf16>) -> tensor<1x62x62x320xbf16> {
@@ -2391,7 +2391,7 @@
 // CHECK-LABEL:   func @convert_conv2d(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x8x8x207xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
 // CHECK:           return %[[VAL_2]] : tensor<1x8x8x16xf32>
 // CHECK:         }
 func.func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
@@ -2414,7 +2414,7 @@
 // CHECK-LABEL:   func @convert_group_conv2d(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x14x14x2240xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x112x2240xf32>) -> tensor<1x7x7x2240xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [0, 0, 1, 1, 1, 1, 0, 0], padding = "EXPLICIT", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x14x14x2240xf32>, tensor<3x3x112x2240xf32>) -> tensor<1x7x7x2240xf32>
+// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [0, 0, 1, 1, 1, 1, 0, 0], padding = "EXPLICIT", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x14x14x2240xf32>, tensor<3x3x112x2240xf32>) -> tensor<1x7x7x2240xf32>
 // CHECK:           return %[[VAL_2]] : tensor<1x7x7x2240xf32>
 // CHECK:         }
 func.func @convert_group_conv2d(%arg0: tensor<1x14x14x2240xf32>, %arg1: tensor<3x3x112x2240xf32>) -> tensor<1x7x7x2240xf32> {
@@ -2428,13 +2428,13 @@
 // CHECK-LABEL:    func.func @convert_transpose_conv_with_transpose(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x256x64x64xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<2x2x64x256xf32>) -> tensor<1x64x128x128xf32> {
-// CHECK:            %[[VAL_2:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:            %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:            %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0:.*]], %[[VAL_2:.*]]) : (tensor<1x256x64x64xf32>, tensor<4xi64>) -> tensor<1x64x64x256xf32>
-// CHECK:            %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
+// CHECK:            %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
 // CHECK:            %[[VAL_5:.*]] = "tf.ReverseV2"(%[[VAL_1:.*]], %[[VAL_4:.*]]) : (tensor<2x2x64x256xf32>, tensor<2xi64>) -> tensor<2x2x64x256xf32>
-// CHECK:            %[[VAL_6:.*]] = "tf.Const"() {value = dense<[1, 128, 128, 64]> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK:            %[[VAL_7:.*]] = "tf.Conv2DBackpropInput"(%[[VAL_6:.*]], %[[VAL_5:.*]], %[[VAL_3:.*]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<4xi32>, tensor<2x2x64x256xf32>, tensor<1x64x64x256xf32>) -> tensor<1x128x128x64xf32>
-// CHECK:            %[[VAL_8:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:            %[[VAL_6:.*]] = "tf.Const"() <{value = dense<[1, 128, 128, 64]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK:            %[[VAL_7:.*]] = "tf.Conv2DBackpropInput"(%[[VAL_6:.*]], %[[VAL_5:.*]], %[[VAL_3:.*]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> : (tensor<4xi32>, tensor<2x2x64x256xf32>, tensor<1x64x64x256xf32>) -> tensor<1x128x128x64xf32>
+// CHECK:            %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:            %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7:.*]], %[[VAL_8:.*]]) : (tensor<1x128x128x64xf32>, tensor<4xi64>) -> tensor<1x64x128x128xf32>
 // CHECK:            return %[[VAL_9:.*]] : tensor<1x64x128x128xf32>
 // CHECK:           }
@@ -2450,13 +2450,13 @@
 // CHECK-LABEL:    func.func @convert_transpose_conv_with_transpose2(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<64x64x1x256xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<2x2x64x256xf32>) -> tensor<128x128x1x64xf32> {
-// CHECK:            %[[VAL_2:.*]] = "tf.Const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:            %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:            %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0:.*]], %[[VAL_2:.*]]) : (tensor<64x64x1x256xf32>, tensor<4xi64>) -> tensor<1x64x64x256xf32>
-// CHECK:            %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
+// CHECK:            %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
 // CHECK:            %[[VAL_5:.*]] = "tf.ReverseV2"(%[[VAL_1:.*]], %[[VAL_4:.*]]) : (tensor<2x2x64x256xf32>, tensor<2xi64>) -> tensor<2x2x64x256xf32>
-// CHECK:            %[[VAL_6:.*]] = "tf.Const"() {value = dense<[1, 128, 128, 64]> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK:            %[[VAL_7:.*]] = "tf.Conv2DBackpropInput"(%[[VAL_6:.*]], %[[VAL_5:.*]], %[[VAL_3:.*]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<4xi32>, tensor<2x2x64x256xf32>, tensor<1x64x64x256xf32>) -> tensor<1x128x128x64xf32>
-// CHECK:            %[[VAL_8:.*]] = "tf.Const"() {value = dense<[1, 2, 0, 3]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:            %[[VAL_6:.*]] = "tf.Const"() <{value = dense<[1, 128, 128, 64]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK:            %[[VAL_7:.*]] = "tf.Conv2DBackpropInput"(%[[VAL_6:.*]], %[[VAL_5:.*]], %[[VAL_3:.*]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> : (tensor<4xi32>, tensor<2x2x64x256xf32>, tensor<1x64x64x256xf32>) -> tensor<1x128x128x64xf32>
+// CHECK:            %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:            %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7:.*]], %[[VAL_8:.*]]) : (tensor<1x128x128x64xf32>, tensor<4xi64>) -> tensor<128x128x1x64xf32>
 // CHECK:            return %[[VAL_9:.*]] : tensor<128x128x1x64xf32>
 // CHECK:           }
@@ -2473,7 +2473,7 @@
 // CHECK-LABEL:   func @convert_conv2d_dynamic_batch(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<?x8x8x207xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<?x8x8x16xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<?x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<?x8x8x16xf32>
+// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<?x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<?x8x8x16xf32>
 // CHECK:           return %[[VAL_2]] : tensor<?x8x8x16xf32>
 // CHECK:         }
 func.func @convert_conv2d_dynamic_batch(%arg0: tensor<?x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<?x8x8x16xf32> {
@@ -2496,7 +2496,7 @@
 // CHECK-LABEL:   func @convert_conv2d_no_padding(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x6x6x207xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x4x4x16xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x6x6x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x4x4x16xf32>
+// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x6x6x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x4x4x16xf32>
 // CHECK:           return %[[VAL_2]] : tensor<1x4x4x16xf32>
 // CHECK:         }
 func.func @convert_conv2d_no_padding(%arg0: tensor<1x6x6x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x4x4x16xf32> {
@@ -2519,7 +2519,7 @@
 // CHECK-LABEL:   func @convert_conv2d_no_rhs_dilation(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x8x8x207xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
 // CHECK:           return %[[VAL_2]] : tensor<1x8x8x16xf32>
 // CHECK:         }
 func.func @convert_conv2d_no_rhs_dilation(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
@@ -2542,7 +2542,7 @@
 // CHECK-LABEL:   func @convert_conv2d_no_window_strides(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x8x8x207xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
 // CHECK:           return %[[VAL_2]] : tensor<1x8x8x16xf32>
 // CHECK:         }
 func.func @convert_conv2d_no_window_strides(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
@@ -2565,7 +2565,7 @@
 // CHECK-LABEL:   func @convert_conv2d_no_lhs_dilation(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x8x8x207xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
 // CHECK:           return %[[VAL_2]] : tensor<1x8x8x16xf32>
 // CHECK:         }
 func.func @convert_conv2d_no_lhs_dilation(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
@@ -2588,12 +2588,12 @@
 // CHECK-LABEL:   func @convert_conv2d_with_transpose(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<8x8x1x207xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : (tensor<8x8x1x207xf32>, tensor<4xi64>) -> tensor<1x8x8x207xf32>
-// CHECK:           %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_1]], %[[VAL_4]]) : (tensor<3x3x16x207xf32>, tensor<4xi64>) -> tensor<3x3x207x16xf32>
-// CHECK:           %[[VAL_6:.*]] = "tf.Conv2D"(%[[VAL_3]], %[[VAL_5]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
-// CHECK:           %[[VAL_7:.*]] = "tf.Const"() {value = dense<[3, 1, 2, 0]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:           %[[VAL_6:.*]] = "tf.Conv2D"(%[[VAL_3]], %[[VAL_5]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+// CHECK:           %[[VAL_7:.*]] = "tf.Const"() <{value = dense<[3, 1, 2, 0]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_8:.*]] = "tf.Transpose"(%[[VAL_6]], %[[VAL_7]]) : (tensor<1x8x8x16xf32>, tensor<4xi64>) -> tensor<16x8x8x1xf32>
 // CHECK:           return %[[VAL_8]] : tensor<16x8x8x1xf32>
 // CHECK:         }
@@ -2617,12 +2617,12 @@
 // CHECK-LABEL:   func @convert_conv2d_with_transpose_dynamic_batch(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<8x8x?x207xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x16x207xf32>) -> tensor<16x8x8x?xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : (tensor<8x8x?x207xf32>, tensor<4xi64>) -> tensor<?x8x8x207xf32>
-// CHECK:           %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_1]], %[[VAL_4]]) : (tensor<3x3x16x207xf32>, tensor<4xi64>) -> tensor<3x3x207x16xf32>
-// CHECK:           %[[VAL_6:.*]] = "tf.Conv2D"(%[[VAL_3]], %[[VAL_5]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<?x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<?x8x8x16xf32>
-// CHECK:           %[[VAL_7:.*]] = "tf.Const"() {value = dense<[3, 1, 2, 0]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK:           %[[VAL_6:.*]] = "tf.Conv2D"(%[[VAL_3]], %[[VAL_5]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<?x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<?x8x8x16xf32>
+// CHECK:           %[[VAL_7:.*]] = "tf.Const"() <{value = dense<[3, 1, 2, 0]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[VAL_8:.*]] = "tf.Transpose"(%[[VAL_6]], %[[VAL_7]]) : (tensor<?x8x8x16xf32>, tensor<4xi64>) -> tensor<16x8x8x?xf32>
 // CHECK:           return %[[VAL_8]] : tensor<16x8x8x?xf32>
 // CHECK:         }
@@ -2646,7 +2646,7 @@
 // CHECK-LABEL:   func @convert_conv2d_explicit_padding(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<64x8x8x8xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<8x8x8x64xf32>) -> tensor<64x3x3x64xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [0, 0, 1, 1, 1, 1, 0, 0], padding = "EXPLICIT", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<64x8x8x8xf32>, tensor<8x8x8x64xf32>) -> tensor<64x3x3x64xf32>
+// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [0, 0, 1, 1, 1, 1, 0, 0], padding = "EXPLICIT", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<64x8x8x8xf32>, tensor<8x8x8x64xf32>) -> tensor<64x3x3x64xf32>
 // CHECK:           return %[[VAL_2]] : tensor<64x3x3x64xf32>
 // CHECK:         }
 func.func @convert_conv2d_explicit_padding(%arg0: tensor<64x8x8x8xf32>, %arg1: tensor<8x8x8x64xf32>) -> tensor<64x3x3x64xf32> {
@@ -2670,7 +2670,7 @@
 // CHECK-LABEL:   func @convert_conv2d_explicit_padding_dynamic_batch(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<?x8x8x8xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<8x8x8x64xf32>) -> tensor<?x3x3x64xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [0, 0, 1, 1, 1, 1, 0, 0], padding = "EXPLICIT", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<?x8x8x8xf32>, tensor<8x8x8x64xf32>) -> tensor<?x3x3x64xf32>
+// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [0, 0, 1, 1, 1, 1, 0, 0], padding = "EXPLICIT", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<?x8x8x8xf32>, tensor<8x8x8x64xf32>) -> tensor<?x3x3x64xf32>
 // CHECK:           return %[[VAL_2]] : tensor<?x3x3x64xf32>
 // CHECK:         }
 func.func @convert_conv2d_explicit_padding_dynamic_batch(%arg0: tensor<?x8x8x8xf32>, %arg1: tensor<8x8x8x64xf32>) -> tensor<?x3x3x64xf32> {
@@ -2694,8 +2694,8 @@
 // CHECK-LABEL:   func @convert_conv2d_negative_explicit_padding(
 // CHECK-SAME:                         %[[ARG0:.*]]: tensor<128x7x9x64xf32>,
 // CHECK-SAME:                         %[[ARG1:.*]]: tensor<3x2x64x4xf32>) -> tensor<128x4x3x4xf32> {
-// CHECK-DAG:       %[[START:.*]] = "tf.Const"() {value = dense<[0, 0, 5, 0]> : tensor<4xi64>} : () -> tensor<4xi64>
-// CHECK-DAG:       %[[SIZE:.*]] = "tf.Const"() {value = dense<[128, 5, 4, 64]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %[[START:.*]] = "tf.Const"() <{value = dense<[0, 0, 5, 0]> : tensor<4xi64>}> : () -> tensor<4xi64>
+// CHECK-DAG:       %[[SIZE:.*]] = "tf.Const"() <{value = dense<[128, 5, 4, 64]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[SLICED_ARG0:.*]] = "tf.Slice"(%[[ARG0]], %[[START]], %[[SIZE]])
 // CHECK-SAME:      (tensor<128x7x9x64xf32>, tensor<4xi64>, tensor<4xi64>) -> tensor<128x5x4x64xf32>
 // CHECK:           %[[CONV:.*]] = "tf.Conv2D"(%[[SLICED_ARG0]], %[[ARG1]])
@@ -2724,8 +2724,8 @@
 // CHECK-LABEL:   func @convert_conv2d_negative_explicit_padding_dynamic_batch(
 // CHECK-SAME:                         %[[ARG0:.*]]: tensor<?x7x9x64xf32>,
 // CHECK-SAME:                         %[[ARG1:.*]]: tensor<3x2x64x4xf32>) -> tensor<?x4x3x4xf32> {
-// CHECK-DAG:       %[[START:.*]] = "tf.Const"() {value = dense<[0, 0, 5, 0]> : tensor<4xi64>} : () -> tensor<4xi64>
-// CHECK-DAG:       %[[SIZE:.*]] = "tf.Const"() {value = dense<[-9223372036854775808, 5, 4, 64]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %[[START:.*]] = "tf.Const"() <{value = dense<[0, 0, 5, 0]> : tensor<4xi64>}> : () -> tensor<4xi64>
+// CHECK-DAG:       %[[SIZE:.*]] = "tf.Const"() <{value = dense<[-9223372036854775808, 5, 4, 64]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK:           %[[SLICED_ARG0:.*]] = "tf.Slice"(%[[ARG0]], %[[START]], %[[SIZE]])
 // CHECK-SAME:      (tensor<?x7x9x64xf32>, tensor<4xi64>, tensor<4xi64>) -> tensor<?x5x4x64xf32>
 // CHECK:           %[[CONV:.*]] = "tf.Conv2D"(%[[SLICED_ARG0]], %[[ARG1]])
@@ -2756,7 +2756,7 @@
 // CHECK-SAME:                                   %[[VAL_1:.*]]: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x3312xf32> {
 // CHECK:           %[[CST:.*]] = arith.constant dense<[3, 3, 207, 16]> : tensor<4xi64>
 // CHECK:           %[[VAL_2:.*]] = "tf.Reshape"(%[[VAL_1]], %[[CST]]) : (tensor<3x3x1x3312xf32>, tensor<4xi64>) -> tensor<3x3x207x16xf32>
-// CHECK:           %[[VAL_3:.*]] = "tf.DepthwiseConv2dNative"(%[[VAL_0]], %[[VAL_2]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x3312xf32>
+// CHECK:           %[[VAL_3:.*]] = "tf.DepthwiseConv2dNative"(%[[VAL_0]], %[[VAL_2]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x3312xf32>
 // CHECK:           return %[[VAL_3]] : tensor<1x8x8x3312xf32>
 // CHECK:         }
 func.func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x3312xf32> {
@@ -2779,8 +2779,8 @@
 // CHECK-LABEL:   func @convert_conv2d_to_resize(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x56x624x16xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<1x257x16x1xf32>) -> tensor<1x56x904x16xf32> {
-// CHECK-DAG:       %[[SIZE:.*]] = "tf.Const"() {value = dense<[56, 904]> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK:           %[[VAL_2:.*]] = "tf.ResizeBilinear"(%[[VAL_0]], %[[SIZE]]) {align_corners = true, half_pixel_centers = false} : (tensor<1x56x624x16xf32>, tensor<2xi32>) -> tensor<1x56x904x16xf32>
+// CHECK-DAG:       %[[SIZE:.*]] = "tf.Const"() <{value = dense<[56, 904]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK:           %[[VAL_2:.*]] = "tf.ResizeBilinear"(%[[VAL_0]], %[[SIZE]]) <{align_corners = true, half_pixel_centers = false}> : (tensor<1x56x624x16xf32>, tensor<2xi32>) -> tensor<1x56x904x16xf32>
 // CHECK:           return %[[VAL_2]] : tensor<1x56x904x16xf32>
 // CHECK:         }
 func.func @convert_conv2d_to_resize(%arg0: tensor<1x56x624x16xf32>, %arg1: tensor<1x257x16x1xf32>) -> tensor<1x56x904x16xf32> {
@@ -2798,8 +2798,8 @@
 // CHECK-LABEL:   func @convert_conv2d_resize_perferred(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x56x1248x16xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x1x16x1xf32>) -> tensor<1x111x1248x16xf32> {
-// CHECK-DAG:       %[[SIZE:.*]] = "tf.Const"() {value = dense<[111, 1248]> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK:           %[[VAL_2:.*]] = "tf.ResizeBilinear"(%[[VAL_0]], %[[SIZE]]) {align_corners = true, half_pixel_centers = false} : (tensor<1x56x1248x16xf32>, tensor<2xi32>) -> tensor<1x111x1248x16xf32>
+// CHECK-DAG:       %[[SIZE:.*]] = "tf.Const"() <{value = dense<[111, 1248]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK:           %[[VAL_2:.*]] = "tf.ResizeBilinear"(%[[VAL_0]], %[[SIZE]]) <{align_corners = true, half_pixel_centers = false}> : (tensor<1x56x1248x16xf32>, tensor<2xi32>) -> tensor<1x111x1248x16xf32>
 // CHECK:           return %[[VAL_2]] : tensor<1x111x1248x16xf32>
 // CHECK:         }
 func.func @convert_conv2d_resize_perferred(%arg0: tensor<1x56x1248x16xf32>, %arg1: tensor<3x1x16x1xf32>) -> tensor<1x111x1248x16xf32> {
@@ -2817,10 +2817,10 @@
 // CHECK-LABEL:   func @convert_conv2d_back_prop_input_same_pad(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x256x256x2xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<4x4x2x2xf32>) -> tensor<1x512x512x2xf32> {
-// CHECK:           %[[VAL_3:.*]] = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
+// CHECK:           %[[VAL_3:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
 // CHECK:           %[[VAL_4:.*]] = "tf.ReverseV2"(%[[VAL_1]], %[[VAL_3]]) : (tensor<4x4x2x2xf32>, tensor<2xi64>) -> tensor<4x4x2x2xf32>
-// CHECK:           %[[VAL_2:.*]] = "tf.Const"() {value = dense<[1, 512, 512, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK:           %[[VAL_5:.*]] = "tf.Conv2DBackpropInput"(%[[VAL_2]], %[[VAL_4]], %[[VAL_0]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<4xi32>, tensor<4x4x2x2xf32>, tensor<1x256x256x2xf32>) -> tensor<1x512x512x2xf32>
+// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[1, 512, 512, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK:           %[[VAL_5:.*]] = "tf.Conv2DBackpropInput"(%[[VAL_2]], %[[VAL_4]], %[[VAL_0]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> : (tensor<4xi32>, tensor<4x4x2x2xf32>, tensor<1x256x256x2xf32>) -> tensor<1x512x512x2xf32>
 // CHECK:           return %[[VAL_5]] : tensor<1x512x512x2xf32>
 // CHECK:         }
 func.func @convert_conv2d_back_prop_input_same_pad(%arg0: tensor<1x256x256x2xf32>, %arg1: tensor<4x4x2x2xf32>) -> tensor<1x512x512x2xf32> {
@@ -2845,10 +2845,10 @@
 // CHECK-LABEL:   func @convert_conv2d_back_prop_input(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<8x4x4x32xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x64x32xf32>) -> tensor<8x8x8x64xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
+// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
 // CHECK:           %[[VAL_3:.*]] = "tf.ReverseV2"(%[[VAL_1]], %[[VAL_2]]) : (tensor<3x3x64x32xf32>, tensor<2xi64>) -> tensor<3x3x64x32xf32>
-// CHECK:           %[[VAL_4:.*]] = "tf.Const"() {value = dense<[8, 8, 8, 64]> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK:           %[[VAL_5:.*]] = "tf.Conv2DBackpropInput"(%[[VAL_4]], %[[VAL_3]], %[[VAL_0]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<4xi32>, tensor<3x3x64x32xf32>, tensor<8x4x4x32xf32>) -> tensor<8x8x8x64xf32>
+// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[8, 8, 8, 64]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK:           %[[VAL_5:.*]] = "tf.Conv2DBackpropInput"(%[[VAL_4]], %[[VAL_3]], %[[VAL_0]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> : (tensor<4xi32>, tensor<3x3x64x32xf32>, tensor<8x4x4x32xf32>) -> tensor<8x8x8x64xf32>
 // CHECK:           return %[[VAL_5]] : tensor<8x8x8x64xf32>
 // CHECK:         }
 func.func @convert_conv2d_back_prop_input(%arg0: tensor<8x4x4x32xf32>, %arg1: tensor<3x3x64x32xf32>) -> tensor<8x8x8x64xf32> {
@@ -2872,12 +2872,12 @@
 // CHECK-LABEL:   func @convert_conv2d_back_prop_input_transpose_filter(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<8x4x4x32xf32>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x32x64xf32>) -> tensor<8x8x8x64xf32> {
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK-DAG:       %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_1]], %[[VAL_3]]) : (tensor<3x3x32x64xf32>, tensor<4xi64>) -> tensor<3x3x64x32xf32>
 // CHECK:           %[[VAL_5:.*]] = "tf.ReverseV2"(%[[VAL_4]], %[[VAL_2]]) : (tensor<3x3x64x32xf32>, tensor<2xi64>) -> tensor<3x3x64x32xf32>
-// CHECK:           %[[VAL_6:.*]] = "tf.Const"() {value = dense<[8, 8, 8, 64]> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK:           %[[VAL_7:.*]] = "tf.Conv2DBackpropInput"(%[[VAL_6]], %[[VAL_5]], %[[VAL_0]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<4xi32>, tensor<3x3x64x32xf32>, tensor<8x4x4x32xf32>) -> tensor<8x8x8x64xf32>
+// CHECK:           %[[VAL_6:.*]] = "tf.Const"() <{value = dense<[8, 8, 8, 64]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK:           %[[VAL_7:.*]] = "tf.Conv2DBackpropInput"(%[[VAL_6]], %[[VAL_5]], %[[VAL_0]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> : (tensor<4xi32>, tensor<3x3x64x32xf32>, tensor<8x4x4x32xf32>) -> tensor<8x8x8x64xf32>
 // CHECK:           return %[[VAL_7]] : tensor<8x8x8x64xf32>
 // CHECK:         }
 func.func @convert_conv2d_back_prop_input_transpose_filter(%arg0: tensor<8x4x4x32xf32>, %arg1: tensor<3x3x32x64xf32>) -> tensor<8x8x8x64xf32> {
@@ -2901,7 +2901,7 @@
 // CHECK-LABEL:   func @convert_conv2d_valid_padding(
 // CHECK-SAME:                                       %[[VAL_0:.*]]: tensor<1x8x8x207xf32>,
 // CHECK-SAME:                                       %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32>
+// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32>
 // CHECK:           return %[[VAL_2]] : tensor<1x6x6x16xf32>
 // CHECK:         }
 func.func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> {
@@ -2924,7 +2924,7 @@
 // CHECK-LABEL:   func @convert_conv2d_valid_padding_dynamic_batch(
 // CHECK-SAME:                                       %[[VAL_0:.*]]: tensor<?x8x8x207xf32>,
 // CHECK-SAME:                                       %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<?x6x6x16xf32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<?x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<?x6x6x16xf32>
+// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<?x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<?x6x6x16xf32>
 // CHECK:           return %[[VAL_2]] : tensor<?x6x6x16xf32>
 // CHECK:         }
 func.func @convert_conv2d_valid_padding_dynamic_batch(%arg0: tensor<?x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<?x6x6x16xf32> {
@@ -2946,9 +2946,9 @@
 
 // CHECK-LABEL:   func @convert_reduce_to_prod(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x256xf32>) -> tensor<1xf32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_3:.*]] = "tf.Prod"(%[[VAL_0]], %[[VAL_2]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_3:.*]] = "tf.Prod"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
 // CHECK:           return %[[VAL_3]] : tensor<1xf32>
 // CHECK:         }
 func.func @convert_reduce_to_prod(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
@@ -2963,9 +2963,9 @@
 
 // CHECK-LABEL:   func @convert_reduce_to_sum(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x256xf32>) -> tensor<1xf32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_3:.*]] = "tf.Sum"(%[[VAL_0]], %[[VAL_2]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_3:.*]] = "tf.Sum"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
 // CHECK:           return %[[VAL_3]] : tensor<1xf32>
 // CHECK:         }
 func.func @convert_reduce_to_sum(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
@@ -2981,8 +2981,8 @@
 // CHECK-LABEL:   func @convert_reduce_to_prod_non_constant_init(
 // CHECK-SAME:                                %[[ARG_0:.*]]: tensor<1x256xf32>,
 // CHECK-SAME:                                %[[ARG_1:.*]]: tensor<f32>) -> tensor<1xf32> {
-// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_1:.*]] = "tf.Prod"(%[[ARG_0]], %[[VAL_0]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
+// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_1:.*]] = "tf.Prod"(%[[ARG_0]], %[[VAL_0]]) <{keep_dims = false}> : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
 // CHECK:           %[[VAL_2:.*]] = "tf.Mul"(%[[VAL_1]], %[[ARG_1]]) : (tensor<1xf32>, tensor<f32>) -> tensor<1xf32>
 // CHECK:           return %[[VAL_2]] : tensor<1xf32>
 // CHECK:         }
@@ -2999,8 +2999,8 @@
 // CHECK-LABEL:   func @convert_reduce_to_sum_non_constant_init(
 // CHECK-SAME:                                %[[ARG_0:.*]]: tensor<1x256xf32>,
 // CHECK-SAME:                                %[[ARG_1:.*]]: tensor<f32>) -> tensor<1xf32> {
-// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_1:.*]] = "tf.Sum"(%[[ARG_0]], %[[VAL_0]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
+// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_1:.*]] = "tf.Sum"(%[[ARG_0]], %[[VAL_0]]) <{keep_dims = false}> : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
 // CHECK:           %[[VAL_2:.*]] = "tf.Add"(%[[VAL_1]], %[[ARG_1]]) : (tensor<1xf32>, tensor<f32>) -> tensor<1xf32>
 // CHECK:           return %[[VAL_2]] : tensor<1xf32>
 // CHECK:         }
@@ -3015,9 +3015,9 @@
 
 // CHECK-LABEL:   func @convert_int_reduce_to_prod(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x256xi32>) -> tensor<1xi32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_3:.*]] = "tf.Prod"(%[[VAL_0]], %[[VAL_2]]) {keep_dims = false} : (tensor<1x256xi32>, tensor<1xi64>) -> tensor<1xi32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_3:.*]] = "tf.Prod"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x256xi32>, tensor<1xi64>) -> tensor<1xi32>
 // CHECK:           return %[[VAL_3]] : tensor<1xi32>
 // CHECK:         }
 func.func @convert_int_reduce_to_prod(%arg0: tensor<1x256xi32>) -> tensor<1xi32> {
@@ -3033,9 +3033,9 @@
 
 // CHECK-LABEL:   func @convert_int_reduce_to_sum(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x256xi32>) -> tensor<1xi32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_3:.*]] = "tf.Sum"(%[[VAL_0]], %[[VAL_2]]) {keep_dims = false} : (tensor<1x256xi32>, tensor<1xi64>) -> tensor<1xi32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_3:.*]] = "tf.Sum"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x256xi32>, tensor<1xi64>) -> tensor<1xi32>
 // CHECK:           return %[[VAL_3]] : tensor<1xi32>
 // CHECK:         }
 func.func @convert_int_reduce_to_sum(%arg0: tensor<1x256xi32>) -> tensor<1xi32> {
@@ -3050,8 +3050,8 @@
 
 // CHECK-LABEL:   func @convert_reduce_to_max(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x256xf32>) -> tensor<1xf32> {
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_3:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_2]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_3:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
 // CHECK:           return %[[VAL_3]] : tensor<1xf32>
 // CHECK:         }
 func.func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
@@ -3067,8 +3067,8 @@
 
 // CHECK-LABEL:   func @convert_reduce_to_max_int(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x4xi32>) -> tensor<1xi32> {
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_3:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_2]]) {keep_dims = false} : (tensor<1x4xi32>, tensor<1xi64>) -> tensor<1xi32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_3:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x4xi32>, tensor<1xi64>) -> tensor<1xi32>
 // CHECK:           return %[[VAL_3]] : tensor<1xi32>
 func.func @convert_reduce_to_max_int(%arg0: tensor<1x4xi32>) -> tensor<1xi32> {
   // -2147483648 is MIN for INT32
@@ -3083,8 +3083,8 @@
 
 // CHECK-LABEL:   func @convert_reduce_to_min(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x256xf32>) -> tensor<1xf32> {
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_3:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_2]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_3:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
 // CHECK:           return %[[VAL_3]] : tensor<1xf32>
 // CHECK:         }
 func.func @convert_reduce_to_min(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
@@ -3100,8 +3100,8 @@
 
 // CHECK-LABEL:   func @convert_reduce_to_min_int(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x4xi32>) -> tensor<1xi32> {
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_3:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_2]]) {keep_dims = false} : (tensor<1x4xi32>, tensor<1xi64>) -> tensor<1xi32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_3:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x4xi32>, tensor<1xi64>) -> tensor<1xi32>
 // CHECK:           return %[[VAL_3]] : tensor<1xi32>
 func.func @convert_reduce_to_min_int(%arg0: tensor<1x4xi32>) -> tensor<1xi32> {
   // 2147483647 is MAX for INT32
@@ -3115,9 +3115,9 @@
 }
 
 // CHECK-LABEL:   func @convert_iota_1d() -> tensor<123xf32> {
-// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<1.230000e+02> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<1.230000e+02> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK:           %[[VAL_3:.*]] = "tf.Range"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<123xf32>
 // CHECK:           return %[[VAL_3]] : tensor<123xf32>
 // CHECK:         }
@@ -3127,13 +3127,13 @@
 }
 
 // CHECK-LABEL:   func @convert_iota_3d() -> tensor<5x7x9xi32> {
-// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<7> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<7> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %[[VAL_3:.*]] = "tf.Range"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<7xi32>
-// CHECK:           %[[VAL_4:.*]] = "tf.Const"() {value = dense<[1, 7, 1]> : tensor<3xi64>} : () -> tensor<3xi64>
+// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[1, 7, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
 // CHECK:           %[[VAL_5:.*]] = "tf.Reshape"(%[[VAL_3]], %[[VAL_4]]) : (tensor<7xi32>, tensor<3xi64>) -> tensor<1x7x1xi32>
-// CHECK:           %[[VAL_6:.*]] = "tf.Const"() {value = dense<[5, 7, 9]> : tensor<3xi64>} : () -> tensor<3xi64>
+// CHECK:           %[[VAL_6:.*]] = "tf.Const"() <{value = dense<[5, 7, 9]> : tensor<3xi64>}> : () -> tensor<3xi64>
 // CHECK:           %[[VAL_7:.*]] = "tf.BroadcastTo"(%[[VAL_5]], %[[VAL_6]]) : (tensor<1x7x1xi32>, tensor<3xi64>) -> tensor<5x7x9xi32>
 // CHECK:           return %[[VAL_7]] : tensor<5x7x9xi32>
 // CHECK:         }
@@ -3144,7 +3144,7 @@
 
 // CHECK-LABEL:   func @convert_avgpool_valid(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
 // CHECK:           return %[[VAL_1]] : tensor<4x7x7x8xf32>
 // CHECK:         }
 func.func @convert_avgpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
@@ -3166,7 +3166,7 @@
 
 // CHECK-LABEL:   func @convert_avgpool_valid_broadcasted_divisor(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
 // CHECK:           return %[[VAL_1]] : tensor<4x7x7x8xf32>
 // CHECK:         }
 func.func @convert_avgpool_valid_broadcasted_divisor(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
@@ -3189,7 +3189,7 @@
 
 // CHECK-LABEL:   func @convert_avgpool_valid_channel_first(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NCHW", ksize = [1, 1, 3, 3], padding = "VALID", strides = [1, 1, 2, 2]} : (tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32>
+// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NCHW", ksize = [1, 1, 3, 3], padding = "VALID", strides = [1, 1, 2, 2]}> : (tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32>
 // CHECK:           return %[[VAL_1]] : tensor<4x3x7x7xf32>
 // CHECK:         }
 func.func @convert_avgpool_valid_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> {
@@ -3211,7 +3211,7 @@
 
 // CHECK-LABEL:   func @convert_avgpool_valid_rw(
 // CHECK-SAME:                               %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
 // CHECK:           return %[[VAL_1]] : tensor<4x7x7x8xf32>
 // CHECK:         }
 func.func @convert_avgpool_valid_rw(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
@@ -3243,7 +3243,7 @@
 
 // CHECK-LABEL:   func @convert_avgpool_valid_rw_broadcasted_const_lhs(
 // CHECK-SAME:                               %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
 // CHECK:           return %[[VAL_1]] : tensor<4x7x7x8xf32>
 // CHECK:         }
 func.func @convert_avgpool_valid_rw_broadcasted_const_lhs(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
@@ -3276,7 +3276,7 @@
 
 // CHECK-LABEL:   func @convert_avgpool_valid_3d(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool3D"(%[[VAL_0]]) {data_format = "NDHWC", ksize = [1, 3, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 2, 1]} : (tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32>
+// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool3D"(%[[VAL_0]]) <{data_format = "NDHWC", ksize = [1, 3, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 2, 1]}> : (tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32>
 // CHECK:           return %[[VAL_1]] : tensor<4x7x7x7x8xf32>
 // CHECK:         }
 func.func @convert_avgpool_valid_3d(%arg0: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
@@ -3298,7 +3298,7 @@
 
 // CHECK-LABEL:   func @convert_avgpool_valid_3d_channel_first(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x3x16x16x16xf32>) -> tensor<4x3x7x7x7xf32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool3D"(%[[VAL_0]]) {data_format = "NCDHW", ksize = [1, 1, 3, 3, 3], padding = "VALID", strides = [1, 1, 2, 2, 2]} : (tensor<4x3x16x16x16xf32>) -> tensor<4x3x7x7x7xf32>
+// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool3D"(%[[VAL_0]]) <{data_format = "NCDHW", ksize = [1, 1, 3, 3, 3], padding = "VALID", strides = [1, 1, 2, 2, 2]}> : (tensor<4x3x16x16x16xf32>) -> tensor<4x3x7x7x7xf32>
 // CHECK:           return %[[VAL_1]] : tensor<4x3x7x7x7xf32>
 // CHECK:         }
 func.func @convert_avgpool_valid_3d_channel_first(%arg0: tensor<4x3x16x16x16xf32>) -> tensor<4x3x7x7x7xf32> {
@@ -3320,7 +3320,7 @@
 
 // CHECK-LABEL:   func @convert_avgpool_same(
 // CHECK-SAME:                               %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
+// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
 // CHECK:           return %[[VAL_1]] : tensor<4x8x8x8xf32>
 // CHECK:         }
 func.func @convert_avgpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
@@ -3352,7 +3352,7 @@
 
 // CHECK-LABEL:   func @convert_avgpool_reshape_broadcast(
 // CHECK-SAME:                               %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
+// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
 // CHECK:           return %[[VAL_1]] : tensor<4x8x8x8xf32>
 // CHECK:         }
 func.func @convert_avgpool_reshape_broadcast(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
@@ -3376,7 +3376,7 @@
 
 // CHECK-LABEL:   func @convert_maxpool_valid(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.MaxPool"(%[[VAL_0]]) {data_format = "NHWC", explicit_paddings = [], ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
+// CHECK:           %[[VAL_1:.*]] = "tf.MaxPool"(%[[VAL_0]]) <{data_format = "NHWC", explicit_paddings = [], ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
 // CHECK:           return %[[VAL_1]] : tensor<4x7x7x8xf32>
 // CHECK:         }
 func.func @convert_maxpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
@@ -3397,7 +3397,7 @@
 
 // CHECK-LABEL:   func @convert_maxpool_valid_channel_first(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> {
-// CHECK:           %[[VAL_1:.*]]  = "tf.MaxPool"(%[[VAL_0]]) {data_format = "NCHW", explicit_paddings = [], ksize = [1, 1, 3, 3], padding = "VALID", strides = [1, 1, 2, 2]} : (tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32>
+// CHECK:           %[[VAL_1:.*]]  = "tf.MaxPool"(%[[VAL_0]]) <{data_format = "NCHW", explicit_paddings = [], ksize = [1, 1, 3, 3], padding = "VALID", strides = [1, 1, 2, 2]}> : (tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32>
 // CHECK:           return %[[VAL_1]] : tensor<4x3x7x7xf32>
 // CHECK:         }
 func.func @convert_maxpool_valid_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> {
@@ -3418,7 +3418,7 @@
 
 // CHECK-LABEL:   func @convert_maxpool_valid_3d(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.MaxPool3D"(%[[VAL_0]]) {data_format = "NDHWC", ksize = [1, 3, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 2, 1]} : (tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32>
+// CHECK:           %[[VAL_1:.*]] = "tf.MaxPool3D"(%[[VAL_0]]) <{data_format = "NDHWC", ksize = [1, 3, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 2, 1]}> : (tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32>
 // CHECK:           return %[[VAL_1]] : tensor<4x7x7x7x8xf32>
 // CHECK:         }
 func.func @convert_maxpool_valid_3d(%arg0: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
@@ -3439,7 +3439,7 @@
 
 // CHECK-LABEL:   func @convert_maxpool_valid_3d_channel_first(
 // CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x3x16x16x16xf32>) -> tensor<4x3x7x7x7xf32> {
-// CHECK:           %[[VAL_1:.*]]  = "tf.MaxPool3D"(%[[VAL_0]]) {data_format = "NCDHW", ksize = [1, 1, 3, 3, 3], padding = "VALID", strides = [1, 1, 2, 2, 2]} : (tensor<4x3x16x16x16xf32>) -> tensor<4x3x7x7x7xf32>
+// CHECK:           %[[VAL_1:.*]]  = "tf.MaxPool3D"(%[[VAL_0]]) <{data_format = "NCDHW", ksize = [1, 1, 3, 3, 3], padding = "VALID", strides = [1, 1, 2, 2, 2]}> : (tensor<4x3x16x16x16xf32>) -> tensor<4x3x7x7x7xf32>
 // CHECK:           return %[[VAL_1]] : tensor<4x3x7x7x7xf32>
 // CHECK:         }
 func.func @convert_maxpool_valid_3d_channel_first(%arg0: tensor<4x3x16x16x16xf32>) -> tensor<4x3x7x7x7xf32> {
@@ -3460,7 +3460,7 @@
 
 // CHECK-LABEL:   func @convert_maxpool_same(
 // CHECK-SAME:                               %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
-// CHECK:           %[[VAL_1:.*]] = "tf.MaxPool"(%[[VAL_0]]) {data_format = "NHWC", explicit_paddings = [], ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
+// CHECK:           %[[VAL_1:.*]] = "tf.MaxPool"(%[[VAL_0]]) <{data_format = "NHWC", explicit_paddings = [], ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
 // CHECK:           return %[[VAL_1]] : tensor<4x8x8x8xf32>
 // CHECK:         }
 func.func @convert_maxpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
@@ -3574,8 +3574,8 @@
 }
 
 // CHECK-LABEL: func @convert_floor_mod_float_cst
-// CHECK-DAG: %[[CST1:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<192x8xbf16>} : () -> tensor<192x8xbf16>
-// CHECK-DAG: %[[CST2:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<192x8xbf16>} : () -> tensor<192x8xbf16>
+// CHECK-DAG: %[[CST1:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<192x8xbf16>}> : () -> tensor<192x8xbf16>
+// CHECK-DAG: %[[CST2:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<192x8xbf16>}> : () -> tensor<192x8xbf16>
 // CHECK: %[[RESULT:.*]] = "tf.FloorMod"(%arg0, %[[CST2]]) : (tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xbf16>
 // CHECK: return %[[RESULT]] : tensor<192x8xbf16>
 // CHECK: }
@@ -3592,8 +3592,8 @@
 }
 
 // CHECK-LABEL: func @convert_floor_mod_int_cst
-// CHECK-DAG: %[[CST1:.*]] = "tf.Const"() {value = dense<2> : tensor<192x8xi32>} : () -> tensor<192x8xi32>
-// CHECK-DAG: %[[CST2:.*]] = "tf.Const"() {value = dense<2> : tensor<192x8xi32>} : () -> tensor<192x8xi32>
+// CHECK-DAG: %[[CST1:.*]] = "tf.Const"() <{value = dense<2> : tensor<192x8xi32>}> : () -> tensor<192x8xi32>
+// CHECK-DAG: %[[CST2:.*]] = "tf.Const"() <{value = dense<2> : tensor<192x8xi32>}> : () -> tensor<192x8xi32>
 // CHECK: %[[RESULT:.*]] = "tf.FloorMod"(%arg0, %[[CST2]]) : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32>
 // CHECK: return %[[RESULT]] : tensor<192x8xi32>
 // CHECK: }
@@ -3649,7 +3649,7 @@
 }
 
 // CHECK-LABEL: func @convert_floor_div_cst
-// CHECK: %[[CST2:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<10x10xbf16>} : () -> tensor<10x10xbf16>
+// CHECK: %[[CST2:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<10x10xbf16>}> : () -> tensor<10x10xbf16>
 // CHECK: %[[RESULT:.*]] = "tf.FloorDiv"(%arg0, %[[CST2]]) : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xbf16>
 // CHECK: return %[[RESULT]]
 // CHECK: }
@@ -3674,7 +3674,7 @@
 }
 
 // CHECK-LABEL: func @convert_floor_div_cst2
-// CHECK: %[[CST2:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<10x10xbf16>} : () -> tensor<10x10xbf16>
+// CHECK: %[[CST2:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<10x10xbf16>}> : () -> tensor<10x10xbf16>
 // CHECK: %[[RESULT:.*]] = "tf.FloorDiv"(%arg0, %[[CST2]]) : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xbf16>
 // CHECK: return %[[RESULT]]
 // CHECK: }
@@ -3791,10 +3791,10 @@
 // CHECK-LABEL: func @convert_gather_offset(
 // CHECK-SAME:                                      %[[VAL_0:.*]]: tensor<1x20xi32>,
 // CHECK-SAME:                                      %[[VAL_1:.*]]: tensor<1x1xi32>) -> tensor<1x1xi32> {
-// CHECK:           %[[VAL_2:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
+// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
 // CHECK:           %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : (tensor<1x20xi32>, tensor<2xi64>) -> tensor<20x1xi32>
 // CHECK:           %[[VAL_4:.*]] = "tf.GatherNd"(%[[VAL_3]], %[[VAL_1]]) : (tensor<20x1xi32>, tensor<1x1xi32>) -> tensor<1x1xi32>
-// CHECK:           %[[VAL_5:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
+// CHECK:           %[[VAL_5:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
 // CHECK:           %[[VAL_6:.*]] = "tf.Transpose"(%[[VAL_4]], %[[VAL_5]]) : (tensor<1x1xi32>, tensor<2xi64>) -> tensor<1x1xi32>
 // CHECK:           return %[[VAL_6]] : tensor<1x1xi32>
 // CHECK:         }
@@ -3815,12 +3815,12 @@
 // CHECK-LABEL:   func @convert_gather_to_slice_batch_size_1(
 // CHECK-SAME:                         %[[ARG_0:.*]]: tensor<1x2944xi32>,
 // CHECK-SAME:                         %[[ARG_1:.*]]: tensor<1x2xi32>)
-// CHECK-DAG:         %[[CST:.*]] = "tf.Const"() {value = dense<[0, 1440]> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK-DAG:         %[[CST_0:.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
+// CHECK-DAG:         %[[CST:.*]] = "tf.Const"() <{value = dense<[0, 1440]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-DAG:         %[[CST_0:.*]] = "tf.Const"() <{value = dense<0> : tensor<2xi32>}> : () -> tensor<2xi32>
 // CHECK:             %[[VAL_0:.*]] = "tf.Maximum"(%[[ARG_1]], %[[CST_0:.*]]) : (tensor<1x2xi32>, tensor<2xi32>) -> tensor<1x2xi32>
 // CHECK:             %[[VAL_1:.*]] = "tf.Minimum"(%[[VAL_0]], %[[CST]]) : (tensor<1x2xi32>, tensor<2xi32>) -> tensor<1x2xi32>
-// CHECK-DAG:         %[[CST_1:.*]] = "tf.Const"() {value = dense<[1, 1504]> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK:             %[[VAL_2:.*]] = "tf.Squeeze"(%[[VAL_1]]) {squeeze_dims = [0]} : (tensor<1x2xi32>) -> tensor<2xi32>
+// CHECK-DAG:         %[[CST_1:.*]] = "tf.Const"() <{value = dense<[1, 1504]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK:             %[[VAL_2:.*]] = "tf.Squeeze"(%[[VAL_1]]) <{squeeze_dims = [0]}> : (tensor<1x2xi32>) -> tensor<2xi32>
 // CHECK:             %[[VAL_3:.*]] = "tf.Slice"(%[[ARG_0]], %[[VAL_2]], %[[CST_1]]) : (tensor<1x2944xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x1504xi32>
 // CHECK:            return %[[VAL_3]]
 // CHECK:         }
@@ -3878,27 +3878,27 @@
 // CHECK-LABEL:   func @convert_gather_to_slice(
 // CHECK-SAME:                         %[[ARG_0:.*]]: tensor<3x2944xi32>,
 // CHECK-SAME:                         %[[ARG_1:.*]]: tensor<3x2xi32>)
-// CHECK-DAG:        %[[CST:.*]] = "tf.Const"() {value = dense<[2, 1440]> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK-DAG:        %[[CST_0:.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
+// CHECK-DAG:        %[[CST:.*]] = "tf.Const"() <{value = dense<[2, 1440]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-DAG:        %[[CST_0:.*]] = "tf.Const"() <{value = dense<0> : tensor<2xi32>}> : () -> tensor<2xi32>
 // CHECK:            %[[VAL_0:.*]] = "tf.Maximum"(%[[ARG_1]], %[[CST_0]]) : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x2xi32>
 // CHECK:            %[[VAL_1:.*]] = "tf.Minimum"(%[[VAL_0]], %[[CST]]) : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x2xi32>
-// CHECK-DAG:        %[[CST_1:.*]] = "tf.Const"() {value = dense<[1, 1504]> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK-DAG:        %[[CST_2:.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK-DAG:        %[[CST_3:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+// CHECK-DAG:        %[[CST_1:.*]] = "tf.Const"() <{value = dense<[1, 1504]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-DAG:        %[[CST_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-DAG:        %[[CST_3:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
 // CHECK:            %[[VAL_2:.*]] = "tf.Slice"(%[[VAL_1]], %[[CST_2]], %[[CST_3]]) : (tensor<3x2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x2xi32>
-// CHECK:            %[[VAL_3:.*]] = "tf.Squeeze"(%[[VAL_2]]) {squeeze_dims = [0]} : (tensor<1x2xi32>) -> tensor<2xi32>
+// CHECK:            %[[VAL_3:.*]] = "tf.Squeeze"(%[[VAL_2]]) <{squeeze_dims = [0]}> : (tensor<1x2xi32>) -> tensor<2xi32>
 // CHECK:            %[[VAL_4:.*]] = "tf.Slice"(%[[ARG_0]], %[[VAL_3]], %[[CST_1]]) : (tensor<3x2944xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x1504xi32>
-// CHECK-DAG:        %[[CST_4:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK-DAG:        %[[CST_5:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+// CHECK-DAG:        %[[CST_4:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-DAG:        %[[CST_5:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
 // CHECK:            %[[VAL_5:.*]] = "tf.Slice"(%[[VAL_1]], %[[CST_4]], %[[CST_5]]) : (tensor<3x2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x2xi32>
-// CHECK:            %[[VAL_6:.*]] = "tf.Squeeze"(%[[VAL_5]]) {squeeze_dims = [0]} : (tensor<1x2xi32>) -> tensor<2xi32>
+// CHECK:            %[[VAL_6:.*]] = "tf.Squeeze"(%[[VAL_5]]) <{squeeze_dims = [0]}> : (tensor<1x2xi32>) -> tensor<2xi32>
 // CHECK:            %[[VAL_7:.*]] = "tf.Slice"(%[[ARG_0]], %[[VAL_6]], %[[CST_1]]) : (tensor<3x2944xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x1504xi32>
-// CHECK-DAG:        %[[CST_6:.*]] = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK-DAG:        %[[CST_7:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+// CHECK-DAG:        %[[CST_6:.*]] = "tf.Const"() <{value = dense<[2, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-DAG:        %[[CST_7:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
 // CHECK:            %[[VAL_8:.*]] = "tf.Slice"(%[[VAL_1]], %[[CST_6]], %[[CST_7]]) : (tensor<3x2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x2xi32>
-// CHECK:            %[[VAL_9:.*]] = "tf.Squeeze"(%[[VAL_8]]) {squeeze_dims = [0]} : (tensor<1x2xi32>) -> tensor<2xi32>
+// CHECK:            %[[VAL_9:.*]] = "tf.Squeeze"(%[[VAL_8]]) <{squeeze_dims = [0]}> : (tensor<1x2xi32>) -> tensor<2xi32>
 // CHECK:            %[[VAL_10:.*]] = "tf.Slice"(%[[ARG_0]], %[[VAL_9]], %[[CST_1]]) : (tensor<3x2944xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x1504xi32>
-// CHECK-DAG:        %[[CST_8:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:        %[[CST_8:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:            %[[VAL_11:.*]] = "tf.ConcatV2"(%[[VAL_4]], %[[VAL_7]], %[[VAL_10]], %[[CST_8]]) : (tensor<1x1504xi32>, tensor<1x1504xi32>, tensor<1x1504xi32>, tensor<i32>) -> tensor<3x1504xi32>
 // CHECK:            return %[[VAL_11]]
 // CHECK:         }
@@ -3936,17 +3936,17 @@
 // CHECK-SAME:                                      %[[VAL_0:.*]]: tensor<7x3xf32>,
 // CHECK-SAME:                                      %[[VAL_1:.*]]: tensor<i32>,
 // CHECK-SAME:                                      %[[VAL_2:.*]]: tensor<i32>) -> tensor<4x2xf32> {
-// CHECK:           %[[VAL_3:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK:           %[[VAL_4:.*]] = "tf.Cast"(%[[VAL_1]]) {Truncate = false} : (tensor<i32>) -> tensor<i32>
-// CHECK:           %[[VAL_5:.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+// CHECK:           %[[VAL_3:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:           %[[VAL_4:.*]] = "tf.Cast"(%[[VAL_1]]) <{Truncate = false}> : (tensor<i32>) -> tensor<i32>
+// CHECK:           %[[VAL_5:.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %[[VAL_6:.*]] = "tf.Minimum"(%[[VAL_4]], %[[VAL_5]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK:           %[[VAL_7:.*]] = "tf.Maximum"(%[[VAL_6]], %[[VAL_3]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-// CHECK:           %[[VAL_8:.*]] = "tf.Cast"(%[[VAL_2]]) {Truncate = false} : (tensor<i32>) -> tensor<i32>
-// CHECK:           %[[VAL_9:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK:           %[[VAL_8:.*]] = "tf.Cast"(%[[VAL_2]]) <{Truncate = false}> : (tensor<i32>) -> tensor<i32>
+// CHECK:           %[[VAL_9:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %[[VAL_10:.*]] = "tf.Minimum"(%[[VAL_8]], %[[VAL_9]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK:           %[[VAL_11:.*]] = "tf.Maximum"(%[[VAL_10]], %[[VAL_3]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-// CHECK:           %[[VAL_12:.*]] = "tf.Pack"(%[[VAL_7]], %[[VAL_11]]) {axis = 0 : i64} : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
-// CHECK:           %[[VAL_13:.*]] = "tf.Const"() {value = dense<[4, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
+// CHECK:           %[[VAL_12:.*]] = "tf.Pack"(%[[VAL_7]], %[[VAL_11]]) <{axis = 0 : i64}> : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
+// CHECK:           %[[VAL_13:.*]] = "tf.Const"() <{value = dense<[4, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
 // CHECK:           %[[VAL_14:.*]] = "tf.Slice"(%[[VAL_0]], %[[VAL_12]], %[[VAL_13]]) : (tensor<7x3xf32>, tensor<2xi32>, tensor<2xi64>) -> tensor<4x2xf32>
 // CHECK:           return %[[VAL_14]] : tensor<4x2xf32>
 // CHECK:         }
@@ -3959,17 +3959,17 @@
 // CHECK-SAME:                                           %[[VAL_0:.*]]: tensor<7x3xf32>,
 // CHECK-SAME:                                           %[[VAL_1:.*]]: tensor<ui32>,
 // CHECK-SAME:                                           %[[VAL_2:.*]]: tensor<ui32>) -> tensor<4x2xf32> {
-// CHECK:           %[[VAL_3:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK:           %[[VAL_4:.*]] = "tf.Cast"(%[[VAL_1]]) {Truncate = false} : (tensor<ui32>) -> tensor<i32>
-// CHECK:           %[[VAL_5:.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+// CHECK:           %[[VAL_3:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:           %[[VAL_4:.*]] = "tf.Cast"(%[[VAL_1]]) <{Truncate = false}> : (tensor<ui32>) -> tensor<i32>
+// CHECK:           %[[VAL_5:.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %[[VAL_6:.*]] = "tf.Minimum"(%[[VAL_4]], %[[VAL_5]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK:           %[[VAL_7:.*]] = "tf.Maximum"(%[[VAL_6]], %[[VAL_3]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-// CHECK:           %[[VAL_8:.*]] = "tf.Cast"(%[[VAL_2]]) {Truncate = false} : (tensor<ui32>) -> tensor<i32>
-// CHECK:           %[[VAL_9:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK:           %[[VAL_8:.*]] = "tf.Cast"(%[[VAL_2]]) <{Truncate = false}> : (tensor<ui32>) -> tensor<i32>
+// CHECK:           %[[VAL_9:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %[[VAL_10:.*]] = "tf.Minimum"(%[[VAL_8]], %[[VAL_9]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK:           %[[VAL_11:.*]] = "tf.Maximum"(%[[VAL_10]], %[[VAL_3]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-// CHECK:           %[[VAL_12:.*]] = "tf.Pack"(%[[VAL_7]], %[[VAL_11]]) {axis = 0 : i64} : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
-// CHECK:           %[[VAL_13:.*]] = "tf.Const"() {value = dense<[4, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
+// CHECK:           %[[VAL_12:.*]] = "tf.Pack"(%[[VAL_7]], %[[VAL_11]]) <{axis = 0 : i64}> : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
+// CHECK:           %[[VAL_13:.*]] = "tf.Const"() <{value = dense<[4, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
 // CHECK:           %[[VAL_14:.*]] = "tf.Slice"(%[[VAL_0]], %[[VAL_12]], %[[VAL_13]]) : (tensor<7x3xf32>, tensor<2xi32>, tensor<2xi64>) -> tensor<4x2xf32>
 // CHECK:           return %[[VAL_14]] : tensor<4x2xf32>
 // CHECK:         }
@@ -4091,17 +4091,17 @@
   func.return %0 : tensor<16x1504xf32>
 }
 
-// CHECK-LABEL:   func.func @convert_scatter_add(
-// CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<20x6xf32>,
-// CHECK-SAME:                                   %[[VAL_1:.*]]: tensor<4x1xi32>,
-// CHECK-SAME:                                   %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
-// CHECK:           %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
-// CHECK:           ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
-// CHECK:             %[[VAL_6:.*]] = "tf.AddV2"(%[[VAL_4]], %[[VAL_5]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK:             mhlo.return %[[VAL_6]] : tensor<f32>
-// CHECK:           }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
-// CHECK:           return %[[VAL_3]] : tensor<20x6xf32>
-// CHECK:         }
+// CHECK-LABEL:  func.func @convert_scatter_add(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<20x6xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<4x1xi32>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
+// CHECK:    %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
+// CHECK:    ^bb0(%[[VAL_3:.*]]: tensor<f32>, %[[VAL_4:.*]]: tensor<f32>):
+// CHECK:      %[[VAL_5:.*]] = mhlo.add %[[VAL_3]], %[[VAL_4]] : tensor<f32>
+// CHECK:      mhlo.return %[[VAL_5]] : tensor<f32>
+// CHECK:    }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
+// CHECK:    return %[[VAL_6]] : tensor<20x6xf32>
+// CHECK:  }
 func.func @convert_scatter_add(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> {
   %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
   ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
@@ -4119,17 +4119,17 @@
   func.return %0 : tensor<20x6xf32>
 }
 
-// CHECK-LABEL:   func.func @convert_scatter_max(
-// CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<20x6xf32>,
-// CHECK-SAME:                                   %[[VAL_1:.*]]: tensor<4x1xi32>,
-// CHECK-SAME:                                   %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
-// CHECK:           %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
-// CHECK:           ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
-// CHECK:             %[[VAL_6:.*]] = "tf.Maximum"(%[[VAL_4]], %[[VAL_5]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK:             mhlo.return %[[VAL_6]] : tensor<f32>
-// CHECK:           }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
-// CHECK:           return %[[VAL_3]] : tensor<20x6xf32>
-// CHECK:         }
+// CHECK-LABEL:  func.func @convert_scatter_max(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<20x6xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<4x1xi32>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
+// CHECK:    %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
+// CHECK:    ^bb0(%[[VAL_3:.*]]: tensor<f32>, %[[VAL_4:.*]]: tensor<f32>):
+// CHECK:      %[[VAL_5:.*]] = mhlo.maximum %[[VAL_3]], %[[VAL_4]] : tensor<f32>
+// CHECK:      mhlo.return %[[VAL_5]] : tensor<f32>
+// CHECK:    }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
+// CHECK:    return %[[VAL_6]] : tensor<20x6xf32>
+// CHECK:  }
 func.func @convert_scatter_max(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> {
   %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
   ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
@@ -4147,17 +4147,17 @@
   func.return %0 : tensor<20x6xf32>
 }
 
-// CHECK-LABEL:   func.func @convert_scatter_min(
-// CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<20x6xf32>,
-// CHECK-SAME:                                   %[[VAL_1:.*]]: tensor<4x1xi32>,
-// CHECK-SAME:                                   %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
-// CHECK:           %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
-// CHECK:           ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
-// CHECK:             %[[VAL_6:.*]] = "tf.Minimum"(%[[VAL_4]], %[[VAL_5]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK:             mhlo.return %[[VAL_6]] : tensor<f32>
-// CHECK:           }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
-// CHECK:           return %[[VAL_3]] : tensor<20x6xf32>
-// CHECK:         }
+// CHECK-LABEL:  func.func @convert_scatter_min(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<20x6xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<4x1xi32>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
+// CHECK:    %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
+// CHECK:    ^bb0(%[[VAL_3:.*]]: tensor<f32>, %[[VAL_4:.*]]: tensor<f32>):
+// CHECK:      %[[VAL_5:.*]] = mhlo.minimum %[[VAL_3]], %[[VAL_4]] : tensor<f32>
+// CHECK:      mhlo.return %[[VAL_5]] : tensor<f32>
+// CHECK:    }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
+// CHECK:    return %[[VAL_6]] : tensor<20x6xf32>
+// CHECK:  }
 func.func @convert_scatter_min(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> {
   %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
   ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
@@ -4175,17 +4175,17 @@
   func.return %0 : tensor<20x6xf32>
 }
 
-// CHECK-LABEL:   func.func @convert_scatter_sub(
-// CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<20x6xf32>,
-// CHECK-SAME:                                   %[[VAL_1:.*]]: tensor<4x1xi32>,
-// CHECK-SAME:                                   %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
-// CHECK:           %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
-// CHECK:           ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
-// CHECK:             %[[VAL_6:.*]] = "tf.Sub"(%[[VAL_4]], %[[VAL_5]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK:             mhlo.return %[[VAL_6]] : tensor<f32>
-// CHECK:           }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
-// CHECK:           return %[[VAL_3]] : tensor<20x6xf32>
-// CHECK:         }
+// CHECK-LABEL:  func.func @convert_scatter_sub(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<20x6xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<4x1xi32>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
+// CHECK:    %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
+// CHECK:    ^bb0(%[[VAL_3:.*]]: tensor<f32>, %[[VAL_4:.*]]: tensor<f32>):
+// CHECK:      %[[VAL_5:.*]] = mhlo.subtract %[[VAL_3]], %[[VAL_4]] : tensor<f32>
+// CHECK:      mhlo.return %[[VAL_5]] : tensor<f32>
+// CHECK:    }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
+// CHECK:    return %[[VAL_6]] : tensor<20x6xf32>
+// CHECK:  }
 func.func @convert_scatter_sub(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> {
   %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
   ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
@@ -4206,7 +4206,7 @@
 // CHECK-LABEL:   func @convert_argmax(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32xi32>) {
 // CHECK:           %[[VAL_9:.*]] = "tf.Const"{{.*}}value = dense<2> : tensor<1xi64>
-// CHECK:           %[[VAL_10:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_9]]) {keep_dims = false} : {{.*}} -> tensor<4x32xf32>
+// CHECK:           %[[VAL_10:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_9]]) <{keep_dims = false}> : {{.*}} -> tensor<4x32xf32>
 // CHECK:           %[[VAL_11:.*]] = "tf.ArgMax"(%[[VAL_0]], %[[VAL_9]]) : {{.*}} -> tensor<4x32xi32>
 // CHECK:           return %[[VAL_10]], %[[VAL_11]]
 // CHECK:         }
@@ -4233,11 +4233,11 @@
 
 // CHECK-LABEL: func @convert_argmax_constant(
 // CHECK-SAME:                                        %[[VAL_0:.*]]: tensor<2x2x4xf32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<0xFF800000> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() {value = dense<{{\[\[}}[0, 1, 2, 3], [0, 1, 2, 3]], {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x2x4xi32>} : () -> tensor<2x2x4xi32>
-// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_5:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_4]]) {keep_dims = false} : (tensor<2x2x4xf32>, tensor<1xi64>) -> tensor<2x2xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0xFF800000> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<{{\[\[}}[0, 1, 2, 3], [0, 1, 2, 3]], {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x2x4xi32>}> : () -> tensor<2x2x4xi32>
+// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_5:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_4]]) <{keep_dims = false}> : (tensor<2x2x4xf32>, tensor<1xi64>) -> tensor<2x2xf32>
 // CHECK:           %[[VAL_6:.*]] = "tf.ArgMax"(%[[VAL_0]], %[[VAL_4]]) : (tensor<2x2x4xf32>, tensor<1xi64>) -> tensor<2x2xi32>
 // CHECK:           return %[[VAL_5]], %[[VAL_6]] : tensor<2x2xf32>, tensor<2x2xi32>
 // CHECK:         }
@@ -4263,11 +4263,11 @@
 
 // CHECK-LABEL:   func @convert_argmax_constant_non_z_axis(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x4xf32>) -> (tensor<4xf32>, tensor<4xi32>) {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<0xFF800000> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() {value = dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]> : tensor<4x4xi32>} : () -> tensor<4x4xi32>
-// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_5:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_4]]) {keep_dims = false} : (tensor<4x4xf32>, tensor<1xi64>) -> tensor<4xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0xFF800000> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]> : tensor<4x4xi32>}> : () -> tensor<4x4xi32>
+// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_5:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_4]]) <{keep_dims = false}> : (tensor<4x4xf32>, tensor<1xi64>) -> tensor<4xf32>
 // CHECK:           %[[VAL_6:.*]] = "tf.ArgMax"(%[[VAL_0]], %[[VAL_4]]) : (tensor<4x4xf32>, tensor<1xi64>) -> tensor<4xi32>
 // CHECK:           return %[[VAL_5]], %[[VAL_6]] : tensor<4xf32>, tensor<4xi32>
 // CHECK:         }
@@ -4293,14 +4293,14 @@
 
 // CHECK-LABEL:   func.func @convert_argmax_bool(
 // CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<2xi1>) -> tensor<i32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %[[VAL_4:.*]] = "tf.Range"(%[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<2xi32>
-// CHECK-DAG:       %[[VAL_5:.*]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
-// CHECK-DAG:       %[[VAL_6:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_7:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_8:.*]] = "tf.Any"(%[[VAL_0]], %[[VAL_7]]) {keep_dims = false} : (tensor<2xi1>, tensor<1xi64>) -> tensor<i1>
+// CHECK-DAG:       %[[VAL_5:.*]] = "tf.Const"() <{value = dense<false> : tensor<i1>}> : () -> tensor<i1>
+// CHECK-DAG:       %[[VAL_6:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_7:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_8:.*]] = "tf.Any"(%[[VAL_0]], %[[VAL_7]]) <{keep_dims = false}> : (tensor<2xi1>, tensor<1xi64>) -> tensor<i1>
 // CHECK:           %[[VAL_9:.*]] = "tf.ArgMax"(%[[VAL_0]], %[[VAL_7]]) : (tensor<2xi1>, tensor<1xi64>) -> tensor<i32>
 // CHECK:           return %[[VAL_9]] : tensor<i32>
 // CHECK:         }
@@ -4326,7 +4326,7 @@
 // CHECK-LABEL:   func @convert_argmin(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32xi32>) {
 // CHECK:           %[[VAL_9:.*]] = "tf.Const"{{.*}}value = dense<2> : tensor<1xi64>
-// CHECK:           %[[VAL_10:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_9]]) {keep_dims = false} : {{.*}} -> tensor<4x32xf32>
+// CHECK:           %[[VAL_10:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_9]]) <{keep_dims = false}> : {{.*}} -> tensor<4x32xf32>
 // CHECK:           %[[VAL_11:.*]] = "tf.ArgMin"(%[[VAL_0]], %[[VAL_9]]) : {{.*}} -> tensor<4x32xi32>
 // CHECK:           return %[[VAL_10]], %[[VAL_11]]
 // CHECK:         }
@@ -4354,7 +4354,7 @@
 // CHECK-LABEL:   func @convert_argmin_i16(
 // CHECK-SAME:                         %[[VAL_0:.*]]: tensor<2xi16>) -> (tensor<i16>, tensor<i32>) {
 // CHECK:           %[[VAL_9:.*]] = "tf.Const"{{.*}}value = dense<0> : tensor<1xi64>
-// CHECK:           %[[VAL_10:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_9]]) {keep_dims = false} : {{.*}} -> tensor<i16>
+// CHECK:           %[[VAL_10:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_9]]) <{keep_dims = false}> : {{.*}} -> tensor<i16>
 // CHECK:           %[[VAL_11:.*]] = "tf.ArgMin"(%[[VAL_0]], %[[VAL_9]]) : {{.*}} -> tensor<i32>
 // CHECK:           return %[[VAL_10]], %[[VAL_11]]
 // CHECK:         }
@@ -4381,11 +4381,11 @@
 
 // CHECK-LABEL: func @convert_argmin_constant(
 // CHECK-SAME:                                        %[[VAL_0:.*]]: tensor<2x2x4xf32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<0x7F800000> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() {value = dense<{{\[\[}}[0, 1, 2, 3], [0, 1, 2, 3]], {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x2x4xi32>} : () -> tensor<2x2x4xi32>
-// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_5:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_4]]) {keep_dims = false} : (tensor<2x2x4xf32>, tensor<1xi64>) -> tensor<2x2xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0x7F800000> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<{{\[\[}}[0, 1, 2, 3], [0, 1, 2, 3]], {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x2x4xi32>}> : () -> tensor<2x2x4xi32>
+// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_5:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_4]]) <{keep_dims = false}> : (tensor<2x2x4xf32>, tensor<1xi64>) -> tensor<2x2xf32>
 // CHECK:           %[[VAL_6:.*]] = "tf.ArgMin"(%[[VAL_0]], %[[VAL_4]]) : (tensor<2x2x4xf32>, tensor<1xi64>) -> tensor<2x2xi32>
 // CHECK:           return %[[VAL_5]], %[[VAL_6]] : tensor<2x2xf32>, tensor<2x2xi32>
 // CHECK:         }
@@ -4411,14 +4411,14 @@
 
 // CHECK-LABEL:   func.func @convert_argmin_bool(
 // CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<2xi1>) -> tensor<i32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %[[VAL_4:.*]] = "tf.Range"(%[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<2xi32>
-// CHECK-DAG:       %[[VAL_5:.*]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
-// CHECK-DAG:       %[[VAL_6:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_7:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_8:.*]] = "tf.All"(%[[VAL_0]], %[[VAL_7]]) {keep_dims = false} : (tensor<2xi1>, tensor<1xi64>) -> tensor<i1>
+// CHECK-DAG:       %[[VAL_5:.*]] = "tf.Const"() <{value = dense<false> : tensor<i1>}> : () -> tensor<i1>
+// CHECK-DAG:       %[[VAL_6:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_7:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_8:.*]] = "tf.All"(%[[VAL_0]], %[[VAL_7]]) <{keep_dims = false}> : (tensor<2xi1>, tensor<1xi64>) -> tensor<i1>
 // CHECK:           %[[VAL_9:.*]] = "tf.ArgMin"(%[[VAL_0]], %[[VAL_7]]) : (tensor<2xi1>, tensor<1xi64>) -> tensor<i32>
 // CHECK:           return %[[VAL_9]] : tensor<i32>
 // CHECK:         }
@@ -4442,16 +4442,16 @@
 
 // CHECK-LABEL:   func @convert_argmax_with_reshaped_iota(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<1x32x1xf32>) -> (tensor<1x1xf32>, tensor<1x1xi32>) {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<0xFF800000> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() {value = dense<32> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_5:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0xFF800000> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<32> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_5:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %[[VAL_6:.*]] = "tf.Range"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<32xi32>
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant dense<[1, 32, 1]> : tensor<3xi64>
 // CHECK:           %[[VAL_8:.*]] = "tf.Reshape"(%[[VAL_6]], %[[VAL_7]]) : (tensor<32xi32>, tensor<3xi64>) -> tensor<1x32x1xi32>
-// CHECK-DAG:       %[[VAL_9:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:           %[[VAL_10:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_9]]) {keep_dims = false} : (tensor<1x32x1xf32>, tensor<1xi64>) -> tensor<1x1xf32>
+// CHECK-DAG:       %[[VAL_9:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:           %[[VAL_10:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_9]]) <{keep_dims = false}> : (tensor<1x32x1xf32>, tensor<1xi64>) -> tensor<1x1xf32>
 // CHECK:           %[[VAL_11:.*]] = "tf.ArgMax"(%[[VAL_0]], %[[VAL_9]]) : (tensor<1x32x1xf32>, tensor<1xi64>) -> tensor<1x1xi32>
 // CHECK:           return %[[VAL_10]], %[[VAL_11]] : tensor<1x1xf32>, tensor<1x1xi32>
 // CHECK:         }
@@ -4488,7 +4488,7 @@
 
 // CHECK-LABEL:   func @convert_not_i8(
 // CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xi8>) -> tensor<7x9x11xi8> {
-// CHECK:           %[[CST:.*]] = "tf.Const"() {value = dense<-1> : tensor<i8>} : () -> tensor<i8>
+// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<-1> : tensor<i8>}> : () -> tensor<i8>
 // CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xi8>, tensor<i8>) -> tensor<7x9x11xi8>
 // CHECK:           return %[[RES]] : tensor<7x9x11xi8>
 // CHECK:         }
@@ -4499,7 +4499,7 @@
 
 // CHECK-LABEL:   func @convert_not_i16(
 // CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xi16>) -> tensor<7x9x11xi16> {
-// CHECK:           %[[CST:.*]] = "tf.Const"() {value = dense<-1> : tensor<i16>} : () -> tensor<i16>
+// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<-1> : tensor<i16>}> : () -> tensor<i16>
 // CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xi16>, tensor<i16>) -> tensor<7x9x11xi16>
 // CHECK:           return %[[RES]] : tensor<7x9x11xi16>
 // CHECK:         }
@@ -4510,7 +4510,7 @@
 
 // CHECK-LABEL:   func @convert_not_i32(
 // CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xi32>) -> tensor<7x9x11xi32> {
-// CHECK:           %[[CST:.*]] = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<-1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xi32>, tensor<i32>) -> tensor<7x9x11xi32>
 // CHECK:           return %[[RES]] : tensor<7x9x11xi32>
 // CHECK:         }
@@ -4521,7 +4521,7 @@
 
 // CHECK-LABEL:   func @convert_not_i64(
 // CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xi64>) -> tensor<7x9x11xi64> {
-// CHECK:           %[[CST:.*]] = "tf.Const"() {value = dense<-1> : tensor<i64>} : () -> tensor<i64>
+// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<-1> : tensor<i64>}> : () -> tensor<i64>
 // CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xi64>, tensor<i64>) -> tensor<7x9x11xi64>
 // CHECK:           return %[[RES]] : tensor<7x9x11xi64>
 // CHECK:         }
@@ -4532,7 +4532,7 @@
 
 // CHECK-LABEL:   func @convert_not_ui8(
 // CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xui8>) -> tensor<7x9x11xui8> {
-// CHECK:           %[[CST:.*]] = "tf.Const"() {value = dense<255> : tensor<ui8>} : () -> tensor<ui8>
+// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<255> : tensor<ui8>}> : () -> tensor<ui8>
 // CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xui8>, tensor<ui8>) -> tensor<7x9x11xui8>
 // CHECK:           return %[[RES]] : tensor<7x9x11xui8>
 // CHECK:         }
@@ -4543,7 +4543,7 @@
 
 // CHECK-LABEL:   func @convert_not_ui16(
 // CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xui16>) -> tensor<7x9x11xui16> {
-// CHECK:           %[[CST:.*]] = "tf.Const"() {value = dense<65535> : tensor<ui16>} : () -> tensor<ui16>
+// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<65535> : tensor<ui16>}> : () -> tensor<ui16>
 // CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xui16>, tensor<ui16>) -> tensor<7x9x11xui16>
 // CHECK:           return %[[RES]] : tensor<7x9x11xui16>
 // CHECK:         }
@@ -4554,7 +4554,7 @@
 
 // CHECK-LABEL:   func @convert_not_ui32(
 // CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xui32>) -> tensor<7x9x11xui32> {
-// CHECK:           %[[CST:.*]] = "tf.Const"() {value = dense<4294967295> : tensor<ui32>} : () -> tensor<ui32>
+// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<4294967295> : tensor<ui32>}> : () -> tensor<ui32>
 // CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xui32>, tensor<ui32>) -> tensor<7x9x11xui32>
 // CHECK:           return %[[RES]] : tensor<7x9x11xui32>
 // CHECK:         }
@@ -4565,7 +4565,7 @@
 
 // CHECK-LABEL:   func @convert_not_ui64(
 // CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xui64>) -> tensor<7x9x11xui64> {
-// CHECK:           %[[CST:.*]] = "tf.Const"() {value = dense<18446744073709551615> : tensor<ui64>} : () -> tensor<ui64>
+// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<18446744073709551615> : tensor<ui64>}> : () -> tensor<ui64>
 // CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xui64>, tensor<ui64>) -> tensor<7x9x11xui64>
 // CHECK:           return %[[RES]] : tensor<7x9x11xui64>
 // CHECK:         }
@@ -4580,7 +4580,7 @@
 // CHECK-DAG:      %[[CST_0:.*]] = arith.constant dense<1> : tensor<i32>
 // CHECK-DAG:      %[[CST_1:.*]] = arith.constant dense<0> : tensor<i32>
 // CHECK-DAG:      %[[CST_2:.*]] = arith.constant dense<1000> : tensor<i32>
-// CHECK:          %[[WHILEREGION_0:.*]]:3 = "tf.WhileRegion"(%[[CST_1]], %[[CST_0]], %[[CST_2]]) ({
+// CHECK:          %[[WHILEREGION_0:.*]]:3 = "tf.WhileRegion"(%[[CST_1]], %[[CST_0]], %[[CST_2]]) <{is_stateless = false, parallel_iterations = 10 : i64}> ({
 // CHECK:          ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>):
 // CHECK:            %[[LESS_0:.*]] = "tf.Less"(%arg0, %arg2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
 // CHECK:            "tf.Yield"(%[[LESS_0]]) : (tensor<i1>) -> ()
@@ -4588,7 +4588,7 @@
 // CHECK:          ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>):
 // CHECK:            %[[ADDV2_0:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK:            "tf.Yield"(%[[ADDV2_0]], %arg1, %arg2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> ()
-// CHECK:          }) {is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
+// CHECK:          }) : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
 // CHECK:          return %[[WHILEREGION_0]]#0, %[[WHILEREGION_0]]#1, %[[WHILEREGION_0]]#2 : tensor<i32>, tensor<i32>, tensor<i32>
 // CHECK:        }
 func.func @while_with_variadic() -> (tensor<i32>, tensor<i32>, tensor<i32>) {
@@ -4613,19 +4613,19 @@
 // CHECK-DAG:      %[[CST_0:.*]] = arith.constant dense<1> : tensor<i32>
 // CHECK-DAG:      %[[CST_1:.*]] = arith.constant dense<0> : tensor<i32>
 // CHECK-DAG:      %[[CST_2:.*]] = arith.constant dense<1000> : tensor<i32>
-// CHECK:          %[[WHILEREGION_0:.*]]:5 = "tf.WhileRegion"(%[[CST_1]], %[[CST_0]], %[[CST_2]], %arg0, %arg1) ({
+// CHECK:          %[[WHILEREGION_0:.*]]:5 = "tf.WhileRegion"(%[[CST_1]], %[[CST_0]], %[[CST_2]], %arg0, %arg1) <{is_stateless = false, parallel_iterations = 10 : i64}> ({
 // CHECK:          ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<1x256xf32>, %arg6: tensor<1xf32>):
 // CHECK:            %[[LESS_0:.*]] = "tf.Less"(%arg2, %arg4) : (tensor<i32>, tensor<i32>) -> tensor<i1>
 // CHECK:            "tf.Yield"(%[[LESS_0]]) : (tensor<i1>) -> ()
 // CHECK:          },  {
 // CHECK:          ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<1x256xf32>, %arg6: tensor<1xf32>):
 // CHECK:            %[[ADDV2_0:.*]] = "tf.AddV2"(%arg2, %arg3) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-// CHECK-DAG:        %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG:        %[[CONST_1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK:            %[[SUM_0:.*]] = "tf.Sum"(%arg5, %[[CONST_1]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
+// CHECK-DAG:        %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG:        %[[CONST_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK:            %[[SUM_0:.*]] = "tf.Sum"(%arg5, %[[CONST_1]]) <{keep_dims = false}> : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
 // CHECK:            %[[ADDV2_1:.*]] = "tf.AddV2"(%[[SUM_0]], %arg6) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 // CHECK:            "tf.Yield"(%[[ADDV2_0]], %arg3, %arg4, %arg5, %[[ADDV2_1]]) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x256xf32>, tensor<1xf32>) -> ()
-// CHECK:          }) {is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x256xf32>, tensor<1xf32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x256xf32>, tensor<1xf32>)
+// CHECK:          }) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x256xf32>, tensor<1xf32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x256xf32>, tensor<1xf32>)
 // CHECK:          return %[[WHILEREGION_0]]#0, %[[WHILEREGION_0]]#1, %[[WHILEREGION_0]]#2, %[[WHILEREGION_0]]#4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<1xf32>
 // CHECK:        }
 func.func @while_with_reduce(%arg0: tensor<1x256xf32>, %arg1: tensor<1xf32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1xf32>) {
@@ -4656,11 +4656,11 @@
 // CHECK-LABEL:  func @if
 // CHECK-DAG:      %[[CST_0:.*]] = arith.constant dense<0> : tensor<i32>
 // CHECK-DAG:      %[[CST_1:.*]] = arith.constant dense<1000> : tensor<i32>
-// CHECK:          %[[RES:.*]]  = "tf.IfRegion"(%arg0) ({
+// CHECK:          %[[RES:.*]]  = "tf.IfRegion"(%arg0) <{is_stateless = false}> ({
 // CHECK:            "tf.Yield"(%[[CST_0]]) : (tensor<i32>) -> ()
 // CHECK:          }, {
 // CHECK:            "tf.Yield"(%[[CST_1]]) : (tensor<i32>) -> ()
-// CHECK:          }) {is_stateless = false} : (tensor<i1>) -> tensor<i32>
+// CHECK:          }) : (tensor<i1>) -> tensor<i32>
 // CHECK:          return %[[RES]]
 func.func @if(%arg0: tensor<i1>) -> (tensor<i32>) {
   %cst_0 = arith.constant dense<0> : tensor<i32>
@@ -4679,7 +4679,7 @@
 // CHECK-SAME:                                       %[[VAL_2:[a-z0-9]*]]: tensor<i32>,
 // CHECK-SAME:                                       %[[VAL_3:[a-z0-9]*]]: tensor<i32>,
 // CHECK-SAME:                                       %[[VAL_4:[a-z0-9]*]]: tensor<i32>) -> tensor<28x1x100xf32> {
-// CHECK:         %0 = "tf.Pack"(%arg2, %arg3, %arg4) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
+// CHECK:         %0 = "tf.Pack"(%arg2, %arg3, %arg4) <{axis = 0 : i64}> : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
 // CHECK:         %1 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %0) : (tensor<28x1x100xf32>, tensor<1x1x100xf32>, tensor<3xi32>) -> tensor<28x1x100xf32>
 // CHECK:         return %1 : tensor<28x1x100xf32>
 func.func @convert_dynamic_update_slice(%arg0: tensor<28x1x100xf32>, %arg1: tensor<1x1x100xf32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>) -> tensor<28x1x100xf32> {
@@ -4692,7 +4692,7 @@
 // CHECK-SAME:                                       %arg1: tensor<?x2xi32>,
 // CHECK-SAME:                                       %arg2: tensor<i32>,
 // CHECK-SAME:                                       %arg3: tensor<i32>) -> tensor<?x4xi32> {
-// CHECK:         %0 = "tf.Pack"(%arg2, %arg3) {axis = 0 : i64} : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
+// CHECK:         %0 = "tf.Pack"(%arg2, %arg3) <{axis = 0 : i64}> : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
 // CHECK:         %1 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %0) : (tensor<?x4xi32>, tensor<?x2xi32>, tensor<2xi32>) -> tensor<?x4xi32>
 // CHECK:         return %1 : tensor<?x4xi32>
 // CHECK:         }
@@ -4707,7 +4707,7 @@
 // CHECK-SAME:                                       %arg2: tensor<i32>,
 // CHECK-SAME:                                       %arg3: tensor<i32>,
 // CHECK-SAME:                                       %arg4: tensor<i32>) -> tensor<1x?x256xf32> {
-// CHECK:         %0 = "tf.Pack"(%arg2, %arg3, %arg4) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
+// CHECK:         %0 = "tf.Pack"(%arg2, %arg3, %arg4) <{axis = 0 : i64}> : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
 // CHECK:         %1 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %0) : (tensor<1x?x256xf32>, tensor<1x1x256xf32>, tensor<3xi32>) -> tensor<1x?x256xf32>
 // CHECK:         return %1 : tensor<1x?x256xf32>
 // CHECK:         }
@@ -4719,9 +4719,9 @@
 // CHECK-LABEL:   func @convert_reduce_to_all(
 // CHECK-SAME:                                %[[ARG_0:.*]]: tensor<1x2x3x4x5xi1>,
 // CHECK-SAME:                                %[[ARG_1:.*]]: tensor<2xi64>) -> tensor<2x4x5xi1> {
-// CHECK-DAG:       %[[TRUE_CST:.*]] = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
-// CHECK-DAG:       %[[DIMENSIONS:.*]] = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK:           %[[VAL_0:.*]] = "tf.All"(%[[ARG_0]], %[[DIMENSIONS]]) {keep_dims = false} : (tensor<1x2x3x4x5xi1>, tensor<2xi64>) -> tensor<2x4x5xi1>
+// CHECK-DAG:       %[[TRUE_CST:.*]] = "tf.Const"() <{value = dense<true> : tensor<i1>}> : () -> tensor<i1>
+// CHECK-DAG:       %[[DIMENSIONS:.*]] = "tf.Const"() <{value = dense<[0, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK:           %[[VAL_0:.*]] = "tf.All"(%[[ARG_0]], %[[DIMENSIONS]]) <{keep_dims = false}> : (tensor<1x2x3x4x5xi1>, tensor<2xi64>) -> tensor<2x4x5xi1>
 // CHECK:           return %[[VAL_0:.*]] : tensor<2x4x5xi1>
 // CHECK:         }
 func.func @convert_reduce_to_all(%arg0: tensor<1x2x3x4x5xi1>, %arg1: tensor<2xi64>) -> tensor<2x4x5xi1> {
@@ -4738,8 +4738,8 @@
 // CHECK-SAME:                                %[[ARG_0:.*]]: tensor<i1>,
 // CHECK-SAME:                                %[[ARG_1:.*]]: tensor<1x2x3x4x5xi1>,
 // CHECK-SAME:                                %[[ARG_2:.*]]: tensor<2xi64>) -> tensor<2x4x5xi1> {
-// CHECK-DAG:       %[[DIMENSIONS:.*]] = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK:           %[[VAL_0:.*]] = "tf.All"(%[[ARG_1]], %[[DIMENSIONS]]) {keep_dims = false} : (tensor<1x2x3x4x5xi1>, tensor<2xi64>) -> tensor<2x4x5xi1>
+// CHECK-DAG:       %[[DIMENSIONS:.*]] = "tf.Const"() <{value = dense<[0, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK:           %[[VAL_0:.*]] = "tf.All"(%[[ARG_1]], %[[DIMENSIONS]]) <{keep_dims = false}> : (tensor<1x2x3x4x5xi1>, tensor<2xi64>) -> tensor<2x4x5xi1>
 // CHECK:           %[[VAL_1:.*]] = "tf.LogicalAnd"(%[[VAL_0]], %[[ARG_0]]) : (tensor<2x4x5xi1>, tensor<i1>) -> tensor<2x4x5xi1>
 // CHECK:           return %[[VAL_1:.*]] : tensor<2x4x5xi1>
 // CHECK:         }
@@ -4755,9 +4755,9 @@
 // CHECK-LABEL:   func @convert_reduce_to_any(
 // CHECK-SAME:                                %[[ARG_0:.*]]: tensor<1x2x3x4x5xi1>,
 // CHECK-SAME:                                %[[ARG_1:.*]]: tensor<2xi64>) -> tensor<2x4x5xi1> {
-// CHECK-DAG:       %[[FALSE_CST:.*]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
-// CHECK-DAG:       %[[DIMENSIONS:.*]] = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK:           %[[VAL_0:.*]] = "tf.Any"(%[[ARG_0]], %[[DIMENSIONS]]) {keep_dims = false} : (tensor<1x2x3x4x5xi1>, tensor<2xi64>) -> tensor<2x4x5xi1>
+// CHECK-DAG:       %[[FALSE_CST:.*]] = "tf.Const"() <{value = dense<false> : tensor<i1>}> : () -> tensor<i1>
+// CHECK-DAG:       %[[DIMENSIONS:.*]] = "tf.Const"() <{value = dense<[0, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK:           %[[VAL_0:.*]] = "tf.Any"(%[[ARG_0]], %[[DIMENSIONS]]) <{keep_dims = false}> : (tensor<1x2x3x4x5xi1>, tensor<2xi64>) -> tensor<2x4x5xi1>
 // CHECK:           return %[[VAL_0:.*]] : tensor<2x4x5xi1>
 // CHECK:         }
 func.func @convert_reduce_to_any(%arg0: tensor<1x2x3x4x5xi1>, %arg1: tensor<2xi64>) -> tensor<2x4x5xi1> {
@@ -4774,8 +4774,8 @@
 // CHECK-SAME:                                %[[ARG_0:.*]]: tensor<i1>,
 // CHECK-SAME:                                %[[ARG_1:.*]]: tensor<1x2x3x4x5xi1>,
 // CHECK-SAME:                                %[[ARG_2:.*]]: tensor<2xi64>) -> tensor<2x4x5xi1> {
-// CHECK-DAG:       %[[DIMENSIONS:.*]] = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK:           %[[VAL_0:.*]] = "tf.Any"(%[[ARG_1]], %[[DIMENSIONS]]) {keep_dims = false} : (tensor<1x2x3x4x5xi1>, tensor<2xi64>) -> tensor<2x4x5xi1>
+// CHECK-DAG:       %[[DIMENSIONS:.*]] = "tf.Const"() <{value = dense<[0, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK:           %[[VAL_0:.*]] = "tf.Any"(%[[ARG_1]], %[[DIMENSIONS]]) <{keep_dims = false}> : (tensor<1x2x3x4x5xi1>, tensor<2xi64>) -> tensor<2x4x5xi1>
 // CHECK:           %[[VAL_1:.*]] = "tf.LogicalOr"(%[[VAL_0]], %[[ARG_0]]) : (tensor<2x4x5xi1>, tensor<i1>) -> tensor<2x4x5xi1>
 // CHECK:           return %[[VAL_1:.*]] : tensor<2x4x5xi1>
 // CHECK:         }
@@ -4790,14 +4790,14 @@
 
 // CHECK-LABEL:   func @convert_sort_to_topk_iota_broadcast(
 // CHECK-SAME:                                              %[[ARG_0:.*]]: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
-// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<6> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           %[[VAL_3:.*]] = "tf.Range"(%cst, %cst_0, %cst_1) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<6xi32>
 // CHECK:           %[[VAL_4:.*]] = arith.constant dense<[3, 6]> : tensor<2xi64>
 // CHECK:           %[[VAL_5:.*]] = "tf.BroadcastTo"(%0, %cst_2) : (tensor<6xi32>, tensor<2xi64>) -> tensor<3x6xi32>
-// CHECK:           %[[K:.*]] = "tf.Const"() {value = dense<6> : tensor<i32>} : () -> tensor<i32>
-// CHECK:           %[[VALUES:.*]], %[[INDICES:.*]] = "tf.TopKV2"(%[[ARG_0]], %[[K]]) {sorted = true} : (tensor<3x6xf32>, tensor<i32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
+// CHECK:           %[[K:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:           %[[VALUES:.*]], %[[INDICES:.*]] = "tf.TopKV2"(%[[ARG_0]], %[[K]]) <{sorted = true}> : (tensor<3x6xf32>, tensor<i32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
 // CHECK:           return %[[VALUES]], %[[INDICES]] : tensor<3x6xf32>, tensor<3x6xi32>
 // CHECK:         }
 func.func @convert_sort_to_topk_iota_broadcast(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
@@ -4813,11 +4813,11 @@
 
 // CHECK-LABEL:   func @convert_sort_to_topk_iotacst_broadcast(
 // CHECK-SAME:                                                 %[[ARG_0:.*]]: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
-// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() {value = dense<[0, 1, 2, 3, 4, 5]> : tensor<6xi32>} : () -> tensor<6xi32>
+// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() <{value = dense<[0, 1, 2, 3, 4, 5]> : tensor<6xi32>}> : () -> tensor<6xi32>
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<[3, 6]> : tensor<2xi64>
 // CHECK:           %[[VAL_2:.*]] = "tf.BroadcastTo"(%cst, %cst_0) : (tensor<6xi32>, tensor<2xi64>) -> tensor<3x6xi32>
-// CHECK:           %[[K:.*]] = "tf.Const"() {value = dense<6> : tensor<i32>} : () -> tensor<i32>
-// CHECK:           %[[VALUES:.*]], %[[INDICES:.*]] = "tf.TopKV2"(%[[ARG_0]], %[[K]]) {sorted = true} : (tensor<3x6xf32>, tensor<i32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
+// CHECK:           %[[K:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:           %[[VALUES:.*]], %[[INDICES:.*]] = "tf.TopKV2"(%[[ARG_0]], %[[K]]) <{sorted = true}> : (tensor<3x6xf32>, tensor<i32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
 // CHECK:           return %[[VALUES]], %[[INDICES]] : tensor<3x6xf32>, tensor<3x6xi32>
 // CHECK:         }
 func.func @convert_sort_to_topk_iotacst_broadcast(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
@@ -4833,9 +4833,9 @@
 
 // CHECK-LABEL:   func @convert_sort_to_topk_const(
 // CHECK-SAME:                                     %[[ARG_0:.*]]: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
-// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<3x6xi32>} : () -> tensor<3x6xi32>
-// CHECK-DAG:       %[[K:.*]] = "tf.Const"() {value = dense<6> : tensor<i32>} : () -> tensor<i32>
-// CHECK:           %[[VALUES:.*]], %[[INDICES:.*]] = "tf.TopKV2"(%[[ARG_0]], %[[K]]) {sorted = true} : (tensor<3x6xf32>, tensor<i32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
+// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<3x6xi32>}> : () -> tensor<3x6xi32>
+// CHECK-DAG:       %[[K:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:           %[[VALUES:.*]], %[[INDICES:.*]] = "tf.TopKV2"(%[[ARG_0]], %[[K]]) <{sorted = true}> : (tensor<3x6xf32>, tensor<i32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
 // CHECK:           return %[[VALUES]], %[[INDICES]] : tensor<3x6xf32>, tensor<3x6xi32>
 // CHECK:         }
 func.func @convert_sort_to_topk_const(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
@@ -4913,7 +4913,7 @@
 // CHECK-LABEL:   func @convert_population_count_i32(
 // CHECK-SAME:                                   %[[ARG_0:.*]]: tensor<8xi32>
 // CHECK:       %[[POP_CNT:.*]] = "tf.PopulationCount"(%[[ARG_0]]) : (tensor<8xi32>) -> tensor<8xui8>
-// CHECK:       %[[RES:.*]] = "tf.Cast"(%[[POP_CNT]]) {Truncate = false} : (tensor<8xui8>) -> tensor<8xi32>
+// CHECK:       %[[RES:.*]] = "tf.Cast"(%[[POP_CNT]]) <{Truncate = false}> : (tensor<8xui8>) -> tensor<8xi32>
 // CHECK:       return %[[RES]]
 // CHECK:         }
 func.func @convert_population_count_i32(%arg0: tensor<8xi32>) -> tensor<8xi32> {
@@ -4932,8 +4932,8 @@
 }
 
 // CHECK-LABEL:   func @torch_index_select(
-// CHECK:       %[[AXIS:.+]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
-// CHECK:       %[[RES:.+]] = "tf.GatherV2"(%arg0, %arg1, %[[AXIS]]) {batch_dims = 0 : i64}
+// CHECK:       %[[AXIS:.+]] = "tf.Const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
+// CHECK:       %[[RES:.+]] = "tf.GatherV2"(%arg0, %arg1, %[[AXIS]]) <{batch_dims = 0 : i64}>
 // CHECK:       return %[[RES]]
 
 func.func @torch_index_select(%arg0: tensor<2x1xf32>, %arg1: tensor<2xi32>) -> tensor<2x1xf32> {
@@ -4945,9 +4945,9 @@
 
 // CHECK-LABEL:   func @lowered_cumsum(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x12xf32>) -> tensor<4x12xf32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
-// CHECK:           %[[VAL_3:.*]] = "tf.Cumsum"(%[[VAL_0]], %[[VAL_2]]) {exclusive = false, reverse = false} : (tensor<4x12xf32>, tensor<i64>) -> tensor<4x12xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
+// CHECK:           %[[VAL_3:.*]] = "tf.Cumsum"(%[[VAL_0]], %[[VAL_2]]) <{exclusive = false, reverse = false}> : (tensor<4x12xf32>, tensor<i64>) -> tensor<4x12xf32>
 // CHECK:           return %[[VAL_3]] : tensor<4x12xf32>
 // CHECK:         }
 func.func @lowered_cumsum(%arg0: tensor<4x12xf32>) -> tensor<4x12xf32> {
@@ -4962,9 +4962,9 @@
 
 // CHECK-LABEL:   func @lowered_cumsum_trivial_attrs(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x12xf32>) -> tensor<4x12xf32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
-// CHECK:           %[[VAL_3:.*]] = "tf.Cumsum"(%[[VAL_0]], %[[VAL_2]]) {exclusive = false, reverse = false} : (tensor<4x12xf32>, tensor<i64>) -> tensor<4x12xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
+// CHECK:           %[[VAL_3:.*]] = "tf.Cumsum"(%[[VAL_0]], %[[VAL_2]]) <{exclusive = false, reverse = false}> : (tensor<4x12xf32>, tensor<i64>) -> tensor<4x12xf32>
 // CHECK:           return %[[VAL_3]] : tensor<4x12xf32>
 // CHECK:         }
 func.func @lowered_cumsum_trivial_attrs(%arg0: tensor<4x12xf32>) -> tensor<4x12xf32> {
@@ -4979,9 +4979,9 @@
 
 // CHECK-LABEL:   func @lowered_cumprod(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x12xf32>) -> tensor<4x12xf32> {
-// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
-// CHECK:           %[[VAL_3:.*]] = "tf.Cumprod"(%[[VAL_0]], %[[VAL_2]]) {exclusive = false, reverse = false} : (tensor<4x12xf32>, tensor<i64>) -> tensor<4x12xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
+// CHECK:           %[[VAL_3:.*]] = "tf.Cumprod"(%[[VAL_0]], %[[VAL_2]]) <{exclusive = false, reverse = false}> : (tensor<4x12xf32>, tensor<i64>) -> tensor<4x12xf32>
 // CHECK:           return %[[VAL_3]] : tensor<4x12xf32>
 // CHECK:         }
 func.func @lowered_cumprod(%arg0: tensor<4x12xf32>) -> tensor<4x12xf32> {
@@ -5022,7 +5022,7 @@
 
 // CHECK-LABEL: func @get_dimension_size(
 // CHECK-SAME:              %[[ARG_0:.*]]: tensor<4x256x?xf32>) -> tensor<i32> {
-// CHECK          %[[CST_0:.*]] = "tf.Const"() {value = dense<256> : tensor<i32>} : () -> tensor<i32>
+// CHECK          %[[CST_0:.*]] = "tf.Const"() <{value = dense<256> : tensor<i32>}> : () -> tensor<i32>
 // CHECK          return %[[CST_0]] : tensor<i32>
 func.func @get_dimension_size(%arg0: tensor<4x256x?xf32>) -> tensor<i32> {
   %0 = "mhlo.get_dimension_size"(%arg0) {dimension = 1 : i64} : (tensor<4x256x?xf32>) -> tensor<i32>
@@ -5032,10 +5032,10 @@
 // CHECK-LABEL: func @get_dimension_size_dynamic(
 // CHECK-SAME:              %[[ARG_0:.*]]: tensor<4x256x?xf32>) -> tensor<i32> {
 // CHECK          %[[VAL_0:.*]] = "tf.Shape"(%[[ARG_0]]) : (tensor<4x256x?xf32>) -> tensor<3xi32>
-// CHECK          %[[CST_0:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
-// CHECK          %[[CST_1:.*]] = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64>
+// CHECK          %[[CST_0:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK          %[[CST_1:.*]] = "tf.Const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64>
 // CHECK          %[[VAL_1:.*]] = "tf.Slice"(%[[VAL_0]], %[[CST_1]], %[[CST_0]]) : (tensor<3xi32>, tensor<1xi64>, tensor<1xi32>) -> tensor<1xi32>
-// CHECK          %[[VAL_2:.*]] = "tf.Squeeze"(%[[VAL_1]]) {squeeze_dims = [0]} : (tensor<1xi32>) -> tensor<i32>
+// CHECK          %[[VAL_2:.*]] = "tf.Squeeze"(%[[VAL_1]]) <{squeeze_dims = [0]}> : (tensor<1xi32>) -> tensor<i32>
 // CHECK          return %[[VAL_2]] : tensor<i32>
 func.func @get_dimension_size_dynamic(%arg0: tensor<4x256x?xf32>) -> tensor<i32> {
   %0 = "mhlo.get_dimension_size"(%arg0) {dimension = 2 : i64} : (tensor<4x256x?xf32>) -> tensor<i32>
@@ -5044,10 +5044,10 @@
 
 // CHECK-LABEL: func @dynamic_iota_i32_1d(
 // CHECK-SAME:                  %[[ARG_0:.*]]: tensor<1xi32>) -> tensor<?xi32> {
-// CHECK-DAG:     %[[CST_0:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+// CHECK-DAG:     %[[CST_0:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
 // CHECK:         %[[VAL_0:.*]] = "tf.Reshape"(%arg0, %[[CST_0]]) : (tensor<1xi32>, tensor<0xi32>) -> tensor<i32>
-// CHECK-DAG:     %[[CST_1:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:     %[[CST_2:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:     %[[CST_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:     %[[CST_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:         %[[VAL_1:.*]] = "tf.Range"(%[[CST_1]], %[[VAL_0]], %[[CST_2]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
 // CHECK:         return %[[VAL_1]] : tensor<?xi32>
 func.func @dynamic_iota_i32_1d(%arg0: tensor<1xi32>) -> tensor<?xi32> {
@@ -5057,11 +5057,11 @@
 
 // CHECK-LABEL: func @dynamic_iota_f32_1d(
 // CHECK-SAME:                  %[[ARG_0:.*]]: tensor<1xi32>) -> tensor<?xf32> {
-// CHECK:         %[[VAL_0:.*]] = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1xi32>) -> tensor<1xf32>
-// CHECK-DAG:     %[[CST_0:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+// CHECK:         %[[VAL_0:.*]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xf32>
+// CHECK-DAG:     %[[CST_0:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
 // CHECK:         %[[VAL_1:.*]] = "tf.Reshape"(%[[VAL_0]], %[[CST_0]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor<f32>
-// CHECK-DAG:     %[[CST_1:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG:     %[[CST_2:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK-DAG:     %[[CST_1:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG:     %[[CST_2:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK:         %[[VAL_2:.*]] = "tf.Range"(%[[CST_1]], %[[VAL_1]], %[[CST_2]]) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
 // CHECK:         return %[[VAL_2]] : tensor<?xf32>
 func.func @dynamic_iota_f32_1d(%arg0: tensor<1xi32>) -> tensor<?xf32> {
@@ -5073,8 +5073,8 @@
 // CHECK-SAME:              %arg0: tensor<1x?x4x256xf32>,
 // CHECK-SAME:              %arg1: tensor<4xi32>,
 // CHECK-SAME:              %arg2: tensor<4xi32>) -> tensor<1x?x4x128xf32> {
-// CHECK:         %cst = "tf.Const"() {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK:         %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %cst) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x?x4x256xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x?x4x128xf32>
+// CHECK:         %cst = "tf.Const"() <{value = dense<1> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK:         %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %cst) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<1x?x4x256xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x?x4x128xf32>
 // CHECK:         return %0 : tensor<1x?x4x128xf32>
 func.func @real_dynamic_slice_strides_equal_to_1_signed(%arg0: tensor<1x?x4x256xf32>, %arg1: tensor<4xi32>, %arg2: tensor<4xi32>) -> tensor<1x?x4x128xf32> {
 %cst = mhlo.constant dense<1> : tensor<4xi32>
@@ -5086,8 +5086,8 @@
 // CHECK-SAME:              %arg0: tensor<1x?x2x4xf32>,
 // CHECK-SAME:              %arg1: tensor<4xi32>,
 // CHECK-SAME:              %arg2: tensor<4xi32>) -> tensor<1x?x1x2xf32> {
-// CHECK          %cst = "tf.Const"() {value = dense<2> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK          %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %cst) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x?x2x4xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x?x1x2xf32>
+// CHECK          %cst = "tf.Const"() <{value = dense<2> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK          %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %cst) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<1x?x2x4xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x?x1x2xf32>
 // CHECK          return %0 : tensor<1x?x1x2xf32>
 func.func @real_dynamic_slice_strides_not_equal_to_1(%arg0: tensor<1x?x2x4xf32>, %arg1: tensor<4xi32>, %arg2: tensor<4xi32>) -> tensor<1x?x1x2xf32> {
 %cst = mhlo.constant dense<2> : tensor<4xi32>
@@ -5113,7 +5113,7 @@
 // CHECK-SAME:                                        %[[ARG_1:.*]]: tensor<1x4xi32>,
 // CHECK-SAME:                                        %[[ARG_2:.*]]: tensor<f32>,
 // CHECK-SAME:                                        %[[ARG_3:.*]]: tensor<i32>) -> (tensor<1x4xf32>, tensor<1x4xi32>) {
-// CHECK:          %[[VALUES:.*]], %[[INDICES:.*]] = "tf.ApproxTopK"(%[[ARG_0]]) {aggregate_to_topk = true, is_max_k = true, k = 4 : i64, recall_target = 8.500000e-01 : f32, reduction_dimension = 1 : i64, reduction_input_size_override = -1 : i64} : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi32>)
+// CHECK:          %[[VALUES:.*]], %[[INDICES:.*]] = "tf.ApproxTopK"(%[[ARG_0]]) <{aggregate_to_topk = true, is_max_k = true, k = 4 : i64, recall_target = 8.500000e-01 : f32, reduction_dimension = 1 : i64, reduction_input_size_override = -1 : i64}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi32>)
 // CHECK:          return %[[VALUES]], %[[INDICES]] : tensor<1x4xf32>, tensor<1x4xi32>
 // CHECK:        }
 func.func @convert_approx_top_k_custom_call(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xi32>, %arg2: tensor<f32>, %arg3: tensor<i32>) -> (tensor<1x4xf32>, tensor<1x4xi32>) {
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir
index e822963..4e12ffd 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir
+++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir
@@ -1,9 +1,22 @@
-//RUN: tf_tfl_translate --enable-stablehlo-conversion --input-mlir %s -o /tmp/temp.stablehlo; [ -f /tmp/temp.stablehlo ]
+//RUN: tf_tfl_translate --enable-stablehlo-conversion --input-mlir %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
 
 
 module {
-func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
-  %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
-  func.return %0 : tensor<2xi32>
+func.func @tfInplaceUpdate(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> {
+  %1 = arith.constant dense<1> : tensor<1xi32>
+  %2 = arith.constant dense<2.0> : tensor<1x1x2xf32>
+  %3 = "tf.InplaceUpdate"(%arg0, %1, %2) {device = ""}
+    : (tensor<2x1x2xf32>, tensor<1xi32>, tensor<1x1x2xf32>) -> tensor<2x1x2xf32>
+  func.return %3 : tensor<2x1x2xf32>
 }
-}
\ No newline at end of file
+}
+
+//CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {keep_stablehlo_constant = "true"}, tfl.schema_version = 3 : i32} {
+//CHECK-NEXT:  func.func @main(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "stablehlo.dynamic_update_slice"}} {
+//CHECK-DAG:    %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x2xf32>
+//CHECK-DAG:    %1 = stablehlo.constant dense<1> : tensor<i32>
+//CHECK-DAG:    %2 = stablehlo.constant dense<0> : tensor<i32>
+//CHECK-NEXT:   %3 = stablehlo.dynamic_update_slice %arg0, %0, %1, %2, %2 : (tensor<2x1x2xf32>, tensor<1x1x2xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<2x1x2xf32>
+//CHECK-NEXT:   return %3 : tensor<2x1x2xf32>
+//CHECK-NEXT:  }
+//CHECK-NEXT:}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir
index bd2f56d..073f31e 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir
+++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir
@@ -132,13 +132,13 @@
 func.func @batchNormTraining_4D_middle_features(
     %x: tensor<3x4x256x6xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>)
     -> (tensor<3x4x256x6xf32>) {
-  // CHECK-DAG: %[[CST_AXIS:.+]] = "tf.Const"() {value = dense<[0, 1, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK-DAG: %[[CST_AXIS:.+]] = "tf.Const"() <{value = dense<[0, 1, 3]> : tensor<3xi32>}> : () -> tensor<3xi32>
   // CHECK-DAG: %[[X_SHAPE:.+]] = shape.const_shape [3, 4, 256, 6] : tensor<4xindex>
   // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<256xf32>
-  // CHECK-DAG: %[[MEAN:.+]] = "tf.Mean"(%arg0, %[[CST_AXIS]]) {keep_dims = false} : (tensor<3x4x256x6xf32>, tensor<3xi32>) -> tensor<256xf32>
+  // CHECK-DAG: %[[MEAN:.+]] = "tf.Mean"(%arg0, %[[CST_AXIS]]) <{keep_dims = false}> : (tensor<3x4x256x6xf32>, tensor<3xi32>) -> tensor<256xf32>
   // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[X_SHAPE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>, tensor<4xindex>) -> tensor<3x4x256x6xf32>
   // CHECK-DAG: %[[SQ_DIFF:.+]] = "tf.SquaredDifference"(%arg0, %[[MEAN_BCAST]]) : (tensor<3x4x256x6xf32>, tensor<3x4x256x6xf32>) -> tensor<3x4x256x6xf32>
-  // CHECK-DAG: %[[VARIANCE:.+]] = "tf.Mean"(%[[SQ_DIFF]], %[[CST_AXIS]]) {keep_dims = false} : (tensor<3x4x256x6xf32>, tensor<3xi32>) -> tensor<256xf32>
+  // CHECK-DAG: %[[VARIANCE:.+]] = "tf.Mean"(%[[SQ_DIFF]], %[[CST_AXIS]]) <{keep_dims = false}> : (tensor<3x4x256x6xf32>, tensor<3xi32>) -> tensor<256xf32>
   // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS]] : tensor<256xf32>
   // CHECK-DAG: %[[VARIANCE_EPS_RSQRT:.+]] = mhlo.rsqrt %[[VARIANCE_EPS]] : tensor<256xf32>
   // CHECK-DAG: %[[MULTIPLIER:.+]] = mhlo.multiply %[[VARIANCE_EPS_RSQRT]], %[[SCALE]] : tensor<256xf32>
@@ -152,4 +152,4 @@
       {epsilon = 1.0 : f32, feature_index = 2 : i64} :
       (tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32>) -> (tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32>)
   func.return %0 : tensor<3x4x256x6xf32>
-}
\ No newline at end of file
+}
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc
index 2a0c673..f161bb3 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc
@@ -3795,6 +3795,18 @@
           output_rank - input_rank);
 }
 
+// Returns true if the operation producing the provided result (`op_result`)
+// is within an op region of an operation of type `ParentType`.
+template <typename ParentType>
+bool IsWithinOpRegion(mlir::OpResult op_result) {
+  mlir::Operation* parent_op = op_result.getDefiningOp()->getParentOp();
+
+  return llvm::isa<ParentType>(parent_op);
+}
+
 // Returns the intermediate shape that input tensor should be reshaped to during
 // legalization of BroadcastInDimOp.
 arith::ConstantOp ExpandedShape(PatternRewriter& rewriter, Value input,
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td
index fbdcfca..473b370 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td
@@ -33,6 +33,9 @@
 def IsTFStyleBroadcast : Constraint<CPred<"IsTFStyleBroadcast($0, $1)">,
     "new dimensions are added as prefix">;
 
+def IsNotWithinScatterRegion : Constraint<CPred<"!IsWithinOpRegion<mhlo::ScatterOp>($0)">,
+    "binary ops within scatter regions are not converted">;
+
 // Check if broadcast dimensions do not match Tensorflow convention.
 def IsNotTFStyleBroadcast : Constraint<Neg<CPred<"IsTFStyleBroadcast($0, $1)">>,
     "new dimensions are inserted in intermediate positions">;
@@ -60,7 +63,9 @@
                          [MHLO_PowOp, CHLO_BroadcastPowOp, TF_PowOp],
                          [MHLO_SubtractOp, CHLO_BroadcastSubOp, TF_SubOp],
                          [MHLO_Atan2Op, CHLO_BroadcastAtan2Op, TF_Atan2Op]] in {
-  def : Pat<(fromToBinPair[0] $l, $r), (fromToBinPair[2] $l, $r)>;
+  def : Pat<(fromToBinPair[0]:$result $l, $r),
+            (fromToBinPair[2] $l, $r),
+            [(IsNotWithinScatterRegion $result)]>;
   def : Pat<(fromToBinPair[1] $l, $r, $broadcast_dimensions),
             (fromToBinPair[2] $l, $r),
             [(IsLegalNumpyRankedBroadcast $l, $r, $broadcast_dimensions)]>;
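Illustrative sketch only, not part of the patch (shapes and scatter dimension numbers are assumed): the new IsNotWithinScatterRegion constraint keeps the direct MHLO->TF binary-op patterns from firing on ops like the mhlo.add below, whose parent region belongs to an mhlo.scatter, so the reduction body stays in MHLO.

// Hypothetical example: the add in the update region is not rewritten to a tf op.
func.func @scatter_with_add_region(%operand: tensor<3xf32>, %indices: tensor<1x1xi32>, %updates: tensor<1xf32>) -> tensor<3xf32> {
  %0 = "mhlo.scatter"(%operand, %indices, %updates) ({
  ^bb0(%old: tensor<f32>, %new: tensor<f32>):
    // Binary op inside the scatter region: excluded by IsNotWithinScatterRegion.
    %sum = mhlo.add %old, %new : tensor<f32>
    "mhlo.return"(%sum) : (tensor<f32>) -> ()
  }) {
    indices_are_sorted = false,
    scatter_dimension_numbers = #mhlo.scatter<
      inserted_window_dims = [0],
      scatter_dims_to_operand_dims = [0],
      index_vector_dim = 1>,
    unique_indices = false
  } : (tensor<3xf32>, tensor<1x1xi32>, tensor<1xf32>) -> tensor<3xf32>
  func.return %0 : tensor<3xf32>
}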
diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc
index 562afd5..df3a5f6 100644
--- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc
+++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc
@@ -18,6 +18,7 @@
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/drop_savedmodel_semantics.h"
+#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h"
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h"
 #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h"
@@ -34,6 +35,9 @@
 void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize,
                             bool smuggle_disallowed_ops) {
   pm.addPass(CreateRenameEntrypointToMainPass());
+
+  // If the input contains a tf.XlaCallModule op, unwrap its contents.
+  pm.addPass(mlir::odml::CreateLegalizeTFXlaCallModuleToStablehloPass());
   // TODO(b/230572023): Consider improving shape inference for While op instead
   // of dropping the attribute. This need not be correct for models not trained
   // on TPU.
diff --git a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir
index ccb5507..378ed7f 100644
--- a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir
@@ -10,7 +10,7 @@
 
   // CHECK-LABEL: testDilatedConv
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
-  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x120x120x8xf32>
 }
 
@@ -24,7 +24,7 @@
 
   // CHECK-LABEL: testDilatedConvWithNonConstantPadAndCrops
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
-  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x120x120x8xf32>
 }
 
@@ -39,7 +39,7 @@
 
   // CHECK-LABEL: testDilatedConvWithNonZeroBasePadding
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
-  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
 }
 
@@ -54,7 +54,7 @@
 
   // CHECK-LABEL: testDilatedConvWithFp16
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x20x30x40xf16>, [[FILTER:%.*]]: tensor<5x5x40x32xf16>)
-  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {data_format = "NHWC", dilations = [1, 2, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x20x30x40xf16>, tensor<5x5x40x32xf16>) -> tensor<1x20x30x32xf16>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) <{data_format = "NHWC", dilations = [1, 2, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x20x30x40xf16>, tensor<5x5x40x32xf16>) -> tensor<1x20x30x32xf16>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x20x30x32xf16>
 }
 
@@ -85,7 +85,7 @@
 
   // CHECK-LABEL: testDilatedDepthWiseConv
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
-  // CHECK-NEXT: [[RESULT:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
 }
 
@@ -103,7 +103,7 @@
 
   // CHECK-LABEL: testDilatedConvWithPad
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
   // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
 }
@@ -122,7 +122,7 @@
 
   // CHECK-LABEL: testDilatedDepthWiseConvWithPad
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
   // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
 }
@@ -139,7 +139,7 @@
 
   // CHECK-LABEL: testDilatedConvWithBiasAdd
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
   // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
 }
@@ -156,7 +156,7 @@
 
   // CHECK-LABEL: testDilatedDepthWiseConvWithBiasAdd
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
   // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
 }
@@ -176,10 +176,10 @@
 
   // CHECK-LABEL: testDilatedConvWithExpandSqueeze1
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
-  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
+  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) <{squeeze_dims = [3]}> : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
 }
@@ -199,10 +199,10 @@
 
   // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze1
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
-  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
+  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) <{squeeze_dims = [3]}> : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
 }
@@ -222,10 +222,10 @@
 
   // CHECK-LABEL: testDilatedConvWithExpandSqueeze2
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
-  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
+  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) <{squeeze_dims = [3]}> : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<?xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
 }
@@ -245,10 +245,10 @@
 
   // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
-  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
+  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) <{squeeze_dims = [3]}> : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<?xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
 }
@@ -270,10 +270,10 @@
 
   // CHECK-LABEL: testDilatedConvWithExpandSqueeze3
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
-  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
+  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) <{squeeze_dims = [3]}> : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
 }
@@ -295,10 +295,10 @@
 
   // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze3
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
-  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
+  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) <{squeeze_dims = [3]}> : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
 }
@@ -407,10 +407,10 @@
 
   // CHECK-LABEL: testDilatedConv1DExpandH
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<1x5x3x8xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() <{value = dense<-3> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x1x128x3xf32>
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 1, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x1x128x3xf32>, tensor<1x5x3x8xf32>) -> tensor<1x1x128x8xf32>
-  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-3]} : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) <{dilations = [1, 1, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x1x128x3xf32>, tensor<1x5x3x8xf32>) -> tensor<1x1x128x8xf32>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) <{squeeze_dims = [-3]}> : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
 }
 
@@ -429,10 +429,10 @@
 
   // CHECK-LABEL: testDilatedConv1DExpandHWithBiasAdd
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<1x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() <{value = dense<-3> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x1x128x3xf32>
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 1, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x1x128x3xf32>, tensor<1x5x3x8xf32>) -> tensor<1x1x128x8xf32>
-  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-3]} : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) <{dilations = [1, 1, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x1x128x3xf32>, tensor<1x5x3x8xf32>) -> tensor<1x1x128x8xf32>
+  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) <{squeeze_dims = [-3]}> : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
   // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x8xf32>, tensor<8xf32>) -> tensor<1x128x8xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
 }
@@ -451,10 +451,10 @@
 
   // CHECK-LABEL: testDilatedConv1DExpandW
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<5x1x3x8xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-2> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() <{value = dense<-2> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x128x1x3xf32>
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<1x128x1x8xf32>
-  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-2]} : (tensor<1x128x1x8xf32>) -> tensor<1x128x8xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) <{dilations = [1, 2, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x128x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<1x128x1x8xf32>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) <{squeeze_dims = [-2]}> : (tensor<1x128x1x8xf32>) -> tensor<1x128x8xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
 }
 
@@ -473,10 +473,10 @@
 
   // CHECK-LABEL: testDilatedConv1DExpandWWithBiasAdd
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<5x1x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-2> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() <{value = dense<-2> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x128x1x3xf32>
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<1x128x1x8xf32>
-  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-2]} : (tensor<1x128x1x8xf32>) -> tensor<1x128x8xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) <{dilations = [1, 2, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x128x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<1x128x1x8xf32>
+  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) <{squeeze_dims = [-2]}> : (tensor<1x128x1x8xf32>) -> tensor<1x128x8xf32>
   // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x8xf32>, tensor<8xf32>) -> tensor<1x128x8xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
 }
@@ -495,10 +495,10 @@
 
   // CHECK-LABEL: testDilatedConv1DWithMixedPostiveAndNegativeAxis
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<1x5x3x8xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x1x128x3xf32>
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 1, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x1x128x3xf32>, tensor<1x5x3x8xf32>) -> tensor<1x1x128x8xf32>
-  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-3]} : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) <{dilations = [1, 1, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x1x128x3xf32>, tensor<1x5x3x8xf32>) -> tensor<1x1x128x8xf32>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) <{squeeze_dims = [-3]}> : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
 }
 
@@ -518,11 +518,11 @@
 
   // CHECK-LABEL: testPaddedDilatedConv
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<2x1920x64xf32>)
-  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-NEXT: [[FILTER:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x1x64x128xf32>} : () -> tensor<3x1x64x128xf32>
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-NEXT: [[FILTER:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<3x1x64x128xf32>}> : () -> tensor<3x1x64x128xf32>
   // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) {device = ""} : (tensor<2x1920x64xf32>, tensor<i32>) -> tensor<2x1920x1x64xf32>
-  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {data_format = "NHWC", device = "", dilations = [1, 2, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<2x1920x1x64xf32>, tensor<3x1x64x128xf32>) -> tensor<2x1920x1x128xf32>
-  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {device = "", squeeze_dims = [2]} : (tensor<2x1920x1x128xf32>) -> tensor<2x1920x128xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) <{data_format = "NHWC", dilations = [1, 2, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> {device = ""} : (tensor<2x1920x1x64xf32>, tensor<3x1x64x128xf32>) -> tensor<2x1920x1x128xf32>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) <{squeeze_dims = [2]}> {device = ""} : (tensor<2x1920x1x128xf32>) -> tensor<2x1920x128xf32>
   // CHECK-NEXT: return [[RESULT]] : tensor<2x1920x128xf32>
 }
 
@@ -539,7 +539,7 @@
 
   // CHECK-LABEL: testDilatedConvInterleaved
   // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
-  // CHECK-NEXT: [[RESULT0:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32>
-  // CHECK-NEXT: [[RESULT1:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32>
+  // CHECK-NEXT: [[RESULT0:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32>
+  // CHECK-NEXT: [[RESULT1:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) <{dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]}> : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x120x120x8xf32>
   // CHECK-NEXT: return [[RESULT0]], [[RESULT1]] : tensor<1x120x120x8xf32>, tensor<1x120x120x8xf32>
 }
diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/if_op.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/if_op.mlir
index f29afb3..7ea7e48 100644
--- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/if_op.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/if_op.mlir
@@ -1,7 +1,7 @@
 // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
 // Confirm function references in if ops are preserved
 func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
-// CHECK:   %{{.*}} = "tf.If"(%{{.*}}, %{{.*}}, %{{.*}}) {else_branch = @cond_false, is_stateless = false, then_branch = @cond_true} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+// CHECK:   %{{.*}} = "tf.If"(%{{.*}}, %{{.*}}, %{{.*}}) <{else_branch = @cond_false, is_stateless = false, then_branch = @cond_true}> : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
   %0 = "tfl.less"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
   %1 = "tf.If"(%0, %arg0, %arg1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
   func.return %1 : tensor<1xf32>
diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/metadata_buffer.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/metadata_buffer.mlir
new file mode 100644
index 0000000..6b76b31
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/metadata_buffer.mlir
@@ -0,0 +1,9 @@
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
+
+// CHECK: tfl.metadata_buffer = [3 : i32, 7 : i32]
+module attributes {tfl.metadata_buffer = [3 : i32, 7 : i32]} {
+  func.func @main(%arg0: tensor<i32>, %arg1: tensor<3x2xi32>) -> tensor<3x2xi32> {
+    %0 = "tfl.add" (%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<i32>, tensor<3x2xi32>) -> tensor<3x2xi32>
+    func.return %0 : tensor<3x2xi32>
+  }
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo.mlir
index b1b08d5..64567f5 100644
--- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo.mlir
@@ -355,7 +355,7 @@
     ^bb0(%arg23: tensor<f32>, %arg24: tensor<f32>):
       %1112 = stablehlo.add %arg23, %arg24 : tensor<f32>
       stablehlo.return %1112 : tensor<f32>
-    }) {padding = dense<[[0, 0], [159, 0], [0, 0]]> : tensor<3x2xi64>, window_dimensions = dense<[1, 160, 1]> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64>} : (tensor<1x160x1xf32>, tensor<f32>) -> tensor<1x160x1xf32>
+    }) {base_dilations = dense<1> : tensor<3xi64>, padding = dense<[[0, 0], [159, 0], [0, 0]]> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 160, 1]> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64>} : (tensor<1x160x1xf32>, tensor<f32>) -> tensor<1x160x1xf32>
   return %0 : tensor<1x160x1xf32>
 }
 
@@ -364,7 +364,7 @@
 //CHECK-NEXT:  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
 //CHECK-NEXT:   %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
 //CHECK-NEXT:   stablehlo.return %1 : tensor<f32>
-//CHECK-NEXT{LITERAL}:  }) {padding = dense<[[0, 0], [159, 0], [0, 0]]> : tensor<3x2xi64>, window_dimensions = dense<[1, 160, 1]> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64>} : (tensor<1x160x1xf32>, tensor<f32>) -> tensor<1x160x1xf32>
+//CHECK-NEXT{LITERAL}:  }) {base_dilations = dense<1> : tensor<3xi64>, padding = dense<[[0, 0], [159, 0], [0, 0]]> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 160, 1]> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64>} : (tensor<1x160x1xf32>, tensor<f32>) -> tensor<1x160x1xf32>
 //CHECK-NEXT: return %0 : tensor<1x160x1xf32>
 //CHECK-NEXT:}
 
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-variables.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-variables.mlir
index 39d0aa8..54a10bf 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-variables.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-variables.mlir
@@ -46,7 +46,7 @@
   }
 
   "tf_saved_model.session_initializer"() {initializers = [@init_all_tables]} : () -> ()
-  // CHECK: "tf_saved_model.session_initializer"() {initializers = [@init_all_tables]} : () -> ()
+  // CHECK: "tf_saved_model.session_initializer"() <{initializers = [@init_all_tables]}> : () -> ()
 
   // CHECK-LABEL: serving_default
   func.func @serving_default(%arg0: tensor<1x10xf32> {tf_saved_model.index_path = ["x"]}) ->
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir
index a807cc8..ab9b39b 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-while.mlir
@@ -66,7 +66,7 @@
 // CANON:           },  {
 // CANON:           ^bb0([[VAL_11:%.*]]: tensor<*xi32>, [[VAL_12:%.*]]: tensor<*xi32>, [[VAL_13:%.*]]: tensor<*xf32>):
 // CANON-DAG:         [[VAL_4:%.*]] = arith.constant dense<1> : tensor<i32>
-// CANON-DAG:         [[VAL_5:%.*]] = "tf.Const"() {value = dense<2.560000e+02> : tensor<256x256xf32>} : () -> tensor<?x?xf32>
+// CANON-DAG:         [[VAL_5:%.*]] = "tf.Const"() <{value = dense<2.560000e+02> : tensor<256x256xf32>}> : () -> tensor<?x?xf32>
 // CANON:             [[VAL_14:%.*]] = "tf.AddV2"([[VAL_12]], [[VAL_4]])
 // CANON:             [[VAL_15:%.*]] = "tf.AddV2"([[VAL_13]], [[VAL_5]])
 // CANON:             [[VAL_16:%.*]] = "tf.AddV2"([[VAL_11]], [[VAL_4]])
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index 444a494..685efd5 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -198,7 +198,7 @@
   func.return %0 : tensor<8x8x8x8xf32>
 
   // CHECK-LABEL: fakeQuantVarsTrue
-  // CHECK: "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {max = 1.000000e+00 : f32, min = 0.000000e+00 : f32, narrow_range = true, num_bits = 5 : i64}
+  // CHECK: "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) <{narrow_range = true, num_bits = 5 : i64}> {max = 1.000000e+00 : f32, min = 0.000000e+00 : f32}
 }
 
 func.func @fakeQuantArgsFalse4Bits(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
@@ -235,7 +235,7 @@
   func.return %0 : tensor<8x8x8x8xf32>
 
   // CHECK-LABEL: fakeQuantVarsTrue
-  // CHECK: "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {max = 1.000000e+00 : f32, min = 0.000000e+00 : f32, narrow_range = true, num_bits = 3 : i64}
+  // CHECK: "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) <{narrow_range = true, num_bits = 3 : i64}> {max = 1.000000e+00 : f32, min = 0.000000e+00 : f32}
 }
 
 func.func @const() -> tensor<2xi32> {
@@ -1421,7 +1421,7 @@
   %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 7 : i64, shrink_axis_mask = 0 : i64, offset = false} : (tensor<5x6x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x1x5x6x7xf32>
   func.return %0 : tensor<1x1x5x6x7xf32>
   // CHECK-LABEL: strided_slice_big_dims
-  // CHECK: %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 7 : i64, offset = false, shrink_axis_mask = 0 : i64} : (tensor<5x6x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x1x5x6x7xf32>
+  // CHECK: %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 7 : i64, shrink_axis_mask = 0 : i64}> {offset = false} : (tensor<5x6x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x1x5x6x7xf32>
 }
 
 func.func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
@@ -1606,7 +1606,7 @@
   %0 = "tf.SparseToDense"(%arg0, %arg1, %arg2, %arg3) {validate_indices = true}: (tensor<3x5xi32>, tensor<3xi32>, tensor<2xf32>, tensor<f32>) -> tensor<?x?x?xf32>
   func.return %0 : tensor<?x?x?xf32>
   // CHECK-LABEL: sparse_to_dense_with_2d_sparse_indices_and_second_dim_greater_than_4
-  // CHECK: "tf.SparseToDense"(%arg0, %arg1, %arg2, %arg3) {validate_indices = true} : (tensor<3x5xi32>, tensor<3xi32>, tensor<2xf32>, tensor<f32>) -> tensor<?x?x?xf32>
+  // CHECK: "tf.SparseToDense"(%arg0, %arg1, %arg2, %arg3) <{validate_indices = true}> : (tensor<3x5xi32>, tensor<3xi32>, tensor<2xf32>, tensor<f32>) -> tensor<?x?x?xf32>
 }
 
 func.func @where(%arg0: tensor<3x5xi1>) -> tensor<?x2xi64> {
@@ -2311,7 +2311,7 @@
   %0 = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [2, 1, 1, 1, 1]} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
   func.return %0: tensor<?x?x?x?x?xf32>
   // CHECK-LABEL: conv3d_invalid_strides
-  // CHECK:  [[BCT:%.*]] = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [2, 1, 1, 1, 1]} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  // CHECK:  [[BCT:%.*]] = "tf.Conv3D"(%arg0, %arg1) <{padding = "SAME", strides = [2, 1, 1, 1, 1]}> : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
   // CHECK:  return [[BCT]] : tensor<?x?x?x?x?xf32>
 }
 
@@ -2705,7 +2705,7 @@
   func.return %values, %indices: tensor<1x4xf32>, tensor<1x4xi32>
 
   // CHECK-LABEL: approx_top_k_with_min_k
-  // CHECK:  %values, %indices = "tf.ApproxTopK"(%arg0) {aggregate_to_topk = true, is_max_k = false, k = 4 : i64, recall_target = 8.500000e-01 : f32, reduction_dimension = 1 : i64, reduction_input_size_override = -1 : i64} : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi32>)
+  // CHECK:  %values, %indices = "tf.ApproxTopK"(%arg0) <{aggregate_to_topk = true, is_max_k = false, k = 4 : i64, recall_target = 8.500000e-01 : f32, reduction_dimension = 1 : i64, reduction_input_size_override = -1 : i64}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi32>)
   // CHECK:  return %values, %indices : tensor<1x4xf32>, tensor<1x4xi32>
 }
 
@@ -2714,7 +2714,7 @@
   func.return %values, %indices: tensor<1x4xf32>, tensor<1x4xi32>
 
   // CHECK-LABEL: approx_top_k_reduction_dimension_not_last_dim
-  // CHECK:  %values, %indices = "tf.ApproxTopK"(%arg0) {aggregate_to_topk = true, is_max_k = true, k = 4 : i64, recall_target = 8.500000e-01 : f32, reduction_dimension = 0 : i64, reduction_input_size_override = -1 : i64} : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi32>)
+  // CHECK:  %values, %indices = "tf.ApproxTopK"(%arg0) <{aggregate_to_topk = true, is_max_k = true, k = 4 : i64, recall_target = 8.500000e-01 : f32, reduction_dimension = 0 : i64, reduction_input_size_override = -1 : i64}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi32>)
   // CHECK:  return %values, %indices : tensor<1x4xf32>, tensor<1x4xi32>
 }
 
diff --git a/tensorflow/compiler/mlir/lite/tests/lift_tflite_flex_ops.mlir b/tensorflow/compiler/mlir/lite/tests/lift_tflite_flex_ops.mlir
index a03519f..8ed0fe8 100644
--- a/tensorflow/compiler/mlir/lite/tests/lift_tflite_flex_ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/lift_tflite_flex_ops.mlir
@@ -20,7 +20,7 @@
     custom_option = #tfl<const_bytes : "0x0D42617463684D61744D756C56320038120D42617463684D61744D756C56321A001A002A070A0154120230012A0B0A0561646A5F78120228002A0B0A0561646A5F791202280032000002493B1414042801">
   } : (tensor<4x128x2xf32>, tensor<2x1xf32>) -> tensor<4x128x1xf32>
 
-// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false} : (tensor<4x128x2xf32>, tensor<2x1xf32>) -> tensor<4x128x1xf32>
+// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) <{adj_x = false, adj_y = false}> {T = f32} : (tensor<4x128x2xf32>, tensor<2x1xf32>) -> tensor<4x128x1xf32>
   func.return %0 : tensor<4x128x1xf32>
 }
 
@@ -66,7 +66,7 @@
 
   func.return %0 : tensor<!tf_type.variant>
 // CHECK: "tf.MapDataset"(
-// CHECK-SAME: {Targuments = [], f = @{{.*}}, metadata = "", output_shapes = [#tf_type.shape<>], output_types = [!tf_type.string], preserve_cardinality = true, use_inter_op_parallelism = true}
+// CHECK-SAME: <{f = @{{.*}}, metadata = "", output_shapes = [#tf_type.shape<>], output_types = [!tf_type.string], preserve_cardinality = true, use_inter_op_parallelism = true}> {Targuments = []}
 }
 
 // CHECK-LABEL: TfTakeWhileDataset
@@ -78,7 +78,7 @@
 
   func.return %0 : tensor<!tf_type.variant>
 // CHECK: "tf.TakeWhileDataset"(
-// CHECK-SAME: {Targuments = [!tf_type.resource, !tf_type.resource, i64, !tf_type.resource, !tf_type.resource, !tf_type.resource, !tf_type.resource, i64], metadata = "", output_shapes = [#tf_type.shape<>], output_types = [!tf_type.string], predicate = @{{.*}}}
+// CHECK-SAME: <{metadata = "", output_shapes = [#tf_type.shape<>], output_types = [!tf_type.string], predicate = @{{.*}}}> {Targuments = [!tf_type.resource, !tf_type.resource, i64, !tf_type.resource, !tf_type.resource, !tf_type.resource, !tf_type.resource, i64]}
 }
 
 // CHECK-LABEL: FailureOnInvalidOp
diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
index 3bfecee..79d9691 100644
--- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
@@ -4,9 +4,9 @@
 
 // CHECK-LABEL: tensorlistConst
 func.func @tensorlistConst(%arg0 : tensor<1xi32>) -> tensor<2x3xi32> {
-  // CHECK-DAG: %[[ELEMENT0:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
-  // CHECK-DAG: %[[ELEMENT1:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi32>} : () -> tensor<3xi32>
-  // CHECK: %[[LIST:.*]] = "tf.Pack"(%[[ELEMENT0]], %[[ELEMENT1]]) {axis = 0 : i64} : (tensor<3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
+  // CHECK-DAG: %[[ELEMENT0:.*]] = "tf.Const"() <{value = dense<[0, 1, 2]> : tensor<3xi32>}> : () -> tensor<3xi32>
+  // CHECK-DAG: %[[ELEMENT1:.*]] = "tf.Const"() <{value = dense<[3, 4, 5]> : tensor<3xi32>}> : () -> tensor<3xi32>
+  // CHECK: %[[LIST:.*]] = "tf.Pack"(%[[ELEMENT0]], %[[ELEMENT1]]) <{axis = 0 : i64}> : (tensor<3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
   %0 = "tf.Const"() {value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A2022485C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C3030335C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030305C3030315C3030325C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030335C3030345C30303522"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<3xi32>>>
 
   // CHECK: return %[[LIST]]
@@ -20,7 +20,7 @@
 func.func @emptyTensorlistConst(%arg0 : tensor<1xi32>) -> tensor<0x3xi32> {
   %0 = "tf.Const"() {value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A20222A5C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C30303322"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<3xi32>>>
 
-  // CHECK: "tf.Const"() {value = dense<> : tensor<0x3xi32>} : () -> tensor<0x3xi32>
+  // CHECK: "tf.Const"() <{value = dense<> : tensor<0x3xi32>}> : () -> tensor<0x3xi32>
   // CHECK-NOT: tf.TensorListStack
   %1 = "tf.TensorListStack"(%0, %arg0) : (tensor<!tf_type.variant<tensor<3xi32>>>, tensor<1xi32>) -> tensor<0x3xi32>
   func.return %1 : tensor<0x3xi32>
@@ -35,7 +35,7 @@
   %2 = "tf.TensorListStack"(%0, %arg1) : (tensor<!tf_type.variant<tensor<10xf32>>>, tensor<1xi32>) -> tensor<3x10xf32>
   func.return %1, %2 : tensor<10xf32>, tensor<3x10xf32>
 
-// CHECK:  %0 = "tf.Gather"(%arg0, %arg2) {validate_indices = true} : (tensor<3x10xf32>, tensor<i32>) -> tensor<10xf32>
+// CHECK:  %0 = "tf.Gather"(%arg0, %arg2) <{validate_indices = true}> : (tensor<3x10xf32>, tensor<i32>) -> tensor<10xf32>
 // CHECK: return %0, %arg0 : tensor<10xf32>, tensor<3x10xf32>
 }
 
@@ -48,7 +48,7 @@
   %2 = "tf.TensorListStack"(%0, %arg1) : (tensor<!tf_type.variant<tensor<*xf32>>>, tensor<1xi32>) -> tensor<*xf32>
   func.return %1, %2 : tensor<*xf32>, tensor<*xf32>
 
-// CHECK:  %0 = "tf.Gather"(%arg0, %arg2) {validate_indices = true} : (tensor<*xf32>, tensor<i32>) -> tensor<*xf32>
+// CHECK:  %0 = "tf.Gather"(%arg0, %arg2) <{validate_indices = true}> : (tensor<*xf32>, tensor<i32>) -> tensor<*xf32>
 // CHECK: return %0, %arg0 : tensor<*xf32>, tensor<*xf32>
 }
 
@@ -175,7 +175,7 @@
 // CHECK-DAG:  [[SHAPE:%.*]] = "tf.Concat"([[ZERO2]], [[DIM0]], %arg0) : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
 // CHECK-DAG:  [[VALUES:%.*]] = arith.constant dense<0.000000e+00> : tensor<f32>
 // CHECK:      [[LIST:%.*]] = "tf.Fill"([[SHAPE]], [[VALUES]]) : (tensor<4xi32>, tensor<f32>) -> tensor<?x?x?x?xf32>
-// CHECK:      [[RESULT:%.*]] = "tf.Gather"([[LIST]], %arg2) {validate_indices = true} : (tensor<?x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32>
+// CHECK:      [[RESULT:%.*]] = "tf.Gather"([[LIST]], %arg2) <{validate_indices = true}> : (tensor<?x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32>
 // CHECK:      return [[RESULT]] : tensor<?x?x?xf32>
 }
 
@@ -188,7 +188,7 @@
   func.return %1 : tensor<*xf32>
 
 // CHECK:  [[RESULT:%[0-9]+]] = "tf.Fill"{{.*}}(tensor<?xi32>, tensor<f32>) -> tensor<*xf32>
-// CHECK:  [[RESULT2:%[0-9]+]] = "tf.Gather"{{.*}}{validate_indices = true} : (tensor<*xf32>, tensor<i32>) -> tensor<*xf32>
+// CHECK:  [[RESULT2:%[0-9]+]] = "tf.Gather"{{.*}}<{validate_indices = true}> : (tensor<*xf32>, tensor<i32>) -> tensor<*xf32>
 // CHECK:  return [[RESULT2]] : tensor<*xf32>
 }
 
@@ -208,7 +208,7 @@
 // CHECK-DAG:  [[SHAPE:%.*]] = "tf.Concat"([[ZERO2]], [[DIM0]], [[ELEMENT_SHAPE]]) : (tensor<i32>, tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>
 // CHECK-DAG:  [[VALUES:%.*]] = arith.constant dense<0.000000e+00> : tensor<f32>
 // CHECK:      [[LIST:%.*]] = "tf.Fill"([[SHAPE]], [[VALUES]]) : (tensor<3xi32>, tensor<f32>) -> tensor<?x?x7xf32>
-// CHECK:      [[RESULT:%.*]] = "tf.Gather"([[LIST]], %arg1) {validate_indices = true} : (tensor<?x?x7xf32>, tensor<i32>) -> tensor<?x7xf32>
+// CHECK:      [[RESULT:%.*]] = "tf.Gather"([[LIST]], %arg1) <{validate_indices = true}> : (tensor<?x?x7xf32>, tensor<i32>) -> tensor<?x7xf32>
 // CHECK:      return [[RESULT]] : tensor<?x7xf32>
 }
 
@@ -245,7 +245,7 @@
 // CHECK-DAG:  [[CONCAT:%.*]] = "tf.Concat"([[AXIS_1]], [[EXPAND_DIM]], %arg0) : (tensor<i32>, tensor<1xi32>, tensor<*xi32>) -> tensor<?xi32>
 // CHECK:  [[CST:%.*]] = arith.constant dense<0.000000e+00> : tensor<f32>
 // CHECK:  [[FILL:%.*]] = "tf.Fill"([[CONCAT]], [[CST]]) : (tensor<?xi32>, tensor<f32>) -> tensor<*xf32>
-// CHECK:  [[GATHER:%.*]] = "tf.Gather"([[FILL]], %arg2) {validate_indices = true} : (tensor<*xf32>, tensor<i32>) -> tensor<*xf32>
+// CHECK:  [[GATHER:%.*]] = "tf.Gather"([[FILL]], %arg2) <{validate_indices = true}> : (tensor<*xf32>, tensor<i32>) -> tensor<*xf32>
 // CHECK:  return [[GATHER]] : tensor<*xf32>
 }
 
@@ -263,7 +263,7 @@
 // CHECK-DAG:  [[SHAPE:%.*]] = "tf.Concat"([[ZERO]], [[DIM0]], [[ELEM_SHAPE]]) : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
 // CHECK-DAG:  [[VALUES:%.*]] = arith.constant dense<0.000000e+00> : tensor<f32>
 // CHECK:      [[LIST:%.*]] = "tf.Fill"([[SHAPE]], [[VALUES]]) : (tensor<4xi32>, tensor<f32>) -> tensor<0x?x?x?xf32>
-// CHECK:      [[RESULT:%.*]] = "tf.Gather"([[LIST]], [[IDX]]) {validate_indices = true} : (tensor<0x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32>
+// CHECK:      [[RESULT:%.*]] = "tf.Gather"([[LIST]], [[IDX]]) <{validate_indices = true}> : (tensor<0x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32>
 // CHECK:      return [[RESULT]] : tensor<?x?x?xf32>
 }
 
@@ -294,7 +294,7 @@
 // CHECK-SAME: ([[INPUT:%.*]]: tensor<3x10xf32>, [[ELEM_SHAPE:%.*]]: tensor<1xi32>)
 // CHECK-DAG: [[SHAPE:%.*]] = "tf.Shape"([[INPUT]]) {{.*}} -> tensor<2xi32>
 // CHECK-DAG: [[ZERO:%cst.*]] = arith.constant dense<0> : tensor<i32>
-// CHECK: [[RESULT:%.*]] = "tf.Gather"([[SHAPE]], [[ZERO]]) {validate_indices = true} : (tensor<2xi32>, tensor<i32>) -> tensor<i32>
+// CHECK: [[RESULT:%.*]] = "tf.Gather"([[SHAPE]], [[ZERO]]) <{validate_indices = true}> : (tensor<2xi32>, tensor<i32>) -> tensor<i32>
 // CHECK: return [[RESULT]] : tensor<i32>
 }
 
@@ -352,7 +352,8 @@
   %cst_1 = arith.constant dense<-1> : tensor<i32>
   %0 = "tf.TensorListFromTensor"(%arg0, %cst) : (tensor<2x3xf32>, tensor<1xi32>) -> tensor<!tf_type.variant<tensor<3xf32>>>
   // CHECK: "tf.WhileRegion"
-  %1:2 = "tf.WhileRegion"(%cst_0, %0) ({
+  // CHECK: <{is_stateless = false}>
+  %1:2 = "tf.WhileRegion"(%cst_0, %0) <{is_stateless = false}> ({
       ^bb0(%carg0: tensor<i32>, %carg1: tensor<!tf_type.variant>):
        %cst_2 = arith.constant dense<2> : tensor<i32>
        %1 = "tf.Less"(%carg0, %cst_2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
@@ -376,9 +377,9 @@
       // CHECK-NOT: tensor<!tf_type.variant>
       // CHECK:  "tf.Yield"(%[[LEN]], %[[BARG1]]) : (tensor<i32>, tensor<*xf32>) -> ()
 
-  }) {is_stateless = false} : (tensor<i32>, tensor<!tf_type.variant<tensor<3xf32>>>) -> (tensor<i32>, tensor<!tf_type.variant<tensor<*xf32>>>)
+  }) : (tensor<i32>, tensor<!tf_type.variant<tensor<3xf32>>>) -> (tensor<i32>, tensor<!tf_type.variant<tensor<*xf32>>>)
   // make sure the variant types in input/output have been updated
-  // CHECK: {is_stateless = false} : (tensor<i32>, tensor<2x3xf32>) -> (tensor<i32>, tensor<*xf32>)
+  // : (tensor<i32>, tensor<2x3xf32>) -> (tensor<i32>, tensor<*xf32>)
   %2 = "tf.TensorListStack"(%1#1, %cst_1) : (tensor<!tf_type.variant<tensor<*xf32>>>, tensor<i32>) -> tensor<*xf32>
   // CHECK:  return %0#1 : tensor<*xf32>
   func.return %2 : tensor<*xf32>
@@ -443,11 +444,11 @@
 // CHECK:  [[ZERO:%.*]] = arith.constant dense<0> : tensor<i32>
 // CHECK:  [[SHAPE:%.*]] = "tf.Shape"([[INPUT]]) : (tensor<3x10xf32>) -> tensor<2xi32>
 // CHECK:  [[ZERO_1:%.*]] = arith.constant dense<0> : tensor<i32>
-// CHECK:  [[INPUT_SIZE:%.*]] = "tf.Gather"([[SHAPE]], [[ZERO_1]]) {validate_indices = true} : (tensor<2xi32>, tensor<i32>) -> tensor<i32>
+// CHECK:  [[INPUT_SIZE:%.*]] = "tf.Gather"([[SHAPE]], [[ZERO_1]]) <{validate_indices = true}> : (tensor<2xi32>, tensor<i32>) -> tensor<i32>
 // CHECK:  [[SIZE_DIFF:%.*]] = "tf.Sub"([[SIZE]], [[INPUT_SIZE]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK:  [[DIFF_RES:%.*]] = "tf.Greater"([[SIZE_DIFF]], [[ZERO]]) : (tensor<i32>, tensor<i32>) -> tensor<i1>
 // CHECK:  [[SHAPE_1:%.*]] = "tf.Shape"([[INPUT]]) : (tensor<3x10xf32>) -> tensor<?xi32>
-// CHECK:  [[RESULT:%.*]] = "tf.If"([[DIFF_RES]], [[INPUT]], [[SHAPE_1]], [[SIZE_DIFF]], [[SIZE]]) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor<i1>, tensor<3x10xf32>, tensor<?xi32>, tensor<i32>, tensor<i32>) -> tensor<?x10xf32>
+// CHECK:  [[RESULT:%.*]] = "tf.If"([[DIFF_RES]], [[INPUT]], [[SHAPE_1]], [[SIZE_DIFF]], [[SIZE]]) <{else_branch = @cond_false, is_stateless = true, then_branch = @cond_true}> : (tensor<i1>, tensor<3x10xf32>, tensor<?xi32>, tensor<i32>, tensor<i32>) -> tensor<?x10xf32>
 // CHECK:  return [[RESULT]] : tensor<?x10xf32>
 }
 
@@ -510,7 +511,7 @@
   func.return %t#0, %t#1 : tensor<?x2xf32>, tensor<0xi64>
 
 // CHECK: [[ELEMENT_SHAPE:%.*]] = arith.constant dense<2> : tensor<2xi32>
-// CHECK: [[UNPACK:%.*]]:3 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<3x2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>)
+// CHECK: [[UNPACK:%.*]]:3 = "tf.Unpack"(%arg0) <{axis = 0 : i64}> : (tensor<3x2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>)
 // CHECK: [[SCALAR_ZERO:%.*]] = arith.constant dense<0> : tensor<i32>
 // CHECK: [[CONCAT:%.*]] = "tf.Concat"([[SCALAR_ZERO]], [[UNPACK]]#0, [[UNPACK]]#1, [[UNPACK]]#2) : (tensor<i32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<?x2xf32>
 // CHECK: [[LENGTHS:%.*]] = arith.constant dense<0> : tensor<0xi64>
@@ -567,7 +568,7 @@
 // CHECK: func @tensorListIf
 // CHECK-NEXT:  %cst = arith.constant dense<2> : tensor<i32>
 // CHECK-NEXT:  %0 = "tf.Less"(%arg2, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i1>
-// CHECK-NEXT:  %1 = "tf.If"(%0, %arg0) {else_branch = @tensorListIfCondFalse, is_stateless = true, then_branch = @tensorListIfCondTrue} : (tensor<i1>, tensor<3x10xf32>) -> tensor<3x10xf32>
+// CHECK-NEXT:  %1 = "tf.If"(%0, %arg0) <{else_branch = @tensorListIfCondFalse, is_stateless = true, then_branch = @tensorListIfCondTrue}> : (tensor<i1>, tensor<3x10xf32>) -> tensor<3x10xf32>
 // CHECK-NEXT:  return %1 : tensor<3x10xf32>
 }
 
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata_buffer.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata_buffer.mlir
new file mode 100644
index 0000000..f53f395
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata_buffer.mlir
@@ -0,0 +1,11 @@
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
+
+module attributes {tfl.metadata_buffer = [3 : i32, 7 : i32]} {
+  func.func @main(%arg0: tensor<i32>, %arg1: tensor<3x2xi32>) -> tensor<3x2xi32> {
+    %0 = "tfl.add" (%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<i32>, tensor<3x2xi32>) -> tensor<3x2xi32>
+    func.return %0 : tensor<3x2xi32>
+  }
+}
+
+// CHECK: metadata_buffer: [ 3, 7 ],
+// CHECK-NEXT: metadata:
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index 52e1599..20c4a03 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -410,7 +410,7 @@
 
 func.func @add_with_quantized_i16_broadcasting(tensor<2x2xf32>, tensor<1xf32>) -> tensor<2x2x!quant.any<i16:f32>> {
 ^bb0(%arg0: tensor<2x2xf32>, %arg1: tensor<1xf32>):
-  // expected-error @+1 {{Operands do not have valid shapes}}
+  // expected-error @+1 {{Operands should have valid shapes and element type needs to match}}
   %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<2x2xf32>, tensor<1xf32>) -> tensor<2x2x!quant.any<i16:f32>>
   func.return %0#0 : tensor<2x2x!quant.any<i16:f32>>
 }
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir
index 7aac866..27d98c7 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir
@@ -194,24 +194,24 @@
 // CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
 // CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
 // CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32>
-// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
-// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
-// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_17:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
 // CHECK:           [[VAL_19:%.*]] = "tfl.no_value"() {value} : () -> none
 // CHECK:           [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
 // CHECK-DAG:       [[VAL_21:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_23:%.*]] = arith.constant dense<1> : tensor<3xi32>
-// CHECK:           [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<?x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
+// CHECK:           [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) <{begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<?x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_25:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
 // CHECK:           return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
 // CHECK:         }
 }
@@ -240,32 +240,32 @@
 // CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
 // CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
 // CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32>
-// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
-// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<1> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<4x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>)
-// CHECK-DAG:       [[VAL_20:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_20:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_21:%.*]] = "tf.Reshape"([[VAL_15]]#0, [[VAL_20]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_22:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_22:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_23:%.*]] = "tf.Reshape"([[VAL_15]]#1, [[VAL_22]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_24:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_24:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_25:%.*]] = "tf.Reshape"([[VAL_15]]#2, [[VAL_24]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_27:%.*]] = "tf.Reshape"([[VAL_15]]#3, [[VAL_26]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
 // CHECK:           [[VAL_31:%.*]] = "tfl.no_value"() {value} : () -> none
 // CHECK:           [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
 // CHECK-DAG:       [[VAL_33:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32>
-// CHECK:           [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_37:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_38:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_39:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
+// CHECK:           [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) <{begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_37:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_38:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_39:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
 // CHECK:           return [[VAL_36]], [[VAL_32]], [[VAL_37]], [[VAL_38]], [[VAL_39]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
 // CHECK:         }
 
@@ -290,24 +290,24 @@
 // CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
 // CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
 // CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32>
-// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
-// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
-// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_17:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
 // CHECK:           [[VAL_19:%.*]] = "tfl.no_value"() {value} : () -> none
 // CHECK:           [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
 // CHECK-DAG:       [[VAL_21:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_23:%.*]] = arith.constant dense<1> : tensor<3xi32>
-// CHECK:           [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) {begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
+// CHECK:           [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) <{begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64}> : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_25:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
 // CHECK:           return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
 // CHECK:         }
 
@@ -337,32 +337,32 @@
 // CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
 // CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
 // CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32>
-// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
-// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<1> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<4x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>)
-// CHECK-DAG:       [[VAL_20:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_20:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_21:%.*]] = "tf.Reshape"([[VAL_15]]#0, [[VAL_20]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_22:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_22:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_23:%.*]] = "tf.Reshape"([[VAL_15]]#1, [[VAL_22]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_24:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_24:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_25:%.*]] = "tf.Reshape"([[VAL_15]]#2, [[VAL_24]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_27:%.*]] = "tf.Reshape"([[VAL_15]]#3, [[VAL_26]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
 // CHECK:           [[VAL_31:%.*]] = "tfl.no_value"() {value} : () -> none
 // CHECK:           [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
 // CHECK-DAG:       [[VAL_33:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32>
-// CHECK:           [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) {begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_37:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_38:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_39:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
+// CHECK:           [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) <{begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64}> : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_37:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_38:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_39:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
 // CHECK:           return [[VAL_36]], [[VAL_32]], [[VAL_37]], [[VAL_38]], [[VAL_39]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
 // CHECK:         }
 
@@ -389,24 +389,24 @@
 // CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
 // CHECK:           [[VAL_10:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
 // CHECK:           [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32>
-// CHECK-DAG:       [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_12:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
-// CHECK-DAG:       [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_15:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
-// CHECK-DAG:       [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_18:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_19:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
 // CHECK:           [[VAL_21:%.*]] = "tfl.no_value"() {value} : () -> none
 // CHECK:           [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
 // CHECK-DAG:       [[VAL_23:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_24:%.*]] = arith.constant dense<0> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_25:%.*]] = arith.constant dense<1> : tensor<3xi32>
-// CHECK:           [[VAL_26:%.*]] = "tf.StridedSlice"([[VAL_22]], [[VAL_23]], [[VAL_24]], [[VAL_25]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<?x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
+// CHECK:           [[VAL_26:%.*]] = "tf.StridedSlice"([[VAL_22]], [[VAL_23]], [[VAL_24]], [[VAL_25]]) <{begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<?x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
 // CHECK:           return [[VAL_26]], [[VAL_22]], [[VAL_27]], [[VAL_28]], [[VAL_29]] : tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
 // CHECK:         }
 
@@ -438,32 +438,32 @@
 // CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
 // CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
 // CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32>
-// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
-// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<1> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<4x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>)
-// CHECK-DAG:       [[VAL_20:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_20:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_21:%.*]] = "tf.Reshape"([[VAL_15]]#0, [[VAL_20]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_22:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_22:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_23:%.*]] = "tf.Reshape"([[VAL_15]]#1, [[VAL_22]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_24:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_24:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_25:%.*]] = "tf.Reshape"([[VAL_15]]#2, [[VAL_24]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_27:%.*]] = "tf.Reshape"([[VAL_15]]#3, [[VAL_26]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
 // CHECK:           [[VAL_31:%.*]] = "tfl.no_value"() {value} : () -> none
 // CHECK:           [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_41]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
 // CHECK-DAG:       [[VAL_33:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32>
-// CHECK:           [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_37:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_38:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_39:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
+// CHECK:           [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) <{begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_37:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_38:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_39:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
 // CHECK:           return [[VAL_36]], [[VAL_32]], [[VAL_37]], [[VAL_38]], [[VAL_39]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
 // CHECK:         }
 
@@ -490,24 +490,24 @@
 // CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
 // CHECK:           [[VAL_10:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
 // CHECK:           [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32>
-// CHECK-DAG:       [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_12:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
-// CHECK-DAG:       [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_15:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
-// CHECK-DAG:       [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_18:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_19:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
 // CHECK:           [[VAL_21:%.*]] = "tfl.no_value"() {value} : () -> none
 // CHECK:           [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
 // CHECK-DAG:       [[VAL_23:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_24:%.*]] = arith.constant dense<0> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_25:%.*]] = arith.constant dense<1> : tensor<3xi32>
-// CHECK:           [[VAL_26:%.*]] = "tf.StridedSlice"([[VAL_22]], [[VAL_23]], [[VAL_24]], [[VAL_25]]) {begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
+// CHECK:           [[VAL_26:%.*]] = "tf.StridedSlice"([[VAL_22]], [[VAL_23]], [[VAL_24]], [[VAL_25]]) <{begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64}> : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
 // CHECK:           return [[VAL_26]], [[VAL_22]], [[VAL_27]], [[VAL_28]], [[VAL_29]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
 // CHECK:         }
 
@@ -539,32 +539,32 @@
 // CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
 // CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
 // CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32>
-// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
-// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<1> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<4x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>)
-// CHECK-DAG:       [[VAL_20:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_20:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_21:%.*]] = "tf.Reshape"([[VAL_15]]#0, [[VAL_20]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_22:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_22:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_23:%.*]] = "tf.Reshape"([[VAL_15]]#1, [[VAL_22]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_24:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_24:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_25:%.*]] = "tf.Reshape"([[VAL_15]]#2, [[VAL_24]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK:           [[VAL_27:%.*]] = "tf.Reshape"([[VAL_15]]#3, [[VAL_26]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
-// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
 // CHECK:           [[VAL_31:%.*]] = "tfl.no_value"() {value} : () -> none
 // CHECK:           [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_41]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
 // CHECK-DAG:       [[VAL_33:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32>
-// CHECK:           [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) {begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_37:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_38:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_39:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
+// CHECK:           [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) <{begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64}> : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_37:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_38:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_39:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
 // CHECK:           return [[VAL_36]], [[VAL_32]], [[VAL_37]], [[VAL_38]], [[VAL_39]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
 // CHECK:         }
 
@@ -596,24 +596,24 @@
 // CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
 // CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
 // CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32>
-// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
-// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
-// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_17:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
 // CHECK:           [[VAL_19:%.*]] = "tfl.no_value"() {value} : () -> none
 // CHECK:           [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
 // CHECK-DAG:       [[VAL_21:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_23:%.*]] = arith.constant dense<1> : tensor<3xi32>
-// CHECK:           [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<?x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
+// CHECK:           [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) <{begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<?x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_25:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
 // CHECK:           return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
 // CHECK:         }
 
@@ -646,24 +646,24 @@
 // CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
 // CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
 // CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32>
-// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
-// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
-// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK-DAG:       [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG:       [[VAL_17:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:           [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
 // CHECK:           [[VAL_19:%.*]] = "tfl.no_value"() {value} : () -> none
 // CHECK:           [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
 // CHECK-DAG:       [[VAL_21:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32>
 // CHECK-DAG:       [[VAL_23:%.*]] = arith.constant dense<1> : tensor<3xi32>
-// CHECK:           [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<?x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32>
-// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
+// CHECK:           [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) <{begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<?x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_25:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
+// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
 // CHECK:           return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
 // CHECK:         }
 
@@ -684,13 +684,13 @@
 }
 
 // CHECK:       func @inference_standard_lstm_with_mask([[ARG_0:%.*]]: tensor<?x8x8xf32>, [[ARG_1:%.*]]: tensor<8x10xf32>, [[ARG_2:%.*]]: tensor<8x10xf32>, [[ARG_3:%.*]]: tensor<8x40xf32>, [[ARG_4:%.*]]: tensor<10x40xf32>, [[ARG_5:%.*]]: tensor<40xf32>,  [[ARG_6:%.*]]: tensor<?x8xi1>) -> (tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false", "tfshape$dim { size: -1 } dim { size: 8 }"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
-// CHECK:         [[VAL_0:%.*]] = "tf.BatchMatMulV2"([[ARG_0]], [[ARG_3]]) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
+// CHECK:         [[VAL_0:%.*]] = "tf.BatchMatMulV2"([[ARG_0]], [[ARG_3]]) <{adj_x = false, adj_y = false}> : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
 // CHECK:         [[VAL_1:%.*]] = "tf.Add"([[VAL_0]], [[ARG_5]]) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
-// CHECK:         [[VAL_2:%.*]] = "tf.BatchMatMulV2"([[VAL_1]], [[ARG_4]]) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
+// CHECK:         [[VAL_2:%.*]] = "tf.BatchMatMulV2"([[VAL_1]], [[ARG_4]]) <{adj_x = false, adj_y = true}> : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
 // CHECK:         [[VAL_3:%.*]] = "tf.Add"([[VAL_2]], [[ARG_1]]) : (tensor<?x8x10xf32>, tensor<8x10xf32>) -> tensor<?x8x10xf32>
 // CHECK:         [[VAL_4:%.*]] = "tf.Add"([[VAL_2]], [[ARG_2]]) : (tensor<?x8x10xf32>, tensor<8x10xf32>) -> tensor<?x8x10xf32>
 // CHECK:         [[VAL_5:%.*]] = "tf.Add"([[ARG_1]], [[ARG_2]]) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32>
-// CHECK:         [[VAL_6:%.*]] = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK:         [[VAL_6:%.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32} : () -> tensor<f32>
 // CHECK:         return [[VAL_5]], [[VAL_4]], [[VAL_5]], [[VAL_5]], [[VAL_6]] : tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
 // CHECK:       }
 
@@ -718,13 +718,13 @@
 }
 
 // CHECK:        func @inference_standard_lstm_time_major_cannot_fuse([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
-// CHECK:           [[VAL_6:%.*]] = "tf.BatchMatMulV2"([[VAL_0]], [[VAL_3]]) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
+// CHECK:           [[VAL_6:%.*]] = "tf.BatchMatMulV2"([[VAL_0]], [[VAL_3]]) <{adj_x = false, adj_y = false}> : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
 // CHECK:           [[VAL_7:%.*]] = "tf.Add"([[VAL_6]], [[VAL_5]]) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
-// CHECK:           [[VAL_8:%.*]] = "tf.BatchMatMulV2"([[VAL_7]], [[VAL_4]]) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
+// CHECK:           [[VAL_8:%.*]] = "tf.BatchMatMulV2"([[VAL_7]], [[VAL_4]]) <{adj_x = false, adj_y = true}> : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
 // CHECK:           [[VAL_9:%.*]] = "tf.Add"([[VAL_8]], [[VAL_1]]) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
 // CHECK:           [[VAL_10:%.*]] = "tf.Add"([[VAL_8]], [[VAL_2]]) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
 // CHECK:           [[VAL_11:%.*]] = "tf.Add"([[VAL_1]], [[VAL_2]]) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
-// CHECK:           [[VAL_12:%.*]] = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK:           [[VAL_12:%.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32} : () -> tensor<f32>
 // CHECK:           return [[VAL_11]], [[VAL_10]], [[VAL_11]], [[VAL_11]], [[VAL_12]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
 // CHECK:         }
 }
@@ -745,13 +745,13 @@
 }
 
 // CHECK: func @dynamic_shape_non_fuse_standard_lstm(%[[VAL_0:.*]]: tensor<?x8x8xf32>, %[[VAL_1:.*]]: tensor<?x10xf32>, %[[VAL_2:.*]]: tensor<?x10xf32>, %[[VAL_3:.*]]: tensor<8x40xf32>, %[[VAL_4:.*]]: tensor<10x40xf32>, %[[VAL_5:.*]]: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
-// CHECK:         %[[VAL_6:.*]] = "tf.BatchMatMulV2"(%[[VAL_0]], %[[VAL_3]]) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
+// CHECK:         %[[VAL_6:.*]] = "tf.BatchMatMulV2"(%[[VAL_0]], %[[VAL_3]]) <{adj_x = false, adj_y = false}> : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
 // CHECK:         %[[VAL_7:.*]] = "tf.Add"(%[[VAL_6]], %[[VAL_5]]) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
-// CHECK:         %[[VAL_8:.*]] = "tf.BatchMatMulV2"(%[[VAL_7]], %[[VAL_4]]) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
+// CHECK:         %[[VAL_8:.*]] = "tf.BatchMatMulV2"(%[[VAL_7]], %[[VAL_4]]) <{adj_x = false, adj_y = true}> : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
 // CHECK:         %[[VAL_9:.*]] = "tf.Add"(%[[VAL_8]], %[[VAL_1]]) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
 // CHECK:         %[[VAL_10:.*]] = "tf.Add"(%[[VAL_8]], %[[VAL_2]]) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
 // CHECK:         %[[VAL_11:.*]] = "tf.Add"(%[[VAL_1]], %[[VAL_2]]) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
-// CHECK:         %[[VAL_12:.*]] = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK:         %[[VAL_12:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32} : () -> tensor<f32>
 // CHECK:         return %[[VAL_11]], %[[VAL_10]], %[[VAL_11]], %[[VAL_11]], %[[VAL_12]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
 // CHECK:       }
 }
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant-4bit.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant-4bit.mlir
index 9a865f0..dca4c21 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant-4bit.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant-4bit.mlir
@@ -40,7 +40,7 @@
   %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform<u4:f32, 1.000000e+00>>} : (tensor<8xf32>) -> tensor<8x!quant.uniform<u4:f32, 1.000000e+00>>
   func.return %1 : tensor<8x!quant.uniform<u4:f32, 1.000000e+00>>
 
-// CHECK:  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64}
+// CHECK:  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) <{narrow_range = false, num_bits = 3 : i64}>
 // CHECK:  %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform<u4:f32, 1.000000e+00>>}
 // CHECK:  return %1
 }
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir
index 0c5bdac..c65cecc 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir
@@ -39,7 +39,7 @@
   %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>} : (tensor<8xf32>) -> tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>
   func.return %1 : tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>
 
-// CHECK:  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 5 : i64}
+// CHECK:  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) <{narrow_range = false, num_bits = 5 : i64}>
 // CHECK:  %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform<u8:f32, 1.000000e+00>>}
 // CHECK:  return %1
 }
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 4f39142..fff0082 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -30,7 +30,7 @@
 // CHECK:  %5 = "tf.Pad"(%arg0, %[[CONSTANT1]]) : (tensor<256x32x32x3xf32>, tensor<4x2xi32>) -> tensor<*xf32>
 // CHECK:  %6 = "tf.Transpose"(%arg1, %[[CONSTANT0]]) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32>
 // CHECK:  %7 = "tfl.conv_2d"(%5, %6, %[[CONSTANT]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<*xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x32x32x16xf32>
-// CHECK:  %8 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [2, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32>
+// CHECK:  %8 = "tf.Conv2D"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [2, 1, 1, 1]}> {T = "tfdtype$DT_FLOAT"} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32>
 }
 
 func.func @depthwiseConv2D(tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>, tensor<256x3x32x32xf32>) -> (tensor<256x30x30x12xf32>, tensor<256x12x30x30xf32>, tensor<256x30x30x12xf32>, tensor<256x30x30x12xf32>) {
@@ -224,9 +224,9 @@
   func.return %166 : tensor<1x1000xf32>
 
   // CHECK-LABEL: matmulNoTransposeAOrB
-  // CHECK: %[[RES:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
+  // CHECK: %[[RES:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<?xi32>
   // CHECK: %[[TRANS:.*]] = "tf.Transpose"(%arg1, %[[RES]]) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
-  // CHECK: %[[MM:.*]] = "tf.MatMul"(%arg0, %[[TRANS]]) {transpose_a = false, transpose_b = true} : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
+  // CHECK: %[[MM:.*]] = "tf.MatMul"(%arg0, %[[TRANS]]) <{transpose_a = false, transpose_b = true}> : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
   // CHECK: return %[[MM]] : tensor<1x1000xf32>
  }
 
@@ -235,10 +235,10 @@
   func.return %166 : tensor<1x1000xf32>
 
   // CHECK-LABEL: matmulNoTransposeB
-  // CHECK: %[[RES:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
+  // CHECK: %[[RES:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<?xi32>
   // CHECK: %[[TRANS1:.*]] = "tf.Transpose"(%arg0, %[[RES]]) : (tensor<1x1280xf32>, tensor<?xi32>) -> tensor<*xf32>
   // CHECK: %[[TRANS2:.*]] = "tf.Transpose"(%arg1, %[[RES]]) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
-  // CHECK: %[[MM:.*]] = "tf.MatMul"(%[[TRANS1]], %[[TRANS2]]) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
+  // CHECK: %[[MM:.*]] = "tf.MatMul"(%[[TRANS1]], %[[TRANS2]]) <{transpose_a = false, transpose_b = true}> : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
   // CHECK: return %[[MM]] : tensor<1x1000xf32>
 
 }
@@ -284,7 +284,7 @@
 
   // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : tensor<3xi32>
   // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<1> : tensor<3xi32>
-  // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 3 : i64, ellipsis_mask = 0 : i64, end_mask = 3 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<21x15x2xf32>
+  // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) <{begin_mask = 3 : i64, ellipsis_mask = 0 : i64, end_mask = 3 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<21x15x2xf32>
 }
 
 // CHECK-LABEL: @StridedSliceEllipsisMaskBeforeWithBeginAndEndMask
@@ -298,7 +298,7 @@
   // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 1, 0]> : tensor<3xi32>
   // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<0> : tensor<3xi32>
   // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<1> : tensor<3xi32>
-  // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST_0]], %[[CST_1]]) {begin_mask = 7 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4x5x4xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x4x4xf32>
+  // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST_0]], %[[CST_1]]) <{begin_mask = 7 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<4x5x4xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x4x4xf32>
 }
 
 // CHECK-LABEL: @StridedSliceEllipsisMaskAfter
@@ -310,7 +310,7 @@
 
   // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : tensor<3xi32>
   // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<1> : tensor<3xi32>
-  // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<5x15x7xf32>
+  // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) <{begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<5x15x7xf32>
 }
 
 // CHECK-LABEL: @NoStridedSliceEllipsisMask
@@ -322,7 +322,7 @@
 
   // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : tensor<2xi32>
   // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<1> : tensor<2xi32>
-  // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 0 : i64, ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<*xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<21x15x2xf32>
+  // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) <{begin_mask = 0 : i64, ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<*xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<21x15x2xf32>
 }
 
 // CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
@@ -334,7 +334,7 @@
 
   // CHECK-DAG: %cst = arith.constant dense<0> : tensor<4xi32>
   // CHECK-DAG: %cst_0 = arith.constant dense<1> : tensor<4xi32>
-  // CHECK: %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
+  // CHECK: %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) <{begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
 }
 
 // CHECK-LABEL: @PadStridedSliceNewAxisMask1
@@ -348,7 +348,7 @@
   // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : tensor<4xi32>
   // CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<[1, 2, 3, 1]> : tensor<4xi32>
   // CHECK: %0 = "tf.Reshape"(%arg0, %[[cst_1]]) : (tensor<2x3xf32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
-  // CHECK: %1 = "tf.StridedSlice"(%0, %[[CST0]], %[[CST0]], %[[CST1]]) {begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
+  // CHECK: %1 = "tf.StridedSlice"(%0, %[[CST0]], %[[CST0]], %[[CST1]]) <{begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
 }
 
 // CHECK-LABEL: @PadStridedSliceNewAxisMask2
@@ -401,7 +401,7 @@
   // CHECK-DAG: [[BEGIN:%cst.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32>
   // CHECK-DAG: [[END:%cst.*]] = arith.constant dense<[0, 10, 10]> : tensor<3xi32>
   // CHECK-DAG: [[STRIDES:%cst.*]] = arith.constant dense<1> : tensor<3xi32>
-  // CHECK-NEXT: "tf.StridedSlice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
+  // CHECK-NEXT: "tf.StridedSlice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) <{begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
 }
 
 // CHECK-LABEL: @StridedSliceEllipsisAndNewAxisMaskBothSet
@@ -419,7 +419,7 @@
   // CHECK-DAG: %[[STEP:.*]] = arith.constant dense<1> : tensor<5xi32>
   // CHECK-DAG: %[[NEW_DIMS:.*]] = arith.constant dense<[6, 1, 7, 8, 1]> : tensor<5xi32>
   // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[NEW_DIMS]]) : (tensor<6x7x8xf32>, tensor<5xi32>) -> tensor<6x1x7x8x1xf32>
-  // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%[[RESHAPE]], %[[BEGIN]], %[[END]], %[[STEP]]) {begin_mask = 30 : i64, ellipsis_mask = 0 : i64, end_mask = 30 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<6x1x7x8x1xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<2x1x7x8x1xf32>
+  // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%[[RESHAPE]], %[[BEGIN]], %[[END]], %[[STEP]]) <{begin_mask = 30 : i64, ellipsis_mask = 0 : i64, end_mask = 30 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<6x1x7x8x1xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<2x1x7x8x1xf32>
 }
 
 // CHECK-LABEL: @StridedSliceShrinkAxisAndNewAxisMaskBothSet
@@ -437,7 +437,7 @@
   // CHECK-DAG: %[[END:.*]] = arith.constant dense<[2, 3, 4, 5, 8]> : tensor<5xi32>
   // CHECK-DAG: %[[STEP:.*]] = arith.constant dense<1> : tensor<5xi32>
   // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[NEW_DIMS]]) : (tensor<6x7x8xf32>, tensor<5xi32>) -> tensor<6x1x7x1x8xf32>
-  // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%[[RESHAPE]], %[[BEGIN]], %[[END]], %[[STEP]]) {begin_mask = 26 : i64, ellipsis_mask = 0 : i64, end_mask = 26 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<6x1x7x1x8xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<1x4x1x8xf32>
+  // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%[[RESHAPE]], %[[BEGIN]], %[[END]], %[[STEP]]) <{begin_mask = 26 : i64, ellipsis_mask = 0 : i64, end_mask = 26 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<6x1x7x1x8xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<1x4x1x8xf32>
 }
 
 func.func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
@@ -572,7 +572,7 @@
 // CHECK:  %[[EXP:.*]] = "tf.ExpandDims"(%arg0, %[[CST]]) : (tensor<10x20x30xf32>, tensor<i32>) -> tensor<10x20x1x30xf32>
 // CHECK:  %[[CON:.*]] = "tf.ConcatV2"(%[[CST0]], %arg1, %[[CST1]]) : (tensor<1xi32>, tensor<1xi32>, tensor<i32>) -> tensor<2xi32>
 // CHECK:  %[[RFF:.*]] = "tf.RFFT2D"(%[[EXP]], %[[CON]]) : (tensor<10x20x1x30xf32>, tensor<2xi32>) -> tensor<10x20x1x30xcomplex<f64>>
-// CHECK:  %[[SQE:.*]] = "tf.Squeeze"(%[[RFF]]) {squeeze_dims = [-2]} : (tensor<10x20x1x30xcomplex<f64>>) -> tensor<10x20x30xcomplex<f64>>
+// CHECK:  %[[SQE:.*]] = "tf.Squeeze"(%[[RFF]]) <{squeeze_dims = [-2]}> : (tensor<10x20x1x30xcomplex<f64>>) -> tensor<10x20x30xcomplex<f64>>
 }
 
 // CHECK-LABEL: xla_gather_to_strided_slice
@@ -585,7 +585,7 @@
 // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : tensor<4xi64>
 // CHECK-DAG: %[[CST0:.*]] = arith.constant dense<[1, 9, 23, 768]> : tensor<4xi64>
 // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : tensor<4xi64>
-// CHECK: %[[V0:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST0]], %[[CST1]]) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x9x104x768xf32>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>) -> tensor<*xf32>
+// CHECK: %[[V0:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST0]], %[[CST1]]) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<1x9x104x768xf32>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>) -> tensor<*xf32>
 // CHECK: return %[[V0]] : tensor<*xf32>
 }
 
@@ -660,9 +660,9 @@
   // CHECK-LABEL: fused_batch_norm_v3_training
   // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 1, 2]> : tensor<3xi32>
   // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e-03> : tensor<f32>
-  // CHECK:  %[[MEAN:.*]] = "tf.Mean"(%arg0, %[[CST]]) {keep_dims = false} : (tensor<1x1x6x2xf32>, tensor<3xi32>) -> tensor<2xf32>
+  // CHECK:  %[[MEAN:.*]] = "tf.Mean"(%arg0, %[[CST]]) <{keep_dims = false}> : (tensor<1x1x6x2xf32>, tensor<3xi32>) -> tensor<2xf32>
   // CHECK:  %[[SQ:.*]] = "tf.SquaredDifference"(%arg0, %[[MEAN]]) : (tensor<1x1x6x2xf32>, tensor<2xf32>) -> tensor<1x1x6x2xf32>
-  // CHECK:  %[[MEAN0:.*]] = "tf.Mean"(%[[SQ]], %[[CST]]) {keep_dims = false} : (tensor<1x1x6x2xf32>, tensor<3xi32>) -> tensor<2xf32>
+  // CHECK:  %[[MEAN0:.*]] = "tf.Mean"(%[[SQ]], %[[CST]]) <{keep_dims = false}> : (tensor<1x1x6x2xf32>, tensor<3xi32>) -> tensor<2xf32>
   // CHECK:  %[[ADD:.*]] = "tf.Add"(%[[MEAN0]], %[[CST1]]) : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
   // CHECK:  %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD]]) : (tensor<2xf32>) -> tensor<2xf32>
   // CHECK:  %[[MUL1:.*]] = "tf.Mul"(%arg1, %[[RSQRT]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
@@ -689,10 +689,10 @@
   func.return %0 : tensor<ui32>
 
   // CHECK-LABEL: add_v2_uint32
-  // CHECK:  %[[CAST:.*]] = "tf.Cast"(%arg0) {Truncate = false} : (tensor<ui32>) -> tensor<i32>
-  // CHECK:  %[[CAST1:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<ui32>) -> tensor<i32>
+  // CHECK:  %[[CAST:.*]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<ui32>) -> tensor<i32>
+  // CHECK:  %[[CAST1:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<ui32>) -> tensor<i32>
   // CHECK:  %[[ADD:.*]] = "tf.AddV2"(%[[CAST]], %[[CAST1]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  // CHECK:  %[[CAST2:.*]] = "tf.Cast"(%[[ADD]]) {Truncate = false} : (tensor<i32>) -> tensor<ui32>
+  // CHECK:  %[[CAST2:.*]] = "tf.Cast"(%[[ADD]]) <{Truncate = false}> : (tensor<i32>) -> tensor<ui32>
   // CHECK:  return %[[CAST2]] : tensor<ui32>
 }
 
@@ -713,12 +713,12 @@
   func.return %6 : tensor<2x4xf32>
 
   // CHECK-LABEL: QuantDequantTranspose
-  // CHECK-DAG: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
+  // CHECK-DAG: %[[CST:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<?xi32>
   // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<1.00392163> : tensor<3x4xf32>
   // CHECK: %[[QUANT:.*]] = "tfl.quantize"(%[[CST_0]]) {qtype = tensor<3x4x!quant.uniform<u8:f32:1, {0.0078431372549019607:128,0.0078431372549019607:128,0.0078431372549019607:128,0.0078431372549019607:128}>>} : (tensor<3x4xf32>) -> tensor<3x4x!quant.uniform<u8:f32:1, {0.0078431372549019607:128,0.0078431372549019607:128,0.0078431372549019607:128,0.0078431372549019607:128}>>
   // CHECK: %[[DEQUANT:.*]] = "tfl.dequantize"(%[[QUANT]]) : (tensor<3x4x!quant.uniform<u8:f32:1, {0.0078431372549019607:128,0.0078431372549019607:128,0.0078431372549019607:128,0.0078431372549019607:128}>>) -> tensor<3x4xf32>
   // CHECK: %[[TRANSPOSE:.*]] = "tf.Transpose"(%[[DEQUANT]], %[[CST]]) : (tensor<3x4xf32>, tensor<?xi32>) -> tensor<*xf32>
-  // CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %[[TRANSPOSE]]) {transpose_a = false, transpose_b = true} : (tensor<2x3xf32>, tensor<*xf32>) -> tensor<2x4xf32>
+  // CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %[[TRANSPOSE]]) <{transpose_a = false, transpose_b = true}> : (tensor<2x3xf32>, tensor<*xf32>) -> tensor<2x4xf32>
   // CHECK: return %[[MATMUL]] : tensor<2x4xf32>
 }
 
diff --git a/tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir b/tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir
index e72f421..477315d 100644
--- a/tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir
@@ -42,11 +42,11 @@
 // CHECK: tf_executor.island wraps "tf.FakeQuantWithMinMaxVarsPerChannel"
 
 // WRAPPED-NEXT: tf_executor.graph {
-// WRAPPED-NEXT:   tf_executor.island wraps "tf.Const"() {device = "", value = dense<1.000000e+00> : tensor<186xf32>} : () -> tensor<186xf32>
-// WRAPPED-NEXT:   tf_executor.island wraps "tf.Const"() {device = "", value = dense<2.000000e+00> : tensor<186xf32>} : () -> tensor<186xf32>
+// WRAPPED-NEXT:   tf_executor.island wraps "tf.Const"() <{value = dense<1.000000e+00> : tensor<186xf32>}> {device = ""} : () -> tensor<186xf32>
+// WRAPPED-NEXT:   tf_executor.island wraps "tf.Const"() <{value = dense<2.000000e+00> : tensor<186xf32>}> {device = ""} : () -> tensor<186xf32>
 // WRAPPED-NEXT:   tf_executor.island wraps "tfl.custom_tf"
 // WRAPPED-NEXT:     ^bb0(%arg1: tensor<*xf32>, %arg2: tensor<186xf32>, %arg3: tensor<186xf32>):
-// WRAPPED-NEXT:   %[[fq:.*]] = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg1, %arg2, %arg3) {device = "", narrow_range = true, num_bits = 8 : i64} : (tensor<*xf32>, tensor<186xf32>, tensor<186xf32>) -> tensor<*xf32>
+// WRAPPED-NEXT:   %[[fq:.*]] = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg1, %arg2, %arg3) <{narrow_range = true, num_bits = 8 : i64}> {device = ""} : (tensor<*xf32>, tensor<186xf32>, tensor<186xf32>) -> tensor<*xf32>
 // WRAPPED-NEXT:   "tfl.yield"(%[[fq]]) : (tensor<*xf32>) -> ()
 // WRAPPED-NEXT:   }) {device = "", narrow_range = true, num_bits = 8 : i64} : (tensor<*xf32>, tensor<186xf32>, tensor<186xf32>) -> tensor<*xf32>
 }
diff --git a/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc b/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc
index 9563f5d..f8d1de0 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc
@@ -19,6 +19,7 @@
 #include <string>
 #include <utility>
 
+#include "absl/strings/match.h"
 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
@@ -111,20 +112,6 @@
     Operation* tf_op = rewriter.create(op_state);
     rewriter.replaceOp(op, tf_op->getResults());
 
-    if (isa<TF::MapDatasetOp, TF::ReduceDatasetOp>(tf_op)) {
-      constexpr StringRef kFuncAttrName = "f";
-      tf_op->setAttr(
-          kFuncAttrName,
-          tf_op->getAttr(kFuncAttrName).cast<TF::FuncAttr>().getName());
-    }
-
-    if (isa<TF::TakeWhileDatasetOp>(tf_op)) {
-      constexpr StringRef kFuncAttrName = "predicate";
-      tf_op->setAttr(
-          kFuncAttrName,
-          tf_op->getAttr(kFuncAttrName).cast<TF::FuncAttr>().getName());
-    }
-
     // Special type fixes for TF Resource Tensors that are casted to
     // Int32 tensor during MLIR->TFLite flatbuffer conversion.
     // TODO(b/146131919): correct handling of resource type
@@ -237,6 +224,10 @@
       if (!mlir_attr.ok()) {
         return emitError(loc, mlir_attr.status().message());
       }
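+      // Dataset ops reference their body functions through `TF::FuncAttr`
+      // (e.g. "f" on MapDataset/ReduceDataset, "predicate" on
+      // TakeWhileDataset); keep only the function's symbol name here.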
+      if (absl::StrContains(op_name, "Dataset") &&
+          mlir_attr->isa<TF::FuncAttr>()) {
+        mlir_attr = mlir_attr->cast<TF::FuncAttr>().getName();
+      }
       attributes.push_back(builder.getNamedAttr(attr_name, *mlir_attr));
     }
     return success();
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD
index 85ae6c2..640cd2e 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD
@@ -32,8 +32,10 @@
         "passes/lift_quantizable_spots_as_functions.cc",
         "passes/lift_quantizable_spots_as_functions_fusion.inc",
         "passes/lift_quantizable_spots_as_functions_simple.inc",
+        "passes/post_quantize.cc",
         "passes/prepare_quantize.cc",
         "passes/quantize.cc",
+        "passes/quantize_composite_functions.cc",
         "passes/quantize_weight.cc",
         "passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc",
         "passes/restore_function_name.cc",
@@ -51,6 +53,7 @@
         ":quantization_options_proto_cc",
         ":stablehlo_passes_inc_gen",
         ":stablehlo_type_utils",
+        ":uniform_quantized_types",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
@@ -58,6 +61,7 @@
         "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils",
         "//tensorflow/compiler/mlir/quantization/tensorflow:passes",
         "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc",
+        "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes",
         "//tensorflow/compiler/mlir/quantization/tensorflow/ops:tf_op_quant_spec",
         "//tensorflow/compiler/mlir/quantization/tensorflow/utils:lift_as_function_call_utils",
         "//tensorflow/compiler/mlir/tensorflow",
@@ -185,7 +189,6 @@
     ],
     deps = [
         ":bridge_passes_inc_gen",
-        ":math_utils",
         ":tf_type_utils",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:mangling_util",
@@ -210,7 +213,6 @@
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:QuantOps",
         "@llvm-project//mlir:ShapeDialect",
-        "@llvm-project//mlir:SparseTensorDialect",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Transforms",
         "@local_xla//xla:xla_data_proto_cc",
@@ -232,21 +234,35 @@
     tags = ["nomac"],  # TODO(b/297362678): re-enable mac test.
     deps = [
         ":bridge_passes",
+        "//tensorflow/compiler/mlir/quantization/tensorflow/cc:constant_fold",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:convert_tensor",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/kernels:math",
+        "//tensorflow/core/kernels:nn",
+        "//tensorflow/core/kernels/uniform_quant_ops:kernels",
+        "//tensorflow/core/ops",
         "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/random",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "@com_google_googletest//:gtest_main",
+        "@llvm-project//llvm:Support",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:QuantOps",
         "@llvm-project//mlir:Support",
+        "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:statusor",
         "@local_xla//xla:error_spec",
         "@local_xla//xla:literal",
         "@local_xla//xla:literal_util",
+        "@local_xla//xla:shape_util",
         "@local_xla//xla/mlir_hlo",
         "@local_xla//xla/pjrt:pjrt_client",
         "@local_xla//xla/pjrt:pjrt_executable",
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc
index e60c271..16af4b2 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc
@@ -15,12 +15,10 @@
 
 #include <cstdint>
 #include <cstdlib>
-#include <cstring>
 #include <memory>
 #include <optional>
-#include <string>
-#include <string_view>
 #include <utility>
+#include <variant>
 
 #include "absl/algorithm/container.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -31,12 +29,11 @@
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
-#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
-#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
@@ -49,7 +46,6 @@
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "stablehlo/dialect/ChloOps.h"  // from @stablehlo
 #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h"
-#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/math_utils.h"
 #include "tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
 #include "xla/mlir_hlo/mhlo/transforms/rewriters.h"
@@ -57,88 +53,193 @@
 namespace mlir::quant::stablehlo {
 namespace {
 
-#define GEN_PASS_DEF_CONVERTMHLOQUANTTOINT
-#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h.inc"
-
-// This helper function create ops to requantize `input` tensor and output to
-// `res_int32` tensor. Clamping is omitted because for some ops clamping can be
-// done later to avoid duplicate.
-LogicalResult RequantizeWithoutClamping(
-    mlir::OpState op, Value input, TensorType int32_tensor_type,
-    quant::UniformQuantizedType input_quantized_type,
-    quant::UniformQuantizedType result_quantized_type, Value &res_int32,
-    ConversionPatternRewriter &rewriter) {
+// This helper function creates ops to requantize `input` tensor and returns the
+// output tensor. Clamping is done if output integer bit-width < 32.
+//
+// Requantization is essentially dequantize --> quantize.
+//
+// Dequantize: (input - zp) * scale
+// Quantize: input / scale + zp
+//
+// Hence,
+//   output = (input - input_zp) * input_scale / output_scale + output_zp
+//
+// This is simplified as:
+//   output = input * merged_scale + merged_zp
+// where:
+//   merged_zp = output_zp - input_zp * merged_scale.
+//   merged_scale = input_scale / output_scale.
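+//
+// Worked example (illustrative): with input_scale = 0.5, output_scale = 0.25,
+// input_zp = 10 and output_zp = 2, merged_scale = 0.5 / 0.25 = 2.0 and
+// merged_zp = 2 - 10 * 2.0 = -18, so output = input * 2.0 - 18.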
+Value Requantize(mlir::OpState op, Value input,
+                 UniformQuantizedType input_quantized_type,
+                 UniformQuantizedType output_quantized_type,
+                 TensorType output_tensor_type,
+                 ConversionPatternRewriter &rewriter) {
   // Skip requantization when input and result have the same type.
-  if (input_quantized_type == result_quantized_type) {
-    res_int32 = rewriter.create<mhlo::ConvertOp>(op->getLoc(),
-                                                 int32_tensor_type, input);
-    return success();
+  if (input_quantized_type == output_quantized_type) {
+    return rewriter.create<mhlo::ConvertOp>(op->getLoc(), output_tensor_type,
+                                            input);
   }
 
-  // Convert input to int32 tensor.
-  res_int32 =
-      rewriter.create<mhlo::ConvertOp>(op->getLoc(), int32_tensor_type, input);
-  // Undo the input zero point.
-  Value input_zero_point = rewriter.create<mhlo::ConstantOp>(
-      op->getLoc(), rewriter.getI32IntegerAttr(static_cast<int32_t>(
-                        input_quantized_type.getZeroPoint())));
-  res_int32 = rewriter.create<chlo::BroadcastSubOp>(
-      op->getLoc(), int32_tensor_type, res_int32, input_zero_point, nullptr);
-
-  // Adjust the scale.
-  const double effective_scale =
-      input_quantized_type.getScale() / result_quantized_type.getScale();
-  int32_t effective_quantized_fraction;
-  int32_t effective_shift;
-  if (failed(quant::stablehlo::QuantizeMultiplier(
-          effective_scale, effective_quantized_fraction, effective_shift))) {
-    op->emitError("Invalid effective quantization scale.");
-    return failure();
-  }
-  Value multiplier = rewriter.create<mhlo::ConstantOp>(
-      op->getLoc(), rewriter.getI32IntegerAttr(
-                        static_cast<int32_t>(effective_quantized_fraction)));
-  // The effective_quantized_fraction value has been quantized by multiplying
-  // (1 << 15).  So, we have to shift it back by (15 - effective_shift) to get
-  // the desired outcome.
-  Value total_shift = rewriter.create<mhlo::ConstantOp>(
+  double merged_scale_fp =
+      input_quantized_type.getScale() / output_quantized_type.getScale();
+  Value merged_scale = rewriter.create<mhlo::ConstantOp>(
       op->getLoc(),
-      rewriter.getI32IntegerAttr(static_cast<int32_t>(15 - effective_shift)));
+      rewriter.getF32FloatAttr(static_cast<float>(merged_scale_fp)));
 
-  // Apply the effective scale with rounding.
-  Value half = rewriter.create<mhlo::ConstantOp>(
-      op->getLoc(), rewriter.getI32IntegerAttr(
-                        static_cast<int32_t>(1 << (14 - effective_shift))));
-  res_int32 = rewriter.create<chlo::BroadcastMulOp>(
-      op->getLoc(), int32_tensor_type, res_int32, multiplier, nullptr);
-  res_int32 = rewriter.create<chlo::BroadcastAddOp>(
-      op->getLoc(), int32_tensor_type, res_int32, half, nullptr);
-  res_int32 = rewriter.create<chlo::BroadcastShiftRightArithmeticOp>(
-      op->getLoc(), int32_tensor_type, res_int32, total_shift, nullptr);
+  auto float_tensor_type =
+      input.getType().cast<TensorType>().clone(rewriter.getF32Type());
+  Value output_float =
+      rewriter.create<mhlo::ConvertOp>(op->getLoc(), float_tensor_type, input);
 
-  // Apply the output zero point.
-  Value output_zero_point = rewriter.create<mhlo::ConstantOp>(
-      op->getLoc(), rewriter.getI32IntegerAttr(static_cast<int32_t>(
-                        result_quantized_type.getZeroPoint())));
-  res_int32 = rewriter.create<chlo::BroadcastAddOp>(
-      op->getLoc(), int32_tensor_type, res_int32, output_zero_point, nullptr);
+  output_float = rewriter.create<chlo::BroadcastMulOp>(
+      op->getLoc(), float_tensor_type, output_float, merged_scale, nullptr);
 
-  return success();
+  // Add merged_zp only when it is non-zero.
+  double merged_zp_fp = output_quantized_type.getZeroPoint() -
+                        input_quantized_type.getZeroPoint() * merged_scale_fp;
+  if (merged_zp_fp != 0) {
+    Value merged_zp = rewriter.create<mhlo::ConstantOp>(
+        op->getLoc(),
+        rewriter.getF32FloatAttr(static_cast<float>(merged_zp_fp)));
+    output_float = rewriter.create<chlo::BroadcastAddOp>(
+        op->getLoc(), float_tensor_type, output_float, merged_zp, nullptr);
+  }
+
+  // Clamp output if the output integer bit-width < 32.
+  if (output_tensor_type.getElementType().cast<IntegerType>().getWidth() < 32) {
+    Value quantization_min = rewriter.create<mhlo::ConstantOp>(
+        op->getLoc(), rewriter.getF32FloatAttr(static_cast<float>(
+                          output_quantized_type.getStorageTypeMin())));
+    Value quantization_max = rewriter.create<mhlo::ConstantOp>(
+        op->getLoc(), rewriter.getF32FloatAttr(static_cast<float>(
+                          output_quantized_type.getStorageTypeMax())));
+    // Clamp results by [quantization_min, quantization_max].
+    output_float = rewriter.create<mhlo::ClampOp>(
+        op->getLoc(), float_tensor_type, quantization_min, output_float,
+        quantization_max);
+  }
+
+  output_float = rewriter.create<mhlo::RoundNearestEvenOp>(
+      op->getLoc(), float_tensor_type, output_float);
+  return rewriter.create<mhlo::ConvertOp>(op->getLoc(), output_tensor_type,
+                                          output_float);
 }
 
-class ConvertMHLOQuantToInt
-    : public impl::ConvertMHLOQuantToIntBase<ConvertMHLOQuantToInt> {
- public:
-  ConvertMHLOQuantToInt() = default;
-  ConvertMHLOQuantToInt(const ConvertMHLOQuantToInt &) {}
+using QuantType =
+    std::variant<UniformQuantizedType, UniformQuantizedPerAxisType>;
+FailureOr<QuantType> GetQuantType(Type type) {
+  if (auto quant_type =
+          getElementTypeOrSelf(type).dyn_cast<UniformQuantizedType>()) {
+    return QuantType(quant_type);
+  } else if (auto quant_type = getElementTypeOrSelf(type)
+                                   .dyn_cast<UniformQuantizedPerAxisType>()) {
+    return QuantType(quant_type);
+  } else {
+    return failure();
+  }
+}
 
-  explicit ConvertMHLOQuantToInt(bool legalize_chlo) {
-    legalize_chlo_ = legalize_chlo;
+// Extract scale and zero point info from input quant type info.
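+// For a per-tensor type, `scales`/`zero_points` are scalar constants; for a
+// per-axis type they are 1-D constants (one entry per channel) and
+// `broadcast_dims` holds the quantized dimension.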
+void GetQuantizationParams(OpBuilder &builder, Location loc,
+                           QuantType quant_type, Value &scales,
+                           Value &zero_points, bool output_zero_point_in_fp,
+                           DenseIntElementsAttr &broadcast_dims) {
+  // Get scales/zero points for per-tensor and per-axis quantization cases.
+  if (auto *quant_per_tensor_type =
+          std::get_if<UniformQuantizedType>(&quant_type)) {
+    scales = builder.create<mhlo::ConstantOp>(
+        loc, builder.getF32FloatAttr(quant_per_tensor_type->getScale()));
+    if (output_zero_point_in_fp) {
+      zero_points = builder.create<mhlo::ConstantOp>(
+          loc, builder.getF32FloatAttr(
+                   static_cast<float>(quant_per_tensor_type->getZeroPoint())));
+    } else {
+      zero_points = builder.create<mhlo::ConstantOp>(
+          loc, builder.getI32IntegerAttr(static_cast<int32_t>(
+                   quant_per_tensor_type->getZeroPoint())));
+    }
+  } else {
+    auto &quant_per_channel_type =
+        std::get<UniformQuantizedPerAxisType>(quant_type);
+    llvm::SmallVector<float> scales_vec;
+    for (auto scale : quant_per_channel_type.getScales())
+      scales_vec.push_back(scale);
+    scales = builder.create<mhlo::ConstantOp>(
+        loc, DenseFPElementsAttr::get(
+                 RankedTensorType::get(
+                     {static_cast<int64_t>(
+                         quant_per_channel_type.getScales().size())},
+                     builder.getF32Type()),
+                 scales_vec));
+    if (output_zero_point_in_fp) {
+      llvm::SmallVector<float> zero_points_vec;
+      for (auto zero_point : quant_per_channel_type.getZeroPoints())
+        zero_points_vec.push_back(zero_point);
+      zero_points = builder.create<mhlo::ConstantOp>(
+          loc, DenseFPElementsAttr::get(
+                   RankedTensorType::get(
+                       {static_cast<int64_t>(
+                           quant_per_channel_type.getZeroPoints().size())},
+                       builder.getF32Type()),
+                   zero_points_vec));
+    } else {
+      llvm::SmallVector<int32_t> zero_points_vec;
+      for (auto zero_point : quant_per_channel_type.getZeroPoints())
+        zero_points_vec.push_back(zero_point);
+      zero_points = builder.create<mhlo::ConstantOp>(
+          loc, DenseIntElementsAttr::get(
+                   RankedTensorType::get(
+                       {static_cast<int64_t>(
+                           quant_per_channel_type.getZeroPoints().size())},
+                       builder.getI32Type()),
+                   zero_points_vec));
+    }
+    broadcast_dims = DenseIntElementsAttr::get(
+        RankedTensorType::get({1}, builder.getI64Type()),
+        {static_cast<int64_t>(quant_per_channel_type.getQuantizedDimension())});
+  }
+}
+
+// Extract storage min/max from input quant type info.
+void GetQuantizationStorageInfo(OpBuilder &builder, Location loc,
+                                QuantType quant_type, Value &storage_min,
+                                Value &storage_max) {
+  if (auto *quant_per_tensor_type =
+          std::get_if<UniformQuantizedType>(&quant_type)) {
+    storage_min = builder.create<mhlo::ConstantOp>(
+        loc, builder.getF32FloatAttr(static_cast<float>(
+                 quant_per_tensor_type->getStorageTypeMin())));
+    storage_max = builder.create<mhlo::ConstantOp>(
+        loc, builder.getF32FloatAttr(static_cast<float>(
+                 quant_per_tensor_type->getStorageTypeMax())));
+  } else {
+    auto &quant_per_channel_type =
+        std::get<UniformQuantizedPerAxisType>(quant_type);
+    storage_min = builder.create<mhlo::ConstantOp>(
+        loc, builder.getF32FloatAttr(static_cast<float>(
+                 quant_per_channel_type.getStorageTypeMin())));
+    storage_max = builder.create<mhlo::ConstantOp>(
+        loc, builder.getF32FloatAttr(static_cast<float>(
+                 quant_per_channel_type.getStorageTypeMax())));
+  }
+}
+
+// Get the storage type of a UQ type. Returns the original type if it is not a
+// UQ type.
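+// For example, tensor<2x!quant.uniform<i8:f32, 1.000000e+00>> maps to
+// tensor<2xi8>, while a non-quantized type such as tensor<2xf32> is returned
+// unchanged.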
+Type GetQuantStorageType(Type type) {
+  if (auto shaped = type.dyn_cast<ShapedType>()) {
+    return shaped.clone(GetQuantStorageType(shaped.getElementType()));
   }
 
-  // Performs conversion of MHLO quant ops to primitive ops.
-  void runOnOperation() override;
-};
+  if (auto element_type =
+          getElementTypeOrSelf(type).dyn_cast<UniformQuantizedType>()) {
+    return element_type.getStorageType();
+  } else if (auto element_type = getElementTypeOrSelf(type)
+                                     .dyn_cast<UniformQuantizedPerAxisType>()) {
+    return element_type.getStorageType();
+  } else {
+    return type;
+  }
+}
 
 class ConvertUniformQuantizeOp
     : public OpConversionPattern<mhlo::UniformQuantizeOp> {
@@ -148,124 +249,66 @@
   LogicalResult matchAndRewrite(
       mhlo::UniformQuantizeOp op, mhlo::UniformQuantizeOpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    auto quantized_type = getElementTypeOrSelf(op.getResult().getType())
-                              .dyn_cast<quant::UniformQuantizedType>();
-    // Currently for activation, PTQ supports per-tensor quantization only, and
-    // UniformQuantize op is only for activation.
-    if (!quantized_type) {
-      return rewriter.notifyMatchFailure(
-          op, "Legalization supports only per-tensor quantization.");
-    }
     auto input_element_type = getElementTypeOrSelf(op.getOperand().getType());
     if (input_element_type.isF32()) {
-      return matchAndRewriteQuantize(op, adaptor, rewriter, quantized_type);
-    } else if (input_element_type.isa<quant::UniformQuantizedType>()) {
-      return matchAndRewriteRequantize(op, adaptor, rewriter, quantized_type);
+      auto quant_type = GetQuantType(op.getResult().getType());
+      if (succeeded(quant_type)) {
+        return matchAndRewriteQuantize(op, adaptor, rewriter, *quant_type);
+      }
+    } else if (input_element_type.isa<UniformQuantizedType>()) {
+      return matchAndRewriteRequantize(op, adaptor, rewriter);
     }
     return rewriter.notifyMatchFailure(op, "Unsupported input element type.");
   }
 
-  LogicalResult matchAndRewriteQuantize(
-      mhlo::UniformQuantizeOp op, mhlo::UniformQuantizeOpAdaptor adaptor,
-      ConversionPatternRewriter &rewriter,
-      const quant::UniformQuantizedType &quantized_type) const {
-    Value scale = rewriter.create<mhlo::ConstantOp>(
-        op->getLoc(), rewriter.getF32FloatAttr(quantized_type.getScale()));
-    Value zero_point = rewriter.create<mhlo::ConstantOp>(
-        op->getLoc(), rewriter.getF32FloatAttr(
-                          static_cast<float>(quantized_type.getZeroPoint())));
-    Value quantization_min = rewriter.create<mhlo::ConstantOp>(
-        op->getLoc(), rewriter.getF32FloatAttr(static_cast<float>(
-                          quantized_type.getStorageTypeMin())));
-    Value quantization_max = rewriter.create<mhlo::ConstantOp>(
-        op->getLoc(), rewriter.getF32FloatAttr(static_cast<float>(
-                          quantized_type.getStorageTypeMax())));
+  LogicalResult matchAndRewriteQuantize(mhlo::UniformQuantizeOp op,
+                                        mhlo::UniformQuantizeOpAdaptor adaptor,
+                                        ConversionPatternRewriter &rewriter,
+                                        QuantType quant_type) const {
+    Value scales, zero_points;
+    DenseIntElementsAttr broadcast_dims;
+    GetQuantizationParams(rewriter, op->getLoc(), quant_type, scales,
+                          zero_points, /*output_zero_point_in_fp=*/true,
+                          broadcast_dims);
+
+    Value quantization_min, quantization_max;
+    GetQuantizationStorageInfo(rewriter, op->getLoc(), quant_type,
+                               quantization_min, quantization_max);
 
     auto res_float_tensor_type =
         op.getOperand().getType().clone(rewriter.getF32Type());
     Value res_float = rewriter.create<chlo::BroadcastDivOp>(
-        op->getLoc(), res_float_tensor_type, adaptor.getOperand(), scale,
-        nullptr);
+        op->getLoc(), res_float_tensor_type, adaptor.getOperand(), scales,
+        broadcast_dims);
     res_float = rewriter.create<chlo::BroadcastAddOp>(
-        op->getLoc(), res_float_tensor_type, res_float, zero_point, nullptr);
+        op->getLoc(), res_float_tensor_type, res_float, zero_points,
+        broadcast_dims);
 
     res_float = rewriter.create<mhlo::ClampOp>(
         op->getLoc(), res_float_tensor_type, quantization_min, res_float,
         quantization_max);
     res_float = rewriter.create<mhlo::RoundNearestEvenOp>(
         op->getLoc(), res_float_tensor_type, res_float);
-    auto res_final_tensor_type =
-        res_float_tensor_type.clone(quantized_type.getStorageType());
+    auto res_final_tensor_type = res_float_tensor_type.clone(
+        GetQuantStorageType(op.getResult().getType().getElementType()));
     rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, res_final_tensor_type,
                                                  res_float);
     return success();
   }
 
-  // Requantization is essentially dequantize --> quantize.
-  //
-  // Dequantize: (input - zp) * scale
-  // Quantize: input / scale + zp
-  //
-  // Hence,
-  //   result = (input - input_zp) * input_scale / output_scale + output_zp
-  //
-  // This is simplified as:
-  //   result = input * merged_scale + merged_zp
-  // where:
-  //   merged_zp = output_zp - input_zp * merged_scale.
-  //   merged_scale = input_scale / output_scale.
   LogicalResult matchAndRewriteRequantize(
       mhlo::UniformQuantizeOp op, mhlo::UniformQuantizeOpAdaptor adaptor,
-      ConversionPatternRewriter &rewriter,
-      const quant::UniformQuantizedType &output_quantized_type) const {
+      ConversionPatternRewriter &rewriter) const {
     auto input_quantized_type = getElementTypeOrSelf(op.getOperand().getType())
-                                    .cast<quant::UniformQuantizedType>();
-    auto result_quantized_type = getElementTypeOrSelf(op.getResult().getType())
-                                     .cast<quant::UniformQuantizedType>();
-
-    double merged_scale_fp =
-        input_quantized_type.getScale() / result_quantized_type.getScale();
-    Value merged_scale = rewriter.create<mhlo::ConstantOp>(
-        op->getLoc(),
-        rewriter.getF32FloatAttr(static_cast<float>(merged_scale_fp)));
-
-    auto res_float_tensor_type =
-        op.getOperand().getType().clone(rewriter.getF32Type());
-    Value res_float = rewriter.create<mhlo::ConvertOp>(
-        op->getLoc(), res_float_tensor_type, adaptor.getOperand());
-
-    res_float = rewriter.create<chlo::BroadcastMulOp>(
-        op->getLoc(), res_float_tensor_type, res_float, merged_scale, nullptr);
-
-    // Add merged_zp only when it is non-zero.
-    double merged_zp_fp = result_quantized_type.getZeroPoint() -
-                          input_quantized_type.getZeroPoint() * merged_scale_fp;
-    if (merged_zp_fp != 0) {
-      Value merged_zp = rewriter.create<mhlo::ConstantOp>(
-          op->getLoc(),
-          rewriter.getF32FloatAttr(static_cast<float>(merged_zp_fp)));
-      res_float = rewriter.create<chlo::BroadcastAddOp>(
-          op->getLoc(), res_float_tensor_type, res_float, merged_zp, nullptr);
-    }
-
-    Value quantization_min = rewriter.create<mhlo::ConstantOp>(
-        op->getLoc(), rewriter.getF32FloatAttr(static_cast<float>(
-                          output_quantized_type.getStorageTypeMin())));
-    Value quantization_max = rewriter.create<mhlo::ConstantOp>(
-        op->getLoc(), rewriter.getF32FloatAttr(static_cast<float>(
-                          output_quantized_type.getStorageTypeMax())));
-
-    // Clamp results by [quantization_min, quantization_max].
-    res_float = rewriter.create<mhlo::ClampOp>(
-        op->getLoc(), res_float_tensor_type, quantization_min, res_float,
-        quantization_max);
-    res_float = rewriter.create<mhlo::RoundNearestEvenOp>(
-        op->getLoc(), res_float_tensor_type, res_float);
-
-    auto res_final_tensor_type =
-        res_float_tensor_type.clone(output_quantized_type.getStorageType());
-    rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, res_final_tensor_type,
-                                                 res_float);
+                                    .cast<UniformQuantizedType>();
+    auto output_quantized_type = getElementTypeOrSelf(op.getResult().getType())
+                                     .cast<UniformQuantizedType>();
+    rewriter.replaceOp(
+        op, Requantize(op, adaptor.getOperand(), input_quantized_type,
+                       output_quantized_type,
+                       op.getResult().getType().cast<TensorType>().clone(
+                           output_quantized_type.getStorageType()),
+                       rewriter));
     return success();
   }
 };
@@ -278,19 +321,15 @@
   LogicalResult matchAndRewrite(
       mhlo::UniformDequantizeOp op, mhlo::UniformDequantizeOpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    auto element_type = getElementTypeOrSelf(op.getOperand().getType())
-                            .dyn_cast<quant::UniformQuantizedType>();
-    // Currently for activation, PTQ supports per-tensor quantization only, and
-    // UniformQuantize op is only for activation.
-    if (!element_type) {
-      return rewriter.notifyMatchFailure(
-          op, "Legalization supports only per-tensor quantization.");
+    auto quant_type = GetQuantType(op.getOperand().getType());
+    if (failed(quant_type)) {
+      return failure();
     }
-    Value scale = rewriter.create<mhlo::ConstantOp>(
-        op->getLoc(), rewriter.getF32FloatAttr(element_type.getScale()));
-    Value zero_point = rewriter.create<mhlo::ConstantOp>(
-        op->getLoc(), rewriter.getI32IntegerAttr(
-                          static_cast<int32_t>(element_type.getZeroPoint())));
+    Value scales, zero_points;
+    DenseIntElementsAttr broadcast_dims;
+    GetQuantizationParams(rewriter, op->getLoc(), *quant_type, scales,
+                          zero_points,
+                          /*output_zero_point_in_fp=*/false, broadcast_dims);
 
     Value input = adaptor.getOperand();
     // TODO: b/260280919 - Consider avoiding conversion to int32.
@@ -299,13 +338,14 @@
     Value res_int32 = rewriter.create<mhlo::ConvertOp>(
         op->getLoc(), res_int32_tensor_type, input);
     res_int32 = rewriter.create<chlo::BroadcastSubOp>(
-        op->getLoc(), res_int32_tensor_type, res_int32, zero_point, nullptr);
+        op->getLoc(), res_int32_tensor_type, res_int32, zero_points,
+        broadcast_dims);
     auto res_float_tensor_type =
         res_int32.getType().cast<TensorType>().clone(rewriter.getF32Type());
     Value res_float = rewriter.create<mhlo::ConvertOp>(
         op->getLoc(), res_float_tensor_type, res_int32);
     res_float = rewriter.replaceOpWithNewOp<chlo::BroadcastMulOp>(
-        op, res_float_tensor_type, res_float, scale, nullptr);
+        op, res_float_tensor_type, res_float, scales, broadcast_dims);
     return success();
   }
 };
@@ -317,18 +357,14 @@
   LogicalResult matchAndRewrite(
       mhlo::AddOp op, mhlo::AddOpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    auto lhs_element_type = op.getLhs()
-                                .getType()
-                                .getElementType()
-                                .dyn_cast<quant::UniformQuantizedType>();
-    auto rhs_element_type = op.getRhs()
-                                .getType()
-                                .getElementType()
-                                .dyn_cast<quant::UniformQuantizedType>();
+    auto lhs_element_type =
+        op.getLhs().getType().getElementType().dyn_cast<UniformQuantizedType>();
+    auto rhs_element_type =
+        op.getRhs().getType().getElementType().dyn_cast<UniformQuantizedType>();
     auto result_element_type = op.getResult()
                                    .getType()
                                    .getElementType()
-                                   .dyn_cast<quant::UniformQuantizedType>();
+                                   .dyn_cast<UniformQuantizedType>();
 
     // We only handle cases where lhs, rhs and results all have quantized
     // element type.
@@ -347,20 +383,14 @@
     // be the same as the result.
     // TODO: b/260280919 - Consider avoiding conversion to int32.
     Value lhs = adaptor.getLhs();
-    Value lhs_int32_tensor;
-    if (failed(RequantizeWithoutClamping(op, lhs, res_int32_tensor_type,
-                                         lhs_element_type, result_element_type,
-                                         lhs_int32_tensor, rewriter))) {
-      return failure();
-    }
+    Value lhs_int32_tensor =
+        Requantize(op, lhs, lhs_element_type, result_element_type,
+                   res_int32_tensor_type, rewriter);
 
     Value rhs = adaptor.getRhs();
-    Value rhs_int32_tensor;
-    if (failed(RequantizeWithoutClamping(op, rhs, res_int32_tensor_type,
-                                         rhs_element_type, result_element_type,
-                                         rhs_int32_tensor, rewriter))) {
-      return failure();
-    }
+    Value rhs_int32_tensor =
+        Requantize(op, rhs, rhs_element_type, result_element_type,
+                   res_int32_tensor_type, rewriter);
 
     Value zero_point = rewriter.create<mhlo::ConstantOp>(
         op->getLoc(), rewriter.getI32IntegerAttr(static_cast<int32_t>(
@@ -437,9 +467,9 @@
   // result = hybridOp(lhs, dequant(rhs))
   Value lhs_float32_tensor = adaptor.getLhs();
   Value rhs = adaptor.getRhs();
-  quant::UniformQuantizedType rhs_element_type =
+  UniformQuantizedType rhs_element_type =
       getElementTypeOrSelf(op.getRhs().getType())
-          .template cast<quant::UniformQuantizedType>();
+          .template cast<UniformQuantizedType>();
   auto res_float32_tensor_type =
       op.getResult().getType().template cast<TensorType>();
   auto rhs_float32_tensor_type =
@@ -481,7 +511,7 @@
 
   // Calculate the output tensor shape. This is input tensor dims minus
   // contracting dims.
-  auto ranked_tensor = tensor.getType().dyn_cast<RankedTensorType>();
+  auto ranked_tensor = tensor.getType().cast<RankedTensorType>();
   llvm::SmallVector<int64_t> output_dims;
   for (int64_t i = 0; i < ranked_tensor.getRank(); ++i) {
     if (absl::c_count(reduction_dims, i) == 0) {
@@ -492,7 +522,7 @@
   // Convert input tensor to output type since mhlo::Reduce only supports same
   // element type for input/output.
   tensor = builder.create<mhlo::ConvertOp>(
-      loc, tensor.getType().dyn_cast<TensorType>().clone(output_element_type),
+      loc, tensor.getType().cast<TensorType>().clone(output_element_type),
       tensor);
   auto reducer_tensor_type = RankedTensorType::get({}, output_element_type);
 
@@ -592,7 +622,7 @@
   // zero-point-offset tensor to the final output tensor, and then do the
   // broadcast.
   auto zp_contribution_rank =
-      zp_contribution.getType().dyn_cast<ShapedType>().getRank();
+      zp_contribution.getType().cast<ShapedType>().getRank();
   llvm::SmallVector<int64_t> broadcast_dims;
   broadcast_dims.resize(zp_contribution_rank, 0);
   // Result tensor will have batching dims first, then LHS result dims, then
@@ -615,7 +645,7 @@
   }
   // Use broadcast_in_dim or dynamic_broadcast_in_dim based on input shape
   // dynamism.
-  if (zp_contribution.getType().dyn_cast<ShapedType>().hasStaticShape()) {
+  if (zp_contribution.getType().cast<ShapedType>().hasStaticShape()) {
     zp_contribution = builder.create<mhlo::BroadcastInDimOp>(
         loc, output_tensor_type, zp_contribution,
         DenseIntElementsAttr::get(
@@ -742,13 +772,13 @@
         DenseIntElementsAttr::get(
             RankedTensorType::get({}, builder.getI8Type()),
             {static_cast<int8_t>(getElementTypeOrSelf(op.getLhs().getType())
-                                     .dyn_cast<quant::UniformQuantizedType>()
+                                     .cast<UniformQuantizedType>()
                                      .getZeroPoint())}));
     // Convert Padding attributes from mhlo::Convolution to mhlo::Pad. Note that
     // Padding is applied for spatial dimensions [1...rank-1) only for
     // mhlo::Convolution. But mhlo::Pad requires those for all dimensions. Hence
     // we add 0 to the beginning and end of the padding vectors.
-    int64_t rank = lhs.getType().dyn_cast<TensorType>().getRank();
+    int64_t rank = lhs.getType().cast<TensorType>().getRank();
     llvm::SmallVector<int64_t> padding_low(rank, 0), padding_high(rank, 0),
         padding_interior(rank, 0);
     for (int64_t i = 1; i < rank - 1; ++i) {
@@ -786,15 +816,12 @@
                                        ConversionPatternRewriter &rewriter) {
   // Lower Dot/DotGeneral UQ ops to DotGeneral int.
   // Assumes that operands and results are uq types.
-  auto lhs_element_quant_type =
-      getElementTypeOrSelf(op.getLhs().getType())
-          .template dyn_cast<quant::UniformQuantizedType>();
-  auto rhs_element_quant_type =
-      getElementTypeOrSelf(op.getRhs().getType())
-          .template dyn_cast<quant::UniformQuantizedType>();
-  auto res_element_quant_type =
-      getElementTypeOrSelf(op.getResult())
-          .template dyn_cast<quant::UniformQuantizedType>();
+  auto lhs_element_quant_type = getElementTypeOrSelf(op.getLhs().getType())
+                                    .template dyn_cast<UniformQuantizedType>();
+  auto rhs_element_quant_type = getElementTypeOrSelf(op.getRhs().getType())
+                                    .template dyn_cast<UniformQuantizedType>();
+  auto res_element_quant_type = getElementTypeOrSelf(op.getResult())
+                                    .template dyn_cast<UniformQuantizedType>();
   Value lhs = adaptor.getLhs();
   Value rhs = adaptor.getRhs();
   auto res_int32_tensor_type =
@@ -820,7 +847,7 @@
   double combined_scale_fp = lhs_element_quant_type.getScale() *
                              rhs_element_quant_type.getScale() /
                              res_element_quant_type.getScale();
-  if (std::abs(combined_scale_fp - 1.0) > 0.005) {
+  if (std::abs(combined_scale_fp - 1.0) > 0.001) {
     Value combined_scale = rewriter.create<mhlo::ConstantOp>(
         op->getLoc(), rewriter.getF32FloatAttr(combined_scale_fp));
 
@@ -837,8 +864,7 @@
     // Skip zp_offset if it is 0.
     if (zp_offset) {
       auto zp_offset_float32_tensor_type =
-          zp_offset.getType().dyn_cast<TensorType>().clone(
-              rewriter.getF32Type());
+          zp_offset.getType().cast<TensorType>().clone(rewriter.getF32Type());
       zp_offset = rewriter.create<mhlo::ConvertOp>(
           op->getLoc(), zp_offset_float32_tensor_type, zp_offset);
       zp_offset = rewriter.create<chlo::BroadcastMulOp>(
@@ -867,15 +893,12 @@
 FailureOr<bool> IsDotLikeOpHybrid(DotLikeOp op) {
   // Checks whether a dot-like op is hybrid by looking at input/output types.
   // Returns failure() when the type is not supported.
-  auto lhs_element_quant_type =
-      getElementTypeOrSelf(op.getLhs().getType())
-          .template dyn_cast<quant::UniformQuantizedType>();
-  auto rhs_element_quant_type =
-      getElementTypeOrSelf(op.getRhs().getType())
-          .template dyn_cast<quant::UniformQuantizedType>();
-  auto res_element_quant_type =
-      getElementTypeOrSelf(op.getResult())
-          .template dyn_cast<quant::UniformQuantizedType>();
+  auto lhs_element_quant_type = getElementTypeOrSelf(op.getLhs().getType())
+                                    .template dyn_cast<UniformQuantizedType>();
+  auto rhs_element_quant_type = getElementTypeOrSelf(op.getRhs().getType())
+                                    .template dyn_cast<UniformQuantizedType>();
+  auto res_element_quant_type = getElementTypeOrSelf(op.getResult())
+                                    .template dyn_cast<UniformQuantizedType>();
   if (lhs_element_quant_type && rhs_element_quant_type &&
       res_element_quant_type) {
     return false;
@@ -996,8 +1019,7 @@
 FailureOr<DotLikeDimensionNumbers> VerifyConvolutionOp(mhlo::ConvolutionOp op) {
   // RHS (weight) must have zero zp.
   auto rhs_element_quant_type =
-      getElementTypeOrSelf(op.getRhs().getType())
-          .template dyn_cast<quant::UniformQuantizedType>();
+      getElementTypeOrSelf(op.getRhs().getType()).cast<UniformQuantizedType>();
   if (rhs_element_quant_type.getZeroPoint() != 0) {
     op->emitError("RHS UQ type must have zero zp.");
     return failure();
@@ -1074,15 +1096,15 @@
     // Check that all operands and result uq types are the same.
     llvm::SmallVector<Type> uq_types;
     for (auto result_type : op->getResultTypes()) {
-      auto type = getElementTypeOrSelf(result_type)
-                      .dyn_cast<quant::UniformQuantizedType>();
+      auto type =
+          getElementTypeOrSelf(result_type).dyn_cast<UniformQuantizedType>();
       if (type) {
         uq_types.push_back(type);
       }
     }
     for (auto operand : op->getOperands()) {
       auto type = getElementTypeOrSelf(operand.getType())
-                      .dyn_cast<quant::UniformQuantizedType>();
+                      .dyn_cast<UniformQuantizedType>();
       if (type) {
         uq_types.push_back(type);
       }
@@ -1097,15 +1119,7 @@
     // type otherwise.
     llvm::SmallVector<Type, 4> new_result_types;
     for (auto result_type : op->getResultTypes()) {
-      if (getElementTypeOrSelf(result_type)
-              .isa<quant::UniformQuantizedType>()) {
-        new_result_types.push_back(result_type.cast<TensorType>().clone(
-            getElementTypeOrSelf(result_type)
-                .cast<quant::UniformQuantizedType>()
-                .getStorageType()));
-      } else {
-        new_result_types.push_back(result_type);
-      }
+      new_result_types.push_back(GetQuantStorageType(result_type));
     }
 
     OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
@@ -1120,73 +1134,78 @@
 class UQTypeConverter : public TypeConverter {
  public:
   UQTypeConverter() {
-    addConversion([](Type type) -> Type {
-      auto to_legal_type = [](Type type) {
-        if (auto uq_type = dyn_cast<quant::UniformQuantizedType>(type)) {
-          return uq_type.getStorageType();
-        }
-        return type;
-      };
-      if (auto shaped = type.dyn_cast<ShapedType>()) {
-        return shaped.clone(to_legal_type(shaped.getElementType()));
-      } else {
-        return to_legal_type(type);
-      }
-    });
+    addConversion([](Type type) -> Type { return GetQuantStorageType(type); });
   }
 };
 
-// Performs conversion of MHLO quant ops to primitive ops.
-void ConvertMHLOQuantToInt::runOnOperation() {
-  Operation *op = getOperation();
-  MLIRContext *context = op->getContext();
-  RewritePatternSet patterns(context);
+#define GEN_PASS_DEF_CONVERTMHLOQUANTTOINT
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h.inc"
 
-  // Populate MHLO quant ops conversion patterns.
-  patterns.add<ConvertUniformQuantizeOp, ConvertUniformDequantizeOp,
-               ConvertUniformQuantizedAddOp, ConvertUniformQuantizedDotOp,
-               ConvertUniformQuantizedDotGeneralOp,
-               ConvertUniformQuantizedConvolutionOp, ConvertGenericOp>(context);
+class ConvertMHLOQuantToInt
+    : public impl::ConvertMHLOQuantToIntBase<ConvertMHLOQuantToInt> {
+ public:
+  ConvertMHLOQuantToInt() = default;
+  ConvertMHLOQuantToInt(const ConvertMHLOQuantToInt &) {}
 
-  // uq->int convert patterns for func.func and func.return.
-  UQTypeConverter converter;
-  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
-                                                                 converter);
-  populateReturnOpTypeConversionPattern(patterns, converter);
-
-  ConversionTarget target(*op->getContext());
-  auto is_legal = [&converter](Operation *op) { return converter.isLegal(op); };
-  target.addDynamicallyLegalDialect<mhlo::MhloDialect>(is_legal);
-  target.addDynamicallyLegalDialect<chlo::ChloDialect>(is_legal);
-  target.addDynamicallyLegalDialect<func::FuncDialect>(
-      [&converter](Operation *op) {
-        if (auto func = dyn_cast<func::FuncOp>(op)) {
-          return converter.isSignatureLegal(func.getFunctionType());
-        }
-        return converter.isLegal(op);
-      });
-
-  LogicalResult result =
-      applyPartialConversion(op, target, std::move(patterns));
-  if (failed(result)) {
-    signalPassFailure();
+  explicit ConvertMHLOQuantToInt(bool legalize_chlo) {
+    legalize_chlo_ = legalize_chlo;
   }
 
-  // Legalize CHLO if needed.
-  if (!legalize_chlo_) return;
-  RewritePatternSet patterns_2(context);
+  // Performs conversion of MHLO quant ops to primitive ops.
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *context = op->getContext();
+    RewritePatternSet patterns(context);
 
-  chlo::populateDecomposeChloPatterns(context, &patterns_2);
-  chlo::populateChloBroadcastingPatterns(context, &patterns_2);
+    // Populate MHLO quant ops conversion patterns.
+    patterns.add<ConvertUniformQuantizeOp, ConvertUniformDequantizeOp,
+                 ConvertUniformQuantizedAddOp, ConvertUniformQuantizedDotOp,
+                 ConvertUniformQuantizedDotGeneralOp,
+                 ConvertUniformQuantizedConvolutionOp, ConvertGenericOp>(
+        context);
 
-  ConversionTarget target_2 =
-      mhlo::GetDefaultLegalConversionTargets(*op->getContext(), legalize_chlo_);
+    // uq->int convert patterns for func.func and func.return.
+    UQTypeConverter converter;
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                   converter);
+    populateReturnOpTypeConversionPattern(patterns, converter);
 
-  result = applyPartialConversion(op, target_2, std::move(patterns_2));
-  if (failed(result)) {
-    signalPassFailure();
+    ConversionTarget target(*op->getContext());
+    auto is_legal = [&converter](Operation *op) {
+      return converter.isLegal(op);
+    };
+    target.addDynamicallyLegalDialect<mhlo::MhloDialect>(is_legal);
+    target.addDynamicallyLegalDialect<chlo::ChloDialect>(is_legal);
+    target.addDynamicallyLegalDialect<func::FuncDialect>(
+        [&converter](Operation *op) {
+          if (auto func = dyn_cast<func::FuncOp>(op)) {
+            return converter.isSignatureLegal(func.getFunctionType());
+          }
+          return converter.isLegal(op);
+        });
+
+    LogicalResult result =
+        applyPartialConversion(op, target, std::move(patterns));
+    if (failed(result)) {
+      signalPassFailure();
+    }
+
+    // Legalize CHLO if needed.
+    if (!legalize_chlo_) return;
+    RewritePatternSet patterns_2(context);
+
+    chlo::populateDecomposeChloPatterns(context, &patterns_2);
+    chlo::populateChloBroadcastingPatterns(context, &patterns_2);
+
+    ConversionTarget target_2 = mhlo::GetDefaultLegalConversionTargets(
+        *op->getContext(), legalize_chlo_);
+
+    result = applyPartialConversion(op, target_2, std::move(patterns_2));
+    if (failed(result)) {
+      signalPassFailure();
+    }
   }
-}
+};
 
 }  // namespace
 
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc
index e611b25..dc686d8 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc
@@ -13,24 +13,37 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <cstdint>
+#include <cstring>
 #include <memory>
+#include <optional>
 #include <utility>
 #include <vector>
 
 #include <gtest/gtest.h>
 #include "absl/log/check.h"
+#include "absl/random/random.h"
+#include "absl/status/status.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
+#include "llvm/Support/Casting.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/DialectRegistry.h"  // from @llvm-project
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Parser/Parser.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "stablehlo/dialect/ChloOps.h"  // from @stablehlo
 #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
 #include "xla/error_spec.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
@@ -38,7 +51,12 @@
 #include "xla/pjrt/pjrt_client.h"
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/pjrt/tfrt_cpu_pjrt_client.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
 #include "xla/tests/literal_test_util.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tsl/platform/errors.h"
 #include "tsl/platform/statusor.h"
 
 namespace mlir::quant::stablehlo {
@@ -51,6 +69,7 @@
     dialects.insert<TF::TensorFlowDialect, func::FuncDialect, chlo::ChloDialect,
                     mhlo::MhloDialect, quant::QuantizationDialect>();
     ctx_ = std::make_unique<MLIRContext>(dialects);
+    ctx_->loadAllAvailableDialects();
 
     // Create a CPU client with 1 device.
     TF_ASSERT_OK_AND_ASSIGN(
@@ -60,18 +79,117 @@
     CHECK(device_);
   }
 
-  absl::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable>> CompileProgram(
-      absl::string_view program) {
-    // Parse the program.
+  absl::StatusOr<OwningOpRef<ModuleOp>> ReplaceFuncArgsByConstant(
+      absl::string_view program,
+      absl::Span<const xla::Literal* const> arguments,
+      bool use_mhlo_const = false) {
     auto module_op = parseSourceString<ModuleOp>(program, ctx_.get());
     CHECK(module_op);
+    auto func_op = llvm::dyn_cast<func::FuncOp>(
+        *module_op->getBodyRegion().getOps().begin());
+    if (!func_op) {
+      return absl::InternalError("Input MLIR must have only 1 func");
+    }
+    if (arguments.size() != func_op.getNumArguments()) {
+      return absl::InternalError("Input argument has wrong size");
+    }
+
+    // Convert input xla::Literal arguments to constants; this allows using
+    // constant folding to evaluate the function return value.
+    mlir::OpBuilder builder(ctx_.get());
+    for (int i = 0; i < arguments.size(); ++i) {
+      const xla::Literal* const xla_literal = arguments[i];
+      tensorflow::TensorShape shape;
+      TF_ASSIGN_OR_RETURN(auto data_type,
+                          tensorflow::EncodePrimitiveTypeAsDataType(
+                              xla_literal->shape().element_type()));
+      TF_RETURN_IF_ERROR(
+          tensorflow::XLAShapeToTensorShape(xla_literal->shape(), &shape));
+      tensorflow::Tensor tensor(data_type, shape);
+      std::memcpy(static_cast<char*>(tensor.data()),
+                  xla_literal->untyped_data(),
+                  xla::ShapeUtil::ByteSizeOfPrimitiveType(
+                      xla_literal->shape().element_type()) *
+                      xla_literal->element_count());
+      TF_ASSIGN_OR_RETURN(auto attrs,
+                          tensorflow::ConvertTensor(tensor, &builder));
+      builder.setInsertionPoint(
+          &func_op.getFunctionBody().getBlocks().front().front());
+      // Use mhlo.Constant when the constant is consumed by the lowering
+      // passes, since they can't lower tf.Const.
+      Value cst;
+      if (use_mhlo_const) {
+        cst = builder.create<mhlo::ConstantOp>(func_op->getLoc(), attrs);
+      } else {
+        cst = builder.create<TF::ConstOp>(func_op->getLoc(), attrs);
+      }
+      func_op.getArgument(i).replaceAllUsesWith(cst);
+    }
+    return module_op;
+  }
+
+  // Evaluates the return value of a function using TF kernels.
+  // This assumes that the module op has only 1 function, containing TF ops only.
+  absl::StatusOr<std::shared_ptr<xla::Literal>> EvaluateTfFunction(
+      absl::string_view program,
+      absl::Span<const xla::Literal* const> arguments) {
+    TF_ASSIGN_OR_RETURN(auto module_op,
+                        ReplaceFuncArgsByConstant(program, arguments));
+    // Constant fold the func.Return op's producer op to evaluate the return
+    // value. The evaluation will use TF kernels.
+    // This assumes that func.Return is the last op in the function and it
+    // returns only 1 value.
+    auto& return_op = llvm::dyn_cast<func::FuncOp>(
+                          *module_op->getBodyRegion().getOps().begin())
+                          .getFunctionBody()
+                          .getBlocks()
+                          .back()
+                          .back();
+    if (!llvm::isa<func::ReturnOp>(return_op) ||
+        return_op.getNumOperands() != 1) {
+      return absl::InternalError(
+          "Func must have ReturnOp as last op and must return 1 value");
+    }
+    auto def_op = return_op.getOperand(0).getDefiningOp();
+    auto fold_results = ConstantFoldOpIfPossible(def_op);
+    if (fold_results.size() != 1 ||
+        !llvm::isa<TF::ConstOp>(fold_results[0].getDefiningOp())) {
+      return absl::InternalError("Failed to evaluate TF ops");
+    }
+
+    // Convert output tensor back to xla::Literal.
+    tensorflow::Tensor tensor;
+    TF_RETURN_IF_ERROR(tensorflow::ConvertToTensor(
+        llvm::dyn_cast<TF::ConstOp>(fold_results[0].getDefiningOp()).getValue(),
+        &tensor));
+    xla::Shape xla_shape;
+    TF_RETURN_IF_ERROR(tensorflow::TensorShapeToXLAShape(
+        tensor.dtype(), tensor.shape(), &xla_shape));
+    xla::PjRtClient::HostBufferSemantics host_buffer_semantics =
+        xla::PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes;
+    TF_ASSIGN_OR_RETURN(
+        auto buffer,
+        pjrt_client_->BufferFromHostBuffer(
+            tensor.data(), xla_shape.element_type(), xla_shape.dimensions(),
+            /*byte_strides=*/std::nullopt, host_buffer_semantics,
+            /*on_done_with_host_buffer=*/nullptr, device_));
+    return buffer->ToLiteralSync();
+  }
+
+  absl::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable>> CompileProgram(
+      absl::string_view program,
+      absl::Span<const xla::Literal* const> arguments) {
+    // Replace args with mhlo.constant ops since the lowering passes can't
+    // lower tf.Const.
+    TF_ASSIGN_OR_RETURN(
+        auto module_op,
+        ReplaceFuncArgsByConstant(program, arguments, /*use_mhlo_const=*/true));
+
     // Run the Convert TF Quant Types, TF Quant -> MHLO Quant and MHLO Quant ->
     // MHLO int passes.
     PassManager pm(module_op->getContext());
     pm.addNestedPass<func::FuncOp>(CreateConvertTFQuantTypesPass());
-    pm.addNestedPass<func::FuncOp>(CreateConvertTFQuantOpsToMHLOPass());
-    pm.addNestedPass<func::FuncOp>(
-        stablehlo::createConvertMHLOQuantToIntPass(false));
+    AddQuantizationLoweringPasses(pm);
     CHECK(succeeded(pm.run(module_op.get())));
     // Compile the program.
     return pjrt_client_->Compile(*module_op, xla::CompileOptions{});
@@ -98,230 +216,369 @@
     return result[0][0]->ToLiteralSync();
   }
 
+  void ExecuteAndCompareResultsWithTfKernel(
+      absl::string_view program,
+      absl::Span<const xla::Literal* const> arguments,
+      std::optional<absl::string_view> tf_program = std::nullopt,
+      double error_tolerance = 0.1) {
+    // The expected result is calculated by evaluating the program with TF
+    // kernels. In some cases the TF kernel behaves differently from the
+    // lowered graph (e.g. hybrid ops), so we optionally use a different graph
+    // to calculate the expected result.
+    TF_ASSERT_OK_AND_ASSIGN(
+        auto expected,
+        this->EvaluateTfFunction(
+            (tf_program.has_value() ? *tf_program : program), arguments));
+
+    TF_ASSERT_OK_AND_ASSIGN(auto executable,
+                            this->CompileProgram(program, arguments));
+    TF_ASSERT_OK_AND_ASSIGN(
+        auto result,
+        this->ExecuteProgramAndReturnSingleResult(executable.get(), arguments));
+
+    // Convert to double for comparison. This is needed for integer results
+    // because LiteralTestUtil asserts on any difference between integers, even
+    // if it is within error_spec.
+    TF_ASSERT_OK_AND_ASSIGN(auto expected_double, expected->Convert(xla::F64));
+    TF_ASSERT_OK_AND_ASSIGN(auto result_double, result->Convert(xla::F64));
+    EXPECT_TRUE(xla::LiteralTestUtil::Near(expected_double, result_double,
+                                           xla::ErrorSpec(error_tolerance)));
+  }
+
+  absl::StatusOr<xla::Literal> CreateRandomF32Literal(
+      absl::Span<const int64_t> dims, float min = -100, float max = 100) {
+    TF_ASSIGN_OR_RETURN(auto shape,
+                        xla::ShapeUtil::MakeValidatedShape(xla::F32, dims));
+    return xla::LiteralUtil::CreateLiteralWithGenerator<xla::F32, float>(
+        shape, [this, min, max](absl::Span<const int64_t> dims) -> float {
+          return absl::Uniform(bitgen_, min, max);
+        });
+  }
+
+  absl::StatusOr<xla::Literal> CreateRandomI8Literal(
+      absl::Span<const int64_t> dims, int8_t min = -128, int8_t max = 127) {
+    TF_ASSIGN_OR_RETURN(auto shape,
+                        xla::ShapeUtil::MakeValidatedShape(xla::S8, dims));
+    return xla::LiteralUtil::CreateLiteralWithGenerator<xla::S8, int8_t>(
+        shape, [this, min, max](absl::Span<const int64_t> dims) -> int8_t {
+          return absl::Uniform(bitgen_, min, max);
+        });
+  }
+
+  absl::StatusOr<xla::Literal> CreateRandomI32Literal(
+      absl::Span<const int64_t> dims, int32_t min = -128, int32_t max = 127) {
+    TF_ASSIGN_OR_RETURN(auto shape,
+                        xla::ShapeUtil::MakeValidatedShape(xla::S32, dims));
+    return xla::LiteralUtil::CreateLiteralWithGenerator<xla::S32, int32_t>(
+        shape, [this, min, max](absl::Span<const int64_t> dims) -> int32_t {
+          return absl::Uniform(bitgen_, min, max);
+        });
+  }
+
   std::unique_ptr<MLIRContext> ctx_;
   std::unique_ptr<xla::PjRtClient> pjrt_client_;
   xla::PjRtDevice* device_;
+  absl::BitGen bitgen_;
 };
 
 TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeAndDequantize) {
   constexpr absl::string_view kProgram = R"mlir(
-func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
-  %scale = "tf.Const"() { value = dense<10.0> : tensor<f32> } : ()
-    -> tensor<f32>
+func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
+  %scale = "tf.Const"() { value = dense<0.347> : tensor<f32> } : () -> tensor<f32>
   %zp = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
   %0 = "tf.UniformQuantize"(%arg0, %scale, %zp) {
     quantization_axis = -1 : i64,
     quantization_min_val = -128 : i64,
     quantization_max_val = 127 : i64
-  } : (tensor<4xf32>, tensor<f32>, tensor<i32>) -> tensor<4x!tf_type.qint8>
+  } : (tensor<10xf32>, tensor<f32>, tensor<i32>) -> tensor<10x!tf_type.qint8>
   %1 = "tf.UniformDequantize"(%0, %scale, %zp) {
     quantization_axis = -1 : i64,
     quantization_min_val = -128 : i64,
     quantization_max_val = 127 : i64
-  } : (tensor<4x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<4xf32>
-  return %1 : tensor<4xf32>
+  } : (tensor<10x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<10xf32>
+  return %1 : tensor<10xf32>
 })mlir";
-  TF_ASSERT_OK_AND_ASSIGN(auto executable, this->CompileProgram(kProgram));
+  TF_ASSERT_OK_AND_ASSIGN(auto arg0, CreateRandomF32Literal({10}));
+  // error_tolerance is set slightly above the scale because different rounding
+  // implementations for UniformQuantize in the TF kernel and the lowering
+  // passes may cause +/-1 differences in the quantized value.
+  ExecuteAndCompareResultsWithTfKernel(
+      kProgram, {&arg0}, /*tf_program=*/std::nullopt, /*error_tolerance=*/0.35);
+}
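
For reference, a minimal C++ sketch (not part of this patch) of the per-tensor affine quantize/dequantize round trip the test above exercises; the helper names are illustrative, not TF or XLA APIs. A +/-1 rounding difference in the quantized value shifts the dequantized result by exactly one scale, which is why the tolerance is set just above the 0.347 scale.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative helpers only.
int8_t QuantizeToInt8(float x, float scale, int32_t zero_point) {
  const int32_t q = static_cast<int32_t>(std::lround(x / scale)) + zero_point;
  return static_cast<int8_t>(std::clamp(q, -128, 127));  // qint8 range
}

float DequantizeFromInt8(int8_t q, float scale, int32_t zero_point) {
  return (q - zero_point) * scale;
}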
 
-  auto arg0 =
-      xla::LiteralUtil::CreateR1<float>({100.0f, 20000.0f, -2409.0f, -25.1f});
+TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizePerChannel) {
+  constexpr absl::string_view kProgram = R"mlir(
+func.func @main(
+    %arg0: tensor<10x10xf32>, %scale: tensor<10xf32>, %zp: tensor<10xi32>
+  ) -> tensor<10x10xi8> {
+  %0 = "tf.UniformQuantize"(%arg0, %scale, %zp) {
+    quantization_axis = 1 : i64,
+    quantization_min_val = -128 : i64,
+    quantization_max_val = 127 : i64
+  } : (tensor<10x10xf32>, tensor<10xf32>, tensor<10xi32>) -> tensor<10x10x!tf_type.qint8>
+  %1 = "tf.Cast"(%0) {} : (tensor<10x10x!tf_type.qint8>) -> tensor<10x10xi8>
+  return %1 : tensor<10x10xi8>
+})mlir";
+  TF_ASSERT_OK_AND_ASSIGN(auto arg0, CreateRandomF32Literal({10, 10}));
   TF_ASSERT_OK_AND_ASSIGN(
-      auto result_literal,
-      this->ExecuteProgramAndReturnSingleResult(executable.get(), {&arg0}));
-  xla::LiteralTestUtil::ExpectR1Near<float>({100.0f, 1240.0f, -1310.0f, -30.0f},
-                                            *result_literal,
-                                            xla::ErrorSpec(0.001f));
+      auto scale, CreateRandomF32Literal({10}, /*min=*/0.0001, /*max=*/2));
+  TF_ASSERT_OK_AND_ASSIGN(auto zp, CreateRandomI32Literal({10}));
+  // Different rounding implementations for UniformQuantize in the TF kernel
+  // and the lowering passes may cause +/-1 differences.
+  ExecuteAndCompareResultsWithTfKernel(kProgram, {&arg0, &scale, &zp},
+                                       /*tf_program=*/std::nullopt,
+                                       /*error_tolerance=*/1.0);
+}
+
+TEST_F(ConvertTfQuantToMhloIntTest, UniformDequantizePerChannel) {
+  constexpr absl::string_view kProgram = R"mlir(
+func.func @main(
+    %arg0: tensor<10x10xi8>, %scale: tensor<10xf32>, %zp: tensor<10xi32>
+  ) -> tensor<10x10xf32> {
+  %0 = "tf.Cast"(%arg0) {} : (tensor<10x10xi8>) -> tensor<10x10x!tf_type.qint8>
+  %1 = "tf.UniformDequantize"(%0, %scale, %zp) {
+    quantization_axis = 1 : i64,
+    quantization_min_val = -128 : i64,
+    quantization_max_val = 127 : i64
+  } : (tensor<10x10x!tf_type.qint8>, tensor<10xf32>, tensor<10xi32>) -> tensor<10x10xf32>
+  return %1 : tensor<10x10xf32>
+})mlir";
+  TF_ASSERT_OK_AND_ASSIGN(auto arg0, CreateRandomI8Literal({10, 10}));
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto scale, CreateRandomF32Literal({10}, /*min=*/0.0001, /*max=*/2));
+  TF_ASSERT_OK_AND_ASSIGN(auto zp, CreateRandomI32Literal({10}));
+  ExecuteAndCompareResultsWithTfKernel(kProgram, {&arg0, &scale, &zp});
 }
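
For reference, a short sketch (not part of this patch) of what per-channel quantization with quantization_axis = 1 means in the two tests above: each column j of the 10x10 input uses its own scale[j] and zero point zp[j]. The helper below is illustrative only and assumes a row-major [rows x cols] layout.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int8_t> QuantizePerChannel(const std::vector<float>& x, int rows,
                                       int cols,
                                       const std::vector<float>& scale,
                                       const std::vector<int32_t>& zp) {
  std::vector<int8_t> q(x.size());
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      // Axis 1 (the column index) selects the per-channel parameters.
      const int32_t v =
          static_cast<int32_t>(std::lround(x[i * cols + j] / scale[j])) + zp[j];
      q[i * cols + j] = static_cast<int8_t>(std::clamp(v, -128, 127));
    }
  }
  return q;
}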
 
 TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeConvolution) {
   constexpr absl::string_view kProgram = R"mlir(
-func.func @main(%input: tensor<1x2x2x1xf32>, %filter: tensor<2x1x1x1xf32>) -> tensor<1x2x2x1xf32> {
-    %input_scale = "tf.Const"() { value = dense<7.3> : tensor<f32> } : ()
-    -> tensor<f32>
-    %input_zp = "tf.Const"() { value = dense<-45> : tensor<i32> } : () -> tensor<i32>
-    %filter_scale = "tf.Const"() { value = dense<0.047> : tensor<f32> } : ()
-    -> tensor<f32>
-    %filter_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
-    %accum_scale = "tf.Const"() { value = dense<0.3431> : tensor<f32> } : ()
-    -> tensor<f32>
-    %accum_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
-    %quant_input = "tf.UniformQuantize"(%input, %input_scale, %input_zp) {
-      Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT8",
-      attr_map = "", quantization_axis = -1 : i64, quantization_max_val = 127 : i64,
-      quantization_min_val = -128 : i64
-    } : (tensor<1x2x2x1xf32>, tensor<f32>, tensor<i32>) -> tensor<1x2x2x1x!tf_type.qint8>
-    %quant_filter = "tf.UniformQuantize"(%filter, %filter_scale, %filter_zp) {
-      Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT8",
-      attr_map = "", quantization_axis = -1 : i64,
-      quantization_max_val = 127 : i64, quantization_min_val = -128 : i64
-    } : (tensor<2x1x1x1xf32>, tensor<f32>, tensor<i32>) -> tensor<2x1x1x1x!tf_type.qint8>
-    %0 = "tf.UniformQuantizedConvolution"(
-      %quant_input, %quant_filter, %input_scale, %input_zp,
-      %filter_scale, %filter_zp, %accum_scale, %accum_zp
-    ) {
-      Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT32",
-      attr_map = "", batch_group_count = 1 : i64,
-      dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02",
-      explicit_padding = [], feature_group_count = 1 : i64, lhs_dilation = [1, 1],
-      lhs_quantization_axis = -1 : i64, lhs_quantization_max_val = 127 : i64,
-      lhs_quantization_min_val = -128 : i64, output_quantization_axis = -1 : i64,
-      output_quantization_max_val = 2147483647 : i64,
-      output_quantization_min_val = -2147483648 : i64, padding = "SAME",
-      rhs_dilation = [1, 1], rhs_quantization_axis = -1 : i64,
-      rhs_quantization_max_val = 127 : i64, rhs_quantization_min_val = -128 : i64,
-      window_strides = [1, 1]
-    } : (tensor<1x2x2x1x!tf_type.qint8>, tensor<2x1x1x1x!tf_type.qint8>,
-      tensor<f32>, tensor<i32>, tensor<f32>, tensor<i32>, tensor<f32>, tensor<i32>
-    ) -> tensor<1x2x2x1x!tf_type.qint32>
-    %output = "tf.UniformDequantize"(%0, %accum_scale, %accum_zp) {
-      quantization_axis = -1 : i64, quantization_min_val = -128 : i64,
-      quantization_max_val = 127 : i64
-    } : (tensor<1x2x2x1x!tf_type.qint32>, tensor<f32>, tensor<i32>) -> tensor<1x2x2x1xf32>
-    return %output : tensor<1x2x2x1xf32>
+func.func @main(%input: tensor<1x9x9x9xi8>, %filter: tensor<3x3x9x10xi8>) -> tensor<1x9x9x10xi32> {
+  %input_scale = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
+  %input_zp = "tf.Const"() { value = dense<-10> : tensor<i32> } : () -> tensor<i32>
+  %filter_scale = "tf.Const"() { value = dense<0.5> : tensor<f32> } : () -> tensor<f32>
+  %filter_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
+  %accum_scale = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
+  %accum_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
+  %quant_input = "tf.Cast"(%input) {} : (tensor<1x9x9x9xi8>) ->
+    tensor<1x9x9x9x!tf_type.qint8>
+  %quant_filter = "tf.Cast"(%filter) {} : (tensor<3x3x9x10xi8>) ->
+    tensor<3x3x9x10x!tf_type.qint8>
+  %0 = "tf.UniformQuantizedConvolution"(
+    %quant_input, %quant_filter, %input_scale, %input_zp,
+    %filter_scale, %filter_zp, %accum_scale, %accum_zp
+  ) {
+    Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT32",
+    attr_map = "", batch_group_count = 1 : i64,
+    dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02",
+    explicit_padding = [], feature_group_count = 1 : i64, lhs_dilation = [1, 1],
+    lhs_quantization_axis = -1 : i64, lhs_quantization_max_val = 127 : i64,
+    lhs_quantization_min_val = -128 : i64, output_quantization_axis = -1 : i64,
+    output_quantization_max_val = 2147483647 : i64,
+    output_quantization_min_val = -2147483648 : i64, padding = "SAME",
+    rhs_dilation = [1, 1], rhs_quantization_axis = -1 : i64,
+    rhs_quantization_max_val = 127 : i64, rhs_quantization_min_val = -128 : i64,
+    window_strides = [1, 1]
+  } : (tensor<1x9x9x9x!tf_type.qint8>, tensor<3x3x9x10x!tf_type.qint8>,
+    tensor<f32>, tensor<i32>, tensor<f32>, tensor<i32>, tensor<f32>, tensor<i32>
+  ) -> tensor<1x9x9x10x!tf_type.qint32>
+  %output = "tf.Cast"(%0) {} : (tensor<1x9x9x10x!tf_type.qint32>) -> tensor<1x9x9x10xi32>
+  return %output : tensor<1x9x9x10xi32>
 })mlir";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto executable, this->CompileProgram(kProgram));
-
-  auto input = xla::LiteralUtil::CreateR4<float>(
-      {{{{14.f}, {-100.f}}, {{-200.f}, {350.f}}}});
-  auto filter = xla::LiteralUtil::CreateR4<float>({{{{4.1f}}}, {{{-2.f}}}});
-
-  TF_ASSERT_OK_AND_ASSIGN(auto result_literal,
-                          this->ExecuteProgramAndReturnSingleResult(
-                              executable.get(), {&input, &filter}));
-  xla::LiteralTestUtil::ExpectR4Near<float>(
-      {{{{458.f}, {-1126.f}}, {{-806.f}, {1433.f}}}}, *result_literal,
-      xla::ErrorSpec(1.f));
+  TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomI8Literal({1, 9, 9, 9}));
+  TF_ASSERT_OK_AND_ASSIGN(auto filter, CreateRandomI8Literal({3, 3, 9, 10}));
+  ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter});
 }
 
 TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeConvolutionHybrid) {
-  constexpr absl::string_view kProgram = R"mlir(
-func.func @main(%input: tensor<1x2x2x1xf32>, %filter: tensor<2x1x1x1xf32>) -> tensor<1x2x2x1xf32> {
-    %filter_scale = "tf.Const"() { value = dense<0.047> : tensor<f32> } : ()
-    -> tensor<f32>
-    %filter_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
-    %quant_filter = "tf.UniformQuantize"(%filter, %filter_scale, %filter_zp) {
-      Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT8",
-      attr_map = "", quantization_axis = -1 : i64,
-      quantization_max_val = 127 : i64, quantization_min_val = -128 : i64
-    } : (tensor<2x1x1x1xf32>, tensor<f32>, tensor<i32>) -> tensor<2x1x1x1x!tf_type.qint8>
-    %0 = "tf.UniformQuantizedConvolutionHybrid"(
-      %input, %quant_filter, %filter_scale, %filter_zp
-    ) {
-      Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_FLOAT",
-      attr_map = "", batch_group_count = 1 : i64,
-      dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02",
-      explicit_padding = [], feature_group_count = 1 : i64, lhs_dilation = [1, 1],
-      padding = "SAME", rhs_dilation = [1, 1], rhs_quantization_axis = -1 : i64,
-      rhs_quantization_max_val = 127 : i64, rhs_quantization_min_val = -128 : i64,
-      window_strides = [1, 1]
-    } : (tensor<1x2x2x1xf32>, tensor<2x1x1x1x!tf_type.qint8>,
-      tensor<f32>, tensor<i32>) -> tensor<1x2x2x1xf32>
-    return %0 : tensor<1x2x2x1xf32>
+  constexpr absl::string_view kTfProgram = R"mlir(
+func.func @main(%input: tensor<2x10x10x10xf32>, %filter: tensor<3x3x10x20xi8>) -> tensor<2x10x10x20xf32> {
+  %filter_scale = "tf.Const"() { value = dense<0.047> : tensor<f32> } : () -> tensor<f32>
+  %filter_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
+  %quant_filter = "tf.Cast"(%filter) {} : (tensor<3x3x10x20xi8>) ->
+    tensor<3x3x10x20x!tf_type.qint8>
+  %filter_new = "tf.UniformDequantize"(%quant_filter, %filter_scale, %filter_zp) {
+    quantization_axis = -1 : i64, quantization_min_val = -128 : i64,
+    quantization_max_val = 127 : i64
+  } : (
+    tensor<3x3x10x20x!tf_type.qint8>, tensor<f32>, tensor<i32>
+  ) -> tensor<3x3x10x20xf32>
+  %0 = "tf.Conv2D"(%input, %filter_new) {
+    Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_FLOAT",
+    attr_map = "", batch_group_count = 1 : i64,
+    explicit_padding = [], feature_group_count = 1 : i64, lhs_dilation = [1, 1],
+    padding = "SAME", rhs_dilation = [1, 1], strides = [1, 1, 1, 1]
+  } : (tensor<2x10x10x10xf32>, tensor<3x3x10x20xf32>) -> tensor<2x10x10x20xf32>
+  return %0 : tensor<2x10x10x20xf32>
 })mlir";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto executable, this->CompileProgram(kProgram));
-
-  auto input = xla::LiteralUtil::CreateR4<float>(
-      {{{{14.f}, {-100.f}}, {{-200.f}, {350.f}}}});
-  auto filter = xla::LiteralUtil::CreateR4<float>({{{{4.1f}}}, {{{-2.f}}}});
-
-  TF_ASSERT_OK_AND_ASSIGN(auto result_literal,
-                          this->ExecuteProgramAndReturnSingleResult(
-                              executable.get(), {&input, &filter}));
-  xla::LiteralTestUtil::ExpectR4Near<float>(
-      {{{{461}, {-1116.f}}, {{-817.f}, {1431.f}}}}, *result_literal,
-      xla::ErrorSpec(1.f));
+  constexpr absl::string_view kProgram = R"mlir(
+func.func @main(%input: tensor<2x10x10x10xf32>, %filter: tensor<3x3x10x20xi8>) -> tensor<2x10x10x20xf32> {
+  %filter_scale = "tf.Const"() { value = dense<0.047> : tensor<f32> } : () -> tensor<f32>
+  %filter_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
+  %quant_filter = "tf.Cast"(%filter) {} : (tensor<3x3x10x20xi8>) -> tensor<3x3x10x20x!tf_type.qint8>
+  %0 = "tf.UniformQuantizedConvolutionHybrid"(
+    %input, %quant_filter, %filter_scale, %filter_zp
+  ) {
+    Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_FLOAT",
+    attr_map = "", batch_group_count = 1 : i64,
+    dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02",
+    explicit_padding = [], feature_group_count = 1 : i64, lhs_dilation = [1, 1],
+    padding = "SAME", rhs_dilation = [1, 1], rhs_quantization_axis = -1 : i64,
+    rhs_quantization_max_val = 127 : i64, rhs_quantization_min_val = -128 : i64,
+    window_strides = [1, 1]
+  } : (tensor<2x10x10x10xf32>, tensor<3x3x10x20x!tf_type.qint8>,
+    tensor<f32>, tensor<i32>) -> tensor<2x10x10x20xf32>
+  return %0 : tensor<2x10x10x20xf32>
+})mlir";
+  TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomF32Literal({2, 10, 10, 10}));
+  TF_ASSERT_OK_AND_ASSIGN(auto filter, CreateRandomI8Literal({3, 3, 10, 20}));
+  // The TF kernel for UniformQuantizedConvolutionHybrid does DRQ, but the
+  // StableHLO hybrid ops do weight-only quantization. So we use a different TF
+  // graph for evaluating the expected weight-only quantized results.
+  ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}, kTfProgram);
 }
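
For reference, a simplified sketch (not part of this patch) of why the hybrid test needs a separate TF reference graph. StableHLO hybrid ops only dequantize the int8 weights and compute in float (weight-only), while the TF kernel also quantizes the activations on the fly (dynamic-range quantization), so the two paths differ numerically. Symmetric activation quantization is assumed and the names are illustrative, not TF APIs.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Weight-only: keep activations in float, only dequantize the weights.
float WeightOnlyDot(const std::vector<float>& x, const std::vector<int8_t>& w,
                    float w_scale) {
  float acc = 0.0f;
  for (size_t i = 0; i < x.size(); ++i) acc += x[i] * (w[i] * w_scale);
  return acc;
}

// Dynamic-range: additionally quantize the activations per call and
// accumulate in int32 before rescaling back to float.
float DynamicRangeDot(const std::vector<float>& x, const std::vector<int8_t>& w,
                      float w_scale) {
  float max_abs = 1e-6f;
  for (float v : x) max_abs = std::max(max_abs, std::abs(v));
  const float x_scale = max_abs / 127.0f;
  int32_t acc = 0;
  for (size_t i = 0; i < x.size(); ++i) {
    acc += static_cast<int32_t>(std::lround(x[i] / x_scale)) * w[i];
  }
  return acc * x_scale * w_scale;
}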
 
 TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeDot) {
   constexpr absl::string_view kProgram = R"mlir(
-func.func @main(%input: tensor<1x2xf32>, %filter: tensor<2x3xf32>) -> tensor<1x3xf32> {
-    %input_scale = "tf.Const"() { value = dense<0.588> : tensor<f32> } : ()
-    -> tensor<f32>
-    %input_zp = "tf.Const"() { value = dense<42> : tensor<i32> } : () -> tensor<i32>
-    %filter_scale = "tf.Const"() { value = dense<0.0235> : tensor<f32> } : ()
-    -> tensor<f32>
-    %filter_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
-    %accum_scale = "tf.Const"() { value = dense<0.0138> : tensor<f32> } : ()
-    -> tensor<f32>
-    %accum_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
-    %quant_input = "tf.UniformQuantize"(%input, %input_scale, %input_zp) {
-      Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT8", attr_map = "",
-      quantization_axis = -1 : i64, quantization_max_val = 127 : i64,
-      quantization_min_val = -128 : i64
-    } : (tensor<1x2xf32>, tensor<f32>, tensor<i32>) -> tensor<1x2x!tf_type.qint8>
-    %quant_filter = "tf.UniformQuantize"(%filter, %filter_scale, %filter_zp) {
-      Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT8", attr_map = "",
-      quantization_axis = -1 : i64, quantization_max_val = 127 : i64,
-      quantization_min_val = -128 : i64
-    } : (tensor<2x3xf32>, tensor<f32>, tensor<i32>) -> tensor<2x3x!tf_type.qint8>
-    %0 = "tf.UniformQuantizedDot"(
-      %quant_input, %quant_filter, %input_scale, %input_zp, %filter_scale,
-      %filter_zp, %accum_scale, %accum_zp
-    ) {
-      Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT32", attr_map = "",
-      device = "", lhs_quantization_axis = -1 : i64,
-      lhs_quantization_max_val = 127 : i64, lhs_quantization_min_val = -128 : i64,
-      output_quantization_axis = -1 : i64, output_quantization_max_val = 2147483647 : i64,
-      output_quantization_min_val = -2147483648 : i64, rhs_quantization_axis = -1 : i64,
-      rhs_quantization_max_val = 127 : i64, rhs_quantization_min_val = -128 : i64
-    } : (
-      tensor<1x2x!tf_type.qint8>, tensor<2x3x!tf_type.qint8>, tensor<f32>,
-      tensor<i32>, tensor<f32>, tensor<i32>, tensor<f32>, tensor<i32>
-    ) -> tensor<1x3x!tf_type.qint32>
-    %output = "tf.UniformDequantize"(%0, %accum_scale, %accum_zp) {
-      quantization_axis = -1 : i64, quantization_min_val = -128 : i64,
-      quantization_max_val = 127 : i64
-    } : (tensor<1x3x!tf_type.qint32>, tensor<f32>, tensor<i32>) -> tensor<1x3xf32>
-    return %output : tensor<1x3xf32>
+func.func @main(%input: tensor<8x9xi8>, %filter: tensor<9x10xi8>) -> tensor<8x10xi32> {
+  %input_scale = "tf.Const"() { value = dense<0.588> : tensor<f32> } : () -> tensor<f32>
+  %input_zp = "tf.Const"() { value = dense<42> : tensor<i32> } : () -> tensor<i32>
+  %filter_scale = "tf.Const"() { value = dense<0.0235> : tensor<f32> } : () -> tensor<f32>
+  %filter_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
+  %accum_scale = "tf.Const"() { value = dense<0.013818> : tensor<f32> } : () -> tensor<f32>
+  %accum_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
+  %quant_input = "tf.Cast"(%input) {} : (tensor<8x9xi8>) -> tensor<8x9x!tf_type.qint8>
+  %quant_filter = "tf.Cast"(%filter) {} : (tensor<9x10xi8>) -> tensor<9x10x!tf_type.qint8>
+  %0 = "tf.UniformQuantizedDot"(
+    %quant_input, %quant_filter, %input_scale, %input_zp, %filter_scale,
+    %filter_zp, %accum_scale, %accum_zp
+  ) {
+    Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT32", attr_map = "",
+    device = "", lhs_quantization_axis = -1 : i64,
+    lhs_quantization_max_val = 127 : i64,
+    lhs_quantization_min_val = -128 : i64,
+    output_quantization_axis = -1 : i64,
+    output_quantization_max_val = 2147483647 : i64,
+    output_quantization_min_val = -2147483648 : i64,
+    rhs_quantization_axis = -1 : i64,
+    rhs_quantization_max_val = 127 : i64,
+    rhs_quantization_min_val = -128 : i64
+  } : (
+    tensor<8x9x!tf_type.qint8>, tensor<9x10x!tf_type.qint8>, tensor<f32>,
+    tensor<i32>, tensor<f32>, tensor<i32>, tensor<f32>, tensor<i32>
+  ) -> tensor<8x10x!tf_type.qint32>
+  %output = "tf.Cast"(%0) {} : (tensor<8x10x!tf_type.qint32>) -> tensor<8x10xi32>
+  return %output : tensor<8x10xi32>
 })mlir";
-
-  TF_ASSERT_OK_AND_ASSIGN(auto executable, this->CompileProgram(kProgram));
-
-  auto input = xla::LiteralUtil::CreateR2<float>({{50.f, -100.f}});
-  auto filter =
-      xla::LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {-1.f, -3.f, 1.f}});
-
-  TF_ASSERT_OK_AND_ASSIGN(auto result_literal,
-                          this->ExecuteProgramAndReturnSingleResult(
-                              executable.get(), {&input, &filter}));
-  xla::LiteralTestUtil::ExpectR2Near<float>(
-      {{150.f, 400.f, 50.f}}, *result_literal, xla::ErrorSpec(2.f));
+  TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomI8Literal({8, 9}));
+  TF_ASSERT_OK_AND_ASSIGN(auto filter, CreateRandomI8Literal({9, 10}));
+  ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter});
 }
 
 TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeDotHybrid) {
-  constexpr absl::string_view kProgram = R"mlir(
-func.func @main(%input: tensor<1x2xf32>, %filter: tensor<2x3xf32>) -> tensor<1x3xf32> {
-    %filter_scale = "tf.Const"() { value = dense<0.0235> : tensor<f32> } : ()
-    -> tensor<f32>
-    %filter_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
-    %quant_filter = "tf.UniformQuantize"(%filter, %filter_scale, %filter_zp) {
-      Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT8", attr_map = "",
-      quantization_axis = -1 : i64, quantization_max_val = 127 : i64,
-      quantization_min_val = -128 : i64
-    } : (tensor<2x3xf32>, tensor<f32>, tensor<i32>) -> tensor<2x3x!tf_type.qint8>
-    %0 = "tf.UniformQuantizedDotHybrid"(
-      %input, %quant_filter, %filter_scale, %filter_zp
-    ) {
-      Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_FLOAT", attr_map = "",
-      device = "", rhs_quantization_axis = -1 : i64,
-      rhs_quantization_max_val = 127 : i64, rhs_quantization_min_val = -128 : i64
-    } : (tensor<1x2xf32>, tensor<2x3x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<1x3xf32>
-    return %0 : tensor<1x3xf32>
+  constexpr absl::string_view kTfProgram = R"mlir(
+func.func @main(%input: tensor<8x9xf32>, %filter: tensor<9x10xi8>) -> tensor<8x10xf32> {
+  %filter_scale = "tf.Const"() { value = dense<0.0235> : tensor<f32> } : () -> tensor<f32>
+  %filter_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
+  %quant_filter = "tf.Cast"(%filter) {} : (tensor<9x10xi8>) -> tensor<9x10x!tf_type.qint8>
+  %filter_new = "tf.UniformDequantize"(%quant_filter, %filter_scale, %filter_zp) {
+    quantization_axis = -1 : i64, quantization_min_val = -128 : i64,
+    quantization_max_val = 127 : i64
+  } : (tensor<9x10x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<9x10xf32>
+  %0 = "tf.MatMul"(%input, %filter_new) {
+  } : (tensor<8x9xf32>, tensor<9x10xf32>) -> tensor<8x10xf32>
+  return %0 : tensor<8x10xf32>
 })mlir";
+  constexpr absl::string_view kProgram = R"mlir(
+func.func @main(%input: tensor<8x9xf32>, %filter: tensor<9x10xi8>) -> tensor<8x10xf32> {
+  %filter_scale = "tf.Const"() { value = dense<0.0235> : tensor<f32> } : ()
+  -> tensor<f32>
+  %filter_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
+  %quant_filter = "tf.Cast"(%filter) {} : (tensor<9x10xi8>) -> tensor<9x10x!tf_type.qint8>
+  %0 = "tf.UniformQuantizedDotHybrid"(
+    %input, %quant_filter, %filter_scale, %filter_zp
+  ) {
+    Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_FLOAT", attr_map = "",
+    device = "", rhs_quantization_axis = -1 : i64,
+    rhs_quantization_max_val = 127 : i64, rhs_quantization_min_val = -128 : i64
+  } : (tensor<8x9xf32>, tensor<9x10x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<8x10xf32>
+  return %0 : tensor<8x10xf32>
+})mlir";
+  TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomF32Literal({8, 9}));
+  TF_ASSERT_OK_AND_ASSIGN(auto filter, CreateRandomI8Literal({9, 10}));
+  // The TF kernel for UniformQuantizedDotHybrid does DRQ, but the StableHLO
+  // hybrid ops do weight-only quantization. So we use a different TF graph for
+  // evaluating the expected weight-only quantized results.
+  ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}, kTfProgram);
+}
 
-  TF_ASSERT_OK_AND_ASSIGN(auto executable, this->CompileProgram(kProgram));
+TEST_F(ConvertTfQuantToMhloIntTest, UniformRequantize) {
+  constexpr absl::string_view kProgram = R"mlir(
+func.func @main(%input: tensor<10xi8>) -> tensor<10xi8> {
+  %input_scale = "tf.Const"() { value = dense<0.2235> : tensor<f32> } : () -> tensor<f32>
+  %input_zp = "tf.Const"() { value = dense<-2> : tensor<i32> } : () -> tensor<i32>
+  %output_scale = "tf.Const"() { value = dense<0.11> : tensor<f32> } : () -> tensor<f32>
+  %output_zp = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
+  %0 = "tf.Cast"(%input) {} : (tensor<10xi8>) -> tensor<10x!tf_type.qint8>
+  %1 = "tf.UniformRequantize"(
+    %0, %input_scale, %input_zp, %output_scale, %output_zp
+  ) {
+    Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT8", attr_map = "",
+    device = "", input_quantization_axis = -1,
+    input_quantization_max_val = 127 : i64,
+    input_quantization_min_val = -128 : i64,
+    output_quantization_axis = -1 : i64,
+    output_quantization_max_val = 127 : i64,
+    output_quantization_min_val = -128 : i64
+  } : (
+    tensor<10x!tf_type.qint8>, tensor<f32>, tensor<i32>, tensor<f32>,
+    tensor<i32>
+  ) -> tensor<10x!tf_type.qint8>
+  %2 = "tf.Cast"(%1) {} : (tensor<10x!tf_type.qint8>) -> tensor<10xi8>
+  return %2 : tensor<10xi8>
+})mlir";
+  TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomI8Literal({10}));
+  ExecuteAndCompareResultsWithTfKernel(kProgram, {&input});
+}
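
For reference, the requantization arithmetic the test above exercises, up to rounding details, as a standalone C++ sketch (not part of this patch); the helper name is illustrative. A value is moved from (input_scale, input_zp) to (output_scale, output_zp) by dequantizing and quantizing again.

#include <algorithm>
#include <cmath>
#include <cstdint>

int8_t Requantize(int8_t q_in, float in_scale, int32_t in_zp, float out_scale,
                  int32_t out_zp) {
  const float real_value = (q_in - in_zp) * in_scale;  // dequantize
  const int32_t q_out =
      static_cast<int32_t>(std::lround(real_value / out_scale)) + out_zp;
  return static_cast<int8_t>(std::clamp(q_out, -128, 127));
}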
 
-  auto input = xla::LiteralUtil::CreateR2<float>({{50.f, -100.f}});
-  auto filter =
-      xla::LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {-1.f, -3.f, 1.f}});
-
-  TF_ASSERT_OK_AND_ASSIGN(auto result_literal,
-                          this->ExecuteProgramAndReturnSingleResult(
-                              executable.get(), {&input, &filter}));
-  xla::LiteralTestUtil::ExpectR2Near<float>(
-      {{150.f, 400.f, 50.f}}, *result_literal, xla::ErrorSpec(2.f));
+TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeAdd) {
+  constexpr absl::string_view kProgram = R"mlir(
+func.func @main(%lhs: tensor<10x10xi32>, %rhs: tensor<10x10xi32>) -> tensor<10x10xi32> {
+  %lhs_scale = "tf.Const"() { value = dense<0.518> : tensor<f32> } : () -> tensor<f32>
+  %lhs_zp = "tf.Const"() { value = dense<42> : tensor<i32> } : () -> tensor<i32>
+  %rhs_scale = "tf.Const"() { value = dense<0.0239> : tensor<f32> } : () -> tensor<f32>
+  %rhs_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
+  %accum_scale = "tf.Const"() { value = dense<0.013> : tensor<f32> } : () -> tensor<f32>
+  %accum_zp = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
+  %quant_lhs = "tf.Cast"(%lhs) {} : (tensor<10x10xi32>) -> tensor<10x10x!tf_type.qint32>
+  %quant_rhs = "tf.Cast"(%rhs) {} : (tensor<10x10xi32>) -> tensor<10x10x!tf_type.qint32>
+  %0 = "tf.UniformQuantizedAdd"(
+    %quant_lhs, %quant_rhs, %lhs_scale, %lhs_zp, %rhs_scale,
+    %rhs_zp, %accum_scale, %accum_zp
+  ) {
+    Tin = "tfdtype$DT_QINT32", Tout = "tfdtype$DT_QINT32", attr_map = "",
+    device = "", lhs_quantization_axis = -1 : i64,
+    lhs_quantization_max_val = 2147483647 : i64,
+    lhs_quantization_min_val = -2147483648 : i64,
+    output_quantization_axis = -1 : i64,
+    output_quantization_max_val = 2147483647 : i64,
+    output_quantization_min_val = -2147483648 : i64,
+    rhs_quantization_axis = -1 : i64,
+    rhs_quantization_max_val = 2147483647 : i64,
+    rhs_quantization_min_val = -2147483648 : i64
+  } : (
+    tensor<10x10x!tf_type.qint32>, tensor<10x10x!tf_type.qint32>, tensor<f32>,
+    tensor<i32>, tensor<f32>, tensor<i32>, tensor<f32>, tensor<i32>
+  ) -> tensor<10x10x!tf_type.qint32>
+  %1 = "tf.Cast"(%0) {} : (tensor<10x10x!tf_type.qint32>) ->  tensor<10x10xi32>
+  return %1 : tensor<10x10xi32>
+})mlir";
+  TF_ASSERT_OK_AND_ASSIGN(auto lhs, CreateRandomI32Literal({10, 10}));
+  TF_ASSERT_OK_AND_ASSIGN(auto rhs, CreateRandomI32Literal({10, 10}));
+  // error_tolerance is set to 1 because different rounding implementations in
+  // the TF kernel and the lowering passes may cause +/-1 differences.
+  ExecuteAndCompareResultsWithTfKernel(kProgram, {&lhs, &rhs},
+                                       /*tf_program=*/std::nullopt,
+                                       /*error_tolerance=*/1.0);
 }
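
For reference, a sketch (not part of this patch) of the rescale-then-add that the quantized add test above relies on: both operands are brought to the output scale and zero point before the integer addition. Rounding details differ between the TF kernel and the lowered graph, hence the tolerance of 1. The helper name is illustrative.

#include <cmath>
#include <cstdint>

int32_t QuantizedAdd(int32_t q_lhs, float lhs_scale, int32_t lhs_zp,
                     int32_t q_rhs, float rhs_scale, int32_t rhs_zp,
                     float out_scale, int32_t out_zp) {
  const int32_t lhs_rescaled = static_cast<int32_t>(
      std::lround((q_lhs - lhs_zp) * lhs_scale / out_scale));
  const int32_t rhs_rescaled = static_cast<int32_t>(
      std::lround((q_rhs - rhs_zp) * rhs_scale / out_scale));
  return lhs_rescaled + rhs_rescaled + out_zp;  // qint32 result, no clamping
}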
 
 }  // namespace
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td
index a449b4d..116037d 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td
@@ -13,9 +13,10 @@
 limitations under the License.
 ==============================================================================*/
 
-include "mlir/IR/OpBase.td"
-include "mlir/Dialect/Func/IR/FuncOps.td"
 include "mlir/Dialect/Arith/IR/ArithOps.td"
+include "mlir/Dialect/Func/IR/FuncOps.td"
+include "mlir/Dialect/Shape/IR/ShapeOps.td"
+include "mlir/IR/OpBase.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 include "stablehlo/dialect/StablehloOps.td"
@@ -47,7 +48,7 @@
       (NamedAttr<"feature_group_count"> $feature_group_count),
       (NamedAttr<"batch_group_count"> $batch_group_count),
       (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))),
-  [(IsNotInLiftedFunc $res)], [], (addBenefit 5)>;
+  [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias)], [], (addBenefit 5)>;
 
 def LiftDotGeneralWithBias : Pat<
   (StableHLO_AddOp:$res
@@ -60,7 +61,44 @@
     (NamedAttributeList
       (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers),
       (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))),
-  [(IsNotInLiftedFunc $res)], [], (addBenefit 5)>;
+  [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias)], [], (addBenefit 5)>;
+
+def LiftConvWithBiasDynamic : Pat<
+  (StableHLO_AddOp:$res
+    (StableHLO_ConvolutionOp $lhs, $rhs, $window_strides, $padding,
+        $lhs_dilation, $rhs_dilation, $window_reversal, $dimension_numbers,
+        $feature_group_count, $batch_group_count, $precision_config),
+    (StableHLO_DynamicBroadcastInDimOp
+      $bias,
+      (Shape_ShapeOfOp $conv), $_, $_, $_)),
+  (LiftAsTFXlaCallModule<"composite_conv_with_bias_dynamic_fn">
+    (ArgumentList $lhs, $rhs, $bias),
+    (ResultList $res),
+    (NamedAttributeList
+      (NamedAttr<"window_strides"> (DefaultOrNullAttr $window_strides)),
+      (NamedAttr<"padding"> (DefaultOrNullAttr $padding)),
+      (NamedAttr<"lhs_dilation"> (DefaultOrNullAttr $lhs_dilation)),
+      (NamedAttr<"rhs_dilation"> (DefaultOrNullAttr $rhs_dilation)),
+      (NamedAttr<"window_reversal"> (DefaultOrNullAttr $window_reversal)),
+      (NamedAttr<"dimension_numbers"> $dimension_numbers),
+      (NamedAttr<"feature_group_count"> $feature_group_count),
+      (NamedAttr<"batch_group_count"> $batch_group_count),
+      (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))),
+  [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias)], [], (addBenefit 10)>;
+
+def LiftDotGeneralWithBiasDynamic : Pat<
+  (StableHLO_AddOp:$res
+    (StableHLO_DotGeneralOp $lhs, $rhs, $dot_dimension_numbers, $precision_config),
+    (StableHLO_DynamicBroadcastInDimOp
+      $bias,
+      (Shape_ShapeOfOp $dot_general), $_, $_, $_)),
+  (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_dynamic_fn">
+    (ArgumentList $lhs, $rhs, $bias),
+    (ResultList $res),
+    (NamedAttributeList
+      (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers),
+      (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))),
+  [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias)], [], (addBenefit 10)>;
 
 //===----------------------------------------------------------------------===//
 // Pattern rules for lifting ops with activation as functions
@@ -101,6 +139,45 @@
   [(IsNotInLiftedFunc $res),
    (FloatValueEquals<"0"> $cst)], [], (addBenefit 10)>;
 
+def LiftConvWithReluDynamic : Pat<
+  (StableHLO_MaxOp:$res
+    (StableHLO_ConvolutionOp $lhs, $rhs, $window_strides, $padding,
+        $lhs_dilation, $rhs_dilation, $window_reversal, $dimension_numbers,
+        $feature_group_count, $batch_group_count, $precision_config),
+    (StableHLO_DynamicBroadcastInDimOp
+      (StableHLO_ConstantOp $cst),
+      (Shape_ShapeOfOp $conv), $_, $_, $_)),
+  (LiftAsTFXlaCallModule<"composite_conv_with_relu_dynamic_fn">
+    (ArgumentList $lhs, $rhs),
+    (ResultList $res),
+    (NamedAttributeList
+      (NamedAttr<"window_strides"> (DefaultOrNullAttr $window_strides)),
+      (NamedAttr<"padding"> (DefaultOrNullAttr $padding)),
+      (NamedAttr<"lhs_dilation"> (DefaultOrNullAttr $lhs_dilation)),
+      (NamedAttr<"rhs_dilation"> (DefaultOrNullAttr $rhs_dilation)),
+      (NamedAttr<"window_reversal"> (DefaultOrNullAttr $window_reversal)),
+      (NamedAttr<"dimension_numbers"> $dimension_numbers),
+      (NamedAttr<"feature_group_count"> $feature_group_count),
+      (NamedAttr<"batch_group_count"> $batch_group_count),
+      (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))),
+  [(IsNotInLiftedFunc $res),
+   (FloatValueEquals<"0"> $cst)], [], (addBenefit 15)>;
+
+def LiftDotGeneralWithReluDynamic : Pat<
+  (StableHLO_MaxOp:$res
+    (StableHLO_DotGeneralOp $lhs, $rhs, $dot_dimension_numbers, $precision_config),
+    (StableHLO_DynamicBroadcastInDimOp
+      (StableHLO_ConstantOp $cst),
+      (Shape_ShapeOfOp $dot_general), $_, $_, $_)),
+  (LiftAsTFXlaCallModule<"composite_dot_general_with_relu_dynamic_fn">
+    (ArgumentList $lhs, $rhs),
+    (ResultList $res),
+    (NamedAttributeList
+      (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers),
+      (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))),
+  [(IsNotInLiftedFunc $res),
+   (FloatValueEquals<"0"> $cst)], [], (addBenefit 15)>;
+
 def LiftConvWithRelu6 : Pat<
   (StableHLO_ClampOp:$res
     (StableHLO_ConstantOp $cst_0),
@@ -163,7 +240,7 @@
       (NamedAttr<"batch_group_count"> $batch_group_count),
       (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))),
   [(IsNotInLiftedFunc $res),
-   (FloatValueEquals<"0"> $cst)], [], (addBenefit 10)>;
+   (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias)], [], (addBenefit 10)>;
 
 def LiftDotGeneralWithBiasAndRelu : Pat<
   (StableHLO_MaxOp:$res
@@ -179,7 +256,55 @@
       (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers),
       (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))),
   [(IsNotInLiftedFunc $res),
-   (FloatValueEquals<"0"> $cst)], [], (addBenefit 10)>;
+   (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias)], [], (addBenefit 10)>;
+
+def LiftConvWithBiasAndReluDynamic : Pat<
+  (StableHLO_MaxOp:$res
+    (StableHLO_AddOp
+      (StableHLO_ConvolutionOp $lhs, $rhs, $window_strides, $padding,
+          $lhs_dilation, $rhs_dilation, $window_reversal, $dimension_numbers,
+          $feature_group_count, $batch_group_count, $precision_config),
+      (StableHLO_DynamicBroadcastInDimOp
+        $bias,
+        (Shape_ShapeOfOp $conv), $_, $_, $_)),
+    (StableHLO_DynamicBroadcastInDimOp
+      (StableHLO_ConstantOp $cst),
+      (Shape_ShapeOfOp $add), $_, $_, $_)),
+  (LiftAsTFXlaCallModule<"composite_conv_with_bias_and_relu_dynamic_fn">
+    (ArgumentList $lhs, $rhs, $bias),
+    (ResultList $res),
+    (NamedAttributeList
+      (NamedAttr<"window_strides"> (DefaultOrNullAttr $window_strides)),
+      (NamedAttr<"padding"> (DefaultOrNullAttr $padding)),
+      (NamedAttr<"lhs_dilation"> (DefaultOrNullAttr $lhs_dilation)),
+      (NamedAttr<"rhs_dilation"> (DefaultOrNullAttr $rhs_dilation)),
+      (NamedAttr<"window_reversal"> (DefaultOrNullAttr $window_reversal)),
+      (NamedAttr<"dimension_numbers"> $dimension_numbers),
+      (NamedAttr<"feature_group_count"> $feature_group_count),
+      (NamedAttr<"batch_group_count"> $batch_group_count),
+      (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))),
+  [(IsNotInLiftedFunc $res),
+   (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias)], [], (addBenefit 15)>;
+
+def LiftDotGeneralWithBiasAndReluDynamic : Pat<
+  (StableHLO_MaxOp:$res
+    (StableHLO_AddOp
+      (StableHLO_DotGeneralOp $lhs, $rhs, $dot_dimension_numbers, $precision_config),
+      (StableHLO_DynamicBroadcastInDimOp
+        $bias,
+        (Shape_ShapeOfOp $dot_general), $_, $_, $_)),
+    (StableHLO_DynamicBroadcastInDimOp
+      (StableHLO_ConstantOp $cst),
+      (Shape_ShapeOfOp $add), $_, $_, $_)),
+  (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_and_relu_dynamic_fn">
+    (ArgumentList $lhs, $rhs, $bias),
+    (ResultList $res),
+    (NamedAttributeList
+      (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers),
+      (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))),
+  [(IsNotInLiftedFunc $res),
+   (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias)], [], (addBenefit 15)>;
+
 
 def LiftConvWithBiasAndRelu6 : Pat<
   (StableHLO_ClampOp:$res
@@ -203,7 +328,7 @@
       (NamedAttr<"feature_group_count"> $feature_group_count),
       (NamedAttr<"batch_group_count"> $batch_group_count),
       (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))),
-  [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1)], [], (addBenefit 10)>;
+  [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1)], [], (addBenefit 10)>;
 
 def LiftDotGeneralWithBiasAndRelu6 : Pat<
   (StableHLO_ClampOp:$res
@@ -219,4 +344,4 @@
     (NamedAttributeList
       (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers),
       (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))),
-  [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1)], [], (addBenefit 10)>;
+  [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1)], [], (addBenefit 10)>;
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h
index 8e9cfb5..0b05069 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h
@@ -21,10 +21,16 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project  // IWYU pragma: keep
 #include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
 #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h"
 
 namespace mlir::quant::stablehlo {
 
+// Creates a `QuantizePass` that quantizes ops according to surrounding qcast /
+// dcast ops.
+std::unique_ptr<OperationPass<func::FuncOp>> CreateQuantizePass(
+    const quant::QuantizationSpecs& quantization_specs);
+
 // Creates a pass that quantizes weight component of StableHLO graph.
 std::unique_ptr<OperationPass<func::FuncOp>> CreateQuantizeWeightPass(
     const ::stablehlo::quantization::QuantizationComponentSpec&
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td
index 0992c04..52dca78 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td
@@ -76,7 +76,30 @@
   ];
 }
 
-
 def RestoreFunctionNamePass : Pass<"stablehlo-restore-function-name", "ModuleOp"> {
   let summary = "Restores function name from XlaCallModule op.";
 }
+
+def PostQuantizePass : Pass<"stablehlo-post-quantize", "mlir::func::FuncOp"> {
+  let summary = "Apply clean-up after quantization.";
+  let dependentDialects = [
+    "mlir::stablehlo::StablehloDialect",
+    "mlir::quantfork::QuantizationForkDialect",
+  ];
+}
+
+def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-functions", "ModuleOp"> {
+  let summary = "Quantize composite functions with QDQ input / outputs.";
+  let options = [
+    Option<"mlir_dump_file_name_", "mlir-dump-file-name",
+        "std::optional<std::string>", /*default=*/"std::nullopt",
+        "MLIR dump file name.">
+  ];
+  let dependentDialects = [
+    "mlir::arith::ArithDialect",
+    "mlir::stablehlo::StablehloDialect",
+    "mlir::quant::QuantizationDialect",
+    "mlir::quantfork::QuantizationForkDialect",
+    "TF::TensorFlowDialect",
+  ];
+}
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/post_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/post_quantize.cc
new file mode 100644
index 0000000..0416bbd
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/post_quantize.cc
@@ -0,0 +1,158 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <utility>
+
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributeInterfaces.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Matchers.h"  // from @llvm-project
+#include "mlir/IR/OpDefinition.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Support/TypeID.h"  // from @llvm-project
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
+#include "stablehlo/dialect/StablehloOps.h"  // from @stablehlo
+#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
+#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
+
+namespace mlir::quant::stablehlo {
+
+#define GEN_PASS_DEF_POSTQUANTIZEPASS
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc"
+
+namespace {
+
+// Applies clean-up patterns after quantization.
+class PostQuantizePass : public impl::PostQuantizePassBase<PostQuantizePass> {
+ public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PostQuantizePass)
+
+  explicit PostQuantizePass() = default;
+
+ private:
+  void runOnOperation() override;
+};
+
+// TODO: b/305815328 - Consider preserving leading and trailing QDQs for
+// ModifyIONodesPass in TFLite use cases.
+// Removes back-to-back quantize and dequantize ops that carry the volatile
+// attribute.
+class RemoveVolatileQdqPattern
+    : public OpRewritePattern<quantfork::DequantizeCastOp> {
+ public:
+  explicit RemoveVolatileQdqPattern(MLIRContext* context)
+      : OpRewritePattern<quantfork::DequantizeCastOp>(context) {}
+
+  LogicalResult matchAndRewrite(quantfork::DequantizeCastOp op,
+                                PatternRewriter& rewriter) const override {
+    auto input_op = op.getArg().getDefiningOp();
+    if (auto q = llvm::dyn_cast_or_null<quantfork::QuantizeCastOp>(input_op)) {
+      if (!q->getAttr(kVolatileOpAttrName)) return failure();
+
+      // If the quantize op is a requantize op, it is being used in other scale
+      // adjustments and should be kept. Instead, move the dequantize op
+      // before the requantize op to remove the unnecessary requantize.
+      if (auto qtype =
+              QuantizedType::getQuantizedElementType(q.getArg().getType())) {
+        rewriter.setInsertionPoint(op);
+        rewriter.replaceOpWithNewOp<quantfork::DequantizeCastOp>(
+            op, op.getResult().getType(), q.getArg());
+        return success();
+      }
+
+      op.replaceAllUsesWith(q.getArg());
+      return success();
+    }
+    return failure();
+  }
+};
+
+// Replaces constant and uniform_quantize ops with a single quantized constant op.
+class QuantizeConstPattern
+    : public OpRewritePattern<mlir::stablehlo::UniformQuantizeOp> {
+ public:
+  explicit QuantizeConstPattern(MLIRContext* context)
+      : OpRewritePattern<mlir::stablehlo::UniformQuantizeOp>(context) {}
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::UniformQuantizeOp op,
+                                PatternRewriter& rewriter) const override {
+    DenseFPElementsAttr attr;
+    if (matchPattern(op.getOperand(), m_Constant(&attr))) {
+      auto qtype = op.getResult().getType();
+      ElementsAttr quantized_attr = Quantize(attr, qtype);
+      if (quantized_attr) {
+        rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(
+            op, qtype, quantized_attr);
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+
+// Replaces quantfork.dcast with stablehlo.uniform_dequantize.
+class ConvertDequantizeCastToUniformDequantizePattern
+    : public OpRewritePattern<quantfork::DequantizeCastOp> {
+ public:
+  explicit ConvertDequantizeCastToUniformDequantizePattern(MLIRContext* context)
+      : OpRewritePattern<quantfork::DequantizeCastOp>(context) {}
+  LogicalResult matchAndRewrite(quantfork::DequantizeCastOp dq_op,
+                                PatternRewriter& rewriter) const override {
+    rewriter.replaceOpWithNewOp<mlir::stablehlo::UniformDequantizeOp>(
+        dq_op, dq_op.getResult().getType(), dq_op.getArg());
+    return success();
+  }
+};
+
+// Replaces quantfork.qcast with stablehlo.uniform_quantize.
+class ConvertQuantizeCastToUniformQuantizePattern
+    : public OpRewritePattern<quantfork::QuantizeCastOp> {
+ public:
+  explicit ConvertQuantizeCastToUniformQuantizePattern(MLIRContext* context)
+      : OpRewritePattern<quantfork::QuantizeCastOp>(context) {}
+  LogicalResult matchAndRewrite(quantfork::QuantizeCastOp q_op,
+                                PatternRewriter& rewriter) const override {
+    rewriter.replaceOpWithNewOp<mlir::stablehlo::UniformQuantizeOp>(
+        q_op, q_op.getResult().getType(), q_op.getArg());
+    return success();
+  }
+};
+
+void PostQuantizePass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  func::FuncOp func = getOperation();
+  MLIRContext* ctx = func.getContext();
+  // TODO: b/307463853 - Consider splitting passes for each pattern set.
+  patterns.add<FoldTrivalRequantizeOp<quantfork::QuantizeCastOp>,
+               RemoveVolatileQdqPattern>(ctx);
+  if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
+    signalPassFailure();
+  }
+
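+  // The second set of patterns converts quantfork.qcast / quantfork.dcast ops
+  // into stablehlo.uniform_quantize / stablehlo.uniform_dequantize ops and
+  // folds constant + uniform_quantize pairs into quantized constants.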
+  RewritePatternSet patterns_2(&getContext());
+  patterns_2
+      .add<ConvertDequantizeCastToUniformDequantizePattern,
+           ConvertQuantizeCastToUniformQuantizePattern, QuantizeConstPattern>(
+          ctx);
+  if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns_2)))) {
+    signalPassFailure();
+  }
+}
+
+}  // namespace
+}  // namespace mlir::quant::stablehlo
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc
index c5b7a3a..16e7ad1 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc
@@ -13,6 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <memory>
 #include <string>
 #include <utility>
 
@@ -137,4 +138,9 @@
 
 }  // namespace
 
+std::unique_ptr<OperationPass<func::FuncOp>> CreateQuantizePass(
+    const QuantizationSpecs& quantization_specs) {
+  return std::make_unique<QuantizePass>(quantization_specs);
+}
+
 }  // namespace mlir::quant::stablehlo
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc
new file mode 100644
index 0000000..cf0c44f
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc
@@ -0,0 +1,358 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdint>
+#include <string>
+#include <type_traits>
+#include <utility>
+
+#include "absl/algorithm/container.h"
+#include "absl/status/status.h"
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project  // IWYU pragma: keep
+#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
+#include "mlir/IR/Block.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/OperationSupport.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/SymbolTable.h"  // from @llvm-project
+#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
+#include "mlir/IR/Visitors.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project  // IWYU pragma: keep
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Support/TypeID.h"  // from @llvm-project
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
+#include "stablehlo/dialect/StablehloOps.h"  // from @stablehlo  // IWYU pragma: keep
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h"
+#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+
+namespace mlir::quant::stablehlo {
+
+#define GEN_PASS_DEF_QUANTIZECOMPOSITEFUNCTIONSPASS
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc"
+
+namespace {
+
+using QuantMethod = tensorflow::quantization::QuantizationMethod::PresetMethod;
+using ::mlir::stablehlo::DotGeneralOp;
+using ::mlir::stablehlo::UniformQuantizeOp;
+using ::tensorflow::quantization::RunPassesOnModuleOp;
+
+constexpr StringRef kCompositeFuncPrefix = "composite_";
+constexpr StringRef kQuantizedFuncPrefix = "quantized_";
+constexpr StringRef kEntryFuncAttrName = "_entry_function";
+
+class QuantizeCompositeFunctionsPass
+    : public impl::QuantizeCompositeFunctionsPassBase<
+          QuantizeCompositeFunctionsPass> {
+ public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizeCompositeFunctionsPass)
+
+  using impl::QuantizeCompositeFunctionsPassBase<
+      QuantizeCompositeFunctionsPass>::QuantizeCompositeFunctionsPassBase;
+
+ private:
+  void runOnOperation() override;
+};
+
+// Returns true if `type` is a TensorType with quantized elements.
+bool IsQuantizedTensorType(const Type type) {
+  return type.isa<TensorType>() &&
+         type.cast<TensorType>().getElementType().isa<QuantizedType>();
+}
+
+// Returns true if all operands and results of `call_op` are quantized tensors.
+bool HasQuantizedOperandOrOutput(Operation* call_op) {
+  SmallVector<Type> arg_types;
+  for (const Value arg : call_op->getOperands()) {
+    arg_types.push_back(arg.getType());
+  }
+
+  SmallVector<Type> output_types;
+  for (const Value output : call_op->getResults()) {
+    output_types.push_back(output.getType());
+  }
+
+  return absl::c_all_of(arg_types, IsQuantizedTensorType) &&
+         absl::c_all_of(output_types, IsQuantizedTensorType);
+}
+
+// Gets the corresponding quantized function name from the given function name.
+// Example: "composite_dot_general_fn_1" => "quantized_dot_general_fn"
+std::string GetQuantizedFunctionName(const StringRef func_name) {
+  return Twine(kQuantizedFuncPrefix)
+      .concat(func_name.rsplit(kCompositeFuncPrefix).second)
+      .str();
+}
+
+// Returns true if `xla_call_module_op` is quantized. To be considered
+// quantized, it should meet three conditions:
+// 1. All of its inputs and outputs should be of uniform quantized types.
+// 2. `xla_call_module_op` should have the `kQuantTraitAttrName` attribute.
+// 3. It should also have the `kEntryFuncAttrName` attribute, which points to
+//    the function that `xla_call_module_op` represents.
+bool IsQuantizedXlaCallModuleOp(TF::XlaCallModuleOp xla_call_module_op) {
+  return HasQuantizedOperandOrOutput(xla_call_module_op) &&
+         xla_call_module_op->hasAttr(kQuantTraitAttrName) &&
+         xla_call_module_op->hasAttr(kEntryFuncAttrName);
+}
+
+// Returns the entry function, i.e. the callee of `xla_call_module_op`.
+func::FuncOp GetEntryFuncOp(TF::XlaCallModuleOp xla_call_module_op,
+                            SymbolTable symbol_table) {
+  auto entry_function_symbol_ref =
+      xla_call_module_op->getAttrOfType<FlatSymbolRefAttr>(kEntryFuncAttrName);
+
+  return dyn_cast_or_null<func::FuncOp>(
+      symbol_table.lookup(entry_function_symbol_ref.getValue()));
+}
+
+// Replaces the function type of `entry_func_op` with a quantized one, matching
+// the input and output types of `xla_call_module_op`.
+void SetQuantizedFunctionType(PatternRewriter& rewriter,
+                              func::FuncOp entry_func_op,
+                              TF::XlaCallModuleOp xla_call_module_op) {
+  SmallVector<Type> arg_types;
+  SmallVector<Location> arg_locs;
+  for (const Value arg : xla_call_module_op.getArgs()) {
+    arg_types.push_back(arg.getType());
+    arg_locs.push_back(arg.getLoc());
+  }
+
+  SmallVector<Type> output_types;
+  for (const Value output : xla_call_module_op.getOutput()) {
+    output_types.push_back(output.getType());
+  }
+
+  entry_func_op.setFunctionType(
+      rewriter.getFunctionType(arg_types, output_types));
+
+  // Replace argument types and locs.
+  Block& entry = entry_func_op->getRegion(0).front();
+  for (auto [arg, arg_type, arg_loc] :
+       llvm::zip_equal(entry.getArguments(), arg_types, arg_locs)) {
+    arg.setType(arg_type);
+    arg.setLoc(arg_loc);
+  }
+}
+
+// An interface representing patterns that quantize an entry function's body.
+// The entry function's signatures should have already been quantized at the
+// point of rewriting.
+class EntryFuncBodyQuantizationPattern {
+ public:
+  virtual ~EntryFuncBodyQuantizationPattern() = default;
+
+  // Returns `success()` if `entry_func_op`'s body is eligible for rewriting. At
+  // this point `entry_func_op`'s signature has not been reset with quantized
+  // types.
+  virtual LogicalResult match(func::FuncOp entry_func_op) const = 0;
+
+  // Rewrites the `entry_func_op`'s body.
+  virtual void rewrite(func::FuncOp entry_func_op,
+                       PatternRewriter& rewriter) const = 0;
+};
+
+// Quantizes the entry function's body containing a `DotGeneralOp`.
+class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern {
+ public:
+  explicit QuantizeDotGeneralOpPattern(MLIRContext& ctx) : ctx_(&ctx) {}
+
+  LogicalResult match(func::FuncOp entry_func_op) const override {
+    auto& operations = entry_func_op.getBody().front().getOperations();
+    return success(operations.size() == 2 &&
+                   isa<DotGeneralOp>(operations.front()));
+  }
+
+  void rewrite(func::FuncOp entry_func_op,
+               PatternRewriter& rewriter) const override {
+    // Update the output type of the dot_general op.
+    auto dot_general_op = *entry_func_op.getOps<DotGeneralOp>().begin();
+
+    const Type input_type = entry_func_op.getArgumentTypes()[0];
+    const Type rhs_type = entry_func_op.getArgumentTypes()[1];
+    const Type func_result_type = entry_func_op.getResultTypes()[0];
+
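+    // With zero-valued zero points, the effective scale of the i32 accumulator
+    // of a quantized dot_general is the product of the lhs and rhs scales,
+    // hence the intermediate type below uses input_scale * rhs_scale.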
+    const double input_scale = getElementTypeOrSelf(input_type)
+                                   .cast<UniformQuantizedType>()
+                                   .getScale();
+    const double rhs_scale =
+        getElementTypeOrSelf(rhs_type).cast<UniformQuantizedType>().getScale();
+
+    // Define the intermediate output type, which is an i32 quantized type.
+    // This is intermediate because the final output type of the entry_func_op
+    // should be an i8 quantized type.
+    const UniformQuantizedType output_quantized_element_type =
+        CreateI32F32UniformQuantizedType(dot_general_op->getLoc(), *ctx_,
+                                         input_scale * rhs_scale,
+                                         /*zero_point=*/0);
+
+    Value dot_general_op_result = dot_general_op->getResult(0);
+    const auto dot_general_op_result_type =
+        dot_general_op_result.getType().cast<RankedTensorType>();
+    const ArrayRef<int64_t> shape = dot_general_op_result_type.getShape();
+
+    const TensorType new_dot_general_op_result_type =
+        dot_general_op_result_type.cloneWith(shape,
+                                             output_quantized_element_type);
+    dot_general_op_result.setType(new_dot_general_op_result_type);
+
+    // Add i32 -> i8 requantization.
+    rewriter.setInsertionPointAfter(dot_general_op);
+    auto uniform_quant_op = rewriter.create<UniformQuantizeOp>(
+        dot_general_op->getLoc(), func_result_type,
+        dot_general_op->getResults());
+
+    auto return_op =
+        cast<func::ReturnOp>(entry_func_op.getBody().front().getTerminator());
+    return_op.setOperand(0, uniform_quant_op);
+  }
+
+ private:
+  MLIRContext* ctx_ = nullptr;
+};
+
+// Converts `entry_func_op` so that it is quantized according to the (possibly
+// quantized) inputs and outputs of `xla_call_module_op`. Its signature (type)
+// is reset to match that of `xla_call_module_op`.
+// `entry_func_body_quantization_pattern` rewrites the function's body, based on
+// the new signature.
+void QuantizeEntryFuncOp(
+    MLIRContext& ctx, PatternRewriter& rewriter,
+    TF::XlaCallModuleOp xla_call_module_op, func::FuncOp entry_func_op,
+    const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) {
+  SetQuantizedFunctionType(rewriter, entry_func_op, xla_call_module_op);
+
+  body_rewrite_pattern.rewrite(entry_func_op, rewriter);
+
+  // Rename the function to make it clear that it has been quantized.
+  const std::string quantized_function_name =
+      GetQuantizedFunctionName(entry_func_op.getSymName());
+  entry_func_op.setSymName(quantized_function_name);
+}
+
+// Replaces a quantized `xla_call_module_op` with a `func::CallOp`. The callee
+// is expected to still be unquantized at this point (hence the signature
+// mismatch), so it is quantized accordingly as part of the replacement.
+void ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp(
+    MLIRContext& ctx, PatternRewriter& rewriter,
+    TF::XlaCallModuleOp xla_call_module_op,
+    const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) {
+  auto module_op = xla_call_module_op->getParentOfType<ModuleOp>();
+  SymbolTable symbol_table(module_op);
+
+  func::FuncOp entry_func_op = GetEntryFuncOp(xla_call_module_op, symbol_table);
+  QuantizeEntryFuncOp(ctx, rewriter, xla_call_module_op, entry_func_op,
+                      body_rewrite_pattern);
+
+  // Replace the XlaCallModuleOp with a new CallOp.
+  rewriter.setInsertionPoint(xla_call_module_op);
+  rewriter.replaceOpWithNewOp<func::CallOp>(xla_call_module_op, entry_func_op,
+                                            xla_call_module_op.getArgs());
+}
+
+// Pattern that mainly does two things:
+//
+//   1. Replaces quantized `TF::XlaCallModuleOp` with a `func::CallOp`.
+//   2. Quantizes the callee function.
+//
+// This pattern assumes input IR that is temporarily invalid: even if a
+// `TF::XlaCallModuleOp` is quantized, its callee remains unquantized. Step (2)
+// not only replaces the callee's input and output tensor types with quantized
+// ones, but also rewrites its body with a quantized equivalent.
+//
+// `FuncBodyRewritePatternT` defines how a function body is quantized and
+// rewritten.
+template <typename FuncBodyRewritePatternT,
+          typename = std::enable_if_t<std::is_base_of_v<
+              EntryFuncBodyQuantizationPattern, FuncBodyRewritePatternT>>>
+class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> {
+ public:
+  explicit XlaCallModuleOpToCallOp(MLIRContext& ctx)
+      : OpRewritePattern<TF::XlaCallModuleOp>(&ctx) {}
+
+  LogicalResult match(TF::XlaCallModuleOp op) const override {
+    auto module_op = op->getParentOfType<ModuleOp>();
+    SymbolTable symbol_table(module_op);
+
+    // Ignore unquantized ops.
+    if (!IsQuantizedXlaCallModuleOp(op)) return failure();
+
+    func::FuncOp entry_func_op = GetEntryFuncOp(op, symbol_table);
+    if (!entry_func_op) {
+      op->emitError("Failed to find a valid entry function.");
+      return failure();
+    }
+
+    return FuncBodyRewritePatternT(*getContext()).match(entry_func_op);
+  }
+
+  void rewrite(TF::XlaCallModuleOp xla_call_module_op,
+               PatternRewriter& rewriter) const override {
+    ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp(
+        *rewriter.getContext(), rewriter, xla_call_module_op,
+        FuncBodyRewritePatternT(*getContext()));
+  }
+};
+
+void QuantizeCompositeFunctionsPass::runOnOperation() {
+  MLIRContext& ctx = getContext();
+
+  QuantizationSpecs quant_specs;
+  quant_specs.inference_type = tensorflow::DT_QINT8;
+
+  PassManager pm(&ctx);
+  // Intermediate output from QuantizePass will have quantized ops
+  // (XlaCallModuleOps) with quantized input and output types, which are not
+  // allowed in the TF dialect.
+  pm.enableVerifier(false);
+
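+  // The per-function pipeline prepares the IR for quantization, quantizes ops
+  // according to the surrounding qcast / dcast ops, and then cleans up the
+  // leftover casts in the post-quantize step.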
+  pm.addNestedPass<func::FuncOp>(CreatePrepareQuantizePass());
+  pm.addNestedPass<func::FuncOp>(CreateQuantizePass(quant_specs));
+  pm.addNestedPass<func::FuncOp>(createPostQuantizePass());
+
+  ModuleOp module_op = getOperation();
+  if (const absl::Status pm_run_status =
+          RunPassesOnModuleOp(mlir_dump_file_name_, pm, module_op);
+      !pm_run_status.ok()) {
+    signalPassFailure();
+  }
+
+  // TODO: b/307839649 - Move this to a separate pass.
+  RewritePatternSet patterns(&ctx);
+  patterns.add<XlaCallModuleOpToCallOp<QuantizeDotGeneralOpPattern>>(ctx);
+
+  if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) {
+    signalPassFailure();
+  }
+}
+
+}  // namespace
+
+}  // namespace mlir::quant::stablehlo
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc
index 8cc75fb..5bf8ba7 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc
@@ -78,7 +78,43 @@
   return Twine("_stablehlo_main_").concat(std::to_string(id)).str();
 }
 
-// Follows the structure of Live-variable analysis.
+// Follows the structure of Live-variable analysis. It is a form of
+// CFG (Control Flow Graph) analysis, often used in compilers.
+//
+// A variable is live if it holds a value that may be used in the future.
+// It is live-in at node n if it is live on any of the node's in-edges.
+// It is live-out at node n if it is live on any of the node's out-edges.
+// def[n] refers to values that are defined at node n.
+// use[n] refers to values that are used at node n.
+//
+// Given a node n, a variable's liveness is defined as follows:
+// live_in[n] = use[n] U (live_out[n] - def[n])
+// live_out[n] = U {live_in[s] | s ∈ succ[n]}
+//
+// Consider a sequence of ops:
+//
+// ```
+// node 1: %0 = stablehlo.constant
+// node 2: %1 = stablehlo.constant
+// node 3: %2 = stablehlo.add %0, %1
+// node 4: %3 = stablehlo.multiply %2, %1
+// node 5: return %3
+// ```
+//
+// In backward liveness analysis, the live-in set for each node above becomes:
+// live_in[5] = use[5]   U (live_out[5] - def[5])
+//            = {%3}     U ({∅} - {∅})            = {%3}
+// live_in[4] = use[4]   U (live_out[4] - def[4])
+//            = {%1, %2} U ({%3} - {%3})          = {%1, %2}
+// live_in[3] = use[3]   U (live_out[3] - def[3])
+//            = {%0, %1} U ({%1, %2} - {%2})      = {%0, %1}
+// live_in[2] = use[2]   U (live_out[2] - def[2])
+//            = {∅}      U ({%0, %1} - {%1})      = {%0}
+// live_in[1] = use[1]   U (live_out[1] - def[1])
+//            = {∅}      U ({%0} - {%0})          = {∅}
+//
+// This analysis is used throughout this pass to ensure that only live values
+// are wired across subgraph boundaries, forming proper subgraphs.
 class LiveOuts {
  public:
   LiveOuts() = default;
@@ -100,10 +136,10 @@
   void snapshot_previous_state() { prev_liveouts_ = liveouts_; }
 
   // Return the current live values.
-  DenseSet<Value>& get() { return liveouts_; }
+  const DenseSet<Value>& get() const { return liveouts_; }
 
   // Return the previous live values.
-  DenseSet<Value>& get_previous() { return prev_liveouts_; }
+  const DenseSet<Value>& get_previous() const { return prev_liveouts_; }
 
  private:
   DenseSet<Value> liveouts_;
@@ -212,6 +248,38 @@
   }
 }
 
+// Contains the actual logic for updating states and replacing StableHLO ops
+// with tf.XlaCallModuleOps.
+void UpdateStatesAndReplaceStablehloOps(
+    const DenseSet<Value>& operands, const DenseSet<Value>& defined_values,
+    const LiveOuts& liveouts, ModuleOp module_op,
+    ArrayRef<Operation*> reverse_subgraph, const int stablehlo_func_id,
+    func::FuncOp main_func, const bool is_last_subgraph = false) {
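+  // The inputs of the subgraph are the operands of its ops, excluding values
+  // defined by ops within the subgraph itself.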
+  DenseSet<Value> inputs = operands;
+  for (Value defined_value : defined_values) {
+    inputs.erase(defined_value);
+  }
+
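+  // The outputs of the subgraph are the values that were live before the
+  // subgraph was visited (in reverse order) but are no longer in the current
+  // live set.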
+  DenseSet<Value> outputs = liveouts.get_previous();
+  for (Value live_value : liveouts.get()) {
+    outputs.erase(live_value);
+  }
+
+  if (is_last_subgraph) {
+    // Additionally remove the function arguments from the outputs. Arguments
+    // are live throughout the function (as if produced by an invisible op
+    // above the very first op), so they should not be treated as subgraph
+    // outputs.
+    for (const BlockArgument arg : main_func.getArguments()) {
+      outputs.erase(arg);
+    }
+  }
+
+  ReplaceStablehloOpsWithXlaCallModuleOp(
+      SmallVector<Value>(inputs.begin(), inputs.end()),
+      SmallVector<Value>(outputs.begin(), outputs.end()), reverse_subgraph,
+      stablehlo_func_id, module_op);
+}
+
 // Replaces the StableHLO ops in the main function block with
 // tf.XlaCallModuleOps as separate subgraphs. Wires them back to the main
 // function block to be compatible with SavedModel structure.
@@ -241,20 +309,14 @@
   DenseSet<Value> operands;
   DenseSet<Value> defined_values;
 
-  int stablehlo_func_id = 0;
+  int stablehlo_func_id = -1;
   for (Operation* op : reverse_main_func_block_ops) {
     if (!IsStablehloOp(op)) {
       // Create an XlaCallModuleOp if reverse_subgraph isn't empty.
       if (!reverse_subgraph.empty()) {
-        DenseSet<Value> outputs = liveouts.get_previous();
-        for (Value live_value : liveouts.get()) {
-          outputs.erase(live_value);
-        }
-
-        ReplaceStablehloOpsWithXlaCallModuleOp(
-            SmallVector<Value>(operands.begin(), operands.end()),
-            SmallVector<Value>(outputs.begin(), outputs.end()),
-            reverse_subgraph, stablehlo_func_id++, module_op);
+        UpdateStatesAndReplaceStablehloOps(operands, defined_values, liveouts,
+                                           module_op, reverse_subgraph,
+                                           ++stablehlo_func_id, main_func);
 
         // Reset states and start a new subgraph.
         reverse_subgraph.clear();
@@ -273,25 +335,16 @@
     }
 
     reverse_subgraph.push_back(op);
+
+    defined_values.insert(op->getResults().begin(), op->getResults().end());
+    operands.insert(op->getOperands().begin(), op->getOperands().end());
   }
 
   // Create the last subgraph if it isn't empty.
   if (!reverse_subgraph.empty()) {
-    DenseSet<Value> outputs = liveouts.get_previous();
-    for (Value live_value : liveouts.get()) {
-      outputs.erase(live_value);
-    }
-    // Additionally remove arguments from the outputs, as it provides liveness
-    // throughout (functions as an invisible op above the very first op that
-    // returns the arguments).
-    for (const BlockArgument arg : main_func.getArguments()) {
-      outputs.erase(arg);
-    }
-
-    ReplaceStablehloOpsWithXlaCallModuleOp(
-        SmallVector<Value>(operands.begin(), operands.end()),
-        SmallVector<Value>(outputs.begin(), outputs.end()), reverse_subgraph,
-        stablehlo_func_id++, module_op);
+    UpdateStatesAndReplaceStablehloOps(
+        operands, defined_values, liveouts, module_op, reverse_subgraph,
+        ++stablehlo_func_id, main_func, /*is_last_subgraph=*/true);
   }
 }
 
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/utils.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/utils.td
index 8c6ab88..744637d 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/utils.td
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/utils.td
@@ -24,3 +24,7 @@
 
 // Fetches the default or null attribute, used for pattern matching.
 def DefaultOrNullAttr : NativeCodeCall<"DefaultOrNullAttr($_builder, $0)">;
+
+// Returns true if the given value is defined by a StableHLO constant op.
+def IsStableHLOConstantOp : Constraint<CPred<"dyn_cast_or_null<::mlir::stablehlo::ConstantOp>($0.getDefiningOp())">>;
+
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir
index 9f9c114..cb8ad65 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir
@@ -106,6 +106,53 @@
 
 // -----
 
+// CHECK-LABEL: func @quantize_per_channel
+func.func @quantize_per_channel(%arg0: tensor<26x26x3x2xf32>
+    ) -> tensor<26x26x3x2x!quant.uniform<i32:f32:3, {1.100000e+00:-10, 1.100000e-01:2}>> {
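+  // Per-channel quantization lowers to: divide by the per-channel scales, add
+  // the zero points, clamp to the target integer range, round to the nearest
+  // even value, and convert to the storage integer type.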
+  // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[1.100000e+00, 1.100000e-01]>
+  // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<[-1.000000e+01, 2.000000e+00]>
+  // CHECK-DAG: %[[QMIN:.*]] = mhlo.constant dense<-2.14748365E+9> : tensor<f32>
+  // CHECK-DAG: %[[QMAX:.*]] = mhlo.constant dense<2.14748365E+9> : tensor<f32>
+  // CHECK: %[[DIVIDE:.*]] = chlo.broadcast_divide %arg0, %[[SCALES]]
+  // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
+  // CHECK-SAME: (tensor<26x26x3x2xf32>, tensor<2xf32>) -> tensor<26x26x3x2xf32>
+  // CHECK: %[[ADD:.*]] = chlo.broadcast_add %[[DIVIDE]], %[[ZPS]]
+  // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
+  // CHECK-SAME: (tensor<26x26x3x2xf32>, tensor<2xf32>) -> tensor<26x26x3x2xf32>
+  // CHECK: %[[CLAMP:.*]] = mhlo.clamp %[[QMIN]], %[[ADD]], %[[QMAX]]
+  // CHECK: %[[ROUND:.*]] = mhlo.round_nearest_even %[[CLAMP]]
+  // CHECK: %[[RESULT:.*]] = mhlo.convert %[[ROUND]]
+  // CHECK-SAME: (tensor<26x26x3x2xf32>) -> tensor<26x26x3x2xi32>
+  %0 = mhlo.uniform_quantize %arg0 : (tensor<26x26x3x2xf32>
+      ) -> tensor<26x26x3x2x!quant.uniform<i32:f32:3, {1.100000e+00:-10, 1.100000e-01:2}>>
+  return %0 : tensor<26x26x3x2x!quant.uniform<i32:f32:3, {1.100000e+00:-10, 1.100000e-01:2}>>
+}
+
+// -----
+
+// CHECK-LABEL: func @dequantize_per_channel
+func.func @dequantize_per_channel(
+    %arg0: tensor<26x26x3x2x!quant.uniform<i32:f32:3, {1.100000e+00:-10, 1.100000e-01:2}>>
+  ) -> tensor<26x26x3x2xf32> {
+  // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[1.100000e+00, 1.100000e-01]>
+  // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<[-10, 2]> : tensor<2xi32>
+  // CHECK: %[[SUBTRACT:.*]] = chlo.broadcast_subtract
+  // CHECK-SAME: %[[INPUT:.*]], %[[ZPS]]
+  // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
+  // CHECK-SAME: (tensor<26x26x3x2xi32>, tensor<2xi32>) -> tensor<26x26x3x2xi32>
+  // CHECK: %[[FLOAT:.*]] = mhlo.convert %[[SUBTRACT]]
+  // CHECK: %[[RESULT:.*]] = chlo.broadcast_multiply
+  // CHECK-SAME: %[[FLOAT]], %[[SCALES]]
+  // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
+  // CHECK-SAME: (tensor<26x26x3x2xf32>, tensor<2xf32>) -> tensor<26x26x3x2xf32>
+  %0 = mhlo.uniform_dequantize %arg0 : (
+      tensor<26x26x3x2x!quant.uniform<i32:f32:3, {1.100000e+00:-10, 1.100000e-01:2}>>
+    ) -> tensor<26x26x3x2xf32>
+  return %0 : tensor<26x26x3x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @add
 func.func @add(
     %arg0: tensor<?x?x!quant.uniform<i8:f32, 1.000000e+00:3>>,
@@ -173,17 +220,11 @@
     %arg0: tensor<?x?x!quant.uniform<i8:f32, 1.000000e+01:3>>,
     %arg1: tensor<?x?x!quant.uniform<i8:f32, 5.000000e+00:1>>
   ) -> tensor<?x?x!quant.uniform<i8:f32, 5.000000e+00:1>> {
-  // CHECK: %[[VAL1:.*]] = mhlo.convert %[[LHS:.*]] : (tensor<?x?xi8>) -> tensor<?x?xi32>
-  // CHECK-DAG: %[[INPUT_ZPS:.*]] = mhlo.constant dense<3> : tensor<i32>
-  // CHECK: %[[VAL2:.*]] = chlo.broadcast_subtract %[[VAL1]], %[[INPUT_ZPS]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK-DAG: %[[MULTIPLIER:.*]] = mhlo.constant dense<16384> : tensor<i32>
-  // CHECK-DAG: %[[TOTAL_SHIFT:.*]] = mhlo.constant dense<13> : tensor<i32>
-  // CHECK-DAG: %[[HALF:.*]] = mhlo.constant dense<4096> : tensor<i32>
-  // CHECK: %[[VAL3:.*]] = chlo.broadcast_multiply %[[VAL2]], %[[MULTIPLIER]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL3]], %[[HALF]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK: %[[VAL5:.*]] = chlo.broadcast_shift_right_arithmetic %[[VAL4]], %[[TOTAL_SHIFT]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK-DAG: %[[OUTPUT_ZPS:.*]] = mhlo.constant dense<1> : tensor<i32>
-  // CHECK: %[[LHS_32_REQ:.*]] = chlo.broadcast_add %[[VAL5]], %[[OUTPUT_ZPS]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
+  // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
+  // CHECK-DAG: %[[LHS:.*]] = mhlo.convert %arg0 : (tensor<?x?xi8>) -> tensor<?x?xf32>
+  // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[LHS]], %[[COMBINED_SCALE]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+  // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00>
+  // CHECK: %[[LHS_32:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
 
   // CHECK-DAG: %[[RHS_32:.*]] = mhlo.convert %[[RHS:.*]] : (tensor<?x?xi8>) -> tensor<?x?xi32>
   // CHECK-DAG: %[[RES_ZPS:.*]] = mhlo.constant dense<1> : tensor<i32>
@@ -207,18 +248,11 @@
     %arg0: tensor<?x?x!quant.uniform<i8:f32, 5.000000e+00:1>>,
     %arg1: tensor<?x?x!quant.uniform<i8:f32, 1.000000e+01:3>>
   ) -> tensor<?x?x!quant.uniform<i8:f32, 5.000000e+00:1>> {
-  // CHECK: %[[VAL0:.*]] = mhlo.convert %[[LHS:.*]] : (tensor<?x?xi8>) -> tensor<?x?xi32>
-  // CHECK: %[[VAL1:.*]] = mhlo.convert %[[RHS:.*]] : (tensor<?x?xi8>) -> tensor<?x?xi32>
-  // CHECK-DAG: %[[INPUT_ZPS:.*]] = mhlo.constant dense<3> : tensor<i32>
-  // CHECK: %[[VAL2:.*]] = chlo.broadcast_subtract %[[VAL1]], %[[INPUT_ZPS]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK-DAG: %[[MULTIPLIER:.*]] = mhlo.constant dense<16384> : tensor<i32>
-  // CHECK-DAG: %[[TOTAL_SHIFT:.*]] = mhlo.constant dense<13> : tensor<i32>
-  // CHECK-DAG: %[[HALF:.*]] = mhlo.constant dense<4096> : tensor<i32>
-  // CHECK: %[[VAL3:.*]] = chlo.broadcast_multiply %[[VAL2]], %[[MULTIPLIER]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL3]], %[[HALF]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK: %[[VAL5:.*]] = chlo.broadcast_shift_right_arithmetic %[[VAL4]], %[[TOTAL_SHIFT]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK-DAG: %[[OUTPUT_ZPS:.*]] = mhlo.constant dense<1> : tensor<i32>
-  // CHECK: %[[RHS_32_REQ:.*]] = chlo.broadcast_add %[[VAL5]], %[[OUTPUT_ZPS]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
+  // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
+  // CHECK-DAG: %[[RHS:.*]] = mhlo.convert %arg1 : (tensor<?x?xi8>) -> tensor<?x?xf32>
+  // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[RHS]], %[[COMBINED_SCALE]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+  // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00>
+  // CHECK: %[[RHS_32:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
 
   // CHECK-DAG: %[[RES_ZPS:.*]] = mhlo.constant dense<1> : tensor<i32>
   // CHECK-DAG: %[[VAL7:.*]] = chlo.broadcast_add %[[LHS_32:.*]], %[[RHS_32_REQ:.*]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
@@ -239,29 +273,17 @@
     %arg0: tensor<?x?x!quant.uniform<i8:f32, 1.000000e+01:3>>,
     %arg1: tensor<?x?x!quant.uniform<i8:f32, 1.000000e+01:3>>
   ) -> tensor<?x?x!quant.uniform<i8:f32, 5.000000e+00:1>> {
-  // CHECK: %[[VAL1:.*]] = mhlo.convert %[[LHS:.*]] : (tensor<?x?xi8>) -> tensor<?x?xi32>
-  // CHECK-DAG: %[[INPUT_ZPS:.*]] = mhlo.constant dense<3> : tensor<i32>
-  // CHECK: %[[VAL2:.*]] = chlo.broadcast_subtract %[[VAL1]], %[[INPUT_ZPS]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK-DAG: %[[MULTIPLIER:.*]] = mhlo.constant dense<16384> : tensor<i32>
-  // CHECK-DAG: %[[TOTAL_SHIFT:.*]] = mhlo.constant dense<13> : tensor<i32>
-  // CHECK-DAG: %[[HALF:.*]] = mhlo.constant dense<4096> : tensor<i32>
-  // CHECK: %[[VAL3:.*]] = chlo.broadcast_multiply %[[VAL2]], %[[MULTIPLIER]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL3]], %[[HALF]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK: %[[VAL5:.*]] = chlo.broadcast_shift_right_arithmetic %[[VAL4]], %[[TOTAL_SHIFT]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK-DAG: %[[OUTPUT_ZPS:.*]] = mhlo.constant dense<1> : tensor<i32>
-  // CHECK: %[[LHS_32_REQ:.*]] = chlo.broadcast_add %[[VAL5]], %[[OUTPUT_ZPS]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
+  // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
+  // CHECK-DAG: %[[LHS:.*]] = mhlo.convert %arg0 : (tensor<?x?xi8>) -> tensor<?x?xf32>
+  // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[LHS]], %[[COMBINED_SCALE]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+  // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00>
+  // CHECK: %[[LHS_32_REQ:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
 
-  // CHECK: %[[VAL6:.*]] = mhlo.convert %[[RHS:.*]] : (tensor<?x?xi8>) -> tensor<?x?xi32>
-  // CHECK-DAG: %[[INPUT_ZPS:.*]] = mhlo.constant dense<3> : tensor<i32>
-  // CHECK: %[[VAL7:.*]] = chlo.broadcast_subtract %[[VAL6]], %[[INPUT_ZPS]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK-DAG: %[[MULTIPLIER:.*]] = mhlo.constant dense<16384> : tensor<i32>
-  // CHECK-DAG: %[[TOTAL_SHIFT:.*]] = mhlo.constant dense<13> : tensor<i32>
-  // CHECK-DAG: %[[HALF:.*]] = mhlo.constant dense<4096> : tensor<i32>
-  // CHECK: %[[VAL8:.*]] = chlo.broadcast_multiply %[[VAL7]], %[[MULTIPLIER]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK: %[[VAL9:.*]] = chlo.broadcast_add %[[VAL8]], %[[HALF]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK: %[[VAL10:.*]] = chlo.broadcast_shift_right_arithmetic %[[VAL9]], %[[TOTAL_SHIFT]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
-  // CHECK-DAG: %[[OUTPUT_ZPS:.*]] = mhlo.constant dense<1> : tensor<i32>
-  // CHECK: %[[RHS_32_REQ:.*]] = chlo.broadcast_add %[[VAL10]], %[[OUTPUT_ZPS]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
+  // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
+  // CHECK-DAG: %[[RHS:.*]] = mhlo.convert %arg1 : (tensor<?x?xi8>) -> tensor<?x?xf32>
+  // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[RHS]], %[[COMBINED_SCALE]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+  // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00>
+  // CHECK: %[[RHS_32_REQ:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
 
   // CHECK-DAG: %[[RES_ZPS:.*]] = mhlo.constant dense<1> : tensor<i32>
   // CHECK-DAG: %[[VAL11:.*]] = chlo.broadcast_add %[[LHS_32_REQ:.*]], %[[RHS_32_REQ:.*]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
@@ -1415,6 +1437,17 @@
 
 // -----
 
+// CHECK-LABEL: func @mhlo_constant_uniform_quantized_per_channel
+func.func @mhlo_constant_uniform_quantized_per_channel() -> () {
+  // CHECK: mhlo.constant dense<[9, 4]> : tensor<2xi8>
+  %0 = mhlo.constant() {value = dense<[9, 4]> : tensor<2xi8>} : ()
+      -> tensor<2x!quant.uniform<i8:f32:0, {1.000000e+00:3, 2.000000e+00:-2}>>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @mhlo_constant_int
 func.func @mhlo_constant_int() -> tensor<i32> {
   // CHECK: mhlo.constant dense<-128> : tensor<i32>
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-tf-quant-types.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-tf-quant-types.mlir
index c83f95c..73555e6 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-tf-quant-types.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-tf-quant-types.mlir
@@ -11,17 +11,17 @@
 
 // CHECK-LABEL: func @if_qint8(%arg0: tensor<i1>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1xi8>
 func.func @if_qint8(%arg0: tensor<i1>, %arg1: tensor<1x!tf_type.qint8>, %arg2: tensor<1x!tf_type.qint8>) -> tensor<1x!tf_type.qint8> {
-  // CHECK-NEXT: %0 = "tf.IfRegion"(%arg0) ({
+  // CHECK-NEXT: %0 = "tf.IfRegion"(%arg0) <{is_stateless = false}> ({
   // CHECK-NEXT:   "tf.Yield"(%arg1) : (tensor<1xi8>) -> ()
   // CHECK-NEXT:   }, {
   // CHECK-NEXT:   "tf.Yield"(%arg2) : (tensor<1xi8>) -> ()
-  // CHECK-NEXT:  }) {is_stateless = false} : (tensor<i1>) -> tensor<1xi8>
+  // CHECK-NEXT:  }) : (tensor<i1>) -> tensor<1xi8>
   // CHECK-NEXT: return %0 : tensor<1xi8>
-  %0 = "tf.IfRegion"(%arg0) ({
+  %0 = "tf.IfRegion"(%arg0) <{is_stateless = false}> ({
     "tf.Yield"(%arg1) : (tensor<1x!tf_type.qint8>) -> ()
     }, {
     "tf.Yield"(%arg2) : (tensor<1x!tf_type.qint8>) -> ()
-   }) {is_stateless = false} : (tensor<i1>) -> tensor<1x!tf_type.qint8>
+   }) : (tensor<i1>) -> tensor<1x!tf_type.qint8>
   func.return %0 : tensor<1x!tf_type.qint8>
 }
 
@@ -74,7 +74,7 @@
   %zps = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
 
   // CHECK: %[[qint:.*]] = "tf.UniformQuantize"
-  // CHECK: %[[int:.*]] = "tf.Cast"(%[[qint]]) {Truncate = false} : (tensor<1x!tf_type.qint8>) -> tensor<1xi8>
+  // CHECK: %[[int:.*]] = "tf.Cast"(%[[qint]]) <{Truncate = false}> : (tensor<1x!tf_type.qint8>) -> tensor<1xi8>
   %0 = "tf.UniformQuantize"(%arg0, %scales, %zps) {
     quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
   } : (tensor<1xf32>, tensor<f32>, tensor<i32>) -> tensor<1x!tf_type.qint8>
@@ -92,7 +92,7 @@
   %zps = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
 
   // CHECK: %[[qint:.*]] = "tf.UniformQuantize"
-  // CHECK: %[[int:.*]] = "tf.Cast"(%[[qint]]) {Truncate = false} : (tensor<1x!tf_type.qint8>) -> tensor<1xi8>
+  // CHECK: %[[int:.*]] = "tf.Cast"(%[[qint]]) <{Truncate = false}> : (tensor<1x!tf_type.qint8>) -> tensor<1xi8>
   %0 = "tf.UniformQuantize"(%arg0, %scales, %zps) {
     quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
   } : (tensor<1xf32>, tensor<f32>, tensor<i32>) -> tensor<1x!tf_type.qint8>
@@ -109,7 +109,7 @@
   %scales = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
   %zps = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
 
-  // CHECK: %[[x:.*]] = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1xi8>) -> tensor<1x!tf_type.qint8>
+  // CHECK: %[[x:.*]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1xi8>) -> tensor<1x!tf_type.qint8>
   // CHECK: %[[y:.*]] = "tf.UniformDequantize"(%[[x]]
   %0 = "tf.UniformDequantize"(%arg0, %scales, %zps) {
     quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
@@ -132,8 +132,8 @@
     quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
   } : (tensor<1xf32>, tensor<f32>, tensor<i32>) -> tensor<1x!tf_type.qint8>
 
-  // CHECK: %[[int:.*]] = "tf.Cast"(%[[qint0]]) {Truncate = false} : (tensor<1x!tf_type.qint8>) -> tensor<1xi8>
-  // CHECK: %[[qint1:.*]] = "tf.Cast"(%[[int]]) {Truncate = false} : (tensor<1xi8>) -> tensor<1x!tf_type.qint8>
+  // CHECK: %[[int:.*]] = "tf.Cast"(%[[qint0]]) <{Truncate = false}> : (tensor<1x!tf_type.qint8>) -> tensor<1xi8>
+  // CHECK: %[[qint1:.*]] = "tf.Cast"(%[[int]]) <{Truncate = false}> : (tensor<1xi8>) -> tensor<1x!tf_type.qint8>
   // CHECK: %[[res:.*]] = "tf.UniformDequantize"(%[[qint1]]
   %1 = "tf.UniformDequantize"(%0, %scales, %zps) {
     quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
@@ -155,10 +155,10 @@
     %output_scales = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
     %output_zps = "tf.Const"() { value = dense<4> : tensor<i32> } : () -> tensor<i32>
 
-    // CHECK: %[[lhs:.*]] = "tf.Cast"(%arg0) {Truncate = false} : (tensor<2xi32>) -> tensor<2x!tf_type.qint32>
-    // CHECK: %[[rhs:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi32>) -> tensor<2x!tf_type.qint32>
+    // CHECK: %[[lhs:.*]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<2xi32>) -> tensor<2x!tf_type.qint32>
+    // CHECK: %[[rhs:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<2xi32>) -> tensor<2x!tf_type.qint32>
     // CHECK: %[[res_qint:.*]] = "tf.UniformQuantizedAdd"(%[[lhs]], %[[rhs]]
-    // CHECK: %[[res_int:.*]] = "tf.Cast"(%[[res_qint]]) {Truncate = false} : (tensor<2x!tf_type.qint32>) -> tensor<2xi32>
+    // CHECK: %[[res_int:.*]] = "tf.Cast"(%[[res_qint]]) <{Truncate = false}> : (tensor<2x!tf_type.qint32>) -> tensor<2xi32>
     // CHECK: return %[[res_int]] : tensor<2xi32>
     %1 = "tf.UniformQuantizedAdd"(
       %arg0, %arg1,
@@ -190,13 +190,13 @@
   %zps4 = "tf.Const"() { value = dense<4> : tensor<i32> } : () -> tensor<i32>
 
   // CHECK: %[[qint_0:.*]] = "tf.UniformQuantize"
-  // CHECK: %[[int_0:.*]] = "tf.Cast"(%[[qint_0]]) {Truncate = false} : (tensor<2x2x!tf_type.qint8>) -> tensor<2x2xi8>
+  // CHECK: %[[int_0:.*]] = "tf.Cast"(%[[qint_0]]) <{Truncate = false}> : (tensor<2x2x!tf_type.qint8>) -> tensor<2x2xi8>
   %0 = "tf.UniformQuantize"(%arg0, %scales, %zps2) {
     quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
   } : (tensor<2x2xf32>, tensor<f32>, tensor<i32>) -> tensor<2x2x!tf_type.qint8>
 
   // CHECK: %[[qint_1:.*]] = "tf.UniformQuantize"
-  // CHECK: %[[int_1:.*]] = "tf.Cast"(%[[qint_1]]) {Truncate = false} : (tensor<2x2x!tf_type.qint8>) -> tensor<2x2xi8>
+  // CHECK: %[[int_1:.*]] = "tf.Cast"(%[[qint_1]]) <{Truncate = false}> : (tensor<2x2x!tf_type.qint8>) -> tensor<2x2xi8>
   %1 = "tf.UniformQuantize"(%arg0, %scales, %zps4) {
     quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
   } : (tensor<2x2xf32>, tensor<f32>, tensor<i32>) -> tensor<2x2x!tf_type.qint8>
@@ -212,11 +212,11 @@
     "tf.Yield"(%id, %barg1) : (tensor<2x?x!tf_type.qint8>, tensor<?x2x!tf_type.qint8>) -> ()
   }) {is_stateless = false} : (tensor<2x2x!tf_type.qint8>, tensor<2x2x!tf_type.qint8>) -> (tensor<2x?x!tf_type.qint8>, tensor<?x2x!tf_type.qint8>)
 
-  // CHECK: %[[out_qint_0:.*]] = "tf.Cast"(%[[while_result]]#0) {Truncate = false} : (tensor<2x?xi8>) -> tensor<2x?x!tf_type.qint8>
+  // CHECK: %[[out_qint_0:.*]] = "tf.Cast"(%[[while_result]]#0) <{Truncate = false}> : (tensor<2x?xi8>) -> tensor<2x?x!tf_type.qint8>
   // CHECK: %[[out_f_0:.*]] = "tf.UniformDequantize"(%[[out_qint_0]]
   %3 = "tf.UniformDequantize"(%2#0, %scales, %zps2) {quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64} : (tensor<2x?x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<2x?xf32>
 
-  // CHECK: %[[out_qint_1:.*]] = "tf.Cast"(%[[while_result]]#1) {Truncate = false} : (tensor<?x2xi8>) -> tensor<?x2x!tf_type.qint8>
+  // CHECK: %[[out_qint_1:.*]] = "tf.Cast"(%[[while_result]]#1) <{Truncate = false}> : (tensor<?x2xi8>) -> tensor<?x2x!tf_type.qint8>
   // CHECK: %[[out_f_1:.*]] = "tf.UniformDequantize"(%[[out_qint_1]]
   %4 = "tf.UniformDequantize"(%2#1, %scales, %zps4) {quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64} : (tensor<?x2x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<?x2xf32>
 
@@ -234,7 +234,7 @@
 
   // CHECK: %[[input:.*]] = "tf.ConcatV2"(%arg0, %arg1
   // CHECK: %[[output_qint:.*]] = "tf.UniformQuantize"(%[[input]]
-  // CHECK: %[[output:.*]] = "tf.Cast"(%[[output_qint]]) {Truncate = false} : (tensor<6x3x!tf_type.qint8>) -> tensor<6x3xi8>
+  // CHECK: %[[output:.*]] = "tf.Cast"(%[[output_qint]]) <{Truncate = false}> : (tensor<6x3x!tf_type.qint8>) -> tensor<6x3xi8>
   // CHECK: return %[[output]] : tensor<6x3xi8>
   %0 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<6x3xf32>
   %1 = "tf.UniformQuantize"(%0, %scales, %zps) {
@@ -252,7 +252,7 @@
   %zps = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
 
   // CHECK: %[[input:.*]] = "tf.ConcatV2"(%arg0, %arg1, %[[VAL:.*]]) : (tensor<3x3xi8>, tensor<3x3xi8>, tensor<i64>) -> tensor<6x3xi8>
-  // CHECK: %[[input_qint:.*]] = "tf.Cast"(%[[input]]) {Truncate = false} : (tensor<6x3xi8>) -> tensor<6x3x!tf_type.qint8>
+  // CHECK: %[[input_qint:.*]] = "tf.Cast"(%[[input]]) <{Truncate = false}> : (tensor<6x3xi8>) -> tensor<6x3x!tf_type.qint8>
   // CHECK: %[[output:.*]] = "tf.UniformDequantize"(%[[input_qint]]
   // CHECK: return %[[output]] : tensor<6x3xf32>
   %0 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3x!tf_type.qint8>, tensor<3x3x!tf_type.qint8>, tensor<i64>) -> tensor<6x3x!tf_type.qint8>
@@ -266,7 +266,7 @@
 
 // CHECK-LABEL: func @tf_const_qint32
 func.func @tf_const_qint32() -> tensor<1x!tf_type.qint32> {
-  // CHECK: %[[result:.*]] = "tf.Const"() {value = dense<127> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: %[[result:.*]] = "tf.Const"() <{value = dense<127> : tensor<1xi32>}> : () -> tensor<1xi32>
   %0 = "tf.Const"() { value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F51494E5433322074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A20225C3137375C3030305C3030305C30303022"> : tensor<1x!tf_type.qint32> } : () -> tensor<1x!tf_type.qint32>
   // CHECK: return %[[result]] : tensor<1xi32>
   func.return %0 :  tensor<1x!tf_type.qint32>
@@ -276,7 +276,7 @@
 
 // CHECK-LABEL: func @tf_const_qint8
 func.func @tf_const_qint8() -> tensor<2x!tf_type.qint8> {
-  // CHECK: %[[result:.*]] = "tf.Const"() {value = dense<[127, 18]> : tensor<2xi8>} : () -> tensor<2xi8>
+  // CHECK: %[[result:.*]] = "tf.Const"() <{value = dense<[127, 18]> : tensor<2xi8>}> : () -> tensor<2xi8>
   %0 = "tf.Const"() { value = #tf_type<tensor_proto : "0x746674656e736f722464747970653a2044545f51494e54382074656e736f725f7368617065207b2064696d207b2073697a653a2032207d207d2074656e736f725f636f6e74656e743a20225c3137375c30323222"> : tensor<2x!tf_type.qint8> } : () -> tensor<2x!tf_type.qint8>
   // CHECK: return %[[result]] : tensor<2xi8>
   func.return %0 :  tensor<2x!tf_type.qint8>
@@ -295,7 +295,7 @@
 
 // CHECK-LABEL: func @cast_op_qint32_int32
 func.func @cast_op_qint32_int32(%arg0: tensor<1x!tf_type.qint32>) -> tensor<1xi32> {
-  // CHECK: "tf.Cast"(%arg0) {Truncate = false} : (tensor<1xi32>) -> tensor<1xi32>
+  // CHECK: "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi32>
   %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x!tf_type.qint32>) -> tensor<1xi32>
   func.return %0: tensor<1xi32>
 }
@@ -304,7 +304,7 @@
 
 // CHECK-LABEL: func @cast_op_int32_qint32
 func.func @cast_op_int32_qint32(%arg0: tensor<1xi32>) -> tensor<1x!tf_type.qint32> {
-  // CHECK: "tf.Cast"(%arg0) {Truncate = false} : (tensor<1xi32>) -> tensor<1xi32>
+  // CHECK: "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi32>
   %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1xi32>) -> tensor<1x!tf_type.qint32>
   func.return %0: tensor<1x!tf_type.qint32>
 }
@@ -313,7 +313,7 @@
 
 // CHECK-LABEL: func @cast_op_qint8_int8
 func.func @cast_op_qint8_int8(%arg0: tensor<1x!tf_type.qint8>) -> tensor<1xi8> {
-  // CHECK: "tf.Cast"(%arg0) {Truncate = false} : (tensor<1xi8>) -> tensor<1xi8>
+  // CHECK: "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1xi8>) -> tensor<1xi8>
   %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x!tf_type.qint8>) -> tensor<1xi8>
   func.return %0: tensor<1xi8>
 }
@@ -322,7 +322,7 @@
 
 // CHECK-LABEL: func @cast_op_int8_qint8
 func.func @cast_op_int8_qint8(%arg0: tensor<1xi8>) -> tensor<1x!tf_type.qint8> {
-  // CHECK: "tf.Cast"(%arg0) {Truncate = false} : (tensor<1xi8>) -> tensor<1xi8>
+  // CHECK: "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1xi8>) -> tensor<1xi8>
   %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1xi8>) -> tensor<1x!tf_type.qint8>
   func.return %0: tensor<1x!tf_type.qint8>
 }
@@ -331,7 +331,7 @@
 
 // CHECK-LABEL: func @cast_op_qint32_int8
 func.func @cast_op_qint32_int8(%arg0: tensor<1x!tf_type.qint32>) -> tensor<1xi8> {
-  // CHECK: "tf.Cast"(%arg0) {Truncate = false} : (tensor<1xi32>) -> tensor<1xi8>
+  // CHECK: "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi8>
   %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x!tf_type.qint32>) -> tensor<1xi8>
   func.return %0: tensor<1xi8>
 }
@@ -340,7 +340,7 @@
 
 // CHECK-LABEL: func @cast_op_int8_qint32
 func.func @cast_op_int8_qint32(%arg0: tensor<1xi8>) -> tensor<1x!tf_type.qint32> {
-  // CHECK: "tf.Cast"(%arg0) {Truncate = false} : (tensor<1xi8>) -> tensor<1xi32>
+  // CHECK: "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1xi8>) -> tensor<1xi32>
   %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1xi8>) -> tensor<1x!tf_type.qint32>
   func.return %0: tensor<1x!tf_type.qint32>
 }
@@ -353,10 +353,10 @@
   %scales = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
   %zps = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
 
-  // CHECK: %[[x:.*]] = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1xi32>) -> tensor<1xi8>
+  // CHECK: %[[x:.*]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi8>
   %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x!tf_type.qint32>) -> tensor<1x!tf_type.qint8>
 
-  // CHECK: %[[y:.*]] = "tf.Cast"(%[[x]]) {Truncate = false} : (tensor<1xi8>) -> tensor<1x!tf_type.qint8>
+  // CHECK: %[[y:.*]] = "tf.Cast"(%[[x]]) <{Truncate = false}> : (tensor<1xi8>) -> tensor<1x!tf_type.qint8>
   // CHECK: %[[z:.*]] = "tf.UniformDequantize"(%[[y]]
   %1 = "tf.UniformDequantize"(%0, %scales, %zps) {
     quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
@@ -379,8 +379,8 @@
     quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
   } : (tensor<1xf32>, tensor<f32>, tensor<i32>) -> tensor<1x!tf_type.qint8>
 
-  // CHECK: %1 = "tf.Cast"(%0) {Truncate = false} : (tensor<1x!tf_type.qint8>) -> tensor<1xi8>
-  // CHECK: %2 = "tf.Cast"(%1) {Truncate = false} : (tensor<1xi8>) -> tensor<1xi32>
+  // CHECK: %1 = "tf.Cast"(%0) <{Truncate = false}> : (tensor<1x!tf_type.qint8>) -> tensor<1xi8>
+  // CHECK: %2 = "tf.Cast"(%1) <{Truncate = false}> : (tensor<1xi8>) -> tensor<1xi32>
   %1 = "tf.Cast"(%0) {Truncate = false} : (tensor<1x!tf_type.qint8>) -> tensor<1x!tf_type.qint32>
 
   // CHECK: return %2 : tensor<1xi32>
@@ -398,15 +398,15 @@
   %zps1 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
 
   // CHECK: %[[qint_1:.*]] = "tf.UniformQuantize"
-  // CHECK: %[[int_1:.*]] = "tf.Cast"(%[[qint_1]]) {Truncate = false} : (tensor<1x!tf_type.qint8>) -> tensor<1xi8>
+  // CHECK: %[[int_1:.*]] = "tf.Cast"(%[[qint_1]]) <{Truncate = false}> : (tensor<1x!tf_type.qint8>) -> tensor<1xi8>
   %0 = "tf.UniformQuantize"(%arg0, %scales, %zps) {
     quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
   } : (tensor<1xf32>, tensor<f32>, tensor<i32>) -> tensor<1x!tf_type.qint8>
 
-  // CHECK: %[[int_2:.*]] = "tf.Cast"(%[[int_1]]) {Truncate = false} : (tensor<1xi8>) -> tensor<1xi32>
+  // CHECK: %[[int_2:.*]] = "tf.Cast"(%[[int_1]]) <{Truncate = false}> : (tensor<1xi8>) -> tensor<1xi32>
   %1 = "tf.Cast"(%0) {Truncate = false} : (tensor<1x!tf_type.qint8>) -> tensor<1x!tf_type.qint32>
 
-  // CHECK: %[[qint_2:.*]] = "tf.Cast"(%[[int_2]]) {Truncate = false} : (tensor<1xi32>) -> tensor<1x!tf_type.qint32>
+  // CHECK: %[[qint_2:.*]] = "tf.Cast"(%[[int_2]]) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1x!tf_type.qint32>
   // CHECK: %[[int_3:.*]] = "tf.UniformDequantize"(%[[qint_2]]
   %2 = "tf.UniformDequantize"(%1, %scales1, %zps1) {
     quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
@@ -423,10 +423,10 @@
   %scale = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
   %zp = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
 
-  // CHECK-DAG: %[[MIN_QINT:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<i32>) -> tensor<!tf_type.qint32>
+  // CHECK-DAG: %[[MIN_QINT:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<i32>) -> tensor<!tf_type.qint32>
   %q_min = "tf.Cast"(%arg1) {Truncate = false} : (tensor<i32>) -> tensor<!tf_type.qint32>
 
-  // CHECK-DAG: %[[INPUT_QINT:.*]] = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x2x2x1xi32>) -> tensor<1x2x2x1x!tf_type.qint32>
+  // CHECK-DAG: %[[INPUT_QINT:.*]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1x2x2x1xi32>) -> tensor<1x2x2x1x!tf_type.qint32>
   // CHECK: "tf.UniformQuantizedClipByValue"(%[[INPUT_QINT]], %[[MIN_QINT]], %[[MIN_QINT]]
   %output = "tf.UniformQuantizedClipByValue"(%arg0, %q_min, %q_min, %scale, %zp)
     {quantization_axis = -1 : i64, quantization_max_val = 2147483647 : i64, quantization_min_val = -2147483648 : i64} :
@@ -441,11 +441,11 @@
   %scale = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
   %zp = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
 
-  // CHECK-DAG: %[[INPUT_QINT:.*]] = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x2x2x1xi32>) -> tensor<1x2x2x1x!tf_type.qint32>
+  // CHECK-DAG: %[[INPUT_QINT:.*]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1x2x2x1xi32>) -> tensor<1x2x2x1x!tf_type.qint32>
   %q_input = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x2x2x1xi32>) -> tensor<1x2x2x1x!tf_type.qint32>
 
-  // CHECK-DAG: %[[MIN_QINT:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<i32>) -> tensor<!tf_type.qint32>
-  // CHECK-DAG: %[[MAX_QINT:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<i32>) -> tensor<!tf_type.qint32>
+  // CHECK-DAG: %[[MIN_QINT:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<i32>) -> tensor<!tf_type.qint32>
+  // CHECK-DAG: %[[MAX_QINT:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<i32>) -> tensor<!tf_type.qint32>
   // CHECK: "tf.UniformQuantizedClipByValue"(%[[INPUT_QINT]], %[[MIN_QINT]], %[[MAX_QINT]]
   %output = "tf.UniformQuantizedClipByValue"(%q_input, %arg1, %arg1, %scale, %zp)
     {quantization_axis = -1 : i64, quantization_max_val = 2147483647 : i64, quantization_min_val = -2147483648 : i64} :
@@ -460,15 +460,15 @@
   %scale = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
   %zp = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
 
-  // CHECK-DAG: %[[INPUT_QINT:.*]] = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x2x2x1xi32>) -> tensor<1x2x2x1x!tf_type.qint32>
-  // CHECK-DAG: %[[MIN_QINT:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<i32>) -> tensor<!tf_type.qint32>
-  // CHECK-DAG: %[[MAX_QINT:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<i32>) -> tensor<!tf_type.qint32>
+  // CHECK-DAG: %[[INPUT_QINT:.*]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1x2x2x1xi32>) -> tensor<1x2x2x1x!tf_type.qint32>
+  // CHECK-DAG: %[[MIN_QINT:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<i32>) -> tensor<!tf_type.qint32>
+  // CHECK-DAG: %[[MAX_QINT:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<i32>) -> tensor<!tf_type.qint32>
   // CHECK: %[[OUTPUT_QINT:.*]] = "tf.UniformQuantizedClipByValue"(%[[INPUT_QINT]], %[[MIN_QINT]], %[[MAX_QINT]]
   %q_output = "tf.UniformQuantizedClipByValue"(%arg0, %arg1, %arg1, %scale, %zp)
     {quantization_axis = -1 : i64, quantization_max_val = 2147483647 : i64, quantization_min_val = -2147483648 : i64} :
     (tensor<1x2x2x1x!tf_type.qint32>, tensor<!tf_type.qint32>, tensor<!tf_type.qint32>, tensor<f32>, tensor<i32>) -> tensor<1x2x2x1x!tf_type.qint32>
 
-  // CHECK: %[[OUTPUT:.*]] = "tf.Cast"(%[[OUTPUT_QINT]]) {Truncate = false} : (tensor<1x2x2x1x!tf_type.qint32>) -> tensor<1x2x2x1xi32>
+  // CHECK: %[[OUTPUT:.*]] = "tf.Cast"(%[[OUTPUT_QINT]]) <{Truncate = false}> : (tensor<1x2x2x1x!tf_type.qint32>) -> tensor<1x2x2x1xi32>
   %output = "tf.Cast"(%q_output) {Truncate = false} : (tensor<1x2x2x1x!tf_type.qint32>) -> tensor<1x2x2x1xi32>
 
   return %output : tensor<1x2x2x1xi32>
@@ -481,19 +481,19 @@
   %scale = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
   %zp = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
 
-  // CHECK-DAG: %[[INPUT_QINT:.*]] = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x2x2x1xi32>) -> tensor<1x2x2x1x!tf_type.qint32>
-  // CHECK-DAG: %[[MIN_QINT:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<i32>) -> tensor<!tf_type.qint32>
-  // CHECK-DAG: %[[MAX_QINT:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<i32>) -> tensor<!tf_type.qint32>
+  // CHECK-DAG: %[[INPUT_QINT:.*]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1x2x2x1xi32>) -> tensor<1x2x2x1x!tf_type.qint32>
+  // CHECK-DAG: %[[MIN_QINT:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<i32>) -> tensor<!tf_type.qint32>
+  // CHECK-DAG: %[[MAX_QINT:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<i32>) -> tensor<!tf_type.qint32>
   // CHECK: %[[OUTPUT_QINT:.*]] = "tf.UniformQuantizedClipByValue"(%[[INPUT_QINT]], %[[MIN_QINT]], %[[MAX_QINT]]
   %q_output = "tf.UniformQuantizedClipByValue"(%arg0, %arg1, %arg1, %scale, %zp)
     {quantization_axis = -1 : i64, quantization_max_val = 2147483647 : i64, quantization_min_val = -2147483648 : i64} :
     (tensor<1x2x2x1x!tf_type.qint32>, tensor<!tf_type.qint32>, tensor<!tf_type.qint32>, tensor<f32>, tensor<i32>) -> tensor<1x2x2x1x!tf_type.qint32>
 
-  // CHECK-DAG: %[[OUTPUT_1:.*]] = "tf.Cast"(%[[OUTPUT_QINT]]) {Truncate = false} : (tensor<1x2x2x1x!tf_type.qint32>) -> tensor<1x2x2x1xi32>
+  // CHECK-DAG: %[[OUTPUT_1:.*]] = "tf.Cast"(%[[OUTPUT_QINT]]) <{Truncate = false}> : (tensor<1x2x2x1x!tf_type.qint32>) -> tensor<1x2x2x1xi32>
   %output = "tf.Cast"(%q_output) {Truncate = false} : (tensor<1x2x2x1x!tf_type.qint32>) -> tensor<1x2x2x1xi32>
 
-  // CHECK-DAG: %[[OUTPUT_2:.*]] = "tf.Cast"(%[[OUTPUT_QINT]]) {Truncate = false} : (tensor<1x2x2x1x!tf_type.qint32>) -> tensor<1x2x2x1xi32>
-  // CHECK-DAG: %[[OUTPUT_QINT_1:.*]] = "tf.Cast"(%[[OUTPUT_1]]) {Truncate = false} : (tensor<1x2x2x1xi32>) -> tensor<1x2x2x1x!tf_type.qint32>
+  // CHECK-DAG: %[[OUTPUT_2:.*]] = "tf.Cast"(%[[OUTPUT_QINT]]) <{Truncate = false}> : (tensor<1x2x2x1x!tf_type.qint32>) -> tensor<1x2x2x1xi32>
+  // CHECK-DAG: %[[OUTPUT_QINT_1:.*]] = "tf.Cast"(%[[OUTPUT_1]]) <{Truncate = false}> : (tensor<1x2x2x1xi32>) -> tensor<1x2x2x1x!tf_type.qint32>
   // CHECK: "tf.UniformDequantize"(%[[OUTPUT_QINT_1:.*]]
   %dq = "tf.UniformDequantize"(%q_output, %scale, %zp) {
     quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/lift_quantizable_spots_as_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/lift_quantizable_spots_as_functions.mlir
index 3e588f5..f085a38 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/lift_quantizable_spots_as_functions.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/lift_quantizable_spots_as_functions.mlir
@@ -1,13 +1,14 @@
 // RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-lift-quantizable-spots-as-functions | FileCheck %s
 
 // CHECK-LABEL: @conv_fn(
-// CHECK-SAME:          %[[ARG_0:.*]]: tensor<1x3x3x4xf32>,
-// CHECK-SAME:          %[[ARG_1:.*]]: tensor<3x3x4x4xf32>)
-func.func @conv_fn(%arg0: tensor<1x3x3x4xf32>, %arg1: tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> {
-  %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32>
-  func.return %0: tensor<1x3x3x4xf32>
+// CHECK-SAME:          %[[ARG_0:.*]]: tensor<1x3x3x4xf32>
+func.func @conv_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32>
+  %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32>
+  func.return %1: tensor<1x3x3x4xf32>
 }
-// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %arg1)
+// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]])
 // CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32>
 // CHECK: }
 
@@ -19,13 +20,14 @@
 // -----
 
 // CHECK-LABEL: @dot_general_fn(
-// CHECK-SAME:                 %[[ARG_0:.*]]: tensor<1x1x167xf32>,
-// CHECK-SAME:                 %[[ARG_1:.*]]: tensor<167x64xf32>
-func.func @dot_general_fn(%arg0: tensor<1x1x167xf32>, %arg1: tensor<167x64xf32>) -> tensor<1x1x64xf32> {
-  %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32>
-  return %0 : tensor<1x1x64xf32>
+// CHECK-SAME:                 %[[ARG_0:.*]]: tensor<1x1x167xf32>
+func.func @dot_general_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32>
+  %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32>
+  return %1 : tensor<1x1x64xf32>
 }
-// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %arg1)
+// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]])
 // CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32>
 // CHECK: }
 
@@ -37,15 +39,17 @@
 // -----
 
 // CHECK-LABEL: @conv_with_bias_fn(
-// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x3x3x4xf32>,
-// CHECK-SAME:                    %[[ARG_1:.*]]: tensor<3x3x4x4xf32>,
-// CHECK-SAME:                    %[[ARG_2:.*]]: tensor<1x3x3x4xf32>)
-func.func @conv_with_bias_fn(%arg0: tensor<1x3x3x4xf32>, %arg1: tensor<3x3x4x4xf32>, %arg2: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> {
-  %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32>
-  %1 = stablehlo.add %0, %arg2 : tensor<1x3x3x4xf32>
-  func.return %1: tensor<1x3x3x4xf32>
+// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x3x3x4xf32>
+func.func @conv_with_bias_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32>
+  %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x3x4xf32>
+  %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32>
+  %3 = stablehlo.add %2, %1 : tensor<1x3x3x4xf32>
+  func.return %3: tensor<1x3x3x4xf32>
 }
-// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %arg1, %arg2)
+// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]])
 // CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32>
 // CHECK: }
 
@@ -58,15 +62,17 @@
 // -----
 
 // CHECK-LABEL: @dot_general_with_bias_fn(
-// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x1x167xf32>,
-// CHECK-SAME:                    %[[ARG_1:.*]]: tensor<167x64xf32>
-// CHECK-SAME:                    %[[ARG_2:.*]]: tensor<1x1x64xf32>)
-func.func @dot_general_with_bias_fn(%arg0: tensor<1x1x167xf32>, %arg1: tensor<167x64xf32>, %arg2: tensor<1x1x64xf32>) -> tensor<1x1x64xf32> {
-  %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32>
-  %1 = stablehlo.add %0, %arg2 : tensor<1x1x64xf32>
-  func.return %1: tensor<1x1x64xf32>
+// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x1x167xf32>
+func.func @dot_general_with_bias_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32>
+  %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x64xf32>
+  %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32>
+  %3 = stablehlo.add %2, %1 : tensor<1x1x64xf32>
+  func.return %3: tensor<1x1x64xf32>
 }
-// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %arg1, %arg2)
+// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]])
 // CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32>
 // CHECK: }
 
@@ -78,16 +84,71 @@
 
 // -----
 
-// CHECK-LABEL: @conv_with_relu_fn(
-// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x3x3x4xf32>,
-// CHECK-SAME:                    %[[ARG_1:.*]]: tensor<3x3x4x4xf32>)
-func.func @conv_with_relu_fn(%arg0: tensor<1x3x3x4xf32>, %arg1: tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> {
-  %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32>
-  %1 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32>
-  %2 = stablehlo.maximum %1, %0 : tensor<1x3x3x4xf32>
-  func.return %2: tensor<1x3x3x4xf32>
+// CHECK-LABEL: @conv_with_bias_dynamic_fn(
+// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<?x28x28x1xf32>
+func.func @conv_with_bias_dynamic_fn(%arg0: tensor<?x28x28x1xf32>) -> tensor<?x28x28x16xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32>
+  %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32>
+  %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<?x28x28x1xf32>, tensor<3x3x1x16xf32>) -> tensor<?x28x28x16xf32>
+  %3 = shape.shape_of %2 : tensor<?x28x28x16xf32> -> tensor<4xindex>
+  %4 = stablehlo.dynamic_broadcast_in_dim %1, %3, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor<?x28x28x16xf32>
+  %5 = stablehlo.add %2, %4 : tensor<?x28x28x16xf32>
+  func.return %5: tensor<?x28x28x16xf32>
 }
-// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %arg1)
+// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]])
+// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<?x28x28x16xf32>
+// CHECK: }
+
+// CHECK-LABEL: private @composite_conv_with_bias_dynamic_fn_1
+// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1)
+// CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[CONV]]
+// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF]]
+// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[DYNAMIC_BROADCAST_IN_DIM]]
+// CHECK: return %[[ADD]] : tensor<?x28x28x16xf32>
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: @dot_general_with_bias_dynamic_fn(
+// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<?x12544xf32>
+func.func @dot_general_with_bias_dynamic_fn(%arg0: tensor<?x12544xf32>) -> tensor<?x10xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<12544x10xf32>
+  %1 = stablehlo.constant dense<2.000000e+00> : tensor<10xf32>
+  %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<?x12544xf32>, tensor<12544x10xf32>) -> tensor<?x10xf32>
+  %3 = shape.shape_of %2 : tensor<?x10xf32> -> tensor<2xindex>
+  %4 = stablehlo.dynamic_broadcast_in_dim %1, %3, dims = [1] : (tensor<10xf32>, tensor<2xindex>) -> tensor<?x10xf32>
+  %5 = stablehlo.add %2, %4 : tensor<?x10xf32>
+  func.return %5: tensor<?x10xf32>
+}
+// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]])
+// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<?x10xf32>
+// CHECK: }
+
+// CHECK-LABEL: private @composite_dot_general_with_bias_dynamic_fn_1
+// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1
+// CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[DOT_GENERAL]]
+// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]]
+// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[DYNAMIC_BROADCAST_IN_DIM_0]]
+// CHECK: return %[[ADD]] : tensor<?x10xf32>
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: @conv_with_relu_fn(
+// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x3x3x4xf32>
+func.func @conv_with_relu_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32>
+  %1 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32>
+  %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32>
+  %3 = stablehlo.maximum %2, %1 : tensor<1x3x3x4xf32>
+  func.return %3: tensor<1x3x3x4xf32>
+}
+// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]])
 // CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32>
 // CHECK: }
 
@@ -102,14 +163,15 @@
 
 // CHECK-LABEL: @dot_general_with_relu_fn(
 // CHECK-SAME:                 %[[ARG_0:.*]]: tensor<1x1x167xf32>,
-// CHECK-SAME:                 %[[ARG_1:.*]]: tensor<167x64xf32>
 func.func @dot_general_with_relu_fn(%arg0: tensor<1x1x167xf32>, %arg1: tensor<167x64xf32>) -> tensor<1x1x64xf32> {
-  %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32>
-  %1 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32>
-  %2 = stablehlo.maximum %1, %0 : tensor<1x1x64xf32>
-  return %2 : tensor<1x1x64xf32>
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32>
+  %1 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32>
+  %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32>
+  %3 = stablehlo.maximum %2, %1 : tensor<1x1x64xf32>
+  return %3 : tensor<1x1x64xf32>
 }
-// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %arg1)
+// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]])
 // CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32>
 // CHECK: }
 
@@ -122,36 +184,95 @@
 
 // -----
 
+// CHECK-LABEL: @conv_with_relu_dynamic_fn(
+// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<?x28x28x1xf32>
+func.func @conv_with_relu_dynamic_fn(%arg0: tensor<?x28x28x1xf32>) -> tensor<?x28x28x16xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32>
+  %1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+  %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<?x28x28x1xf32>, tensor<3x3x1x16xf32>) -> tensor<?x28x28x16xf32>
+  %3 = shape.shape_of %2 : tensor<?x28x28x16xf32> -> tensor<4xindex>
+  %4 = stablehlo.dynamic_broadcast_in_dim %1, %3, dims = [] : (tensor<f32>, tensor<4xindex>) -> tensor<?x28x28x16xf32>
+  %5 = stablehlo.maximum %2, %4 : tensor<?x28x28x16xf32>
+  func.return %5: tensor<?x28x28x16xf32>
+}
+// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]])
+// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<?x28x28x16xf32>
+// CHECK: }
+
+// CHECK-LABEL: private @composite_conv_with_relu_dynamic_fn_1
+// CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00>
+// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1)
+// CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[CONV]]
+// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF]]
+// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[CONV]], %[[DYNAMIC_BROADCAST_IN_DIM]]
+// CHECK: return %[[MAX]] : tensor<?x28x28x16xf32>
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: @dot_general_with_relu_dynamic_fn(
+// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<?x12544xf32>
+func.func @dot_general_with_relu_dynamic_fn(%arg0: tensor<?x12544xf32>) -> tensor<?x10xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<12544x10xf32>
+  %1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+  %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<?x12544xf32>, tensor<12544x10xf32>) -> tensor<?x10xf32>
+  %3 = shape.shape_of %2 : tensor<?x10xf32> -> tensor<2xindex>
+  %4 = stablehlo.dynamic_broadcast_in_dim %1, %3, dims = [] : (tensor<f32>, tensor<2xindex>) -> tensor<?x10xf32>
+  %5 = stablehlo.maximum %2, %4 : tensor<?x10xf32>
+  func.return %5: tensor<?x10xf32>
+}
+// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]])
+// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<?x10xf32>
+// CHECK: }
+
+// CHECK-LABEL: private @composite_dot_general_with_relu_dynamic_fn_1
+// CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00>
+// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1
+// CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[DOT_GENERAL]]
+// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF]]
+// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[DOT_GENERAL]], %[[DYNAMIC_BROADCAST_IN_DIM]]
+// CHECK: return %[[MAX]] : tensor<?x10xf32>
+// CHECK: }
+
+// -----
+
 // The pattern should not match when the const value for relu is not 0.
 
 // CHECK-LABEL: @conv_with_relu_wrong_const_fn(
-// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x3x3x4xf32>,
-// CHECK-SAME:                    %[[ARG_1:.*]]: tensor<3x3x4x4xf32>)
+// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x3x3x4xf32>
 func.func @conv_with_relu_wrong_const_fn(%arg0: tensor<1x3x3x4xf32>, %arg1: tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> {
-  %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x3x4xf32>
-  %1 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32>
-  %2 = stablehlo.maximum %1, %0 : tensor<1x3x3x4xf32>
-  func.return %2: tensor<1x3x3x4xf32>
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32>
+  %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x3x4xf32>
+  %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32>
+  %3 = stablehlo.maximum %2, %1 : tensor<1x3x3x4xf32>
+  func.return %3: tensor<1x3x3x4xf32>
 }
-// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %arg1)
-// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32>
+// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]])
+// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[XLA_CALL_MODULE]], %[[CONST_1]]
+// CHECK: return %[[MAX]] : tensor<1x3x3x4xf32>
 // CHECK: }
 
 // CHECK-LABEL: private @composite_conv_fn_1
+// CHECK-NOT: private @composite_conv_with_relu_fn_1
 
 // -----
 
 // CHECK-LABEL: @conv_with_relu6_fn(
-// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x3x3x4xf32>,
-// CHECK-SAME:                    %[[ARG_1:.*]]: tensor<3x3x4x4xf32>)
-func.func @conv_with_relu6_fn(%arg0: tensor<1x3x3x4xf32>, %arg1: tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> {
-  %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32>
-  %1 = stablehlo.constant dense<6.000000e+00> : tensor<1x3x3x4xf32>
-  %2 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32>
-  %3 = stablehlo.clamp %0, %2, %1 : tensor<1x3x3x4xf32>
-  func.return %3: tensor<1x3x3x4xf32>
+// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x3x3x4xf32>
+func.func @conv_with_relu6_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32>
+  %1 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32>
+  %2 = stablehlo.constant dense<6.000000e+00> : tensor<1x3x3x4xf32>
+  %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32>
+  %4 = stablehlo.clamp %1, %3, %2 : tensor<1x3x3x4xf32>
+  func.return %4: tensor<1x3x3x4xf32>
 }
-// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %arg1)
+// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]])
 // CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32>
 // CHECK: }
 
@@ -166,16 +287,17 @@
 // -----
 
 // CHECK-LABEL: @dot_general_with_relu6_fn(
-// CHECK-SAME:                 %[[ARG_0:.*]]: tensor<1x1x167xf32>,
-// CHECK-SAME:                 %[[ARG_1:.*]]: tensor<167x64xf32>
-func.func @dot_general_with_relu6_fn(%arg0: tensor<1x1x167xf32>, %arg1: tensor<167x64xf32>) -> tensor<1x1x64xf32> {
-  %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32>
-  %1 = stablehlo.constant dense<6.000000e+00> : tensor<1x1x64xf32>
-  %2 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32>
-  %3 = stablehlo.clamp %0, %2, %1 : tensor<1x1x64xf32>
-  return %3 : tensor<1x1x64xf32>
+// CHECK-SAME:                 %[[ARG_0:.*]]: tensor<1x1x167xf32>
+func.func @dot_general_with_relu6_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32>
+  %1 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32>
+  %2 = stablehlo.constant dense<6.000000e+00> : tensor<1x1x64xf32>
+  %3 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32>
+  %4 = stablehlo.clamp %1, %3, %2 : tensor<1x1x64xf32>
+  return %4 : tensor<1x1x64xf32>
 }
-// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %arg1)
+// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]])
 // CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32>
 // CHECK: }
 
@@ -190,17 +312,19 @@
 // -----
 
 // CHECK-LABEL: @conv_with_bias_and_relu_fn(
-// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x3x3x4xf32>,
-// CHECK-SAME:                    %[[ARG_1:.*]]: tensor<3x3x4x4xf32>,
-// CHECK-SAME:                    %[[ARG_2:.*]]: tensor<1x3x3x4xf32>)
-func.func @conv_with_bias_and_relu_fn(%arg0: tensor<1x3x3x4xf32>, %arg1: tensor<3x3x4x4xf32>, %arg2: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> {
-  %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32>
-  %1 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32>
-  %2 = stablehlo.add %1, %arg2 : tensor<1x3x3x4xf32>
-  %3 = stablehlo.maximum %2, %0 : tensor<1x3x3x4xf32>
-  func.return %3: tensor<1x3x3x4xf32>
+// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x3x3x4xf32>
+func.func @conv_with_bias_and_relu_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32>
+  %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x3x4xf32>
+  %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32>
+  %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32>
+  %4 = stablehlo.add %3, %1 : tensor<1x3x3x4xf32>
+  %5 = stablehlo.maximum %4, %2 : tensor<1x3x3x4xf32>
+  func.return %5: tensor<1x3x3x4xf32>
 }
-// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %arg1, %arg2)
+// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]])
 // CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32>
 // CHECK: }
 
@@ -215,17 +339,19 @@
 // -----
 
 // CHECK-LABEL: @dot_general_with_bias_and_relu_fn(
-// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x1x167xf32>,
-// CHECK-SAME:                    %[[ARG_1:.*]]: tensor<167x64xf32>
-// CHECK-SAME:                    %[[ARG_2:.*]]: tensor<1x1x64xf32>)
-func.func @dot_general_with_bias_and_relu_fn(%arg0: tensor<1x1x167xf32>, %arg1: tensor<167x64xf32>, %arg2: tensor<1x1x64xf32>) -> tensor<1x1x64xf32> {
-  %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32>
-  %1 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32>
-  %2 = stablehlo.add %1, %arg2 : tensor<1x1x64xf32>
-  %3 = stablehlo.maximum %2, %0 : tensor<1x1x64xf32>
-  func.return %3: tensor<1x1x64xf32>
+// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x1x167xf32>
+func.func @dot_general_with_bias_and_relu_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32>
+  %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x64xf32>
+  %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32>
+  %3 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32>
+  %4 = stablehlo.add %3, %1 : tensor<1x1x64xf32>
+  %5 = stablehlo.maximum %4, %2 : tensor<1x1x64xf32>
+  func.return %5: tensor<1x1x64xf32>
 }
-// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %arg1, %arg2)
+// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]])
 // CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32>
 // CHECK: }
 
@@ -239,19 +365,91 @@
 
 // -----
 
-// CHECK-LABEL: @conv_with_bias_and_relu6_fn(
-// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x3x3x4xf32>,
-// CHECK-SAME:                    %[[ARG_1:.*]]: tensor<3x3x4x4xf32>,
-// CHECK-SAME:                    %[[ARG_2:.*]]: tensor<1x3x3x4xf32>)
-func.func @conv_with_bias_and_relu6_fn(%arg0: tensor<1x3x3x4xf32>, %arg1: tensor<3x3x4x4xf32>, %arg2: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> {
-  %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32>
-  %1 = stablehlo.constant dense<6.000000e+00> : tensor<1x3x3x4xf32>
-  %2 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32>
-  %3 = stablehlo.add %2, %arg2 : tensor<1x3x3x4xf32>
-  %4 = stablehlo.clamp %0, %3, %1 : tensor<1x3x3x4xf32>
-  func.return %4: tensor<1x3x3x4xf32>
+// CHECK-LABEL: @conv_with_bias_and_relu_dynamic_fn(
+// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<?x28x28x1xf32>
+func.func @conv_with_bias_and_relu_dynamic_fn(%arg0: tensor<?x28x28x1xf32>) -> tensor<?x28x28x16xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32>
+  %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32>
+  %2 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+  %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<?x28x28x1xf32>, tensor<3x3x1x16xf32>) -> tensor<?x28x28x16xf32>
+  %4 = shape.shape_of %3 : tensor<?x28x28x16xf32> -> tensor<4xindex>
+  %5 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor<?x28x28x16xf32>
+  %6 = stablehlo.add %3, %5 : tensor<?x28x28x16xf32>
+  %7 = shape.shape_of %6 : tensor<?x28x28x16xf32> -> tensor<4xindex>
+  %8 = stablehlo.dynamic_broadcast_in_dim %2, %7, dims = [] : (tensor<f32>, tensor<4xindex>) -> tensor<?x28x28x16xf32>
+  %9 = stablehlo.maximum %6, %8 : tensor<?x28x28x16xf32>
+  func.return %9: tensor<?x28x28x16xf32>
 }
-// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %arg1, %arg2)
+// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]])
+// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<?x28x28x16xf32>
+// CHECK: }
+
+// CHECK-LABEL: private @composite_conv_with_bias_and_relu_dynamic_fn_1
+// CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00>
+// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1)
+// CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[CONV]]
+// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]]
+// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[DYNAMIC_BROADCAST_IN_DIM_0]]
+// CHECK: %[[SHAPE_OF_1:.*]] = shape.shape_of %[[ADD]]
+// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_1:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF_1]]
+// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[DYNAMIC_BROADCAST_IN_DIM_1]]
+// CHECK: return %[[MAX]] : tensor<?x28x28x16xf32>
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: @dot_general_with_bias_and_relu_dynamic_fn(
+// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<?x12544xf32>
+func.func @dot_general_with_bias_and_relu_dynamic_fn(%arg0: tensor<?x12544xf32>) -> tensor<?x10xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<12544x10xf32>
+  %1 = stablehlo.constant dense<2.000000e+00> : tensor<10xf32>
+  %2 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+  %3 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<?x12544xf32>, tensor<12544x10xf32>) -> tensor<?x10xf32>
+  %4 = shape.shape_of %3 : tensor<?x10xf32> -> tensor<2xindex>
+  %5 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [1] : (tensor<10xf32>, tensor<2xindex>) -> tensor<?x10xf32>
+  %6 = stablehlo.add %3, %5 : tensor<?x10xf32>
+  %7 = shape.shape_of %6 : tensor<?x10xf32> -> tensor<2xindex>
+  %8 = stablehlo.dynamic_broadcast_in_dim %2, %7, dims = [] : (tensor<f32>, tensor<2xindex>) -> tensor<?x10xf32>
+  %9 = stablehlo.maximum %6, %8 : tensor<?x10xf32>
+  func.return %9: tensor<?x10xf32>
+}
+// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]])
+// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<?x10xf32>
+// CHECK: }
+
+// CHECK-LABEL: private @composite_dot_general_with_bias_and_relu_dynamic_fn_1
+// CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00>
+// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1
+// CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[DOT_GENERAL]]
+// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]]
+// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[DYNAMIC_BROADCAST_IN_DIM_0]]
+// CHECK: %[[SHAPE_OF_1:.*]] = shape.shape_of %[[ADD]]
+// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_1:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF_1]]
+// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[DYNAMIC_BROADCAST_IN_DIM_1]]
+// CHECK: return %[[MAX]] : tensor<?x10xf32>
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: @conv_with_bias_and_relu6_fn(
+// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x3x3x4xf32>
+func.func @conv_with_bias_and_relu6_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32>
+  %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x3x4xf32>
+  %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32>
+  %3 = stablehlo.constant dense<6.000000e+00> : tensor<1x3x3x4xf32>
+  %4 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32>
+  %5 = stablehlo.add %4, %1 : tensor<1x3x3x4xf32>
+  %6 = stablehlo.clamp %2, %5, %3 : tensor<1x3x3x4xf32>
+  func.return %6: tensor<1x3x3x4xf32>
+}
+// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]])
 // CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32>
 // CHECK: }
 
@@ -267,18 +465,20 @@
 // -----
 
 // CHECK-LABEL: @dot_general_with_bias_and_relu6_fn(
-// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x1x167xf32>,
-// CHECK-SAME:                    %[[ARG_1:.*]]: tensor<167x64xf32>
-// CHECK-SAME:                    %[[ARG_2:.*]]: tensor<1x1x64xf32>)
-func.func @dot_general_with_bias_and_relu6_fn(%arg0: tensor<1x1x167xf32>, %arg1: tensor<167x64xf32>, %arg2: tensor<1x1x64xf32>) -> tensor<1x1x64xf32> {
-  %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32>
-  %1 = stablehlo.constant dense<6.000000e+00> : tensor<1x1x64xf32>
-  %2 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32>
-  %3 = stablehlo.add %2, %arg2 : tensor<1x1x64xf32>
-  %4 = stablehlo.clamp %0, %3, %1 : tensor<1x1x64xf32>
-  func.return %4: tensor<1x1x64xf32>
+// CHECK-SAME:                    %[[ARG_0:.*]]: tensor<1x1x167xf32>
+func.func @dot_general_with_bias_and_relu6_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> {
+  %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32>
+  %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x64xf32>
+  %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32>
+  %3 = stablehlo.constant dense<6.000000e+00> : tensor<1x1x64xf32>
+  %4 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32>
+  %5 = stablehlo.add %4, %1 : tensor<1x1x64xf32>
+  %6 = stablehlo.clamp %2, %5, %3 : tensor<1x1x64xf32>
+  func.return %6: tensor<1x1x64xf32>
 }
-// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %arg1, %arg2)
+// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00>
+// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]])
 // CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32>
 // CHECK: }
 
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/post_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/post_quantize.mlir
new file mode 100644
index 0000000..ae2f5708
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/post_quantize.mlir
@@ -0,0 +1,72 @@
+// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-post-quantize | FileCheck %s
+
+// CHECK-LABEL: @remove_volatile_qdq
+func.func @remove_volatile_qdq() -> tensor<3x2xf32> {
+  // CHECK: %[[CST:.*]] = stablehlo.constant
+  // CHECK-NOT: "quantfork.qcast"
+  // CHECK-NOT: "quantfork.dcast"
+  // CHECK: return %[[CST]]
+  %cst = stablehlo.constant dense<[[-0.960978984, -0.390246302], [-0.790828585, -0.601039409], [-1.0280807, -1.02731466]]> : tensor<3x2xf32>
+  %q = "quantfork.qcast"(%cst) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform<i8:f32, 0.013075299590241675:-64>>
+  %dq = "quantfork.dcast"(%q) : (tensor<3x2x!quant.uniform<i8:f32, 0.013075299590241675:-64>>) -> tensor<3x2xf32>
+  func.return %dq : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @remove_volatile_qdq_with_requantization
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2xf32>
+func.func @remove_volatile_qdq_with_requantization(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  // CHECK: %[[Q1:.*]] = stablehlo.uniform_quantize %[[ARG0]]
+  // CHECK: %[[Q2:.*]] = stablehlo.uniform_quantize %[[Q1]]
+  // CHECK: %[[ABS:.*]] = stablehlo.abs %[[Q2]]
+  // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[ABS]]
+  // CHECK: %[[ADD:.*]] = stablehlo.add %[[ARG0]], %[[DQ]]
+  // CHECK: return %[[ADD]]
+  %q1 = "quantfork.qcast"(%arg0) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform<i8:f32, 6.000000e-03:-128>>
+  %q2 = "quantfork.qcast"(%q1) {volatile} : (tensor<3x2x!quant.uniform<i8:f32, 6.000000e-03:-128>>) -> tensor<3x2x!quant.uniform<i8:f32, 0.013075299590241675:-64>>
+  %dq1 = "quantfork.dcast"(%q2) : (tensor<3x2x!quant.uniform<i8:f32, 0.013075299590241675:-64>>) -> tensor<3x2xf32>
+  %abs = stablehlo.abs %q2 : (tensor<3x2x!quant.uniform<i8:f32, 0.013075299590241675:-64>>) -> tensor<3x2x!quant.uniform<i8:f32, 0.013075299590241675:-64>>
+  %dq2 = "quantfork.dcast"(%abs) : (tensor<3x2x!quant.uniform<i8:f32, 0.013075299590241675:-64>>) -> tensor<3x2xf32>
+  %add = stablehlo.add %dq1, %dq2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
+  func.return %add : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @quantize_constant
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x3xf32>
+func.func @quantize_constant(%arg0: tensor<1x3xf32>) -> tensor<1x2xf32> {
+  // CHECK-DAG: %[[QCST:.*]] = stablehlo.constant() {value = dense<-78> : tensor<3x2xi8>} : () -> tensor<3x2x!quant.uniform<i8<-127:127>:f32, 5.000000e-03>>
+  // CHECK-DAG: %[[Q1:.*]] = stablehlo.uniform_quantize %[[ARG0]]
+  // CHECK-NOT: "quantfork.qcast"
+  // CHECK: %[[DOT:.*]] = stablehlo.dot %[[Q1]], %[[QCST]]
+  // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[DOT]]
+  // CHECK: return %[[DQ]]
+  %cst = stablehlo.constant dense<-0.390246302> : tensor<3x2xf32>
+  %q1 = "quantfork.qcast"(%arg0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform<i8:f32, 6.000000e-03:-128>>
+  %q2 = "quantfork.qcast"(%cst) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform<i8<-127:127>:f32, 5.000000e-03>>
+  %dot = stablehlo.dot %q1, %q2 : (tensor<1x3x!quant.uniform<i8:f32, 6.000000e-03:-128>>, tensor<3x2x!quant.uniform<i8<-127:127>:f32, 5.000000e-03>>) -> tensor<1x2x!quant.uniform<i8:f32, 1.000000e-03:-3>>
+  %dq = "quantfork.dcast"(%dot) : (tensor<1x2x!quant.uniform<i8:f32, 1.000000e-03:-3>>) -> tensor<1x2xf32>
+  func.return %dq : tensor<1x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @convert_quantfork_qdq_to_stablehlo_uniform_qdq
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x3xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<3x2xf32>
+func.func @convert_quantfork_qdq_to_stablehlo_uniform_qdq(%arg0: tensor<1x3xf32>, %arg1: tensor<3x2xf32>) -> tensor<1x2xf32> {
+  // CHECK: %[[Q1:.*]] = stablehlo.uniform_quantize %[[ARG0]]
+  // CHECK-NOT: "quantfork.qcast"
+  // CHECK: %[[Q2:.*]] = stablehlo.uniform_quantize %[[ARG1]]
+  // CHECK-NOT: "quantfork.qcast"
+  // CHECK: %[[DOT:.*]] = stablehlo.dot %[[Q1]], %[[Q2]]
+  // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[DOT]]
+  // CHECK: return %[[DQ]]
+  %q1 = "quantfork.qcast"(%arg0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform<i8:f32, 6.000000e-03:-128>>
+  %q2 = "quantfork.qcast"(%arg1) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform<i8<-127:127>:f32, 5.000000e-03>>
+  %dot = stablehlo.dot %q1, %q2 : (tensor<1x3x!quant.uniform<i8:f32, 6.000000e-03:-128>>, tensor<3x2x!quant.uniform<i8<-127:127>:f32, 5.000000e-03>>) -> tensor<1x2x!quant.uniform<i8:f32, 1.000000e-03:-3>>
+  %dq = "quantfork.dcast"(%dot) : (tensor<1x2x!quant.uniform<i8:f32, 1.000000e-03:-3>>) -> tensor<1x2xf32>
+  func.return %dq : tensor<1x2xf32>
+}
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir
index 612b8d6..8f38f88 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir
@@ -3,6 +3,7 @@
 // -----
 
 // CHECK-LABEL: func @dot
+// CHECK-SAME: (%[[ARG_0:.*]]: tensor<?x3xf32>) -> tensor<?x2xf32>
 func.func @dot(%arg0: tensor<?x3xf32>) -> tensor<?x2xf32> {
   // CHECK: %[[cst:.*]] = stablehlo.constant
   // CHECK: %[[q1:.*]] = "quantfork.qcast"(%[[cst]])
@@ -10,7 +11,7 @@
   // CHECK: %[[dq1:.*]] = "quantfork.dcast"(%[[q1]])
   // CHECK-SAME: quant.uniform<i8:f32, 0.0040316890267764818:127>
   %cst = stablehlo.constant dense<[[-0.960978984, -0.390246302], [-0.790828585, -0.601039409], [-1.0280807, -1.02731466]]> : tensor<3x2xf32>
-  // CHECK: %[[q2:.*]] = "quantfork.qcast"(%arg0)
+  // CHECK: %[[q2:.*]] = "quantfork.qcast"(%[[ARG_0]])
   // CHECK-SAME: quant.uniform<i8:f32, 0.0078408040252386357:-1>
   // CHECK: %[[dq2:.*]] = "quantfork.dcast"(%[[q2]])
   // CHECK-SAME: quant.uniform<i8:f32, 0.0078408040252386357:-1>
@@ -29,8 +30,9 @@
 // -----
 
 // CHECK-LABEL: func @duplicate_stats
+// CHECK-SAME: (%[[ARG_0:.*]]: tensor<2x3xf32>) -> tensor<2x3xf32>
 func.func @duplicate_stats(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
-  // CHECK: %[[q1:.*]] = "quantfork.qcast"(%arg0)
+  // CHECK: %[[q1:.*]] = "quantfork.qcast"(%[[ARG_0]])
   // CHECK: %[[dq1:.*]] = "quantfork.dcast"(%[[q1]])
   // CHECK: %[[q2:.*]] = "quantfork.qcast"(%[[dq1]])
   // CHECK: %[[dq2:.*]] = "quantfork.dcast"(%[[q2]])
@@ -44,6 +46,7 @@
 // -----
 
 // CHECK-LABEL: func @dot_redundant_stats
+// CHECK-SAME: (%[[ARG_0:.*]]: tensor<?x3xf32>) -> tensor<?x2xf32>
 func.func @dot_redundant_stats(%arg0: tensor<?x3xf32>) -> tensor<?x2xf32> {
   // CHECK: %[[cst:.*]] = stablehlo.constant
   // CHECK: %[[q1:.*]] = "quantfork.qcast"(%[[cst]])
@@ -51,7 +54,7 @@
   // CHECK: %[[dq1:.*]] = "quantfork.dcast"(%[[q1]])
   // CHECK-SAME: quant.uniform<i8:f32, 0.0040316890267764818:127>
   %cst = stablehlo.constant dense<[[-0.960978984, -0.390246302], [-0.790828585, -0.601039409], [-1.0280807, -1.02731466]]> : tensor<3x2xf32>
-  // CHECK: %[[q2:.*]] = "quantfork.qcast"(%arg0)
+  // CHECK: %[[q2:.*]] = "quantfork.qcast"(%[[ARG_0]])
   // CHECK-SAME: quant.uniform<i8:f32, 0.0078408040252386357:-1>
   // CHECK: %[[dq2:.*]] = "quantfork.dcast"(%[[q2]])
   // CHECK-SAME: quant.uniform<i8:f32, 0.0078408040252386357:-1>
@@ -87,10 +90,11 @@
 // -----
 
 // CHECK-LABEL: func @merge_consecutive_qcast
+// CHECK-SAME: (%[[ARG_0:.*]]: tensor<*xf32>, %[[ARG_1:.*]]: tensor<*xf32>, %[[ARG_2:.*]]: tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
 func.func @merge_consecutive_qcast(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
-  // CHECK: "quantfork.qcast"(%arg1)
+  // CHECK: "quantfork.qcast"(%[[ARG_1]])
   // CHECK-SAME: -> tensor<*x!quant.uniform<i8:f32, 0.02454993117089365:-64>>
-  // CHECK: "quantfork.qcast"(%arg1)
+  // CHECK: "quantfork.qcast"(%[[ARG_1]])
   // CHECK-SAME: -> tensor<*x!quant.uniform<i8:f32, 0.013075299590241675:-64>>
   %0 = "quantfork.stats"(%arg0) {layerStats = dense<[-0.83811146, 2.4960899]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32>
   %1 = "quantfork.stats"(%arg1) {layerStats = dense<[-0.835039615, 1.000000e+00]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32>
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize_int4.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize_int4.mlir
index 0ddb01c..ca467d1 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize_int4.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize_int4.mlir
@@ -1,6 +1,7 @@
 // RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-prepare-quantize=bit-width=4 -verify-diagnostics | FileCheck %s
 
 // CHECK-LABEL: func @dot_int4
+// CHECK-SAME: (%[[ARG_0:.*]]: tensor<?x3xf32>) -> tensor<?x2xf32>
 func.func @dot_int4(%arg0: tensor<?x3xf32>) -> tensor<?x2xf32> {
   // CHECK: %[[cst:.*]] = stablehlo.constant
   // CHECK: %[[q1:.*]] = "quantfork.qcast"(%[[cst]])
@@ -8,7 +9,7 @@
   // CHECK: %[[dq1:.*]] = "quantfork.dcast"(%[[q1]])
   // CHECK-SAME: quant.uniform<i8:f32, 0.0040316890267764818:127>
   %cst = stablehlo.constant dense<[[-0.960978984, -0.390246302], [-0.790828585, -0.601039409], [-1.0280807, -1.02731466]]> : tensor<3x2xf32>
-  // CHECK: %[[q2:.*]] = "quantfork.qcast"(%arg0)
+  // CHECK: %[[q2:.*]] = "quantfork.qcast"(%[[ARG_0]])
   // CHECK-SAME: quant.uniform<i4:f32, 0.13329366842905679:-1>
   // CHECK: %[[dq2:.*]] = "quantfork.dcast"(%[[q2]])
   // CHECK-SAME: quant.uniform<i4:f32, 0.13329366842905679:-1>
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir
index eccd931..d1bfea7 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir
@@ -20,7 +20,7 @@
 // CHECK: %[[CONST_0:.*]] = "stablehlo.constant"() {value = dense<1.000000e+00> : tensor<4x3xf32>} : () -> tensor<4x3xf32>
 // CHECK-DAG: %[[QCAST_0:.*]] = "quantfork.qcast"(%[[CONST_0]]) {volatile} : (tensor<4x3xf32>) -> tensor<4x3x!quant.uniform<i8<-127:127>:f32, 5.000000e-03>>
 // CHECK-DAG: %[[QCAST_1:.*]] = "quantfork.qcast"(%[[ARG_0]]) {volatile} : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform<i8:f32, 6.000000e-03:-128>>
-// CHECK: %[[XLACALLMODULE_0:.*]] = "tf.XlaCallModule"(%[[QCAST_1]], %[[QCAST_0]]) {{{.*}}} : (tensor<1x4x!quant.uniform<i8:f32, 6.000000e-03:-128>>, tensor<4x3x!quant.uniform<i8<-127:127>:f32, 5.000000e-03>>) -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e-03:-3>>
+// CHECK: %[[XLACALLMODULE_0:.*]] = "tf.XlaCallModule"(%[[QCAST_1]], %[[QCAST_0]]) <{{{.*}}}> {{{.*}}} : (tensor<1x4x!quant.uniform<i8:f32, 6.000000e-03:-128>>, tensor<4x3x!quant.uniform<i8<-127:127>:f32, 5.000000e-03>>) -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e-03:-3>>
 // CHECK: %[[DCAST_0:.*]] = "quantfork.dcast"(%[[XLACALLMODULE_0]]) : (tensor<1x3x!quant.uniform<i8:f32, 1.000000e-03:-3>>) -> tensor<1x3xf32>
 // CHECK: "func.return"(%[[DCAST_0]]) : (tensor<1x3xf32>) -> ()
 
@@ -37,6 +37,6 @@
 // Tests that the output of the tf.XlaCallModule op has been replaced by
 // a quantized type, and the corresponding quantfork.qcast ops that turned
 // the float output to a quantized type are removed.
-// CHECK: %[[XLACALLMODULE_0:.*]] = "tf.XlaCallModule"() {{{.*}}} : () -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e-03:-3>>
+// CHECK: %[[XLACALLMODULE_0:.*]] = "tf.XlaCallModule"() <{{{.*}}}> {{{.*}}} : () -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e-03:-3>>
 // CHECK: %[[DCAST_0:.*]] = "quantfork.dcast"(%[[XLACALLMODULE_0]]) : (tensor<1x3x!quant.uniform<i8:f32, 1.000000e-03:-3>>) -> tensor<1x3xf32>
 // CHECK: "func.return"(%[[DCAST_0]]) : (tensor<1x3xf32>) -> ()
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir
new file mode 100644
index 0000000..97ea1f3
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir
@@ -0,0 +1,114 @@
+// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \
+// RUN:     -stablehlo-quantize-composite-functions | FileCheck %s
+
+module attributes {tf_saved_model.semantics} {
+// The following pattern does not converge because of a bug in QuantizePass.
+// TODO - b/305469508: Fix the QuantizePass to avoid this warning.
+// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}}
+  func.func private @quantize_dot_general(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} {
+    %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<3x3xf32>} : () -> tensor<3x3xf32>
+    %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
+    %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable",   device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32>
+    %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
+    return %2 : tensor<1x3xf32>
+  }
+// Checks that the quantized XlaCallModule has been replaced by a CallOp, which
+// calls the quantized entry function.
+
+// CHECK-LABEL: func.func private @quantize_dot_general
+// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"}
+// CHECK: %[[CONST_0:.*]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<3x3xi8>} : () -> tensor<3x3x!quant.uniform<i8<-127:127>:f32, {{.*}}>
+// CHECK: %[[UNIFORM_QUANTIZE_0:.*]] = stablehlo.uniform_quantize %[[ARG_1]] : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform<i8:f32, {{.*}}>>
+// CHECK: %[[CALL_0:.*]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x3x!quant.uniform<i8:f32, {{.*}}>>, tensor<3x3x!quant.uniform<i8<-127:127>:f32, {{.*}}>) -> tensor<1x3x!quant.uniform<i8:f32, {{.*}}>>
+// CHECK: %[[UNIFORM_DEQUANTIZE_0:.*]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform<i8:f32, {{.*}}>) -> tensor<1x3xf32>
+// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32>
+
+  func.func private @composite_dot_general_fn(%arg0: tensor<1x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
+    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32>
+    return %0 : tensor<1x3xf32>
+  }
+// Checks that the entry function is quantized for dot_general. Quantized
+// dot_general outputs an i32 quantized tensor, followed by requantization to
+// i8 quantized tensor.
+
+// CHECK: func.func private @quantized_dot_general_fn(%[[ARG_2:.*]]: tensor<1x3x!quant.uniform<i8:f32, {{.*}}>>, %[[ARG_3:.*]]: tensor<3x3x!quant.uniform<i8<-127:127>:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform<i8:f32, {{.*}}>> attributes {_from_xla_call_module}
+// CHECK: %[[DOT_GENERAL_0:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]], contracting_dims = [1] x [0] : (tensor<1x3x!quant.uniform<i8:f32, {{.*}}>>, tensor<3x3x!quant.uniform<i8<-127:127>:f32, {{.*}}>) -> tensor<1x3x!quant.uniform<i32:f32, {{.*}}>>
+// CHECK: %[[UNIFORM_QUANTIZE_1:.*]] = stablehlo.uniform_quantize %[[DOT_GENERAL_0]] : (tensor<1x3x!quant.uniform<i32:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform<i8:f32, {{.*}}>>
+// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform<i8:f32, {{.*}}>>
+}
+
+// -----
+
+// Tests error when there is no corresponding entry function to quantize
+// (@composite_dot_general_fn).
+
+module attributes {tf_saved_model.semantics} {
+// The following pattern does not converge because of a bug in QuantizePass.
+// TODO - b/305469508: Fix the QuantizePass to avoid this warning.
+// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}}
+  func.func private @error_when_no_entry_function(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} {
+    %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<3x3xf32>} : () -> tensor<3x3xf32>
+    %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
+// expected-error @+2 {{Failed to find a valid entry function}}
+// expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}}
+    %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable",   device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32>
+    %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
+    return %2 : tensor<1x3xf32>
+  }
+}
+
+// -----
+
+// Tests that XlaCallModule op is not quantized without the quantfork.stats ops.
+
+module attributes {tf_saved_model.semantics} {
+  func.func private @not_quantized_without_stats(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} {
+    %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<3x3xf32>} : () -> tensor<3x3xf32>
+    %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable",   device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32>
+    return %1 : tensor<1x3xf32>
+  }
+// Check that "tf.Const" is converted to stablehlo.constant. XlaCallModule is
+// not quantized.
+
+// CHECK-LABEL: func.func private @not_quantized_without_stats
+// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"}
+// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<3.000000e-01> : tensor<3x3xf32>
+// CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[ARG_1]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32>
+// CHECK: return %[[XLA_CALL_MODULE_0]]
+
+  func.func private @composite_dot_general_fn(%arg0: tensor<1x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
+    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32>
+    return %0 : tensor<1x3xf32>
+  }
+// Check that the composite_dot_general_fn is untouched.
+
+// CHECK: func.func private @composite_dot_general_fn(%[[ARG_2:.*]]: tensor<1x3xf32>, %[[ARG_3:.*]]: tensor<3x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module}
+// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]]
+// CHECK: return %[[DOT_GENERAL]]
+}
+
+// -----
+
+// Tests that a fusion pattern for dot_general is not yet supported. Op
+// coverage will be extended in the future.
+// TODO - b/307620428: Increase op coverage to cover this test case.
+
+module attributes {tf_saved_model.semantics} {
+// The following pattern does not converge because of a bug in QuantizePass.
+// TODO - b/305469508: Fix the QuantizePass to avoid this warning.
+// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}}
+  func.func private @dot_general_fn_fusion_not_quantized(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} {
+    %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<3x3xf32>} : () -> tensor<3x3xf32>
+    %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
+// expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}}
+    %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable",   device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32>
+    %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32>
+    return %2 : tensor<1x3xf32>
+  }
+
+  func.func private @composite_dot_general_fn(%arg0: tensor<1x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
+    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32>
+    %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32>
+    return %1 : tensor<1x3xf32>
+  }
+}
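
Note on the new test file above: as its comments state, the quantized entry function accumulates dot_general into an i32 quantized tensor and then requantizes to i8. A rough Python sketch of that arithmetic, assuming per-tensor scales, operand zero points of zero, and the usual convention that the accumulator scale is the product of the operand scales (the function and variable names below are illustrative only; the actual lowering is produced by the pass):

    import numpy as np

    def quantized_dot_general(lhs_q, rhs_q, lhs_scale, rhs_scale, out_scale, out_zp):
        # Integer matmul accumulates in i32; its effective scale is
        # lhs_scale * rhs_scale (assuming zero-point-free operands).
        acc_i32 = lhs_q.astype(np.int32) @ rhs_q.astype(np.int32)
        acc_scale = lhs_scale * rhs_scale
        # Requantize the i32 accumulator to the i8 output type.
        out = np.round(acc_i32 * (acc_scale / out_scale)) + out_zp
        return np.clip(out, -128, 127).astype(np.int8)

    lhs_q = np.array([[10, -3, 7]], dtype=np.int8)
    rhs_q = np.array([[1, 2, 0], [3, -1, 4], [0, 5, 2]], dtype=np.int8)
    print(quantized_dot_general(lhs_q, rhs_q, 0.006, 0.005, 0.001, -3))
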
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir
index 57423f5..3d04c72 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir
@@ -30,13 +30,13 @@
     return %9 : tensor<1x64xf32>
   }
 
-  // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]] = "tf.XlaCallModule"() {Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], _entry_function = @_stablehlo_main_1
-  // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32>
-  // CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]]) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1"
+  // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}}> {_entry_function = @_stablehlo_main_1
+  // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32>
+  // CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1"
   // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_0]])
-  // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]] = "tf.XlaCallModule"() {Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], _entry_function = @_stablehlo_main_0
-  // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]] = "tf.CustomAggregator"(%[[CUSTOM_AGGREGATOR_1]]) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32>
-  // CHECK: %[[XLA_CALL_MODULE_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]]) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1"
+  // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}}> {_entry_function = @_stablehlo_main_0
+  // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]] = "tf.CustomAggregator"(%[[CUSTOM_AGGREGATOR_1]]) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32>
+  // CHECK: %[[XLA_CALL_MODULE_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1"
   // CHECK: %[[CUSTOM_AGGREGATOR_3:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_1:.*]])
   // CHECK: return %[[CUSTOM_AGGREGATOR_3]] : tensor<1x64xf32>
   // CHECK: }
@@ -60,6 +60,39 @@
   }
 }
 
+
+// -----
+
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1654 : i32}, tf_saved_model.semantics} {
+
+  // CHECK: func private @_stablehlo_main_0
+  // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<0.134728625> : tensor<1x3xf32>
+  // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<-1.280000e+02> : tensor<1x1024xf32>
+  // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<0.003921567> : tensor<1x1024xf32>
+  // CHECK: %[[DIVIDE:.*]] = stablehlo.divide %arg0, %[[CONSTANT_2]]
+  // CHECK: %[[ADD:.*]] = stablehlo.add %[[DIVIDE]], %[[CONSTANT_1]]
+  // CHECK: return %[[ADD]]
+  // CHECK: }
+
+  // CHECK: @serving_default
+  func.func @serving_default(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x1024xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+    %0 = stablehlo.constant dense<0.134728625> : tensor<1x3xf32>
+    %1 = stablehlo.constant dense<-1.280000e+02> : tensor<1x1024xf32>
+    %2 = stablehlo.constant dense<0.003921567> : tensor<1x1024xf32>
+    %3 = stablehlo.divide %arg0, %2 : tensor<1x1024xf32>
+    %4 = stablehlo.add %3, %1 : tensor<1x1024xf32>
+    %5 = "tf.Identity"(%4) {device = ""} : (tensor<1x1024xf32>) -> tensor<1x1024xf32>
+    return %5 : tensor<1x1024xf32>
+  }
+
+  // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"(%arg0) <{Sout = [#tf_type.shape<1x1024>]
+  // CHECK-SAME: _entry_function = @_stablehlo_main_0
+  // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP]])
+  // CHECK: return %[[IDENTITY]]
+  // CHECK: }
+
+}
+
 // -----
 
 module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} {
@@ -77,9 +110,9 @@
     return %3 : tensor<1x3xf32>
   }
 
-  // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1024x3>], _entry_function = @_stablehlo_main_
-  // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32>
-  // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]]) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1"
+  // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_
+  // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32>
+  // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1"
   // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]])
   // CHECK: return %[[CUSTOM_AGGREGATOR_1]]
   // CHECK: }
@@ -109,8 +142,8 @@
   }
 
   // CHECK: %[[CONSTANT:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32>
-  // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32>
-  // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[XLA_CALL_MODULE_EXTRACTED_FROM_SUBGRAPH:.*]]) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1"
+  // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32>
+  // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[XLA_CALL_MODULE_EXTRACTED_FROM_SUBGRAPH:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1"
   // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]])
   // CHECK: return %[[CUSTOM_AGGREGATOR_1]]
   // CHECK: }
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc
index 197d5dd..bfd9de9 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc
@@ -37,6 +37,17 @@
       /*storageTypeMin=*/llvm::minIntN(8), /*storageTypeMax=*/llvm::maxIntN(8));
 }
 
+UniformQuantizedType CreateI32F32UniformQuantizedType(
+    const Location loc, MLIRContext& context, const float scale,
+    const int32_t zero_point) {
+  return UniformQuantizedType::getChecked(
+      loc, /*flags=*/QuantizationFlags::Signed,
+      /*storageType=*/IntegerType::get(&context, /*width=*/32),
+      /*expressedType=*/FloatType::getF32(&context), scale, zero_point,
+      /*storageTypeMin=*/llvm::minIntN(32),
+      /*storageTypeMax=*/llvm::maxIntN(32));
+}
+
 UniformQuantizedPerAxisType CreateI8F32UniformQuantizedPerAxisType(
     const Location loc, MLIRContext& context, const ArrayRef<float> scales,
     const ArrayRef<int8_t> zero_points, const int quantization_dimension) {
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h
index 84f5ae2..68774b2 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h
@@ -35,6 +35,16 @@
                                                      float scale,
                                                      int8_t zero_point);
 
+// Creates a `UniformQuantizedType` with the given `scale` and `zero_point`
+// values. The produced type has f32 as its expressed type and i32 as its
+// storage type. The available values use the full range of the storage value.
+// Assumes asymmetric quantization, meaning the zero point values can be
+// non-zero values.
+UniformQuantizedType CreateI32F32UniformQuantizedType(Location loc,
+                                                      MLIRContext& context,
+                                                      float scale,
+                                                      int32_t zero_point);
+
 // Creates a `UniformQuantizedPerAxisType` with the given `scales` and
 // `zero_points` values. The produced type has f32 as its expressed type and
 // i8 as its storage type. The available values use the full range of the
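
Note: `CreateI32F32UniformQuantizedType` only constructs the MLIR type; the value mapping the type describes follows the standard uniform (affine) quantization convention. A minimal sketch of that convention, assuming the full i32 storage range stated in the comment above (the function names below are illustrative only):

    import numpy as np

    def quantize_to_i32(real, scale, zero_point):
        # Map f32 expressed values onto the full signed 32-bit storage range.
        info = np.iinfo(np.int32)
        q = np.round(np.asarray(real, dtype=np.float32) / scale) + zero_point
        return np.clip(q, info.min, info.max).astype(np.int32)

    def dequantize_from_i32(stored, scale, zero_point):
        # Recover the expressed f32 values from the stored i32 values.
        return (stored.astype(np.float32) - zero_point) * scale

    q = quantize_to_i32([0.5, -1.25, 3.0], scale=8.0, zero_point=1111)
    print(dequantize_from_i32(q, scale=8.0, zero_point=1111))
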
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc
index eb40dcb..0888bfa 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc
@@ -15,6 +15,7 @@
 #include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h"
 
 #include <cstdint>
+#include <limits>
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
@@ -81,6 +82,60 @@
   EXPECT_EQ(quantized_type.getZeroPoint(), 99);
 }
 
+class CreateI32F32UniformQuantizedTypeTest : public ::testing::Test {
+ protected:
+  CreateI32F32UniformQuantizedTypeTest() : ctx_() {
+    ctx_.loadDialect<quant::QuantizationDialect>();
+  }
+
+  MLIRContext ctx_;
+};
+
+TEST_F(CreateI32F32UniformQuantizedTypeTest, HasI32StorageType) {
+  const UniformQuantizedType quantized_type =
+      CreateI32F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_,
+                                       /*scale=*/1.0, /*zero_point=*/0);
+
+  EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(32));
+}
+
+TEST_F(CreateI32F32UniformQuantizedTypeTest, HasF32ExpressedType) {
+  const UniformQuantizedType quantized_type =
+      CreateI32F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_,
+                                       /*scale=*/1.0, /*zero_point=*/0);
+
+  EXPECT_TRUE(quantized_type.getExpressedType().isF32());
+}
+
+TEST_F(CreateI32F32UniformQuantizedTypeTest, IsSigned) {
+  const UniformQuantizedType quantized_type =
+      CreateI32F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_,
+                                       /*scale=*/1.0, /*zero_point=*/0);
+
+  EXPECT_TRUE(quantized_type.isSigned());
+}
+
+TEST_F(CreateI32F32UniformQuantizedTypeTest,
+       StorageTypeMinMaxEqualToI32MinMax) {
+  const UniformQuantizedType quantized_type =
+      CreateI32F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_,
+                                       /*scale=*/1.0, /*zero_point=*/0);
+
+  EXPECT_EQ(quantized_type.getStorageTypeMin(),
+            std::numeric_limits<int32_t>::min());
+  EXPECT_EQ(quantized_type.getStorageTypeMax(),
+            std::numeric_limits<int32_t>::max());
+}
+
+TEST_F(CreateI32F32UniformQuantizedTypeTest, HasScaleAndZeroPointProperlySet) {
+  const UniformQuantizedType quantized_type =
+      CreateI32F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_,
+                                       /*scale=*/8.0, /*zero_point=*/1111);
+
+  EXPECT_EQ(quantized_type.getScale(), 8.0);
+  EXPECT_EQ(quantized_type.getZeroPoint(), 1111);
+}
+
 class CreateI8F32UniformQuantizedPerAxisTypeTest : public ::testing::Test {
  protected:
   CreateI8F32UniformQuantizedPerAxisTypeTest() : ctx_() {
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD
index aec612c..286e0e8 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD
@@ -117,6 +117,15 @@
     ],
 )
 
+cc_library(
+    name = "id_assigner",
+    hdrs = ["id_assigner.h"],
+    compatible_with = get_compatible_with_portable(),
+    deps = [
+        "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc",
+    ],
+)
+
 pytype_strict_library(
     name = "calibration_algorithm",
     srcs = ["calibration_algorithm.py"],
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/id_assigner.h b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/id_assigner.h
new file mode 100644
index 0000000..ae75d9f
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/id_assigner.h
@@ -0,0 +1,37 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_ID_ASSIGNER_H_
+#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_ID_ASSIGNER_H_
+
+#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h"
+
+namespace tensorflow::quantization {
+
+// An interface that assigns UUIDs to CustomAggregator ops.
+class CustomAggregatorIdAssigner {
+ public:
+  virtual ~CustomAggregatorIdAssigner() = default;
+
+  // Assigns UUIDs to each CustomAggregator op found in each GraphDef in
+  // `exported_model`. The UUIDs are set to the `id` attributes. The UUIDs will
+  // be used during the calibration step to identify the collected quantization
+  // statistics for each CustomAggregator op.
+  virtual ExportedModel AssignIds(
+      const ExportedModel& exported_model) const = 0;
+};
+
+}  // namespace tensorflow::quantization
+
+#endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_ID_ASSIGNER_H_
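
The interface above is meant to be implemented on the Python side (the pybind11 binding and a concrete subclass appear later in this change). A minimal sketch of such an implementation, mirroring that subclass and assuming the `exported_model_pb2` bindings for `exported_model.proto` (the import path shown is an assumption):

    import uuid

    # Assumed import path for the ExportedModel proto bindings.
    from tensorflow.compiler.mlir.quantization.tensorflow import exported_model_pb2

    class UuidAssigner:  # in practice, derives from the pybind11-exported base class
        def assign_ids(self, exported_model_serialized: bytes) -> bytes:
            exported_model = exported_model_pb2.ExportedModel.FromString(
                exported_model_serialized
            )
            # Stamp a fresh UUID onto every CustomAggregator node's `id` attribute.
            for function_def in exported_model.graph_def.library.function:
                for node_def in function_def.node_def:
                    if node_def.op == 'CustomAggregator':
                        node_def.attr['id'].s = uuid.uuid4().hex.encode('ascii')
            return exported_model.SerializeToString()
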
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc
index d415e755e..22f95fa 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc
@@ -45,8 +45,8 @@
   // Supported quantizable ops.
   return isa<TF::XlaConvV2Op, TF::XlaDotV2Op, TF::MatMulOp, TF::Conv2DOp,
              TF::GatherOp, TF::GatherV2Op, TF::XlaGatherOp,
-             TF::DepthwiseConv2dNativeOp, TF::Conv3DOp, TF::BatchMatMulV2Op,
-             TF::EinsumOp>(op);
+             TF::ResourceGatherOp, TF::DepthwiseConv2dNativeOp, TF::Conv3DOp,
+             TF::BatchMatMulV2Op, TF::EinsumOp>(op);
 }
 
 bool IsOpWithInt8TypeOperand(Operation* op) {
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc
index d6ad460..1945b69 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc
@@ -144,7 +144,7 @@
                      "drq", "Post-training dynamic-range quantizaiton"),
           clEnumValN(tensorflow::quantization::QuantizationMethod::
                          METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8,
-                     "weight_only", "Post-training weight-only quantizaiton"))};
+                     "weight_only", "Post-training weight-only quantization"))};
 
   Option<OpSet> target_opset_{
       *this, "target-opset", llvm::cl::init(OpSet::TF),
@@ -572,6 +572,30 @@
   return success();
 }
 
+// Transfers the location of the main op in the float function to ops with
+// `attr_map` attributes in the quantized function.
+LogicalResult TransferLocation(func::FuncOp float_func,
+                               func::FuncOp quantized_func) {
+  Operation* main_op = nullptr;
+  for (Operation& inner_op : float_func.getBody().front().getOperations()) {
+    // Expect only one quantizable op in the composite function.
+    if (IsOpWithQuantizableTrait(&inner_op)) {
+      main_op = &inner_op;
+      break;
+    }
+  }
+  if (!main_op) {
+    float_func.emitError() << "No quantizable ops found in the function.";
+    return failure();
+  }
+
+  for (Operation& inner_op : quantized_func.getBody().front().getOperations()) {
+    if (!inner_op.hasAttr(kAttrMapAttribute)) continue;
+    inner_op.setLoc(main_op->getLoc());
+  }
+  return success();
+}
+
 // Get the corresponding quantized function name from the given function name.
 std::string GetQuantizedFunctionName(StringRef func_name,
                                      const bool merged_with_dequantize,
@@ -807,6 +831,11 @@
       new_quantized_func_arg.setType(partitioned_call_arg.getType());
     }
 
+    // Set the location for ops so the op name is preserved.
+    if (failed(TransferLocation(float_func, new_quantized_func))) {
+      return failure();
+    }
+
     // Set the attributes for ops with the attr_map attribute.
     if (target_opset_ == OpSet::UNIFORM_QUANTIZED) {
       if (failed(TransferTFAttributesToTFUniformAttributes(
@@ -891,6 +920,11 @@
       new_quantized_func_arg.setType(partitioned_call_arg.getType());
     }
 
+    // Set the location for ops so the op name is preserved.
+    if (failed(TransferLocation(float_func, new_quantized_func))) {
+      return failure();
+    }
+
     // Set the attributes for ops with the attr_map attribute.
     if (target_opset_ == OpSet::UNIFORM_QUANTIZED) {
       if (failed(TransferTFAttributesToTFUniformAttributes(
@@ -1283,9 +1317,7 @@
       ctx, target_opset_);
   patterns_2.add<QuantizeConstPattern>(ctx, target_opset_);
 
-  if (target_opset_ == OpSet::XLA && enable_per_channel_quantization_ &&
-      quantization_method_ == tensorflow::quantization::QuantizationMethod::
-                                  METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8) {
+  if (target_opset_ == OpSet::XLA && enable_per_channel_quantization_) {
     patterns_2.add<RestoreWeightShapePattern>(ctx);
   }
 
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD
index b078bb9..ac1b648 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD
@@ -108,14 +108,15 @@
         "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc",
         "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc",
         "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton",
+        "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:id_assigner",
         "//tensorflow/python/lib/core:pybind11_lib",
-        "//third_party/python_runtime:headers",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@pybind11",
         "@pybind11_abseil//pybind11_abseil:absl_casters",
+        "@pybind11_abseil//pybind11_abseil:import_status_module",
         "@pybind11_abseil//pybind11_abseil:status_casters",
         "@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
     ],
@@ -177,7 +178,6 @@
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:tensor_conversion",
         "//tensorflow/python/lib/io:file_io",
-        "//tensorflow/python/platform:tf_logging",
         "//tensorflow/python/saved_model:load",
         "//tensorflow/python/saved_model:loader",
         "//tensorflow/python/saved_model:signature_constants",
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
index d7156be..406ba58a 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
@@ -785,6 +785,7 @@
             dilations=[1, 1, 1, 1],
             padding='SAME',
             data_format='NHWC',
+            name='sample/conv2d',
         )
         if has_bias:
           out = nn_ops.bias_add(out, bias, data_format='NHWC')
@@ -870,7 +871,9 @@
     )
     graphdef = loader.get_meta_graph_def_from_tags(tags).graph_def
     if target_opset == quant_opts_pb2.XLA:
-      self.assertTrue(self._contains_op(graphdef, 'XlaConvV2'))
+      self.assertTrue(
+          self._contains_op(graphdef, 'XlaConvV2', node_name='sample/conv2d.*')
+      )
 
     new_outputs = converted_model.signatures[signature_key](
         input=ops.convert_to_tensor(input_data)
@@ -1548,7 +1551,7 @@
           'dilations': [1, 2, 2, 1],
       },
       {
-          'testcase_name': 'with_bias_and_relu6_to_xla',
+          'testcase_name': 'with_bias_and_relu6_to_xla_per_tensor',
           'activation_fn': nn_ops.relu6,
           'has_bias': True,
           'has_batch_norm': False,
@@ -1557,6 +1560,15 @@
           'enable_per_channel_quantization': False,
       },
       {
+          'testcase_name': 'with_bias_and_relu6_to_xla_per_channel',
+          'activation_fn': nn_ops.relu6,
+          'has_bias': True,
+          'has_batch_norm': False,
+          'target_opset': quant_opts_pb2.XLA,
+          'input_shape_dynamic': False,
+          'enable_per_channel_quantization': True,
+      },
+      {
           'testcase_name': 'dilation_with_bias_and_relu6_to_xla',
           'activation_fn': nn_ops.relu6,
           'has_bias': True,
@@ -1567,6 +1579,16 @@
           'dilations': [1, 2, 2, 1],
       },
       {
+          'testcase_name': 'dilation_with_bias_and_relu6_to_xla_per_channel',
+          'activation_fn': nn_ops.relu6,
+          'has_bias': True,
+          'has_batch_norm': False,
+          'target_opset': quant_opts_pb2.XLA,
+          'input_shape_dynamic': False,
+          'enable_per_channel_quantization': True,
+          'dilations': [1, 2, 2, 1],
+      },
+      {
           'testcase_name': 'with_bias_and_bn_and_relu6_to_xla',
           'activation_fn': nn_ops.relu6,
           'has_bias': True,
@@ -1576,6 +1598,15 @@
           'enable_per_channel_quantization': False,
       },
       {
+          'testcase_name': 'with_bias_and_bn_and_relu6_to_xla_per_channel',
+          'activation_fn': nn_ops.relu6,
+          'has_bias': True,
+          'has_batch_norm': True,
+          'target_opset': quant_opts_pb2.XLA,
+          'input_shape_dynamic': False,
+          'enable_per_channel_quantization': True,
+      },
+      {
           'testcase_name': 'dilation_with_bias_and_bn_and_relu6_to_xla',
           'activation_fn': nn_ops.relu6,
           'has_bias': True,
@@ -1586,6 +1617,18 @@
           'dilations': [1, 2, 2, 1],
       },
       {
+          'testcase_name': (
+              'dilation_with_bias_and_bn_and_relu6_to_xla_per_channel'
+          ),
+          'activation_fn': nn_ops.relu6,
+          'has_bias': True,
+          'has_batch_norm': True,
+          'target_opset': quant_opts_pb2.XLA,
+          'input_shape_dynamic': False,
+          'enable_per_channel_quantization': True,
+          'dilations': [1, 2, 2, 1],
+      },
+      {
           'testcase_name': 'with_bias_and_relu6_to_xla_dynamic',
           'activation_fn': nn_ops.relu6,
           'has_bias': True,
@@ -1595,6 +1638,15 @@
           'enable_per_channel_quantization': False,
       },
       {
+          'testcase_name': 'with_bias_and_relu6_to_xla_dynamic_per_channel',
+          'activation_fn': nn_ops.relu6,
+          'has_bias': True,
+          'has_batch_norm': False,
+          'target_opset': quant_opts_pb2.XLA,
+          'input_shape_dynamic': True,
+          'enable_per_channel_quantization': True,
+      },
+      {
           'testcase_name': 'dilation_with_bias_and_relu6_to_xla_dynamic',
           'activation_fn': nn_ops.relu6,
           'has_bias': True,
@@ -1605,6 +1657,18 @@
           'dilations': [1, 2, 2, 1],
       },
       {
+          'testcase_name': (
+              'dilation_with_bias_and_relu6_to_xla_dynamic_per_channel'
+          ),
+          'activation_fn': nn_ops.relu6,
+          'has_bias': True,
+          'has_batch_norm': False,
+          'target_opset': quant_opts_pb2.XLA,
+          'input_shape_dynamic': True,
+          'enable_per_channel_quantization': True,
+          'dilations': [1, 2, 2, 1],
+      },
+      {
           'testcase_name': 'with_bias_and_bn_and_relu6_to_xla_dynamic',
           'activation_fn': nn_ops.relu6,
           'has_bias': True,
@@ -1614,6 +1678,17 @@
           'enable_per_channel_quantization': False,
       },
       {
+          'testcase_name': (
+              'with_bias_and_bn_and_relu6_to_xla_dynamic_per_channel'
+          ),
+          'activation_fn': nn_ops.relu6,
+          'has_bias': True,
+          'has_batch_norm': True,
+          'target_opset': quant_opts_pb2.XLA,
+          'input_shape_dynamic': True,
+          'enable_per_channel_quantization': True,
+      },
+      {
           'testcase_name': 'dilation_with_bias_and_bn_and_relu6_to_xla_dynamic',
           'activation_fn': nn_ops.relu6,
           'has_bias': True,
@@ -1624,6 +1699,18 @@
           'dilations': [1, 2, 2, 1],
       },
       {
+          'testcase_name': (
+              'dilation_with_bias_and_bn_and_relu6_to_xla_dynamic_per_channel'
+          ),
+          'activation_fn': nn_ops.relu6,
+          'has_bias': True,
+          'has_batch_norm': True,
+          'target_opset': quant_opts_pb2.XLA,
+          'input_shape_dynamic': True,
+          'enable_per_channel_quantization': True,
+          'dilations': [1, 2, 2, 1],
+      },
+      {
           'testcase_name': 'with_bias_and_relu6_to_uq',
           'activation_fn': nn_ops.relu6,
           'has_bias': True,
@@ -1787,6 +1874,28 @@
 
     if target_opset == quant_opts_pb2.XLA:
       self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2'))
+      if enable_per_channel_quantization:
+        per_channel_size_attr = attr_value_pb2.AttrValue(
+            list=attr_value_pb2.AttrValue.ListValue(
+                shape=[
+                    tensor_shape_pb2.TensorShapeProto(
+                        dim=[
+                            tensor_shape_pb2.TensorShapeProto.Dim(
+                                size=filter_shape[-1]
+                            )
+                        ]
+                    )
+                ]
+            )
+        )
+        self.assertTrue(
+            self._contains_op(
+                output_graphdef,
+                'Const',
+                '_output_shapes',
+                per_channel_size_attr,
+            )
+        )
     elif target_opset == quant_opts_pb2.UNIFORM_QUANTIZED:
       self.assertTrue(
           self._contains_op(output_graphdef, 'UniformQuantizedConvolution')
@@ -2051,6 +2160,15 @@
           'enable_per_channel_quantization': False,
       },
       {
+          'testcase_name': 'with_bias_and_relu6_to_xla_per_channel',
+          'activation_fn': nn_ops.relu6,
+          'has_bias': True,
+          'has_batch_norm': False,
+          'target_opset': quant_opts_pb2.XLA,
+          'input_shape_dynamic': False,
+          'enable_per_channel_quantization': True,
+      },
+      {
           'testcase_name': 'with_bias_and_bn_and_relu6_to_xla',
           'activation_fn': nn_ops.relu6,
           'has_bias': True,
@@ -2060,6 +2178,15 @@
           'enable_per_channel_quantization': False,
       },
       {
+          'testcase_name': 'with_bias_and_bn_and_relu6_to_xla_per_channel',
+          'activation_fn': nn_ops.relu6,
+          'has_bias': True,
+          'has_batch_norm': True,
+          'target_opset': quant_opts_pb2.XLA,
+          'input_shape_dynamic': False,
+          'enable_per_channel_quantization': True,
+      },
+      {
           'testcase_name': 'with_bias_and_relu6_to_xla_dynamic',
           'activation_fn': nn_ops.relu6,
           'has_bias': True,
@@ -2069,6 +2196,15 @@
           'enable_per_channel_quantization': False,
       },
       {
+          'testcase_name': 'with_bias_and_relu6_to_xla_dynamic_per_channel',
+          'activation_fn': nn_ops.relu6,
+          'has_bias': True,
+          'has_batch_norm': False,
+          'target_opset': quant_opts_pb2.XLA,
+          'input_shape_dynamic': True,
+          'enable_per_channel_quantization': True,
+      },
+      {
           'testcase_name': 'with_bias_and_bn_and_relu6_to_xla_dynamic',
           'activation_fn': nn_ops.relu6,
           'has_bias': True,
@@ -2078,6 +2214,17 @@
           'enable_per_channel_quantization': False,
       },
       {
+          'testcase_name': (
+              'with_bias_and_bn_and_relu6_to_xla_dynamic_per_channel'
+          ),
+          'activation_fn': nn_ops.relu6,
+          'has_bias': True,
+          'has_batch_norm': True,
+          'target_opset': quant_opts_pb2.XLA,
+          'input_shape_dynamic': True,
+          'enable_per_channel_quantization': True,
+      },
+      {
           'testcase_name': 'with_bias_and_relu6_to_uq',
           'activation_fn': nn_ops.relu6,
           'has_bias': True,
@@ -2172,6 +2319,28 @@
       self.assertTrue(
           self._contains_op(output_graphdef, 'DepthwiseConv2dNative')
       )
+      if enable_per_channel_quantization:
+        per_channel_size_attr = attr_value_pb2.AttrValue(
+            list=attr_value_pb2.AttrValue.ListValue(
+                shape=[
+                    tensor_shape_pb2.TensorShapeProto(
+                        dim=[
+                            tensor_shape_pb2.TensorShapeProto.Dim(
+                                size=filter_shape[-1] * filter_shape[2]
+                            )
+                        ]
+                    )
+                ]
+            )
+        )
+        self.assertTrue(
+            self._contains_op(
+                output_graphdef,
+                'Const',
+                '_output_shapes',
+                per_channel_size_attr,
+            )
+        )
     elif target_opset == quant_opts_pb2.UNIFORM_QUANTIZED:
       self.assertTrue(
           self._contains_op(output_graphdef, 'UniformQuantizedConvolution')
@@ -2320,9 +2489,19 @@
     loader = saved_model_loader.SavedModelLoader(self._output_saved_model_path)
     output_graphdef = loader.get_meta_graph_def_from_tags(tags).graph_def
     if target_opset == quant_opts_pb2.XLA:
-      self.assertTrue(self._contains_op(output_graphdef, 'XlaDotV2'))
+      self.assertTrue(
+          self._contains_op(
+              output_graphdef, 'XlaDotV2', node_name='sample/matmul.*'
+          )
+      )
     elif target_opset == quant_opts_pb2.UNIFORM_QUANTIZED:
-      self.assertTrue(self._contains_op(output_graphdef, 'UniformQuantizedDot'))
+      self.assertTrue(
+          self._contains_op(
+              output_graphdef,
+              'UniformQuantizedDot',
+              node_name='sample/matmul.*',
+          )
+      )
 
     new_outputs = converted_model.signatures['serving_default'](
         input_tensor=ops.convert_to_tensor(input_data)
@@ -2337,11 +2516,6 @@
     else:
       self.assertAllClose(new_outputs, expected_outputs, atol=0.13)
 
-  # NOTE: Isolated the most basic configuration from `test_matmul_ptq_model`
-  # for StableHLO PTQ prototype testing while integrating. Please note this
-  # test is for intermediate testing purposes as the migration is not complete.
-  # TODO: b/298581932 - Add the full test case for STABLEHLO opset once
-  # migration is complete.
   @test_util.run_in_graph_and_eager_modes
   def test_matmul_ptq_model_stablehlo(self):
     activation_fn = None
@@ -2353,7 +2527,7 @@
     input_shape = (*lhs_batch_size, 1, 1024)
     filter_shape = (*rhs_batch_size, 1024, 3)
     static_input_shape = [dim if dim is not None else 2 for dim in input_shape]
-    self._create_matmul_model(
+    model = self._create_matmul_model(
         input_shape,
         filter_shape,
         self._input_saved_model_path,
@@ -2362,38 +2536,47 @@
     )
     rng = np.random.default_rng(seed=1234)
 
+    input_data = ops.convert_to_tensor(
+        rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype(
+            np.float32
+        )
+    )
+    expected_outputs = model.matmul(input_data)
+
     def data_gen() -> repr_dataset.RepresentativeDataset:
-      for _ in range(5):
+      for _ in range(100):
         yield {
             'input_tensor': rng.uniform(
                 low=0.0, high=1.0, size=static_input_shape
             ).astype(np.float32)
         }
 
-    tags = {tag_constants.SERVING}
-
     quantization_options = quant_opts_pb2.QuantizationOptions(
         quantization_method=quant_opts_pb2.QuantizationMethod(
             preset_method=_PresetMethod.METHOD_STATIC_RANGE_INT8
         ),
-        tags=tags,
+        tags={tag_constants.SERVING},
         signature_keys=['serving_default'],
         op_set=target_opset,
     )
-    # TODO: b/299545836 - Remove exception handling below after migrating
-    # StableHLO export passes.
-    with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
-        Exception,
-        "Failed to convert MLIR to GraphDef. op node 'quantfork.stats' was not"
-        ' a TF op',
-    ):
-      converted_model = quantize_model.quantize(
-          self._input_saved_model_path,
-          self._output_saved_model_path,
-          quantization_options,
-          representative_dataset=data_gen(),
-      )
-      self.assertIsNotNone(converted_model)
+    converted_model = quantize_model.quantize(
+        self._input_saved_model_path,
+        self._output_saved_model_path,
+        quantization_options,
+        representative_dataset=data_gen(),
+    )
+
+    self.assertIsNotNone(converted_model)
+    self.assertCountEqual(
+        converted_model.signatures._signatures.keys(), {'serving_default'}
+    )
+
+    new_outputs = converted_model.signatures['serving_default'](
+        input_tensor=ops.convert_to_tensor(input_data)
+    )
+    # Tests that the quantized graph outputs similar values. The rtol value is
+    # arbitrary.
+    self.assertAllClose(new_outputs, expected_outputs, rtol=0.02)
 
   @parameterized.named_parameters(
       {
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py
index c8a9221..096b162 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Base test class for quantize_model Tests."""
 import os
+import re
 from typing import Collection, Iterable, Mapping, Sequence, Tuple, Optional, Union, List
 
 from absl.testing import parameterized
@@ -169,6 +170,7 @@
       op_name: str,
       attr_name: str,
       attr_val: _AttrValType,
+      node_name: str = '',
   ) -> bool:
     """Determine whether there is a node whose operation name matches `op_name`.
 
@@ -180,15 +182,24 @@
       op_name: Name of the op to match.
       attr_name: Name of the attribute of the op to match.
       attr_val: Value of the attr_name to check.
+      node_name: Name of the node to match. Accepts a regular expression.
 
     Returns:
       True if there exists a node whose name matches `op_name` and 'attr_val' if
       'attr_name' is given.
     """
+
+    def match_node_name(name):
+      if not node_name:
+        return True
+      compiled_regex = re.compile(node_name)
+      match = re.fullmatch(compiled_regex, name)
+      return match is not None
+
     return any(
         node.attr.get(attr_name) == attr_val
         for node in nodes
-        if node.op == op_name
+        if node.op == op_name and match_node_name(node.name)
     )
 
   def _contains_quantized_function_call(
@@ -223,6 +234,7 @@
       op_name: str,
       attr_name: str = '',
       attr_val: _AttrValType = None,
+      node_name: str = '',
   ) -> bool:
     """Determines if the graph def contains the given op.
 
@@ -231,6 +243,7 @@
       op_name: Name of the operation to find within the graph.
       attr_name: Name of the attribute of the op to match.
       attr_val: Value of the attr_name to check.
+      node_name: Name of the node to match. Accepts a regular expression.
 
     Returns:
       True if and only if the graph def contains an op named `op_name`. If
@@ -243,6 +256,7 @@
         op_name=op_name,
         attr_name=attr_name,
         attr_val=attr_val,
+        node_name=node_name,
     ):
       return True
 
@@ -253,6 +267,7 @@
           op_name=op_name,
           attr_name=attr_name,
           attr_val=attr_val,
+          node_name=node_name,
       ):
         return True
     return False
@@ -1303,7 +1318,7 @@
         Returns:
           A map of: output key -> output result.
         """
-        out = math_ops.matmul(input_tensor, self.filters)
+        out = math_ops.matmul(input_tensor, self.filters, name='sample/matmul')
 
         if self.has_reshape():
           input_shape = input_tensor.shape
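
Note that the `match_node_name` helper added above uses `re.fullmatch`, so the `node_name` pattern must cover the entire node name; this is presumably why the tests pass patterns such as 'sample/matmul.*' rather than the bare node name. A small standalone sketch of the matching behavior (standalone reimplementation for illustration only):

    import re

    def match_node_name(node_name: str, name: str) -> bool:
        # Mirrors the helper: an empty pattern matches every node,
        # otherwise the whole node name must match the regex.
        if not node_name:
            return True
        return re.fullmatch(re.compile(node_name), name) is not None

    assert match_node_name('sample/matmul.*', 'sample/matmul_1')
    assert not match_node_name('sample/matmul', 'sample/matmul_1')
    assert match_node_name('', 'any/node/name')
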
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc
index 2400cc7..c08aaf7 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc
@@ -12,22 +12,29 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include <cstring>
 #include <optional>
 #include <string>
 #include <unordered_set>
 #include <utility>
 #include <vector>
 
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "pybind11/cast.h"  // from @pybind11
+#include "pybind11/detail/common.h"  // from @pybind11
+#include "pybind11/detail/descr.h"  // from @pybind11
 #include "pybind11/pybind11.h"  // from @pybind11
 #include "pybind11/pytypes.h"  // from @pybind11
-#include "pybind11/stl.h"  // from @pybind11
-#include "pybind11_abseil/absl_casters.h"  // from @pybind11_abseil
-#include "pybind11_abseil/status_casters.h"  // from @pybind11_abseil
+#include "pybind11/stl.h"  // from @pybind11  // IWYU pragma: keep
+#include "pybind11_abseil/absl_casters.h"  // from @pybind11_abseil   // IWYU pragma: keep
+#include "pybind11_abseil/import_status_module.h"  // from @pybind11_abseil
+#include "pybind11_abseil/status_casters.h"  // from @pybind11_abseil  // IWYU pragma: keep
 #include "pybind11_protobuf/native_proto_caster.h"  // from @pybind11_protobuf
 #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h"
 #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/id_assigner.h"
 #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h"
 #include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h"
 #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
@@ -37,6 +44,7 @@
 
 using ::tensorflow::calibrator::CalibrationStatistics;
 using ::tensorflow::calibrator::CalibratorSingleton;
+using ::tensorflow::quantization::CustomAggregatorIdAssigner;
 using ::tensorflow::quantization::ExportedModel;
 using ::tensorflow::quantization::QuantizationOptions;
 using ::tensorflow::quantization::QuantizePtqDynamicRange;
@@ -80,8 +88,8 @@
 namespace pybind11 {
 namespace detail {
 
-// Converts `ExportedModel` (c++) to `bytes` (python). The resulting `bytes`
-// object is a serialization of `ExportedModel`.
+// Handles `ExportedModel` (c++) <-> `bytes` (python) conversion. The `bytes`
+// object in the python layer is a serialization of `ExportedModel`.
 //
 // See https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html for
 // further details on how custom type conversions work for pybind11.
@@ -90,6 +98,21 @@
  public:
   PYBIND11_TYPE_CASTER(ExportedModel, const_name("ExportedModel"));
 
+  // Loads an `ExportedModel` instance from a python `bytes` object (`src`).
+  bool load(handle src, const bool convert) {
+    auto caster = make_caster<absl::string_view>();
+    // Make sure the user passed a valid python string.
+    if (!caster.load(src, convert)) {
+      return false;
+    }
+
+    const absl::string_view exported_model_serialized =
+        cast_op<absl::string_view>(std::move(caster));
+
+    // NOLINTNEXTLINE: Explicit std::string conversion required for OSS.
+    return value.ParseFromString(std::string(exported_model_serialized));
+  }
+
   // Constructs a `bytes` object after serializing `src`.
   static handle cast(ExportedModel&& src, return_value_policy policy,
                      handle parent) {
@@ -97,6 +120,14 @@
     // destruction of py::bytes and returns a raw python object handle.
     return py::bytes(Serialize(src)).release();
   }
+
+  // Constructs a `bytes` object after serializing `src`.
+  static handle cast(const ExportedModel& src, return_value_policy policy,
+                     handle parent) {
+    // release() prevents the reference count from decreasing upon the
+    // destruction of py::bytes and returns a raw python object handle.
+    return py::bytes(Serialize(src)).release();
+  }
 };
 
 // Python -> cpp conversion for `QuantizationOptions`. Accepts a serialized
@@ -124,10 +155,30 @@
 }  // namespace detail
 }  // namespace pybind11
 
+namespace {
+
+// A "trampoline" class that redirects virtual function calls to the python
+// implementation.
+//
+// Reference:
+// https://pybind11.readthedocs.io/en/stable/advanced/classes.html#overriding-virtual-functions-in-python
+class CustomAggregatorIdAssignerTrampoline : public CustomAggregatorIdAssigner {
+ public:
+  using CustomAggregatorIdAssigner::CustomAggregatorIdAssigner;
+
+  ExportedModel AssignIds(const ExportedModel& exported_model) const override {
+    PYBIND11_OVERRIDE_PURE(ExportedModel, CustomAggregatorIdAssigner,
+                           assign_ids, exported_model);
+  }
+};
+
+}  // namespace
+
 PYBIND11_MODULE(pywrap_quantize_model, m) {
   // Supports absl::StatusOr<T> type conversions.
   pybind11::google::ImportStatusModule();
   pybind11_protobuf::ImportNativeProtoCasters();
+
   // Calibrator related functions.
   m.def(
       "clear_calibrator",
@@ -150,6 +201,14 @@
       Returns the proto CalibrationStatistics given id from calibrator.
     )pbdoc");
 
+  // Exports `CustomAggregatorIdAssigner` class. A pure virtual member function
+  // `AssignIds` is mapped to `assign_ids` in python, which is expected to be
+  // inherited and overridden.
+  py::class_<CustomAggregatorIdAssigner, CustomAggregatorIdAssignerTrampoline>(
+      m, "CustomAggregatorIdAssigner")
+      .def(py::init<>())
+      .def("assign_ids", &CustomAggregatorIdAssigner::AssignIds);
+
   // Quantization functions.
   m.def(
       "quantize_qat_model",
@@ -212,11 +271,17 @@
          const std::vector<std::string>& signature_keys,
          const std::unordered_set<std::string>& tags,
          const QuantizationOptions& quant_opts,
-         const absl::flat_hash_map<std::string, std::string>& function_aliases)
+         const absl::flat_hash_map<std::string, std::string>& function_aliases,
+         const CustomAggregatorIdAssigner& custom_aggregator_id_assigner)
           -> absl::StatusOr<ExportedModel> {
-        return QuantizePtqModelPreCalibration(saved_model_path, signature_keys,
-                                              tags, quant_opts,
-                                              function_aliases);
+        const absl::StatusOr<ExportedModel> exported_model =
+            QuantizePtqModelPreCalibration(saved_model_path, signature_keys,
+                                           tags, quant_opts, function_aliases);
+        if (!exported_model.ok()) {
+          return exported_model.status();
+        }
+
+        return custom_aggregator_id_assigner.AssignIds(*exported_model);
       },
       R"pbdoc(
       Returns serialized ExportedModel that contains the model's GraphDef and
@@ -224,6 +289,10 @@
       user should pass a serialized `QuantizationOptions` for the `quant_opts`
       argument.
 
+      The argument `custom_aggregator_id_assigner` is an instance of
+      `CustomAggregatorIdAssigner` whose virtual function `assign_ids` is
+      implemented in python.
+
       Raises `StatusNotOk` exception if when the run was unsuccessful.
     )pbdoc");
 
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py
index 43dbd28..b83461b 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py
@@ -740,6 +740,38 @@
   return None
 
 
+class CustomAggregatorIdAssigner(
+    pywrap_quantize_model.CustomAggregatorIdAssigner
+):
+  """Python impl. of `pywrap_quantize_model.CustomAggregatorIdAssigner`.
+
+  The interface is defined in the C++ layer, exposing a pure virtual function
+  `assign_ids`.
+  """
+
+  def assign_ids(self, exported_model_serialized: bytes) -> bytes:
+    """Assigns UUIDs to each CustomAggregator op find in the graph def.
+
+    Args:
+      exported_model_serialized: Serialized `ExportedModel` instance.
+
+    Returns:
+      Serialized `ExportedModel` whose CustomAggregator ops have UUIDs assigned
+      to their `id` attributes.
+    """
+    exported_model = exported_model_pb2.ExportedModel.FromString(
+        exported_model_serialized
+    )
+
+    graph_def = exported_model.graph_def
+    for function_def in graph_def.library.function:
+      for node_def in function_def.node_def:
+        if node_def.op == 'CustomAggregator':
+          node_def.attr['id'].s = uuid.uuid4().hex.encode('ascii')
+
+    return exported_model.SerializeToString()
+
+
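
A hypothetical round-trip check of the new class (not part of this change; `exported_model` stands for any in-memory `ExportedModel` instance):

```python
# Sketch: after assignment, every CustomAggregator op should carry a
# non-empty, unique `id` attribute.
assigner = CustomAggregatorIdAssigner()
updated = exported_model_pb2.ExportedModel.FromString(
    assigner.assign_ids(exported_model.SerializeToString())
)

ids = [
    node_def.attr['id'].s
    for function_def in updated.graph_def.library.function
    for node_def in function_def.node_def
    if node_def.op == 'CustomAggregator'
]
assert all(ids) and len(ids) == len(set(ids))
```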
 def _run_static_range_ptq(
     src_saved_model_path: str,
     dst_saved_model_path: str,
@@ -780,6 +812,7 @@
           set(quant_opts.tags),
           quant_opts.SerializeToString(),
           dict(function_aliases),
+          CustomAggregatorIdAssigner(),
       )
   )
 
@@ -788,11 +821,6 @@
   )
 
   graph_def = exported_model.graph_def
-  for function_def in graph_def.library.function:
-    for node_def in function_def.node_def:
-      if node_def.op == 'CustomAggregator':
-        node_def.attr['id'].s = uuid.uuid4().hex.encode('ascii')
-
   pre_calib_output_model_path = tempfile.mkdtemp()
   save_model.save_model_v1(
       graph_def,
@@ -1376,7 +1404,7 @@
       quantization_options.min_num_elements_for_weights = (
           _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS
       )
-      logging.warn(
+      logging.warning(
           (
               'QuantizationOptions.min_num_elements_for_weights is not set (0).'
               ' Setting to the default value: %d.'
@@ -1384,15 +1412,23 @@
           _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS,
       )
 
-  # TODO(b/281595329): Implement static range quantization per-channel support
+  # TODO: b/307900054 - Set the per-channel quantization by default.
   if quantization_options.enable_per_channel_quantization and not (
-      quantization_options.op_set == quant_opts_pb2.OpSet.UNIFORM_QUANTIZED
-      or quantization_options.quantization_method.preset_method
-      == _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8
+      (
+          quantization_options.op_set == quant_opts_pb2.OpSet.UNIFORM_QUANTIZED
+          or quantization_options.quantization_method.preset_method
+          == _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8
+      )
+      or (
+          quantization_options.op_set == quant_opts_pb2.OpSet.XLA
+          and quantization_options.quantization_method.preset_method
+          == _PresetMethod.METHOD_STATIC_RANGE_INT8
+      )
   ):
     raise ValueError(
-        'Currently, per-channel quantization is supported for Uniform '
-        'Quantized opset and Weight-only.'
+        'Currently, per-channel quantization is supported for Uniform Quantized'
+        ' opset, weight only quantization, or XLA opset with static range'
+        ' quantization.'
     )
 
   if (
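
For readability, the combinations accepted by the new check can be restated as a small standalone helper (a sketch only; names mirror this module):

```python
# Sketch of the condition guarded above: per-channel quantization is allowed
# for the Uniform Quantized opset, for weight-only quantization, or for the
# XLA opset combined with static-range INT8.
def _per_channel_quantization_supported(opts) -> bool:
  method = opts.quantization_method.preset_method
  return (
      opts.op_set == quant_opts_pb2.OpSet.UNIFORM_QUANTIZED
      or method == _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8
      or (
          opts.op_set == quant_opts_pb2.OpSet.XLA
          and method == _PresetMethod.METHOD_STATIC_RANGE_INT8
      )
  )
```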
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto b/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto
index 01cf340..4384a13 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto
@@ -289,9 +289,10 @@
   // If not set, it defaults to `true`.
   optional bool freeze_all_variables = 9;
 
-  // Enables chnanel-wise quantizaiton. By default, channel-wise quantization is
+  // Enables channel-wise quantization. By default, channel-wise quantization is
   // not applied regardless of the op support. Currently, it is supported for
-  // Uniform Quantized opset only.
+  // the XLA opset (SRQ on weight tensors only, not activations) and the
+  // Uniform Quantized opset.
   bool enable_per_channel_quantization = 10;
 
   // Enables two inputs of an operation to be both tensors.
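
As a usage sketch (module and enum paths are assumptions based on this diff, not verified), enabling the newly supported XLA static-range combination from Python might look like:

```python
# Sketch: request per-channel quantization for the XLA opset with static-range
# INT8, which this change newly allows.
from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2

options = quant_opts_pb2.QuantizationOptions(
    op_set=quant_opts_pb2.OpSet.XLA,
    enable_per_channel_quantization=True,
)
options.quantization_method.preset_method = (
    quant_opts_pb2.QuantizationMethod.PresetMethod.METHOD_STATIC_RANGE_INT8
)
```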
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc
index 6fb3c91..0b9cdc0 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc
@@ -52,7 +52,7 @@
 void AddStaticRangeQuantizationPass(
     mlir::PassManager &pm, const QuantizationOptions &quantization_options,
     std::optional<const absl::string_view> mlir_dump_file_prefix) {
-  // TODO: b/299545840 - Include QuantizeCompositeFunctionsPass as in bug.
+  pm.addPass(mlir::quant::stablehlo::createQuantizeCompositeFunctionsPass());
 }
 
 void AddConvertTpuToCpuModelPasses(mlir::PassManager &pm) {
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op.mlir
index 4ee9b74..bd35462 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op.mlir
@@ -23,21 +23,21 @@
   }
 
 // WholeModel-LABEL: func @conv
-// WholeModel-DAG: %[[w:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}1.600000e-01, 1.000000e-01
-// WholeModel-DAG: %[[b:.*]] = "tf.Const"() {value = dense<[-2.000000e+00, 3.000000e+00
-// WholeModel-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) {config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}
-// WholeModel-DAG: %[[output1:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}
-// WholeModel-DAG: "tf.DumpTensor"(%[[output1]]) {enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"} : (tensor<*xf32>) -> ()
+// WholeModel-DAG: %[[w:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}1.600000e-01, 1.000000e-01
+// WholeModel-DAG: %[[b:.*]] = "tf.Const"() <{value = dense<[-2.000000e+00, 3.000000e+00
+// WholeModel-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}>
+// WholeModel-DAG: %[[output1:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"}
+// WholeModel-DAG: "tf.DumpTensor"(%[[output1]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> ()
 // WholeModel-DAG: return %[[output0]], %[[output1]]
 
 // PerLayer-LABEL: func @conv
-// PerLayer-DAG: %[[w:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}1.600000e-01, 1.000000e-01
-// PerLayer-DAG: %[[b:.*]] = "tf.Const"() {value = dense<[-2.000000e+00, 3.000000e+00
-// PerLayer-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) {config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}
-// PerLayer-DAG: %[[output1_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}
-// PerLayer-DAG: %[[output1_unquantized:.*]] = "tf.PartitionedCall"(%arg0, %cst, %cst_0) {config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_0}
-// PerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) {enabled = false, file_name = "quantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"} : (tensor<*xf32>) -> ()
-// PerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) {enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"} : (tensor<*xf32>) -> ()
+// PerLayer-DAG: %[[w:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}1.600000e-01, 1.000000e-01
+// PerLayer-DAG: %[[b:.*]] = "tf.Const"() <{value = dense<[-2.000000e+00, 3.000000e+00
+// PerLayer-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}>
+// PerLayer-DAG: %[[output1_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"}
+// PerLayer-DAG: %[[output1_unquantized:.*]] = "tf.PartitionedCall"(%arg0, %cst, %cst_0) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_0}>
+// PerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> ()
+// PerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> ()
 // PerLayer-DAG: return %[[output0]], %[[output1_quantized]]
 }
 
@@ -69,28 +69,28 @@
   }
 
 // WholeModel-LABEL: func @multiple_conv2d
-// WholeModel-DAG: %[[b0:.*]] = "tf.Const"() {value = dense<0.000000e+00>
-// WholeModel-DAG: %[[b1:.*]] = "tf.Const"() {value = dense<1.000000e+00>
-// WholeModel-DAG: %[[w0:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}0.193340182, 0.285152316
-// WholeModel-DAG: %[[w1:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}-0.174680978, -0.367524445
-// WholeModel-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}
-// WholeModel-DAG: "tf.DumpTensor"(%[[output0]]) {enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}
-// WholeModel-DAG: %[[output1:.*]] = "tf.PartitionedCall"(%[[output0]], %[[w1]], %[[b1]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}
-// WholeModel-DAG: "tf.DumpTensor"(%[[output1]]) {enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}
+// WholeModel-DAG: %[[b0:.*]] = "tf.Const"() <{value = dense<0.000000e+00>
+// WholeModel-DAG: %[[b1:.*]] = "tf.Const"() <{value = dense<1.000000e+00>
+// WholeModel-DAG: %[[w0:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}0.193340182, 0.285152316
+// WholeModel-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-0.174680978, -0.367524445
+// WholeModel-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}> {_tfl_quant_trait = "fully_quantizable"}
+// WholeModel-DAG: "tf.DumpTensor"(%[[output0]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}>
+// WholeModel-DAG: %[[output1:.*]] = "tf.PartitionedCall"(%[[output0]], %[[w1]], %[[b1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"}
+// WholeModel-DAG: "tf.DumpTensor"(%[[output1]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}>
 // WholeModel-DAG: return %[[output1]]
 
 // PerLayer-LABEL: func @multiple_conv2d
-// PerLayer-DAG: %[[b0:.*]] = "tf.Const"() {value = dense<0.000000e+00>
-// PerLayer-DAG: %[[b1:.*]] = "tf.Const"() {value = dense<1.000000e+00>
-// PerLayer-DAG: %[[w0:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}0.193340182, 0.285152316
-// PerLayer-DAG: %[[w1:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}-0.174680978, -0.367524445
-// PerLayer-DAG: %[[output0_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}
-// PerLayer-DAG: %[[output0_unquantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) {config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2_0}
-// PerLayer-DAG: "tf.DumpTensor"(%[[output0_quantized]]) {enabled = false, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}
-// PerLayer-DAG: "tf.DumpTensor"(%[[output0_unquantized]]) {enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}
-// PerLayer-DAG: %[[output1_quantized:.*]] = "tf.PartitionedCall"(%[[output0_quantized]], %[[w1]], %[[b1]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}
-// PerLayer-DAG: %[[output1_unquantized:.*]] = "tf.PartitionedCall"(%[[output0_quantized]], %[[w1]], %[[b1]]) {config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_0}
-// PerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) {enabled = false, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}
-// PerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) {enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}
+// PerLayer-DAG: %[[b0:.*]] = "tf.Const"() <{value = dense<0.000000e+00>
+// PerLayer-DAG: %[[b1:.*]] = "tf.Const"() <{value = dense<1.000000e+00>
+// PerLayer-DAG: %[[w0:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}0.193340182, 0.285152316
+// PerLayer-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-0.174680978, -0.367524445
+// PerLayer-DAG: %[[output0_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}> {_tfl_quant_trait = "fully_quantizable"}
+// PerLayer-DAG: %[[output0_unquantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2_0}>
+// PerLayer-DAG: "tf.DumpTensor"(%[[output0_quantized]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}>
+// PerLayer-DAG: "tf.DumpTensor"(%[[output0_unquantized]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}>
+// PerLayer-DAG: %[[output1_quantized:.*]] = "tf.PartitionedCall"(%[[output0_quantized]], %[[w1]], %[[b1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"}
+// PerLayer-DAG: %[[output1_unquantized:.*]] = "tf.PartitionedCall"(%[[output0_quantized]], %[[w1]], %[[b1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_0}>
+// PerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}>
+// PerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}>
 // PerLayer-DAG: return %[[output1_quantized]]
 }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/cast_bf16_ops_to_f32.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/cast_bf16_ops_to_f32.mlir
index 4fc6cbf..deaafb3 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/cast_bf16_ops_to_f32.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/cast_bf16_ops_to_f32.mlir
@@ -10,8 +10,8 @@
 }
 
 // CHECK: func @cast_bf16_conv_to_fp32
-// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>} : () -> tensor<2x3x3x2xbf16>
-// CHECK: %[[cast:.*]] = "tf.Cast"(%[[cst]]) {Truncate = false} : (tensor<2x3x3x2xbf16>) -> tensor<2x3x3x2xf32>
+// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>}> {device = ""} : () -> tensor<2x3x3x2xbf16>
+// CHECK: %[[cast:.*]] = "tf.Cast"(%[[cst]]) <{Truncate = false}> : (tensor<2x3x3x2xbf16>) -> tensor<2x3x3x2xf32>
 // CHECK: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[cast]])
 // CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[conv]]) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32>
 // CHECK: return %[[identity]] : tensor<1x3x2x2xf32>
@@ -28,8 +28,8 @@
 }
 
 // CHECK: func @cast_bf16_conv_with_bias_to_fp32
-// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
-// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
+// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32>
+// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>
 // CHECK: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[cst]])
 // CHECK: %[[bias_add:.*]] = "tf.BiasAdd"(%[[conv]], %[[cst_0]])
 // CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[bias_add]]) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32>
@@ -46,7 +46,7 @@
 }
 
 // CHECK: func @cast_bf16_avg_pool_to_fp32
-// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
+// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32>
 // CHECK: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[cst]])
 // CHECK: %[[avg_pool:.*]] = "tf.AvgPool"(%[[conv]])
 // CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[avg_pool]]) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32>
@@ -62,7 +62,7 @@
 }
 
 // CHECK: func @cast_bf16_matmul_to_fp32
-// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<10x2xf32>} : () -> tensor<10x2xf32>
+// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<10x2xf32>}> : () -> tensor<10x2xf32>
 // CHECK: %[[matmul:.*]] = "tf.MatMul"(%arg0, %[[cst]])
 // CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[matmul]])
 // CHECK: return %[[identity]] : tensor<1x2xf32>
@@ -77,7 +77,7 @@
 }
 
 // CHECK: func @cast_bf16_depthwise_conv_to_fp32
-// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
+// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32>
 // CHECK: %[[depthwise_conv:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[cst]])
 // CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[depthwise_conv]]) {device = ""} : (tensor<1x2x2x6xf32>) -> tensor<1x2x2x6xf32>
 // CHECK: return %[[identity]] : tensor<1x2x2x6xf32>
@@ -92,7 +92,7 @@
 }
 
 // CHECK: func @cast_bf16_batch_matmul_v2_to_fp32
-// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<10x2xf32>} : () -> tensor<10x2xf32>
+// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<10x2xf32>}> : () -> tensor<10x2xf32>
 // CHECK: %[[batch_matmul:.*]] = "tf.BatchMatMulV2"(%arg0, %[[cst]])
 // CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[batch_matmul]]) {device = ""} : (tensor<1x1x2xf32>) -> tensor<1x1x2xf32>
 // CHECK: return %[[identity]] : tensor<1x1x2xf32>
@@ -108,7 +108,7 @@
 // CHECK: func @cast_bf16_add_v2_to_fp32(%[[ARG_0:.*]]: tensor<2xbf16>, %[[ARG_1:.*]]: tensor<2xbf16>) -> tensor<2xf32>
 
 // bfloat16 operands are cast to f32 operands.
-// CHECK-DAG: %[[CAST_0:.*]] = "tf.Cast"(%[[ARG_0]]) {Truncate = false} : (tensor<2xbf16>) -> tensor<2xf32>
-// CHECK-DAG: %[[CAST_1:.*]] = "tf.Cast"(%[[ARG_1]]) {Truncate = false} : (tensor<2xbf16>) -> tensor<2xf32>
+// CHECK-DAG: %[[CAST_0:.*]] = "tf.Cast"(%[[ARG_0]]) <{Truncate = false}> : (tensor<2xbf16>) -> tensor<2xf32>
+// CHECK-DAG: %[[CAST_1:.*]] = "tf.Cast"(%[[ARG_1]]) <{Truncate = false}> : (tensor<2xbf16>) -> tensor<2xf32>
 // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[CAST_0]], %[[CAST_1]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
 // CHECK: return %[[ADD]] : tensor<2xf32>
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tf_xla_op_to_tf_op.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tf_xla_op_to_tf_op.mlir
index d30c61f..27a7bb6 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tf_xla_op_to_tf_op.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tf_xla_op_to_tf_op.mlir
@@ -6,7 +6,7 @@
 }
 
 // CHECK: func @xla_dot_v2
-// CHECK: %[[einsum:.*]] = "tf.Einsum"(%arg0, %arg1) {equation = "abc,cde->abde"} : (tensor<?x2x3xf32>, tensor<3x4x5xf32>) -> tensor<?x2x4x5xf32>
+// CHECK: %[[einsum:.*]] = "tf.Einsum"(%arg0, %arg1) <{equation = "abc,cde->abde"}> : (tensor<?x2x3xf32>, tensor<3x4x5xf32>) -> tensor<?x2x4x5xf32>
 // CHECK: return %[[einsum]] : tensor<?x2x4x5xf32>
 
 // -----
@@ -22,12 +22,12 @@
 }
 
 // CHECK: func @xla_gather
-// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<0> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() {value = dense<1> : tensor<1x1xi64>} : () -> tensor<1x1xi64>
-// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK: %[[arg1_i64:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<1xi32>) -> tensor<1xi64>
+// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<0> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<1> : tensor<1x1xi64>}> : () -> tensor<1x1xi64>
+// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK: %[[arg1_i64:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi64>
 // CHECK: %[[tensor_scatter_update:.*]] = "tf.TensorScatterUpdate"(%[[cst]], %[[cst_0]], %[[arg1_i64]]) : (tensor<2xi64>, tensor<1x1xi64>, tensor<1xi64>) -> tensor<2xi64>
-// CHECK: %[[arg2_i64:.*]] = "tf.Cast"(%arg2) {Truncate = false} : (tensor<2xi32>) -> tensor<2xi64>
+// CHECK: %[[arg2_i64:.*]] = "tf.Cast"(%arg2) <{Truncate = false}> : (tensor<2xi32>) -> tensor<2xi64>
 // CHECK: %[[slice:.*]] = "tf.Slice"(%arg0, %[[tensor_scatter_update]], %[[arg2_i64]]) : (tensor<?x2xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
 // CHECK: %[[reshape:.*]] = "tf.Reshape"(%[[slice]], %[[cst_1]]) : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32>
 // CHECK: return %[[reshape]] : tensor<*xf32>
@@ -47,12 +47,12 @@
 }
 
 // CHECK: func @xla_gather_known_output_shape
-// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64>
-// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() {value = dense<0> : tensor<1x1xi64>} : () -> tensor<1x1xi64>
-// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64>
-// CHECK: %[[arg1_i64:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<1xi32>) -> tensor<1xi64>
+// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi64>}> : () -> tensor<1xi64>
+// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<0> : tensor<1x1xi64>}> : () -> tensor<1x1xi64>
+// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi64>}> : () -> tensor<0xi64>
+// CHECK: %[[arg1_i64:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi64>
 // CHECK: %[[tensor_scatter_update:.*]] = "tf.TensorScatterUpdate"(%[[cst]], %[[cst_0]], %[[arg1_i64]]) : (tensor<1xi64>, tensor<1x1xi64>, tensor<1xi64>) -> tensor<1xi64>
-// CHECK: %[[arg2_i64:.*]] = "tf.Cast"(%arg2) {Truncate = false} : (tensor<1xi32>) -> tensor<1xi64>
+// CHECK: %[[arg2_i64:.*]] = "tf.Cast"(%arg2) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi64>
 // CHECK: %[[slice:.*]] = "tf.Slice"(%arg0, %[[tensor_scatter_update]], %[[arg2_i64]]) : (tensor<5xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi32>
 // CHECK: %[[reshape:.*]] = "tf.Reshape"(%[[slice]], %[[cst_1]]) : (tensor<1xi32>, tensor<0xi64>) -> tensor<i32>
 // CHECK: return %[[reshape]] : tensor<i32>
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tpu_model_to_cpu.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tpu_model_to_cpu.mlir
index 26809cd..ad13ebb 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tpu_model_to_cpu.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tpu_model_to_cpu.mlir
@@ -26,8 +26,8 @@
 }
 
 // CHECK: func @tpu_conv(%[[ARG0:.*]]: tensor<1x3x4x3xf32>)
-// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>} : () -> tensor<2x3x3x2xbf16>
-// CHECK: %[[cast:.*]] = "tf.Cast"(%[[cst]]) {Truncate = false} : (tensor<2x3x3x2xbf16>) -> tensor<2x3x3x2xf32>
+// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>}> {device = ""} : () -> tensor<2x3x3x2xbf16>
+// CHECK: %[[cast:.*]] = "tf.Cast"(%[[cst]]) <{Truncate = false}> : (tensor<2x3x3x2xbf16>) -> tensor<2x3x3x2xf32>
 // CHECK: %[[conv:.*]] = "tf.Conv2D"(%[[ARG0]], %[[cast]])
 // CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[conv]]) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32>
 // CHECK: return %[[identity]] : tensor<1x3x2x2xf32>
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/fake_quant_e2e_flow.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/fake_quant_e2e_flow.mlir
index b08c49f..aaddb72 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/fake_quant_e2e_flow.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/fake_quant_e2e_flow.mlir
@@ -13,18 +13,18 @@
 // CHECK-LABEL: @fake_quant_conv
 // CHECK-SAME: %[[ARG0:.*]]: tensor<1x3
 // CHECK-SAME: %[[ARG1:.*]]: tensor<2x3
-// CHECK-DAG: %[[CST:.*]] = "tf.Const"() {value = dense<0.00117647066> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {value = dense<-43> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {value = dense<0.0117647061> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[CST_2:.*]] = "tf.Const"() {value = dense<1.38408304E-5> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[CST_3:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG: %[[CST_4:.*]] = "tf.Const"() {value = dense<0.0027450982> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[CST_5:.*]] = "tf.Const"() {value = dense<-19> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG: %[[CST_6:.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK-NEXT: %[[V0:.*]] = "tf.PartitionedCall"(%[[ARG1]], %[[CST_1]], %[[CST_0]]) {config = "", config_proto = "", executor_type = "", f = @quantize_i8} : (tensor<2x3x3x2xf32>, tensor<f32>, tensor<i32>) -> tensor<2x3x3x2xi8>
-// CHECK-NEXT: %[[V1:.*]] = "tf.PartitionedCall"(%[[ARG0]], %[[CST]], %[[CST_0]]) {config = "", config_proto = "", executor_type = "", f = @quantize_i8} : (tensor<1x3x4x3xf32>, tensor<f32>, tensor<i32>) -> tensor<1x3x4x3xi8>
-// CHECK-NEXT: %[[V2:.*]] = "tf.PartitionedCall"(%[[V1]], %[[V0]], %[[CST_6]], %[[CST]], %[[CST_0]], %[[CST_1]], %[[CST_0]], %[[CST_2]], %[[CST_3]], %[[CST_4]], %[[CST_5]]) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_fn_0} : (tensor<1x3x4x3xi8>, tensor<2x3x3x2xi8>, tensor<2xi32>, tensor<f32>, tensor<i32>, tensor<f32>, tensor<i32>, tensor<f32>, tensor<i32>, tensor<f32>, tensor<i32>) -> tensor<*xi8>
-// CHECK-NEXT: %[[V3:.*]] = "tf.PartitionedCall"(%[[V2]], %[[CST_4]], %[[CST_5]]) {config = "", config_proto = "", executor_type = "", f = @dequantize_i8} : (tensor<*xi8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
+// CHECK-DAG: %[[CST:.*]] = "tf.Const"() <{value = dense<0.00117647066> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<-43> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{value = dense<0.0117647061> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[CST_2:.*]] = "tf.Const"() <{value = dense<1.38408304E-5> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[CST_3:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[CST_4:.*]] = "tf.Const"() <{value = dense<0.0027450982> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[CST_5:.*]] = "tf.Const"() <{value = dense<-19> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[CST_6:.*]] = "tf.Const"() <{value = dense<0> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-NEXT: %[[V0:.*]] = "tf.PartitionedCall"(%[[ARG1]], %[[CST_1]], %[[CST_0]]) <{config = "", config_proto = "", executor_type = "", f = @quantize_i8}> : (tensor<2x3x3x2xf32>, tensor<f32>, tensor<i32>) -> tensor<2x3x3x2xi8>
+// CHECK-NEXT: %[[V1:.*]] = "tf.PartitionedCall"(%[[ARG0]], %[[CST]], %[[CST_0]]) <{config = "", config_proto = "", executor_type = "", f = @quantize_i8}> : (tensor<1x3x4x3xf32>, tensor<f32>, tensor<i32>) -> tensor<1x3x4x3xi8>
+// CHECK-NEXT: %[[V2:.*]] = "tf.PartitionedCall"(%[[V1]], %[[V0]], %[[CST_6]], %[[CST]], %[[CST_0]], %[[CST_1]], %[[CST_0]], %[[CST_2]], %[[CST_3]], %[[CST_4]], %[[CST_5]]) <{config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_fn_0}> : (tensor<1x3x4x3xi8>, tensor<2x3x3x2xi8>, tensor<2xi32>, tensor<f32>, tensor<i32>, tensor<f32>, tensor<i32>, tensor<f32>, tensor<i32>, tensor<f32>, tensor<i32>) -> tensor<*xi8>
+// CHECK-NEXT: %[[V3:.*]] = "tf.PartitionedCall"(%[[V2]], %[[CST_4]], %[[CST_5]]) <{config = "", config_proto = "", executor_type = "", f = @dequantize_i8}> : (tensor<*xi8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
 // CHECK-NEXT: return %[[V3]] : tensor<*xf32>
 
 // CHECK: func private @quantize_i8(
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/fake_quant_e2e_xla.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/fake_quant_e2e_xla.mlir
index 5d9801e..e5c5d8a 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/fake_quant_e2e_xla.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/fake_quant_e2e_xla.mlir
@@ -25,7 +25,7 @@
 // CHECK: %[[pad:.*]] = "tf.PadV2"(%[[quant]]
 // CHECK: %[[xlaconv:.*]] = "tf.XlaConvV2"(%[[pad]]
 // CHECK: %[[sub:.*]] = "tf.Sub"(%[[xlaconv]]
-// CHECK: %[[cast:.*]] = "tf.Cast"(%[[sub]]) {Truncate = false} : (tensor<1x3x2x2xi32>) -> tensor<1x3x2x2xf32>
+// CHECK: %[[cast:.*]] = "tf.Cast"(%[[sub]]) <{Truncate = false}> : (tensor<1x3x2x2xi32>) -> tensor<1x3x2x2xf32>
 // CHECK: %[[dequant1:.*]] = "tf.Mul"(%[[cast]]
 // CHECK: %[[relu:.*]] = "tf.Relu"(%[[dequant1]]
 // CHECK: %[[clamped:.*]] = "tf.Minimum"(%[[relu]]
@@ -35,12 +35,12 @@
 // CHECK: %[[maximum2:.*]] = "tf.Maximum"(%[[add2]]
 // CHECK: %[[minimum2:.*]] = "tf.Minimum"(%[[maximum2]]
 // CHECK: %[[round2:.*]] = "tf.Round"(%[[minimum2]]
-// CHECK: %[[quant2:.*]] = "tf.Cast"(%[[round2]]) {Truncate = false} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xi8>
+// CHECK: %[[quant2:.*]] = "tf.Cast"(%[[round2]]) <{Truncate = false}> : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xi8>
 
 // CHECK: %[[pad2:.*]] = "tf.PadV2"(%[[quant2]]
 // CHECK: %[[xlaconv2:.*]] = "tf.XlaConvV2"(%[[pad2]]
 // CHECK: %[[sub2:.*]] = "tf.Sub"(%[[xlaconv2]]
-// CHECK: %[[cast2:.*]] = "tf.Cast"(%[[sub2]]) {Truncate = false} : (tensor<1x3x2x2xi32>) -> tensor<1x3x2x2xf32>
+// CHECK: %[[cast2:.*]] = "tf.Cast"(%[[sub2]]) <{Truncate = false}> : (tensor<1x3x2x2xi32>) -> tensor<1x3x2x2xf32>
 // CHECK: %[[rescale2:.*]] = "tf.Mul"(%[[cast2]]
 // CHECK: %[[rescale2_maxclamped:.*]] = "tf.Maximum"(%[[rescale2]]
 // CHECK: %[[rescale2_minclamped:.*]] = "tf.Minimum"(%[[rescale2_maxclamped]]
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir
index dd1c5a4..fa74735 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir
@@ -26,10 +26,10 @@
 
 // CalibrationOptions(calibration_method=CALIBRATION_METHOD_MIN_MAX)
 // MIN-MAX-CHECK: func @add_custom_ops
-// MIN-MAX-CHECK-NEXT:  [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) {calibration_method = 1 : i32, id = "", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
-// MIN-MAX-CHECK-NEXT:  [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// MIN-MAX-CHECK-NEXT:  [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// MIN-MAX-CHECK-NEXT:  [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
 // MIN-MAX-CHECK-NEXT:  [[add:%.*]] = "tf.AddV2"([[lhs]], [[rhs]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-// MIN-MAX-CHECK-NEXT:  [[res:%.*]] = "tf.CustomAggregator"([[add]]) {calibration_method = 1 : i32, id = "", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// MIN-MAX-CHECK-NEXT:  [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
 // MIN-MAX-CHECK-NEXT:  return [[res]] : tensor<*xf32>
 
 // MIN-MAX-CHECK: func @no_custom_ops_on_non_f32_type
@@ -44,10 +44,10 @@
 
 // CalibrationOptions(calibration_method=CALIBRATION_METHOD_AVERAGE_MIN_MAX)
 // AVERAGE-MIN-MAX-CHECK: func @add_custom_ops
-// AVERAGE-MIN-MAX-CHECK-NEXT:  [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) {calibration_method = 2 : i32, id = "", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
-// AVERAGE-MIN-MAX-CHECK-NEXT:  [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) {calibration_method = 2 : i32, id = "", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// AVERAGE-MIN-MAX-CHECK-NEXT:  [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 2 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// AVERAGE-MIN-MAX-CHECK-NEXT:  [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 2 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
 // AVERAGE-MIN-MAX-CHECK-NEXT:  [[add:%.*]] = "tf.AddV2"([[lhs]], [[rhs]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-// AVERAGE-MIN-MAX-CHECK-NEXT:  [[res:%.*]] = "tf.CustomAggregator"([[add]]) {calibration_method = 2 : i32, id = "", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// AVERAGE-MIN-MAX-CHECK-NEXT:  [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 2 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
 // AVERAGE-MIN-MAX-CHECK-NEXT:  return [[res]] : tensor<*xf32>
 
 // AVERAGE-MIN-MAX-CHECK: func @no_custom_ops_on_non_f32_type
@@ -65,10 +65,10 @@
 //   calibration_parameters=CalibrationParameters(initial_num_bins=256, min_percentile=0.001, max_percentile=99.999)
 // )
 // HISTOGRAM-PERCENTILE-CHECK: func @add_custom_ops
-// HISTOGRAM-PERCENTILE-CHECK-NEXT:  [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) {calibration_method = 3 : i32, id = "", initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> tensor<*xf32>
-// HISTOGRAM-PERCENTILE-CHECK-NEXT:  [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) {calibration_method = 3 : i32, id = "", initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// HISTOGRAM-PERCENTILE-CHECK-NEXT:  [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 3 : i32, initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// HISTOGRAM-PERCENTILE-CHECK-NEXT:  [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 3 : i32, initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> tensor<*xf32>
 // HISTOGRAM-PERCENTILE-CHECK-NEXT:  [[add:%.*]] = "tf.AddV2"([[lhs]], [[rhs]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-// HISTOGRAM-PERCENTILE-CHECK-NEXT:  [[res:%.*]] = "tf.CustomAggregator"([[add]]) {calibration_method = 3 : i32, id = "", initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// HISTOGRAM-PERCENTILE-CHECK-NEXT:  [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 3 : i32, initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> tensor<*xf32>
 // HISTOGRAM-PERCENTILE-CHECK-NEXT:  return [[res]] : tensor<*xf32>
 
 // HISTOGRAM-PERCENTILE-CHECK: func @no_custom_ops_on_non_f32_type
@@ -86,10 +86,10 @@
 //   calibration_parameters=CalibrationParameters(initial_num_bins=256)
 // )
 // HISTOGRAM-MSE-BRUTEFORCE-CHECK: func @add_custom_ops
-// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT:  [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) {calibration_method = 4 : i32, id = "", initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
-// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT:  [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) {calibration_method = 4 : i32, id = "", initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT:  [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 4 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT:  [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 4 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
 // HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT:  [[add:%.*]] = "tf.AddV2"([[lhs]], [[rhs]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT:  [[res:%.*]] = "tf.CustomAggregator"([[add]]) {calibration_method = 4 : i32, id = "", initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT:  [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 4 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
 // HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT:  return [[res]] : tensor<*xf32>
 
 // HISTOGRAM-MSE-BRUTEFORCE-CHECK: func @no_custom_ops_on_non_f32_type
@@ -107,10 +107,10 @@
 //   calibration_parameters=CalibrationParameters(initial_num_bins=256)
 // )
 // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK: func @add_custom_ops
-// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT:  [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) {calibration_method = 5 : i32, id = "", initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
-// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT:  [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) {calibration_method = 5 : i32, id = "", initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT:  [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 5 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT:  [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 5 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
 // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT:  [[add:%.*]] = "tf.AddV2"([[lhs]], [[rhs]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT:  [[res:%.*]] = "tf.CustomAggregator"([[add]]) {calibration_method = 5 : i32, id = "", initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT:  [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 5 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
 // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT:  return [[res]] : tensor<*xf32>
 
 // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK: func @no_custom_ops_on_non_f32_type
@@ -128,10 +128,10 @@
 //   calibration_parameters=CalibrationParameters(initial_num_bins=256)
 // )
 // HISTOGRAM-MSE-SYMMETRIC-CHECK: func @add_custom_ops
-// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT:  [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) {calibration_method = 6 : i32, id = "", initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
-// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT:  [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) {calibration_method = 6 : i32, id = "", initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT:  [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 6 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT:  [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 6 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
 // HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT:  [[add:%.*]] = "tf.AddV2"([[lhs]], [[rhs]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT:  [[res:%.*]] = "tf.CustomAggregator"([[add]]) {calibration_method = 6 : i32, id = "", initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT:  [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 6 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
 // HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT:  return [[res]] : tensor<*xf32>
 
 // HISTOGRAM-MSE-SYMMETRIC-CHECK: func @no_custom_ops_on_non_f32_type
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_main_function.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_main_function.mlir
index 5620836..cbd7224 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_main_function.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_main_function.mlir
@@ -25,7 +25,7 @@
     func.return %1 : tensor<1xf32>
   }
 // CHECK: func private @mul2(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> attributes {tf.entry_function = {inputs = "mul2_y:0,mul2_x:0", outputs = "PartitionedCall_1:0"}} {
-// CHECK:   %[[CONST_0:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK:   %[[CONST_0:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK:   %[[MUL_1:.*]] = "tf.Mul"(%arg1, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 // CHECK:   %[[MUL_2:.*]] = "tf.Mul"(%[[MUL_1]], %[[CONST_0]]) : (tensor<1xf32>, tensor<f32>) -> tensor<1xf32>
 // CHECK:   return %[[MUL_2]] : tensor<1xf32>
@@ -33,8 +33,8 @@
 
 // CHECK: func @main(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["mul1_y:0"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["mul1_x:0"]}, %arg2: tensor<1xf32> {tf_saved_model.index_path = ["mul2_y:0"]}, %arg3: tensor<1xf32> {tf_saved_model.index_path = ["mul2_x:0"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["PartitionedCall:0"]}, tensor<1xf32> {tf_saved_model.index_path = ["PartitionedCall_1:0"]}) attributes {tf.entry_function = {inputs = "mul1_y:0,mul1_x:0,mul2_y:0,mul2_x:0", outputs = "PartitionedCall:0,PartitionedCall_1:0"}, tf_saved_model.exported_names = ["main"]} {
 // CHECK-NOT: f = @NoOp
-// CHECK:   %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @mul1} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
-// CHECK:   %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg2, %arg3) {config = "", config_proto = "", executor_type = "", f = @mul2} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+// CHECK:   %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1) <{config = "", config_proto = "", executor_type = "", f = @mul1}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+// CHECK:   %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg2, %arg3) <{config = "", config_proto = "", executor_type = "", f = @mul2}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 // CHECK-DAG:   %[[IDENTITY_0:.*]] = "tf.Identity"(%[[PARTITIONEDCALL_0]])
 // CHECK-DAG:   %[[IDENTITY_1:.*]] = "tf.Identity"(%[[PARTITIONEDCALL_1]])
 // CHECK:   return %[[IDENTITY_0]], %[[IDENTITY_1]] : tensor<1xf32>, tensor<1xf32>
@@ -49,8 +49,8 @@
   "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> ()
   "tf_saved_model.asset"() {filename = "assets/mydata.txt", sym_name = "__tf_saved_model_asset0_mydata.txt"} : () -> ()
 // Session initializer ops and asset ops untouched.
-// CHECK: "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> ()
-// CHECK: "tf_saved_model.asset"() {filename = "assets/mydata.txt", sym_name = "__tf_saved_model_asset0_mydata.txt"} : () -> ()
+// CHECK: "tf_saved_model.session_initializer"() <{initializers = [@NoOp]}> : () -> ()
+// CHECK: "tf_saved_model.asset"() <{filename = "assets/mydata.txt", sym_name = "__tf_saved_model_asset0_mydata.txt"}> : () -> ()
 
   func.func @NoOp(%arg0: tensor<!tf_type.string> {tf_saved_model.bound_input = @__tf_saved_model_asset0_mydata.txt}) attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"]} {
     %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf_type.resource>
@@ -82,10 +82,10 @@
 // CHECK-SAME: tf_saved_model.exported_names = ["main"]
 
 // Check that the function call to @add exists and not to @NoOp.
-// CHECK: %[[CALL0:.*]] = "tf.PartitionedCall"(%[[ARG0]], %[[ARG1]]) {
+// CHECK: %[[CALL0:.*]] = "tf.PartitionedCall"(%[[ARG0]], %[[ARG1]]) <{
 // CHECK-NOT: f = @NoOp
 // CHECK-SAME: f = @add
-// CHECK-SAME: }
+// CHECK-SAME: }>
 // CHECK-SAME: : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[CALL0]])
 // CHECK: return %[[IDENTITY]] : tensor<1xf32>
@@ -111,7 +111,7 @@
 // CHECK: func.func @main(%arg0: tensor<16xf32> {tf_saved_model.index_path = ["input:0"]}, %arg1: tensor<i32> {tf_saved_model.index_path = ["k:0"]})
 // CHECK-SAME: -> (tensor<?xf32> {tf_saved_model.index_path = ["TopK:0"]}, tensor<?xi32> {tf_saved_model.index_path = ["TopK:1"]})
 // CHECK-SAME: attributes {tf.entry_function = {inputs = "input:0,k:0", outputs = "TopK:0,TopK:1"}, tf_saved_model.exported_names = ["main"]}
-// CHECK: %[[CALL0:.*]]:2 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @topk}
+// CHECK: %[[CALL0:.*]]:2 = "tf.PartitionedCall"(%arg0, %arg1) <{config = "", config_proto = "", executor_type = "", f = @topk}>
 // Expects an IdentityN op to be created.
 // CHECK: %[[IDENTITY:.*]]:2 = "tf.IdentityN"(%[[CALL0]]#0, %[[CALL0]]#1) : (tensor<?xf32>, tensor<?xi32>) -> (tensor<?xf32>, tensor<?xi32>)
 // CHECK: return %[[IDENTITY]]#0, %[[IDENTITY]]#1 : tensor<?xf32>, tensor<?xi32>
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_restore_op.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_restore_op.mlir
index 385052a..7f73eee 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_restore_op.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_restore_op.mlir
@@ -32,10 +32,10 @@
 
 // Test that RestoreV2 op is created with 1 resulting value.
 // CHECK: %[[RESTORE:.*]] = "tf.RestoreV2"(%[[ARG_0]], %[[CST_1]], %[[CST_2]]) : (tensor<!tf_type.string>, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<2xf32>
-// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[RESTORE]]) {validate_shape = false} : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[RESTORE]]) <{validate_shape = false}> : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
 
 // Test that the loc is properly set to its shared_name.
-// CHECK-LOC: "tf.VarHandleOp"() {{{.*shared_name = "var_0".*}}}
+// CHECK-LOC: "tf.VarHandleOp"() <{{{.*shared_name = "var_0".*}}}>
 // CHECK-LOC-SAME: loc("var_0")
 }
 
@@ -66,19 +66,19 @@
 // CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "var_0".*}} : () -> tensor<!tf_type.resource<tensor<2xf32>>>
 // CHECK-DAG: %[[VAR_HANDLE_1:.*]] = "tf.VarHandleOp"() {{.*shared_name = "var_1".*}} : () -> tensor<!tf_type.resource<tensor<4xi32>>>
 
-// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {{{.*value = dense<\["var_0", "var_1"\]> : tensor<2x!tf_type.string>.*}}}
-// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {{{.*value = dense<""> : tensor<2x!tf_type.string>.*}}}
+// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{{{.*value = dense<\["var_0", "var_1"\]> : tensor<2x!tf_type.string>.*}}}>
+// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{{{.*value = dense<""> : tensor<2x!tf_type.string>.*}}}>
 
 // Test that RestoreV2 op is created with 2 resulting values.
 // CHECK: %[[RESTORE:.*]]:2 = "tf.RestoreV2"(%[[ARG_0]], %[[CST_0]], %[[CST_1]]) : (tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<2xf32>, tensor<4xi32>)
 
-// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[RESTORE]]#0) {validate_shape = false} : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
-// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_1]], %[[RESTORE]]#1) {validate_shape = false} : (tensor<!tf_type.resource<tensor<4xi32>>>, tensor<4xi32>) -> ()
+// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[RESTORE]]#0) <{validate_shape = false}> : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_1]], %[[RESTORE]]#1) <{validate_shape = false}> : (tensor<!tf_type.resource<tensor<4xi32>>>, tensor<4xi32>) -> ()
 
 // Test that the locs are properly set to their shared_names.
-// CHECK-LOC: "tf.VarHandleOp"() {{{.*shared_name = "var_0".*}}}
+// CHECK-LOC: "tf.VarHandleOp"() <{{{.*shared_name = "var_0".*}}}>
 // CHECK-LOC-SAME: loc("var_0")
-// CHECK-LOC: "tf.VarHandleOp"() {{{.*shared_name = "var_1".*}}}
+// CHECK-LOC: "tf.VarHandleOp"() <{{{.*shared_name = "var_1".*}}}>
 // CHECK-LOC-SAME: loc("var_1")
 }
 
@@ -101,11 +101,11 @@
 // Check that no function argument is created.
 // CHECK: func.func @init_func_init_op()
 
-// CHECK-DAG: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {{{.*shared_name = "var_0".*}}} : () -> tensor<!tf_type.resource<tensor<2xf32>>>
-// CHECK-DAG: %[[CST:.*]] = "tf.Const"() {{{.*value = dense<1.000000e\+00> : tensor<2xf32>.*}}}
+// CHECK-DAG: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() <{{{.*shared_name = "var_0".*}}}> : () -> tensor<!tf_type.resource<tensor<2xf32>>>
+// CHECK-DAG: %[[CST:.*]] = "tf.Const"() <{{{.*value = dense<1.000000e\+00> : tensor<2xf32>.*}}}>
 // Make sure that "tf.RestoreV2" is not created.
 // CHECK-NOT: "tf.RestoreV2"
-// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[CST]]) {validate_shape = false} : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[CST]]) <{validate_shape = false}> : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
 
 // CHECK-LOC: @init_func_init_op
 // CHECK-LOC: return
@@ -140,19 +140,19 @@
 // CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "var_0".*}} : () -> tensor<!tf_type.resource<tensor<2xf32>>>
 // CHECK-DAG: %[[VAR_HANDLE_1:.*]] = "tf.VarHandleOp"() {{.*shared_name = "var_1".*}} : () -> tensor<!tf_type.resource<tensor<2xf32>>>
 
-// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {{{.*value = dense<\["var_0", "var_1"\]> : tensor<2x!tf_type.string>.*}}}
-// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {{{.*value = dense<""> : tensor<2x!tf_type.string>.*}}}
+// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{{{.*value = dense<\["var_0", "var_1"\]> : tensor<2x!tf_type.string>.*}}}>
+// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{{{.*value = dense<""> : tensor<2x!tf_type.string>.*}}}>
 
 // Test that RestoreV2 op is created with 2 resulting values.
 // CHECK: %[[RESTORE:.*]]:2 = "tf.RestoreV2"(%[[ARG_0]], %[[CST_0]], %[[CST_1]]) : (tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<2xf32>, tensor<2xf32>)
 
-// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[RESTORE]]#0) {validate_shape = false} : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
-// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_1]], %[[RESTORE]]#1) {validate_shape = false} : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[RESTORE]]#0) <{validate_shape = false}> : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_1]], %[[RESTORE]]#1) <{validate_shape = false}> : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
 
 // Test that the locs are properly set to their shared_names.
-// CHECK-LOC: "tf.VarHandleOp"() {{{.*shared_name = "var_0".*}}}
+// CHECK-LOC: "tf.VarHandleOp"() <{{{.*shared_name = "var_0".*}}}>
 // CHECK-LOC-SAME: loc("var_0")
-// CHECK-LOC: "tf.VarHandleOp"() {{{.*shared_name = "var_1".*}}}
+// CHECK-LOC: "tf.VarHandleOp"() <{{{.*shared_name = "var_1".*}}}>
 // CHECK-LOC-SAME: loc("var_1")
 }
 
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_save_op.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_save_op.mlir
index 9483331..c142247 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_save_op.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_save_op.mlir
@@ -30,8 +30,8 @@
 // CHECK: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"()
 // CHECK-SAME: {{.*shared_name = "var_0".*}}
 // CHECK: %[[READ_VARIABLE:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor<!tf_type.resource<tensor<2xf32>>>) -> tensor<2xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {{{.*value = dense<"var_0"> : tensor<1x!tf_type.string>.*}}}
-// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {{{.*value = dense<""> : tensor<1x!tf_type.string>.*}}}
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{{{.*value = dense<"var_0"> : tensor<1x!tf_type.string>.*}}}>
+// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{{{.*value = dense<""> : tensor<1x!tf_type.string>.*}}}>
 // CHECK: "tf.SaveV2"(%[[ARG]], %[[CONST_0]], %[[CONST_1]], %[[READ_VARIABLE]])
 // CHECK: return
 }
@@ -73,8 +73,8 @@
 // CHECK-DAG: %[[READ_VARIABLE_0:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_0]]) : (tensor<!tf_type.resource<tensor<2xf32>>>) -> tensor<2xf32>
 // CHECK-DAG: %[[READ_VARIABLE_1:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_1]]) : (tensor<!tf_type.resource<tensor<3xf32>>>) -> tensor<3xf32>
 
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {{{.*value = dense<\["var_0", "var_1"\]> : tensor<2x!tf_type.string>.*}}}
-// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {{{.*value = dense<""> : tensor<2x!tf_type.string>.*}}}
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{{{.*value = dense<\["var_0", "var_1"\]> : tensor<2x!tf_type.string>.*}}}>
+// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{{{.*value = dense<""> : tensor<2x!tf_type.string>.*}}}>
 // CHECK: "tf.SaveV2"(%[[ARG]], %[[CONST_0]], %[[CONST_1]], %[[READ_VARIABLE_0]], %[[READ_VARIABLE_1]])
 // CHECK: return
 }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/issue_ids_of_custom_aggregation_ops.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/issue_ids_of_custom_aggregation_ops.mlir
index 03fcbf4..4aa1ae7 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/issue_ids_of_custom_aggregation_ops.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/issue_ids_of_custom_aggregation_ops.mlir
@@ -10,8 +10,8 @@
 
 
 // CHECK: func @issue_ids
-// CHECK-NEXT:  [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) {id = "0"} : (tensor<*xf32>) -> tensor<*xf32>
-// CHECK-NEXT:  [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) {id = "1"} : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK-NEXT:  [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = "0"}> : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK-NEXT:  [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = "1"}> : (tensor<*xf32>) -> tensor<*xf32>
 // CHECK-NEXT:  [[add:%.*]] = "tf.AddV2"([[lhs]], [[rhs]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-// CHECK-NEXT:  [[res:%.*]] = "tf.CustomAggregator"([[add]]) {id = "2"} : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK-NEXT:  [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = "2"}> : (tensor<*xf32>) -> tensor<*xf32>
 // CHECK-NEXT:  return [[res]] : tensor<*xf32>
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions.mlir
index f61a9fb..6a7f9da 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions.mlir
@@ -26,10 +26,10 @@
   %7 = "tf.BiasAdd"(%6, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
   func.return %2, %5, %7 : tensor<*xf32>, tensor<*xf32>, tensor<*xf32>
 
-// CHECK: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
+// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]])
-// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_1}
+// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_1}>
+// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable"
 // CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]])
 // CHECK-SAME: f = @composite_conv2d_with_bias_and_relu_fn_1}
 // CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]])
@@ -39,8 +39,8 @@
 
 // CHECK-LABEL: private @composite_conv2d_with_bias_and_relu6_fn_1
 // CHECK-NEXT: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1)
+// CHECK-SAME: data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true
 // CHECK-SAME: attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations"
-// CHECK-SAME: data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true
 // CHECK-NEXT: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[CONV2D_0]], %arg2)
 // CHECK-NEXT: %[[RELU6_0:.*]] = "tf.Relu6"(%[[BIASADD_0]])
 // CHECK-NEXT: return %[[RELU6_0]]
@@ -70,15 +70,15 @@
 }
 
 // CHECK-LABEL: func @float_conv_strides_equals_to_dilations(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<*xf32> {
-// CHECK: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<*xf32>
+// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<*xf32>
 // CHECK: return %[[PARTITIONEDCALL_0]] : tensor<*xf32>
 // CHECK: }
 
 // CHECK-LABEL: func private @composite_conv2d_with_bias_and_relu6_fn_1(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} {
 // CHECK-NEXT: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1)
+// CHECK-SAME: data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true
 // CHECK-SAME: attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations"
-// CHECK-SAME: data_format = "NHWC", device = "", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true
 // CHECK-NEXT: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[CONV2D_0]], %arg2)
 // CHECK-NEXT: %[[RELU6_0:.*]] = "tf.Relu6"(%[[BIASADD_0]])
 // CHECK-NEXT: return %[[RELU6_0]]
@@ -111,14 +111,14 @@
   %7 = "tf.BiasAdd"(%6, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
   func.return %2, %5, %7 : tensor<*xf32>, tensor<*xf32>, tensor<*xf32>
 
-// CHECK: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
+// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]])
-// CHECK-SAME: _tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: f = @composite_depthwise_conv2d_with_bias_and_relu6_fn_1}
+// CHECK-SAME: f = @composite_depthwise_conv2d_with_bias_and_relu6_fn_1}>
+// CHECK-SAME: _tfl_quant_trait = "fully_quantizable"
 // CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]])
-// CHECK-SAME: f = @composite_depthwise_conv2d_with_bias_and_relu_fn_1}
+// CHECK-SAME: f = @composite_depthwise_conv2d_with_bias_and_relu_fn_1
 // CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]])
-// CHECK-SAME: f = @composite_depthwise_conv2d_with_bias_fn_1}
+// CHECK-SAME: f = @composite_depthwise_conv2d_with_bias_fn_1
 // CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]], %[[PARTITIONEDCALL_2]]
 // CHECK: }
 
@@ -161,14 +161,14 @@
   %7 = "tf.BiasAdd"(%6, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<10xf32>) -> tensor<*xf32>
   func.return %2, %5, %7 : tensor<*xf32>, tensor<*xf32>, tensor<*xf32>
 
-// CHECK: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<10xf32>}
+// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<10xf32>}>
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]])
-// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: f = @composite_matmul_with_bias_and_relu6_fn_1}
+// CHECK-SAME: f = @composite_matmul_with_bias_and_relu6_fn_1}>
+// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable"
 // CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]])
-// CHECK-SAME: f = @composite_matmul_with_bias_and_relu_fn_1}
+// CHECK-SAME: f = @composite_matmul_with_bias_and_relu_fn_1
 // CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]])
-// CHECK-SAME: f = @composite_matmul_with_bias_fn_1}
+// CHECK-SAME: f = @composite_matmul_with_bias_fn_1
 // CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]], %[[PARTITIONEDCALL_2]]
 // CHECK: }
 
@@ -207,10 +207,10 @@
   func.return %3 : tensor<*xf32>
 
 
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<10xf32>}
-// CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[-1, 10]> : tensor<2xi32>}
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<10xf32>}>
+// CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() <{value = dense<[-1, 10]> : tensor<2xi32>}>
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]], %[[SHAPE]])
-// CHECK-SAME: f = @composite_matmul_with_reshape_and_bias_fn_1}
+// CHECK-SAME: f = @composite_matmul_with_reshape_and_bias_fn_1
 // CHECK: return %[[PARTITIONEDCALL_0]]
 // CHECK: }
 
@@ -247,14 +247,14 @@
   func.return %1, %4, %6 : tensor<*xf32>, tensor<*xf32>, tensor<*xf32>
 
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1)
-// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: f = @composite_conv2d_with_relu6_fn_1}
+// CHECK-SAME: f = @composite_conv2d_with_relu6_fn_1}>
+// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable"
 
 // CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %arg1)
-// CHECK-SAME: f = @composite_conv2d_with_relu_fn_1}
+// CHECK-SAME: f = @composite_conv2d_with_relu_fn_1
 
 // CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %arg1)
-// CHECK-SAME: f = @composite_conv2d_fn_1}
+// CHECK-SAME: f = @composite_conv2d_fn_1
 // CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]], %[[PARTITIONEDCALL_2]]
 // CHECK: }
 
@@ -288,12 +288,12 @@
   func.return %1, %4, %6 : tensor<*xf32>, tensor<*xf32>, tensor<*xf32>
 
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1)
-// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: f = @composite_depthwise_conv2d_with_relu6_fn_1}
+// CHECK-SAME: f = @composite_depthwise_conv2d_with_relu6_fn_1}>
+// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable"
 // CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %arg1)
-// CHECK-SAME: f = @composite_depthwise_conv2d_with_relu_fn_1}
+// CHECK-SAME: f = @composite_depthwise_conv2d_with_relu_fn_1
 // CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %arg1)
-// CHECK-SAME: f = @composite_depthwise_conv2d_fn_1}
+// CHECK-SAME: f = @composite_depthwise_conv2d_fn_1
 // CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]], %[[PARTITIONEDCALL_2]]
 // CHECK: }
 
@@ -323,12 +323,12 @@
   func.return %1, %4, %6 : tensor<*xf32>, tensor<*xf32>, tensor<*xf32>
 
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1)
-// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: f = @composite_matmul_with_relu6_fn_1}
+// CHECK-SAME: f = @composite_matmul_with_relu6_fn_1}>
+// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable"
 // CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %arg1)
-// CHECK-SAME: f = @composite_matmul_with_relu_fn_1}
+// CHECK-SAME: f = @composite_matmul_with_relu_fn_1
 // CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %arg1)
-// CHECK-SAME: f = @composite_matmul_fn_1}
+// CHECK-SAME: f = @composite_matmul_fn_1
 // CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]], %[[PARTITIONEDCALL_2]]
 // CHECK: }
 
@@ -361,14 +361,14 @@
 // CHECK-DAG: %[[CST:.*]] = "tf.Const"() {{.*}} : () -> tensor<2x3x3x3x2xf32>
 
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]])
-// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: f = @composite_conv3d_with_relu_fn_1}
+// CHECK-SAME: f = @composite_conv3d_with_relu_fn_1}>
+// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable"
 
 // CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]])
-// CHECK-SAME: f = @composite_conv3d_with_relu6_fn_1}
+// CHECK-SAME: f = @composite_conv3d_with_relu6_fn_1
 
 // CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]])
-// CHECK-SAME: f = @composite_conv3d_fn_1}
+// CHECK-SAME: f = @composite_conv3d_fn_1
 
 // CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]], %[[PARTITIONEDCALL_2]]
 
@@ -406,14 +406,14 @@
 // CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {{.*}} : () -> tensor<2xf32>
 
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]], %[[CST_1]])
-// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: f = @composite_conv3d_with_bias_and_relu_fn_1}
+// CHECK-SAME: f = @composite_conv3d_with_bias_and_relu_fn_1}>
+// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable"
 
 // CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]], %[[CST_1]])
-// CHECK-SAME: f = @composite_conv3d_with_bias_and_relu6_fn_1}
+// CHECK-SAME: f = @composite_conv3d_with_bias_and_relu6_fn_1
 
 // CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]], %[[CST_1]])
-// CHECK-SAME: f = @composite_conv3d_with_bias_fn_1}
+// CHECK-SAME: f = @composite_conv3d_with_bias_fn_1
 
 // CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]], %[[PARTITIONEDCALL_2]]
 
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq.mlir
index 35215ec..305c634 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq.mlir
@@ -12,10 +12,10 @@
   } : (tensor<1x12x12x512xf32>, tensor<1x12x12x512xf32>) -> tensor<*xf32>
   func.return %out_1, %out_2 : tensor<*xf32>, tensor<*xf32>
 
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<512x512xf32>} : () -> tensor<512x512xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<512x512xf32>}> : () -> tensor<512x512xf32>
 // CHECK: %[[PARTITIONEDCALL:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST]])
-// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: f = @composite_matmul_fn_1}
+// CHECK-SAME: f = @composite_matmul_fn_1}>
+// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable"
 // CHECK: %[[UNQUANTIZED_OUTPUT:.*]] = "tf.MatMul"(%arg0, %arg0)
 // CHECK: }
 
@@ -45,23 +45,23 @@
 
   func.return %2, %4 : tensor<*xf32>, tensor<*xf32>
 
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<3.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32>
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1]])
-// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: f = @composite_conv2d_fn_2}
+// CHECK-SAME: f = @composite_conv2d_fn_2}>
+// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable"
 // CHECK: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0]])
 // CHECK: %[[RELU6_0:.*]] = "tf.Relu6"(%[[BIASADD_0]])
 // CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1]])
-// CHECK-SAME: f = @composite_conv2d_fn_1}
+// CHECK-SAME: f = @composite_conv2d_fn_1
 // CHECK: %[[BIASADD_1:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_1]], %[[CONST_0]])
 // CHECK: return %[[RELU6_0]], %[[BIASADD_1]]
 // CHECK: }
 
 // CHECK-LABEL: private @composite_conv2d_fn_2
 // CHECK-NEXT: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1)
+// CHECK-SAME: data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true
 // CHECK-SAME: attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations"
-// CHECK-SAME: data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true
 // CHECK-NEXT: return %[[CONV2D_0]]
 
 // CHECK-LABEL: private @composite_conv2d_fn_1
@@ -90,7 +90,7 @@
 
   func.return %2, %4 : tensor<*xf32>, tensor<*xf32>
 
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>
 // CHECK-NOT: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1)
 // CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1)
 }
@@ -115,15 +115,15 @@
   %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
   func.return %2, %4 : tensor<*xf32>, tensor<*xf32>
 
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<3.000000e+00> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32>
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1]])
-// CHECK-SAME: _tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: f = @composite_depthwise_conv2d_fn_2}
+// CHECK-SAME: f = @composite_depthwise_conv2d_fn_2}>
+// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable"
 // CHECK: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0]])
 // CHECK: %[[RELU6_0:.*]] = "tf.Relu6"(%[[BIASADD_0]])
 // CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1]])
-// CHECK-SAME: f = @composite_depthwise_conv2d_fn_1}
+// CHECK-SAME: f = @composite_depthwise_conv2d_fn_1
 // CHECK: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_1]], %[[CONST_0]])
 // CHECK: return %[[RELU6_0]], %[[BIASADD_0]]
 // CHECK: }
@@ -153,8 +153,8 @@
 
 // CHECK-DAG: %[[CST:.*]] = "tf.Const"() {{.*}} : () -> tensor<2x3x3x3x2xf32>
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]])
-// CHECK-NOT: {_tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: f = @composite_conv3d_fn_1}
+// CHECK-SAME: f = @composite_conv3d_fn_1}>
+// CHECK-NOT: {_tfl_quant_trait = "fully_quantizable"
 // CHECK: %[[RELU:.*]] = "tf.Relu"(%[[PARTITIONEDCALL_0]])
 // CHECK: return %[[RELU]]
 
@@ -162,8 +162,8 @@
 
 // WEIGHTONLY-DAG: %[[CST:.*]] = "tf.Const"() {{.*}} : () -> tensor<2x3x3x3x2xf32>
 // WEIGHTONLY: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]])
-// WEIGHTONLY: {_tfl_quant_trait = "fully_quantizable",
-// WEIGHTONLY-SAME: f = @composite_conv3d_fn_1}
+// WEIGHTONLY-SAME: f = @composite_conv3d_fn_1}>
+// WEIGHTONLY: {_tfl_quant_trait = "fully_quantizable"
 // WEIGHTONLY: %[[RELU:.*]] = "tf.Relu"(%[[PARTITIONEDCALL_0]])
 // WEIGHTONLY: return %[[RELU]]
 
@@ -181,16 +181,16 @@
 
 // CHECK-DAG: %[[CST:.*]] = "tf.Const"() {{.*}} : () -> tensor<4x3x3xf32>
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]])
-// CHECK-NOT: {_tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: f = @composite_batch_matmul_fn_1}
+// CHECK-SAME: f = @composite_batch_matmul_fn_1}>
+// CHECK-NOT: {_tfl_quant_trait = "fully_quantizable"
 // CHECK: return %[[PARTITIONEDCALL_0]]
 
 // CHECK-LABEL: private @composite_batch_matmul_fn_1
 
 // WEIGHTONLY-DAG: %[[CST:.*]] = "tf.Const"() {{.*}} : () -> tensor<4x3x3xf32>
 // WEIGHTONLY: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]])
-// WEIGHTONLY-SAME: {_tfl_quant_trait = "fully_quantizable",
-// WEIGHTONLY-SAME: f = @composite_batch_matmul_fn_1}
+// WEIGHTONLY-SAME: f = @composite_batch_matmul_fn_1}>
+// WEIGHTONLY-SAME: {_tfl_quant_trait = "fully_quantizable"
 // WEIGHTONLY: return %[[PARTITIONEDCALL_0]]
 
 // WEIGHTONLY-LABEL: private @composite_batch_matmul_fn_1
@@ -209,15 +209,15 @@
 // CHECK-DAG: %[[CST:.*]] = "tf.Const"() {{.*}} : () -> tensor<i32>
 // CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {{.*}} : () -> tensor<128x32xf32>
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%[[CST_1]], %arg0, %[[CST]])
-// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: f = @composite_gather_fn_1}
+// CHECK-SAME: f = @composite_gather_fn_1}>
+// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable"
 // CHECK: return %[[PARTITIONEDCALL_0]]
 
 // WEIGHTONLY-DAG: %[[CST:.*]] = "tf.Const"() {{.*}} : () -> tensor<i32>
 // WEIGHTONLY-DAG: %[[CST_1:.*]] = "tf.Const"() {{.*}} : () -> tensor<128x32xf32>
 // WEIGHTONLY: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%[[CST_1]], %arg0, %[[CST]])
-// WEIGHTONLY-SAME: {_tfl_quant_trait = "fully_quantizable",
-// WEIGHTONLY-SAME: f = @composite_gather_fn_1}
+// WEIGHTONLY-SAME: f = @composite_gather_fn_1}>
+// WEIGHTONLY-SAME: {_tfl_quant_trait = "fully_quantizable"
 // WEIGHTONLY: return %[[PARTITIONEDCALL_0]]
 
 // WEIGHTONLY-LABEL: private @composite_gather_fn_1
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq_min_elements.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq_min_elements.mlir
index b110449..83d6b61 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq_min_elements.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq_min_elements.mlir
@@ -11,11 +11,11 @@
   } : (tensor<1x12x12x512xf32>, tensor<1x12x12x512xf32>) -> tensor<*xf32>
   func.return %out_1, %out_2 : tensor<*xf32>, tensor<*xf32>
 
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<512x512xf32>} : () -> tensor<512x512xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<512x512xf32>}> : () -> tensor<512x512xf32>
 // CHECK: %[[PARTITIONEDCALL:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST]])
-// CHECK-NOT: {_tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: {config = "",
-// CHECK-SAME: f = @composite_matmul_fn_1}
+// CHECK-SAME: <{config = "",
+// CHECK-SAME: f = @composite_matmul_fn_1}>
+// CHECK-NOT: {_tfl_quant_trait = "fully_quantizable"
 // CHECK: %[[UNQUANTIZED_OUTPUT:.*]] = "tf.MatMul"(%arg0, %arg0)
 // CHECK: }
 
@@ -33,9 +33,9 @@
   } : (tensor<1x3x4x512xf32>, tensor<2x3x512x512xf32>) -> tensor<*xf32>
   func.return %0 : tensor<*xf32>
 
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x512x512xf32>} : () -> tensor<2x3x512x512xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<3.000000e+00> : tensor<2x3x512x512xf32>}> : () -> tensor<2x3x512x512xf32>
 // CHECK: %[[PARTITIONEDCALL:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST]])
-// CHECK-NOT: {_tfl_quant_trait = "fully_quantizable",
-// CHECK-SAME: {config = "",
-// CHECK-SAME: f = @composite_conv2d_fn_1}
+// CHECK-SAME: <{config = "",
+// CHECK-SAME: f = @composite_conv2d_fn_1}>
+// CHECK-NOT: {_tfl_quant_trait = "fully_quantizable"
 }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_xla.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_xla.mlir
index 1d80a19..38911e2 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_xla.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_xla.mlir
@@ -14,15 +14,15 @@
 
 // CHECK-LABEL: func @depthwise_conv
 // CHECK: "tf.PartitionedCall"
+// CHECK-SAME: f = @composite_depthwise_conv2d_with_bias_and_relu6_fn_1
 // Check that the `_tfl_quant_trait` attribute has been removed.
 // CHECK-NOT: _tfl_quant_trait = "fully_quantizable"
-// CHECK-SAME: f = @composite_depthwise_conv2d_with_bias_and_relu6_fn_1
 
 // CHECK-LABEL: private @composite_depthwise_conv2d_with_bias_and_relu6_fn_1
 // CHECK: %[[DEPTHWISECONV2D_0:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %arg1)
+// CHECK-SAME: <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1]}>
 // Check that the `attr_map` attribute has been removed.
 // CHECK-NOT: attr_map
-// CHECK-SAME: {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1]}
 
 // -----
 
@@ -35,17 +35,17 @@
 }
 
 // CHECK-LABEL: func @conv_with_non_constant_filter
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]])
+// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_1
 // Check that the `_tfl_quant_trait` attribute has been removed.
 // CHECK-NOT: _tfl_quant_trait = "fully_quantizable"
-// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_1
 
 // CHECK-LABEL: func private @composite_conv2d_with_bias_and_relu6_fn_1
 // CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1)
+// CHECK-SAME: data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1]
 // Check that the `attr_map` attribute has been removed.
 // CHECK-NOT: attr_map
-// CHECK-SAME: data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1]
 
 // -----
 
@@ -59,18 +59,18 @@
 }
 
 // CHECK-LABEL: func @conv_with_dynamic_channel_dim
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {{.*}} : () -> tensor<2x3x3x1xf32>
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1]], %[[CONST_0]])
+// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_1
 // Check that the `_tfl_quant_trait` attribute has been removed.
 // CHECK-NOT: _tfl_quant_trait = "fully_quantizable"
-// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_1
 
 // CHECK-LABEL: func private @composite_conv2d_with_bias_and_relu6_fn_1
 // CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1)
+// CHECK-SAME: data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1]
 // Check that the `attr_map` attribute has been removed.
 // CHECK-NOT: attr_map
-// CHECK-SAME: data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1]
 
 // -----
 
@@ -93,11 +93,11 @@
 
 // CHECK-LABEL: func @const_filter_with_q_dq
 // CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() {{.*}} : () -> tensor<2x3x3x2xf32>
-// CHECK-DAG: %[[BIAS:.*]] = "tf.Const"() {device = "", value = dense<[1.000000e-01, 2.000000e-01]> : tensor<2xf32>}
+// CHECK-DAG: %[[BIAS:.*]] = "tf.Const"() <{value = dense<[1.000000e-01, 2.000000e-01]> : tensor<2xf32>}> {device = ""}
 // CHECK: %[[Q_W:.*]] = "quantfork.qcast"(%[[WEIGHT]])
 // CHECK: %[[DQ_W:.*]] = "quantfork.dcast"(%[[Q_W]])
 // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"({{.*}}, %[[DQ_W]], %[[BIAS]])
-// CHECK-SAME: _tfl_quant_trait = "fully_quantizable"
 // CHECK-SAME: f = @composite_conv2d_with_bias_and_relu_fn_1
+// CHECK-SAME: _tfl_quant_trait = "fully_quantizable"
 
-// CHECK-LABEL: func private @composite_conv2d_with_bias_and_relu_fn_1
\ No newline at end of file
+// CHECK-LABEL: func private @composite_conv2d_with_bias_and_relu_fn_1
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_xla_selective_quantization.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_xla_selective_quantization.mlir
index 8dfd281..a2a86a2 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_xla_selective_quantization.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_xla_selective_quantization.mlir
@@ -16,9 +16,9 @@
 
 // CHECK-LABEL: func @conv2d_unmatching_unit
 // CHECK: "tf.PartitionedCall"
+// CHECK-SAME: f = @composite_conv2d_fn_1
 // Check that the `_tfl_quant_trait` attribute exists since the unit is not in `unit_wise_quantization_specs`.
 // CHECK-SAME: _tfl_quant_trait = "fully_quantizable"
-// CHECK-SAME: f = @composite_conv2d_fn_1
 // CHECK-SAME: loc(callsite("Model/conv2d@conv2d_unmatching_unit"("Conv2D") at "QuantizationUnit({{.*}})"))
 
 // -----
@@ -36,9 +36,9 @@
 
 // CHECK-LABEL: func @conv2d_disable_quantization
 // CHECK: "tf.PartitionedCall"
+// CHECK-SAME: f = @composite_conv2d_fn_1
 // Check that quantization is disabled for this unit.
 // CHECK-NOT: _tfl_quant_trait = "fully_quantizable"
-// CHECK-SAME: f = @composite_conv2d_fn_1
 // CHECK-SAME: loc(callsite("test_opt_out@conv2d_disable_quantization"("Conv2D") at "QuantizationUnit({{.*}})"))
 
 // -----
@@ -58,9 +58,9 @@
 
 // CHECK-LABEL: func @conv2d_with_bias_disable_quantization
 // CHECK: "tf.PartitionedCall"
+// CHECK-SAME: f = @composite_conv2d_with_bias_fn_1
 // Check that quantization is disabled for this unit.
 // CHECK-NOT: _tfl_quant_trait = "fully_quantizable"
-// CHECK-SAME: f = @composite_conv2d_with_bias_fn_1
 // CHECK-SAME: loc(callsite("test_opt_out@conv2d_with_bias_disable_quantization"("Conv2D") at "QuantizationUnit({{.*}})"))
 
 // -----
@@ -80,9 +80,9 @@
 
 // CHECK-LABEL: func @matmul_with_reshape_disable_quantization
 // CHECK: "tf.PartitionedCall"
+// CHECK-SAME: f = @composite_matmul_with_reshape_and_bias_fn_1
 // Check that quantization is disabled for this unit.
 // CHECK-NOT: _tfl_quant_trait = "fully_quantizable"
-// CHECK-SAME: f = @composite_matmul_with_reshape_and_bias_fn_1
 // CHECK-SAME: loc(callsite("test_opt_out@matmul_with_reshape_disable_quantization"("MatMul") at "QuantizationUnit({{.*}})"))
 
 // -----
@@ -105,8 +105,8 @@
 
 // CHECK-LABEL: func @serving_default
 // CHECK: "tf.PartitionedCall"
+// CHECK-SAME: f = @composite_conv2d_fn_1
 // Check that quantization is disabled for this unit.
 // CHECK-NOT: _tfl_quant_trait = "fully_quantizable"
-// CHECK-SAME: f = @composite_conv2d_fn_1
 // CHECK-SAME: loc(callsite("test_opt_out@conv2d_with_inliner"("Conv2D") at "QuantizationUnit({{.*}})"))
 }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/merge_initializer_function_ops_to_main.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/merge_initializer_function_ops_to_main.mlir
index 0ca4745..9dde4fe 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/merge_initializer_function_ops_to_main.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/merge_initializer_function_ops_to_main.mlir
@@ -146,14 +146,14 @@
 // CHECK-NEXT: %[[OUT:.*]], %[[CTL:.*]] = tf_executor.island wraps "tf.PartitionedCall"(%[[ARG]])
 // CHECK-SAME: f = @serving_default
 // Checks that the contents of @NoOp are copied here.
-// CHECK-DAG: %[[OUT_0:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<"test_1">.*}}}
-// CHECK-DAG: %[[OUT_1:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<1>.*}}}
+// CHECK-DAG: %[[OUT_0:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<"test_1">.*}}}>
+// CHECK-DAG: %[[OUT_1:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<1>.*}}}>
 
 // CHECK-NEXT: %[[OUT_2:.*]], %[[CTL_2:.*]] = tf_executor.island wraps "tf.HashTableV2"()
 // CHECK-NEXT: %[[CTL_3:.*]] = tf_executor.island wraps "tf.LookupTableImportV2"(%[[OUT_2]], %[[OUT_0]], %[[OUT_1]])
 
-// CHECK-DAG: %[[OUT_3:.*]], %[[CTL_4:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<"test_2">.*}}}
-// CHECK-DAG: %[[OUT_4:.*]], %[[CTL_5:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<2>.*}}}
+// CHECK-DAG: %[[OUT_3:.*]], %[[CTL_4:.*]] = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<"test_2">.*}}}>
+// CHECK-DAG: %[[OUT_4:.*]], %[[CTL_5:.*]] = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<2>.*}}}>
 
 // CHECK-NEXT: %[[OUT_5:.*]], %[[CTL_6:.*]] = tf_executor.island(%[[CTL_3]]) wraps "tf.HashTableV2"()
 // CHECK-NEXT: %[[CTL_7:.*]] = tf_executor.island wraps "tf.LookupTableImportV2"(%[[OUT_5]], %[[OUT_3]], %[[OUT_4]])
@@ -303,7 +303,7 @@
 // CHECK-LABEL: module
 module attributes {tf_saved_model.semantics} {
   "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> ()
-// CHECK: "tf_saved_model.session_initializer"() {initializers = []}
+// CHECK: "tf_saved_model.session_initializer"() <{initializers = []}>
 
   func.func @init_func_restore_op(%arg: tensor<!tf_type.string> {tf_saved_model.index_path = ["__tf_file_prefix"]})
     attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "restore_op"} {
@@ -330,9 +330,9 @@
 // CHECK-NEXT: tf_executor.graph
 
 // Checks that the ops from @init_func_restore_op are cloned.
-// CHECK-DAG: %[[CONST_0:.*]], %[[CTL:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<""> : tensor<1x!tf_type\.string>.*}}}
-// CHECK-DAG: %[[CONST_1:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<"var_0"> : tensor<1x!tf_type\.string>.*}}}
-// CHECK: %[[VAR_HANDLE:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {{{.*shared_name = "var_0".*}}}
+// CHECK-DAG: %[[CONST_0:.*]], %[[CTL:.*]] = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<""> : tensor<1x!tf_type\.string>.*}}}>
+// CHECK-DAG: %[[CONST_1:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<"var_0"> : tensor<1x!tf_type\.string>.*}}}>
+// CHECK: %[[VAR_HANDLE:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.VarHandleOp"() <{{{.*shared_name = "var_0".*}}}>
 // CHECK: %[[RESTORE:.*]], %[[CTL_2:.*]] = tf_executor.island wraps "tf.RestoreV2"(%[[ARG]], %[[CONST_1]], %[[CONST_0]])
 // CHECK: %[[CTL_3:.*]] = tf_executor.island wraps "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[RESTORE]])
 // CHECK: %[[CTL_4:.*]] = tf_executor.island(%[[CTL_3]]) wraps "tf.NoOp"()
@@ -383,7 +383,7 @@
 // CHECK-LABEL: module
 module attributes {tf_saved_model.semantics} {
   "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> ()
-// CHECK: "tf_saved_model.session_initializer"() {initializers = []}
+// CHECK: "tf_saved_model.session_initializer"() <{initializers = []}>
 
   func.func @init_func_restore_op(%arg: tensor<!tf_type.string> {tf_saved_model.index_path = ["file_prefix"]})
     attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "restore_op"} {
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/merge_save_function_ops_to_main.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/merge_save_function_ops_to_main.mlir
index 5341e15..bc0b283 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/merge_save_function_ops_to_main.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/merge_save_function_ops_to_main.mlir
@@ -32,10 +32,10 @@
 // CHECK: func.func @main
 // CHECK-SAME: %[[ARG_0:.*]]: tensor<!tf_type.string> {tf_saved_model.index_path = ["__tf_file_prefix"]}
 // CHECK: tf_executor.graph
-// CHECK: %[[VAR_HANDLE:.*]], {{.*}} = tf_executor.island wraps "tf.VarHandleOp"() {{{.*shared_name = "var_0".*}}}
+// CHECK: %[[VAR_HANDLE:.*]], {{.*}} = tf_executor.island wraps "tf.VarHandleOp"() <{{{.*shared_name = "var_0".*}}}>
 // CHECK: %[[READ_VARIABLE:.*]], {{.*}} = tf_executor.island wraps "tf.ReadVariableOp"(%[[VAR_HANDLE]])
-// CHECK-DAG: %[[CST_0:.*]], {{.*}} = tf_executor.island wraps "tf.Const"() {{{.*value = dense<"var_0"> : tensor<1x!tf_type\.string>.*}}}
-// CHECK-DAG: %[[CST_1:.*]], {{.*}} = tf_executor.island wraps "tf.Const"() {{{.*value = dense<""> : tensor<1x!tf_type\.string>.*}}}
+// CHECK-DAG: %[[CST_0:.*]], {{.*}} = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<"var_0"> : tensor<1x!tf_type\.string>.*}}}>
+// CHECK-DAG: %[[CST_1:.*]], {{.*}} = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<""> : tensor<1x!tf_type\.string>.*}}}>
 // CHECK: %[[CTL_0:.*]] = tf_executor.island wraps "tf.SaveV2"(%[[ARG_0]], %[[CST_0]], %[[CST_1]], %[[READ_VARIABLE]]) : (tensor<!tf_type.string>, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>, tensor<2xf32>) -> ()
 
 // Test that the Identity op has been created to fetch the file prefix
@@ -150,10 +150,10 @@
 // CHECK-SAME: %[[ARG_0:.*]]: tensor<!tf_type.string> {tf_saved_model.index_path = ["__tf_file_prefix"]}
 // CHECK-SAME: tf.entry_function = {inputs = "__tf_file_prefix:0", outputs = ""}
 // CHECK: tf_executor.graph
-// CHECK: %[[VAR_HANDLE:.*]], {{.*}} = tf_executor.island wraps "tf.VarHandleOp"() {{{.*shared_name = "var_0".*}}}
+// CHECK: %[[VAR_HANDLE:.*]], {{.*}} = tf_executor.island wraps "tf.VarHandleOp"() <{{{.*shared_name = "var_0".*}}}>
 // CHECK: %[[READ_VARIABLE:.*]], {{.*}} = tf_executor.island wraps "tf.ReadVariableOp"(%[[VAR_HANDLE]])
-// CHECK-DAG: %[[CST_0:.*]], {{.*}} = tf_executor.island wraps "tf.Const"() {{{.*value = dense<"var_0"> : tensor<1x!tf_type\.string>.*}}}
-// CHECK-DAG: %[[CST_1:.*]], {{.*}} = tf_executor.island wraps "tf.Const"() {{{.*value = dense<""> : tensor<1x!tf_type\.string>.*}}}
+// CHECK-DAG: %[[CST_0:.*]], {{.*}} = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<"var_0"> : tensor<1x!tf_type\.string>.*}}}>
+// CHECK-DAG: %[[CST_1:.*]], {{.*}} = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<""> : tensor<1x!tf_type\.string>.*}}}>
 // CHECK: %[[CTL_0:.*]] = tf_executor.island wraps "tf.SaveV2"(%[[ARG_0]], %[[CST_0]], %[[CST_1]], %[[READ_VARIABLE]]) : (tensor<!tf_type.string>, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>, tensor<2xf32>) -> ()
 
 // Test that the Identity op has been created to fetch the file prefix
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/optimize.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/optimize.mlir
index 2a69627..48a5c43 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/optimize.mlir
@@ -56,11 +56,11 @@
 
 // CHECK: %[[CLIPBYVALUE_0:.*]] = "tf.ClipByValue"
 // CHECK-SAME: (tensor<1x100x100x1xi32>, tensor<i32>, tensor<i32>) -> tensor<1x100x100x1xi32>
-// CHECK: %[[CAST_1:.*]] = "tf.Cast"(%[[CLIPBYVALUE_0]]) {Truncate = false} : (tensor<1x100x100x1xi32>) -> tensor<1x100x100x1xf32>
+// CHECK: %[[CAST_1:.*]] = "tf.Cast"(%[[CLIPBYVALUE_0]]) <{Truncate = false}> : (tensor<1x100x100x1xi32>) -> tensor<1x100x100x1xf32>
 
 // CHECK: %[[CLIPBYVALUE_1:.*]] = "tf.ClipByValue"
 // CHECK-SAME: (tensor<1x98x98x1xi32>, tensor<i32>, tensor<i32>) -> tensor<1x98x98x1xi32>
-// CHECK: %[[CAST_3:.*]] = "tf.Cast"(%[[CLIPBYVALUE_1]]) {Truncate = false} : (tensor<1x98x98x1xi32>) -> tensor<1x98x98x1xf32>
+// CHECK: %[[CAST_3:.*]] = "tf.Cast"(%[[CLIPBYVALUE_1]]) <{Truncate = false}> : (tensor<1x98x98x1xi32>) -> tensor<1x98x98x1xf32>
 
 // CHECK: %[[CLIPBYVALUE_2:.*]] = "tf.ClipByValue"
 // CHECK-SAME: (tensor<1x96x96x1xi32>, tensor<i32>, tensor<i32>) -> tensor<1x96x96x1xi32>
@@ -76,7 +76,7 @@
 
 // CHECK-LABEL: func.func @consecutive_add_add
 
-// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<-30> : tensor<i32>} : () -> tensor<i32>
+// CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<-30> : tensor<i32>}> : () -> tensor<i32>
 // CHECK: %[[ADD:.*]] = "tf.AddV2"(%arg0, %[[CST]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK: return %[[ADD]] : tensor<i32>
 }
@@ -90,7 +90,7 @@
 
 // CHECK-LABEL: func.func @consecutive_add_sub
 
-// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<6> : tensor<i32>} : () -> tensor<i32>
+// CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
 // CHECK: %[[SUB:.*]] = "tf.Sub"(%arg0, %[[CST]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK: return %[[SUB]] : tensor<i32>
 }
@@ -104,7 +104,7 @@
 
 // CHECK-LABEL: func.func @consecutive_sub_add
 
-// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<6> : tensor<i32>} : () -> tensor<i32>
+// CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
 // CHECK: %[[ADD:.*]] = "tf.AddV2"(%arg0, %[[CST]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK: return %[[ADD]] : tensor<i32>
 }
@@ -118,7 +118,7 @@
 
 // CHECK-LABEL: func.func @consecutive_sub_sub
 
-// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<-30> : tensor<i32>} : () -> tensor<i32>
+// CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<-30> : tensor<i32>}> : () -> tensor<i32>
 // CHECK: %[[SUB:.*]] = "tf.Sub"(%arg0, %[[CST]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK: return %[[SUB]] : tensor<i32>
 }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir
index c99fed3..1e771e2 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir
@@ -7,8 +7,8 @@
   func.return %add : tensor<*xf32>
 }
 // CHECK: func @decompose_batch_norm
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<2.49743462E-5> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.999950051> : tensor<2xf32>} : () -> tensor<2xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.49743462E-5> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.999950051> : tensor<2xf32>}> : () -> tensor<2xf32>
 // CHECK: %[[mul:.*]] = "tf.Mul"(%arg0, %[[CONST_0]]) : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
 // CHECK: %[[add:.*]] = "tf.AddV2"(%[[mul]], %[[CONST]]) : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
 // CHECK-NEXT: return %[[add]] : tensor<*xf32>
@@ -22,9 +22,9 @@
   func.return %bn : tensor<*xf32>
 }
 // CHECK: func @not_decompose_batch_norm
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK: %[[bn:.*]], %batch_mean, %batch_variance, %reserve_space_1, %reserve_space_2, %reserve_space_3 = "tf.FusedBatchNormV3"(%arg0, %[[CONST]], %[[CONST_0]], %[[CONST_0]], %[[CONST]]) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = true} : (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK: %[[bn:.*]], %batch_mean, %batch_variance, %reserve_space_1, %reserve_space_2, %reserve_space_3 = "tf.FusedBatchNormV3"(%arg0, %[[CONST]], %[[CONST_0]], %[[CONST_0]], %[[CONST]]) <{data_format = "NHWC", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = true}> {device = ""} : (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
 // CHECK-NEXT: return %[[bn]] : tensor<*xf32>
 
 // -----
@@ -37,10 +37,10 @@
   func.return %1 : tensor<1x3x2x2xf32>
 }
 // CHECK: func @convert_add_to_biasadd
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32>
-// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32>
+// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32>
 // CHECK-NEXT: return %[[BIASADD]] : tensor<1x3x2x2xf32>
 
 // -----
@@ -53,9 +53,9 @@
   func.return %1 : tensor<1x3x2x3xf32>
 }
 // CHECK: func @not_convert_add_to_biasadd
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x3xf32>} : () -> tensor<2x3x3x3xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<1x3x2x3xf32>} : () -> tensor<1x3x2x3xf32>
-// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x3xf32>) -> tensor<1x3x2x3xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x3xf32>}> : () -> tensor<2x3x3x3xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<1x3x2x3xf32>}> : () -> tensor<1x3x2x3xf32>
+// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x3xf32>) -> tensor<1x3x2x3xf32>
 // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[CONV2D]], %[[CONST_0]]) : (tensor<1x3x2x3xf32>, tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32>
 // CHECK-NEXT: return %[[ADD]] : tensor<1x3x2x3xf32>
 
@@ -69,8 +69,8 @@
   func.return %1 : tensor<1x3x2x2xf32>
 }
 // CHECK: func @fuse_conv2d_and_mul
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<8.000000e-01> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
-// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<8.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32>
+// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32>
 // CHECK-NEXT: return %[[CONV2D]] : tensor<1x3x2x2xf32>
 
 // -----
@@ -83,9 +83,9 @@
   func.return %1 : tensor<1x3x2x2xf32>
 }
 // CHECK: func @not_fuse_conv2d_and_mul
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<4.000000e-01> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
-// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
+// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32>
 // CHECK-NEXT: %[[ADD:.*]] = "tf.Mul"(%[[CONV2D]], %[[CONST_0]]) : (tensor<1x3x2x2xf32>, tensor<2x2xf32>) -> tensor<1x3x2x2xf32>
 // CHECK-NEXT: return %[[ADD]] : tensor<1x3x2x2xf32>
 
@@ -101,10 +101,10 @@
   func.return %2 : tensor<1x3x2x2xf32>
 }
 // CHECK: func @fuse_conv2d_with_bias_and_mul
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<2.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32>
-// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<2.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32>
+// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32>
 // CHECK-NEXT: return %[[BIASADD]] : tensor<1x3x2x2xf32>
 
 // -----
@@ -119,11 +119,11 @@
   func.return %1, %2 : tensor<1x3x2x2xf32>, tensor<1x3x2x2xf32>
 }
 // CHECK: func @not_fuse_conv2d_with_bias_and_mul
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<4.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<8.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32>
-// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<8.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32>
+// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32>
 // CHECK-NEXT: %[[MUL:.*]] = "tf.Mul"(%[[CONV2D]], %[[CONST_1]]) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32>
 // CHECK-NEXT: return %[[BIASADD]], %[[MUL]] : tensor<1x3x2x2xf32>, tensor<1x3x2x2xf32>
 
@@ -139,10 +139,10 @@
   func.return %2 : tensor<1x3x2x2xf32>
 }
 // CHECK: func @fuse_conv2d_with_bias_and_add
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32>
-// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32>
+// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32>
 // CHECK-NEXT: return %[[BIASADD]] : tensor<1x3x2x2xf32>
 
 // -----
@@ -156,10 +156,10 @@
   func.return %2 : tensor<1x3x2x2xf32>
 }
 // CHECK: func @not_fuse_conv2d_with_bias_and_add
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<4.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32>
-// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32>
+// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32>
 // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[BIASADD]], %arg1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32>
 // CHECK-NEXT: return %[[ADD]] : tensor<1x3x2x2xf32>
 
@@ -173,10 +173,10 @@
   func.return %1 : tensor<*xf32>
 }
 // CHECK: func @match_depthwise_conv2d_and_add
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<4.000000e-01> : tensor<3xf32>} : () -> tensor<3xf32>
-// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<?x?x?x3xf32>
-// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor<?x?x?x3xf32>, tensor<3xf32>) -> tensor<*xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32>
+// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> {device = ""} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<?x?x?x3xf32>
+// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<?x?x?x3xf32>, tensor<3xf32>) -> tensor<*xf32>
 // CHECK-NEXT: return %[[BIASADD]] : tensor<*xf32>
 
 // -----
@@ -189,8 +189,8 @@
   func.return %1 : tensor<?x?x?x3xf32>
 }
 // CHECK: func @match_depthwise_conv2d_and_mul
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<8.000000e-01> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32>
-// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<?x?x?x3xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<8.000000e-01> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32>
+// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> {device = ""} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<?x?x?x3xf32>
 // CHECK-NEXT: return %[[DEPTHWISE_CONV2D]] : tensor<?x?x?x3xf32>
 
 // -----
@@ -205,10 +205,10 @@
   func.return %2 : tensor<?x?x?x3xf32>
 }
 // CHECK: func @match_depthwise_conv2d_with_bias_and_add
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<8.000000e-01> : tensor<3xf32>} : () -> tensor<3xf32>
-// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<?x?x?x3xf32>
-// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor<?x?x?x3xf32>, tensor<3xf32>) -> tensor<?x?x?x3xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<8.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32>
+// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> {device = ""} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<?x?x?x3xf32>
+// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<?x?x?x3xf32>, tensor<3xf32>) -> tensor<?x?x?x3xf32>
 // CHECK-NEXT: return %[[BIASADD]] : tensor<?x?x?x3xf32>
 
 // -----
@@ -223,10 +223,10 @@
   func.return %2 : tensor<?x?x?x3xf32>
 }
 // CHECK: func @match_depthwise_conv2d_with_bias_and_mul
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<2.000000e-01> : tensor<3xf32>} : () -> tensor<3xf32>
-// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<?x?x?x3xf32>
-// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor<?x?x?x3xf32>, tensor<3xf32>) -> tensor<?x?x?x3xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<2.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32>
+// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> {device = ""} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<?x?x?x3xf32>
+// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<?x?x?x3xf32>, tensor<3xf32>) -> tensor<?x?x?x3xf32>
 // CHECK-NEXT: return %[[BIASADD]] : tensor<?x?x?x3xf32>
 
 // -----
@@ -236,7 +236,7 @@
   func.return %0 : tensor<3x4x6xf32>
 }
 // CHECK-LABEL: lower_einsum
-// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
+// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
 
 // -----
 
@@ -251,7 +251,7 @@
   func.return %2 : tensor<*xf32>
 }
 // CHECK: func @removing_identity_after_const
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32>
 // CHECK: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]])
 
 // -----
@@ -291,14 +291,14 @@
 }
 
 // CHECK: func @batch_norm_with_q_dq
-// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<0.707036077> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
-// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() {value = dense<-0.914072155> : tensor<2xf32>} : () -> tensor<2xf32>
+// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<0.707036077> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32>
+// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<-0.914072155> : tensor<2xf32>}> : () -> tensor<2xf32>
 // CHECK: %[[q_input:.*]] = "quantfork.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform<i8:f32, 0.0011764706057660721:-43>>
 // CHECK: %[[dq_input:.*]] = "quantfork.dcast"(%[[q_input]]) : (tensor<1x3x4x3x!quant.uniform<i8:f32, 0.0011764706057660721:-43>>) -> tensor<1x3x4x3xf32>
 // CHECK: %[[q_weight:.*]] = "quantfork.qcast"(%[[cst]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform<i8<-127:127>:f32:3, {0.005567213212411235,0.005567213212411235}>>
 // CHECK: %[[dq_weight:.*]] = "quantfork.dcast"(%[[q_weight]]) : (tensor<2x3x3x2x!quant.uniform<i8<-127:127>:f32:3, {0.005567213212411235,0.005567213212411235}>>) -> tensor<2x3x3x2xf32>
 // CHECK: %[[conv:.*]] = "tf.Conv2D"(%[[dq_input]], %[[dq_weight]])
-// CHECK: %[[bias:.*]] = "tf.BiasAdd"(%[[conv]], %[[cst_0]]) {data_format = "NHWC"}
+// CHECK: %[[bias:.*]] = "tf.BiasAdd"(%[[conv]], %[[cst_0]]) <{data_format = "NHWC"}>
 // CHECK: %[[relu6:.*]] = "tf.Relu6"(%[[bias]])
 
 // -----
@@ -334,8 +334,8 @@
   func.return %2 : tensor<?x?x?x256xf32>
 }
 // CHECK: func @conv2d_with_large_weight_and_mul
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<1.250000e+00> : tensor<48x48x3x256xf32>} : () -> tensor<48x48x3x256xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<2.000000e-01> : tensor<256xf32>} : () -> tensor<256xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.250000e+00> : tensor<48x48x3x256xf32>}> : () -> tensor<48x48x3x256xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<2.000000e-01> : tensor<256xf32>}> : () -> tensor<256xf32>
 // CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]])
 // CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]])
 // CHECK-NEXT: return %[[BIASADD]]
@@ -354,8 +354,8 @@
   func.return %2 : tensor<?x?x?x3xf32>
 }
 // CHECK: func @depthwise_conv2d_with_large_weight_and_add
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<2.500000e+00> : tensor<48x48x3x256xf32>} : () -> tensor<48x48x3x256xf32>
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<8.000000e-01> : tensor<3xf32>} : () -> tensor<3xf32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.500000e+00> : tensor<48x48x3x256xf32>}> : () -> tensor<48x48x3x256xf32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<8.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32>
 // CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]])
 // CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]])
 // CHECK-NEXT: return %[[BIASADD]]
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq.mlir
index 4da50d4..0176867 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq.mlir
@@ -15,11 +15,11 @@
 // CHECK-DAG: %[[CONST:.*]] = arith.constant dense<0.000000e+00> : tensor<2x1024xf32>
 // CHECK: %0 = "quantfork.qcast"(%[[CONST]]) : (tensor<2x1024xf32>) -> tensor<2x1024x!quant.uniform<i8<-127:127>:f32, 3.9370078740157481E-9>>
 // CHECK: %1 = "quantfork.dcast"(%0) : (tensor<2x1024x!quant.uniform<i8<-127:127>:f32, 3.9370078740157481E-9>>) -> tensor<2x1024xf32>
-// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 // CHECK: return %2 : tensor<*xf32>
 
 // CHECK-LABEL: func private @composite_matmul_fn
-// CHECK: %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %0 = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 // CHECK: return %0 : tensor<*xf32>
 }
 
@@ -43,7 +43,7 @@
 // CHECK-DAG: %[[CONST_1:.*]] = arith.constant dense<3.000000e+00> : tensor<2x3x3x512xf32>
 // CHECK: %0 = "quantfork.qcast"(%[[CONST_1]]) : (tensor<2x3x3x512xf32>) -> tensor<2x3x3x512x!quant.uniform<i8<-127:127>:f32, 0.023622047244094488>>
 // CHECK: %1 = "quantfork.dcast"(%0) : (tensor<2x3x3x512x!quant.uniform<i8<-127:127>:f32, 0.023622047244094488>>) -> tensor<2x3x3x512xf32>
-// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_1} : (tensor<1x3x4x3xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32>
+// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x3xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32>
 // CHECK: %3 = "tf.BiasAdd"(%2, %[[CONST_0]])
 // CHECK: return %3 : tensor<*xf32>
 
@@ -74,7 +74,7 @@
 // CHECK-DAG: %[[CONST_1:.*]] = arith.constant dense<3.000000e+00> : tensor<2x3x1x1536xf32>
 // CHECK: %0 = "quantfork.qcast"(%[[CONST_1]]) : (tensor<2x3x1x1536xf32>) -> tensor<2x3x1x1536x!quant.uniform<i8<-127:127>:f32, 0.023622047244094488>>
 // CHECK: %1 = "quantfork.dcast"(%0) : (tensor<2x3x1x1536x!quant.uniform<i8<-127:127>:f32, 0.023622047244094488>>) -> tensor<2x3x1x1536xf32>
-// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_0} : (tensor<1x3x4x512xf32>, tensor<2x3x1x1536xf32>) -> tensor<*xf32>
+// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) <{config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_0}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x512xf32>, tensor<2x3x1x1536xf32>) -> tensor<*xf32>
 // CHECK: %3 = "tf.BiasAdd"(%2, %[[CONST_0]])
 // CHECK: return %3 : tensor<*xf32>
 
@@ -85,6 +85,6 @@
 // CHECK-LABEL: func private @composite_depthwise_conv2d_fn_0(
 // CHECK-SAME:                                             %arg0: tensor<1x3x4x512xf32>,
 // CHECK-SAME:                                             %arg1: tensor<2x3x1x1536xf32>)
-// CHECK: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "",
+// CHECK: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1]}> {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", device = ""}
 // CHECK: return %0 : tensor<*xf32>
 }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq_per_channel.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq_per_channel.mlir
index f2d80c0..927fc34 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq_per_channel.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq_per_channel.mlir
@@ -15,11 +15,11 @@
 // CHECK-DAG: %[[CONST:.*]] = arith.constant dense<0.000000e+00> : tensor<2x1024xf32>
 // CHECK: %0 = "quantfork.qcast"(%[[CONST]]) : (tensor<2x1024xf32>) -> tensor<2x1024x!quant.uniform<i8<-127:127>:f32, 3.9370078740157481E-9>>
 // CHECK: %1 = "quantfork.dcast"(%0) : (tensor<2x1024x!quant.uniform<i8<-127:127>:f32, 3.9370078740157481E-9>>) -> tensor<2x1024xf32>
-// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 // CHECK: return %2 : tensor<*xf32>
 
 // CHECK-LABEL: func private @composite_matmul_fn
-// CHECK: %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %0 = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 // CHECK: return %0 : tensor<*xf32>
 }
 
@@ -43,7 +43,7 @@
 // CHECK-DAG: %[[CONST_1:.*]] = arith.constant dense<3.000000e+00> : tensor<2x3x512x2xf32>
 // CHECK: %0 = "quantfork.qcast"(%[[CONST_1]]) : (tensor<2x3x512x2xf32>) -> tensor<2x3x512x2x!quant.uniform<i8<-127:127>:f32:3, {0.023622047244094488,0.023622047244094488}>>
 // CHECK: %1 = "quantfork.dcast"(%0) : (tensor<2x3x512x2x!quant.uniform<i8<-127:127>:f32:3, {0.023622047244094488,0.023622047244094488}>>) -> tensor<2x3x512x2xf32>
-// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_1} : (tensor<1x3x4x512xf32>, tensor<2x3x512x2xf32>) -> tensor<*xf32>
+// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x512xf32>, tensor<2x3x512x2xf32>) -> tensor<*xf32>
 // CHECK: %3 = "tf.BiasAdd"(%2, %[[CONST_0]])
 // CHECK: return %3 : tensor<*xf32>
 
@@ -74,7 +74,7 @@
 // CHECK-DAG: %[[CONST_1:.*]] = arith.constant dense<3.000000e+00> : tensor<2x3x1x1536xf32>
 // CHECK: %0 = "quantfork.qcast"(%[[CONST_1]]) : (tensor<2x3x1x1536xf32>) -> tensor<2x3x1x1536x!quant.uniform<i8<-127:127>:f32:3, {0.023622047244094488,
 // CHECK: %1 = "quantfork.dcast"(%0) : (tensor<2x3x1x1536x!quant.uniform<i8<-127:127>:f32:3, {0.023622047244094488,
-// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_0} : (tensor<1x3x4x512xf32>, tensor<2x3x1x1536xf32>) -> tensor<*xf32>
+// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) <{config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_0}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x512xf32>, tensor<2x3x1x1536xf32>) -> tensor<*xf32>
 // CHECK: %3 = "tf.BiasAdd"(%2, %[[CONST_0]])
 // CHECK: return %3 : tensor<*xf32>
 
@@ -85,6 +85,6 @@
 // CHECK-LABEL: func private @composite_depthwise_conv2d_fn_0(
 // CHECK-SAME:                                             %arg0: tensor<1x3x4x512xf32>,
 // CHECK-SAME:                                             %arg1: tensor<2x3x1x1536xf32>)
-// CHECK: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "",
+// CHECK: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1]}> {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", device = ""}
 // CHECK: return %0 : tensor<*xf32>
 }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/preprocess_op.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/preprocess_op.mlir
index 0ef69f6..ae8a20d 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/preprocess_op.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/preprocess_op.mlir
@@ -22,8 +22,8 @@
 // CHECK: %[[CONST_1:.*]] = arith.constant dense
 // CHECK-NOT: tensor<2x3x3x2xf32>
 // CHECK-SAME: tensor<2x3x1x6xf32>
-// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1:.*]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_0} : (tensor<1x3x4x3xf32>, tensor<2x3x1x6xf32>) -> tensor<*xf32>
-// CHECK: %[[BIAS_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0:.*]]) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<6xf32>) -> tensor<*xf32>
+// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1:.*]]) <{config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_0}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x3xf32>, tensor<2x3x1x6xf32>) -> tensor<*xf32>
+// CHECK: %[[BIAS_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0:.*]]) <{data_format = "NHWC"}> {device = ""} : (tensor<*xf32>, tensor<6xf32>) -> tensor<*xf32>
 // CHECK: return %[[BIAS_0:.*]] : tensor<*xf32>
 
 // CHECK-LABEL: func private @composite_depthwise_conv2d_fn(
@@ -33,7 +33,7 @@
 // CHECK-LABEL: func private @composite_depthwise_conv2d_fn_0(
 // CHECK-SAME:                                             %arg0: tensor<1x3x4x3xf32>,
 // CHECK-SAME:                                             %arg1: tensor<2x3x1x6xf32>)
-// CHECK: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "",
+// CHECK: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1]}> {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", device = ""}
 // CHECK: return %0 : tensor<*xf32>
 }
 
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/preprocess_op_weight_only.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/preprocess_op_weight_only.mlir
index 4f36784..e80db7f 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/preprocess_op_weight_only.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/preprocess_op_weight_only.mlir
@@ -23,14 +23,14 @@
 // PerTensor: %[[CONST_1:.*]] = arith.constant dense
 // PerTensor-NOT: tensor<2x3x1x6xf32>
 // PerTensor-SAME: tensor<2x3x3x2xf32>
-// PerTensor: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1:.*]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32>
-// PerTensor: %[[BIAS_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0:.*]]) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<6xf32>) -> tensor<*xf32>
+// PerTensor: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1:.*]]) <{config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32>
+// PerTensor: %[[BIAS_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0:.*]]) <{data_format = "NHWC"}> {device = ""} : (tensor<*xf32>, tensor<6xf32>) -> tensor<*xf32>
 // PerTensor: return %[[BIAS_0:.*]] : tensor<*xf32>
 
 // PerTensor-LABEL: func private @composite_depthwise_conv2d_fn(
 // PerTensor-SAME:                                             %arg0: tensor<1x3x4x3xf32>,
 // PerTensor-SAME:                                             %arg1: tensor<2x3x3x2xf32>)
-// PerTensor: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "",
+// PerTensor: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1]}> {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", device = ""}
 // PerTensor: return %0 : tensor<*xf32>
 
 // PerChannel-LABEL: func @depthwise_conv
@@ -38,8 +38,8 @@
 // PerChannel: %[[CONST_1:.*]] = arith.constant dense
 // PerChannel-NOT: tensor<2x3x3x2xf32>
 // PerChannel-SAME: tensor<2x3x1x6xf32>
-// PerChannel: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1:.*]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_0} : (tensor<1x3x4x3xf32>, tensor<2x3x1x6xf32>) -> tensor<*xf32>
-// PerChannel: %[[BIAS_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0:.*]]) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<6xf32>) -> tensor<*xf32>
+// PerChannel: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1:.*]]) <{config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_0}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x3xf32>, tensor<2x3x1x6xf32>) -> tensor<*xf32>
+// PerChannel: %[[BIAS_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0:.*]]) <{data_format = "NHWC"}> {device = ""} : (tensor<*xf32>, tensor<6xf32>) -> tensor<*xf32>
 // PerChannel: return %[[BIAS_0:.*]] : tensor<*xf32>
 
 // PerChannel-LABEL: func private @composite_depthwise_conv2d_fn(
@@ -49,7 +49,7 @@
 // PerChannel-LABEL: func private @composite_depthwise_conv2d_fn_0(
 // PerChannel-SAME:                                             %arg0: tensor<1x3x4x3xf32>,
 // PerChannel-SAME:                                             %arg1: tensor<2x3x1x6xf32>)
-// PerChannel: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "",
+// PerChannel: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1]}> {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", device = ""}
 // PerChannel: return %0 : tensor<*xf32>
 }
 
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/propagate_quantize_type.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/propagate_quantize_type.mlir
index 0c69477..6a737b7 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/propagate_quantize_type.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/propagate_quantize_type.mlir
@@ -12,8 +12,8 @@
   }
 
 // CHECK-LABEL: func @not_propagate_matmul
-// CHECK: %[[CASTED_W:.*]] = "tf.Cast"(%0) {Truncate = false} : (tensor<2x1024xi8>) -> tensor<2x1024xf32>
-// CHECK: %2 = "tf.MatMul"(%arg0, %[[CASTED_W]]) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %[[CASTED_W:.*]] = "tf.Cast"(%0) <{Truncate = false}> : (tensor<2x1024xi8>) -> tensor<2x1024xf32>
+// CHECK: %2 = "tf.MatMul"(%arg0, %[[CASTED_W]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 }
 
 // -----
@@ -37,8 +37,8 @@
 
 // CHECK-LABEL: func @propagate_xladotv2_bf16
 // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%cst) : (tensor<2x1024xi8>) -> tensor<2x1024xi8>
-// CHECK: %[[MATMUL:.*]] = "tf.XlaDotV2"(%arg0, %[[IDENTITY]]) {device = "", dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""} : (tensor<1x2x2x2xbf16>, tensor<2x1024xi8>) -> tensor<1x2x2x1024xbf16>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[MATMUL]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<1x2x2x1024xbf16>) -> tensor<1x2x2x1024xbf16>
+// CHECK: %[[MATMUL:.*]] = "tf.XlaDotV2"(%arg0, %[[IDENTITY]]) <{dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""}> {device = ""} : (tensor<1x2x2x2xbf16>, tensor<2x1024xi8>) -> tensor<1x2x2x1024xbf16>
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[MATMUL]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<1x2x2x1024xbf16>) -> tensor<1x2x2x1024xbf16>
 }
 
 // -----
@@ -64,8 +64,8 @@
 
 // CHECK-LABEL: func @not_propagate_last_op
 // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%cst_0) : (tensor<200x100x300xi8>) -> tensor<200x100x300xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[IDENTITY]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<200x100x300xi8>) -> tensor<200x100x300xf32>
-// CHECK: %[[GATHER:.*]] = "tf.XlaGather"(%[[DEQUANTIZED]], %arg0, %cst) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xf32>
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[IDENTITY]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<200x100x300xi8>) -> tensor<200x100x300xf32>
+// CHECK: %[[GATHER:.*]] = "tf.XlaGather"(%[[DEQUANTIZED]], %arg0, %cst) <{dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01 \01", indices_are_sorted = true}> : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xf32>
 // CHECK: return %[[GATHER]] : tensor<1x300x10xf32>
 
 // -----
@@ -91,7 +91,7 @@
 
 // CHECK-LABEL: func @propagate_xlagather
 // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%cst_0) : (tensor<200x100x300xi8>) -> tensor<200x100x300xi8>
-// CHECK: %[[GATHER:.*]] = "tf.XlaGather"(%[[IDENTITY]], %arg0, %cst) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01 \01", indices_are_sorted = true} : (tensor<200x100x300xi8>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[GATHER]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<1x300x10xi8>) -> tensor<1x300x10xf32>
+// CHECK: %[[GATHER:.*]] = "tf.XlaGather"(%[[IDENTITY]], %arg0, %cst) <{dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01 \01", indices_are_sorted = true}> : (tensor<200x100x300xi8>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xi8>
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[GATHER]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<1x300x10xi8>) -> tensor<1x300x10xf32>
 // CHECK: %[[ORIGINAL_IDENTITY:.*]] = "tf.Identity"(%[[DEQUANTIZED]]) : (tensor<1x300x10xf32>) -> tensor<1x300x10xf32>
 // CHECK: return %[[ORIGINAL_IDENTITY]] : tensor<1x300x10xf32>
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir
index d04ec26..0f3c702 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir
@@ -23,7 +23,7 @@
 // CHECK-DAG: [[weight:%.+]] = "arith.constant"() <{value = dense_resource<__elided__> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2x!quant.uniform<i8:f32, 0.074855112561992565:-1>>
 // CHECK: [[q_input:%.+]] = "quantfork.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform<i8:f32, 0.58810077742034317:-128>>
 // CHECK-NEXT: [[q_bias:%.+]] = "quantfork.qcast"([[bias]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i32:f32, 0.044022349891595126>>
-// CHECK-NEXT: [[conv:%.+]] = "tf.PartitionedCall"([[q_input]], [[weight]], [[q_bias]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @[[composite_fn:composite_conv2d_with_bias_and_relu6_fn.*]]} : (tensor<1x3x4x3x!quant.uniform<i8:f32, 0.58810077742034317:-128>>, tensor<2x3x3x2x!quant.uniform<i8:f32, 0.074855112561992565:-1>>, tensor<2x!quant.uniform<i32:f32, 0.044022349891595126>>) -> tensor<*x!quant.uniform<i8:f32, 0.023529411764705882:-128>>
+// CHECK-NEXT: [[conv:%.+]] = "tf.PartitionedCall"([[q_input]], [[weight]], [[q_bias]]) <{config = "", config_proto = "", executor_type = "", f = @[[composite_fn:composite_conv2d_with_bias_and_relu6_fn.*]]}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x3x!quant.uniform<i8:f32, 0.58810077742034317:-128>>, tensor<2x3x3x2x!quant.uniform<i8:f32, 0.074855112561992565:-1>>, tensor<2x!quant.uniform<i32:f32, 0.044022349891595126>>) -> tensor<*x!quant.uniform<i8:f32, 0.023529411764705882:-128>>
 // CHECK-NEXT: [[res:%.+]] = "quantfork.dcast"([[conv]]) : (tensor<*x!quant.uniform<i8:f32, 0.023529411764705882:-128>>) -> tensor<*xf32>
 // CHECK-NEXT: "func.return"([[res]]) : (tensor<*xf32>) -> ()
 
@@ -69,11 +69,11 @@
 
 // CHECK: %[[q:.*]] = "quantfork.qcast"(%arg0)
 // CHECK: %[[sc1:.*]] = "quantfork.scast"(%[[q]]) : (tensor<*x!quant.uniform<i8:f32, 5.000000e-02:-10>>)
-// CHECK: %[[fcast:.*]] = "tf.Cast"(%[[sc1]]) {Truncate = false} : (tensor<*xi8>) -> tensor<*xf32>
+// CHECK: %[[fcast:.*]] = "tf.Cast"(%[[sc1]]) <{Truncate = false}> : (tensor<*xi8>) -> tensor<*xf32>
 // CHECK: %[[avgpool_f32:.*]] = "tf.AvgPool"(%[[fcast]])
 // CHECK-SAME: (tensor<*xf32>) -> tensor<*xf32>
 // CHECK: %[[round:.*]] = "tf.Round"(%[[avgpool_f32]])
-// CHECK: %[[icast:.*]] = "tf.Cast"(%[[round]]) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi8>
+// CHECK: %[[icast:.*]] = "tf.Cast"(%[[round]]) <{Truncate = false}> : (tensor<*xf32>) -> tensor<*xi8>
 // CHECK: %[[sc2:.*]] = "quantfork.scast"(%[[icast]])
 // CHECK: %[[dq:.*]] = "quantfork.dcast"(%[[sc2]]) : (tensor<*x!quant.uniform<i8:f32, 5.000000e-02:-10>>)
 // CHECK: return %[[dq]]
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir
index eab1244..5b5addd 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir
@@ -28,17 +28,17 @@
   }
 
 // CHECK-LABEL: func @conv
-// CHECK-DAG: %[[w_float:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}1.600000e-01
-// CHECK-DAG: %[[b_float:.*]] = "tf.Const"() {value = dense<[-2.000000e+00, 3.000000e+00]> : tensor<2xf32>
-// CHECK-DAG: %[[in_scale:.*]] = "tf.Const"() {value = dense<8.000000e-03> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[in_zp:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-// CHECK-DAG: %[[w_scale:.*]] = "tf.Const"() {value = dense<[4.000000e-03
-// CHECK-DAG: %[[w_zp:.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>}
-// CHECK-DAG: %[[b_scale:.*]] = "tf.Const"() {value = dense<[3.200000e-05, 4.000000e-05]> : tensor<2xf32>}
-// CHECK-DAG: %[[out_scale:.*]] = "tf.Const"() {value = dense<5.000000e-02> : tensor<f32>}
-// CHECK-DAG: %[[out_zp:.*]] = "tf.Const"() {value = dense<-1> : tensor<i32>}
-// CHECK-DAG: %[[b_quant:.*]] = "tf.Const"() {value = dense<[-62500, 75000]> : tensor<2xi32>}
-// CHECK-DAG: %[[w_quant:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}40, 20]
+// CHECK-DAG: %[[w_float:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}1.600000e-01
+// CHECK-DAG: %[[b_float:.*]] = "tf.Const"() <{value = dense<[-2.000000e+00, 3.000000e+00]> : tensor<2xf32>
+// CHECK-DAG: %[[in_scale:.*]] = "tf.Const"() <{value = dense<8.000000e-03> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[in_zp:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+// CHECK-DAG: %[[w_scale:.*]] = "tf.Const"() <{value = dense<[4.000000e-03
+// CHECK-DAG: %[[w_zp:.*]] = "tf.Const"() <{value = dense<0> : tensor<2xi32>}>
+// CHECK-DAG: %[[b_scale:.*]] = "tf.Const"() <{value = dense<[3.200000e-05, 4.000000e-05]> : tensor<2xf32>}
+// CHECK-DAG: %[[out_scale:.*]] = "tf.Const"() <{value = dense<5.000000e-02> : tensor<f32>}>
+// CHECK-DAG: %[[out_zp:.*]] = "tf.Const"() <{value = dense<-1> : tensor<i32>}>
+// CHECK-DAG: %[[b_quant:.*]] = "tf.Const"() <{value = dense<[-62500, 75000]> : tensor<2xi32>}>
+// CHECK-DAG: %[[w_quant:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}40, 20]
 // CHECK-DAG: {{\[\[\[}}-87, -42]
 
 // CHECK: %[[quantize:.*]] = "tf.PartitionedCall"(%arg0, %[[in_scale]], %[[in_zp]])
@@ -58,7 +58,8 @@
 
 // CHECK-LABEL: func private @composite_conv2d_with_bias_and_relu6_fn_1
 // CHECK:      %[[CONV2D_0:.*]] = "tf.Conv2D"
-// CHECK-SAME: data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true
+// CHECK-SAME: data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true
+// CHECK-SAME: device = ""
 // CHECK:      %[[BIASADD_0:.*]] = "tf.BiasAdd"
 // CHECK:      %[[RELU6_0:.*]] = "tf.Relu6"
 
@@ -158,10 +159,10 @@
 // CHECK: %[[conv_quant:.*]] = "tf.PartitionedCall"(%[[quantize]]
 // CHECK-SAME: f = @quantized_conv2d_with_bias_and_relu6_fn_0
 // CHECK-SAME: (tensor<1x2x2x3xi8>, tensor<2x2x3x2xi8>, tensor<2xi32>, tensor<f32>, tensor<i32>, tensor<2xf32>, tensor<2xi32>, tensor<2xf32>, tensor<2xi32>, tensor<f32>, tensor<i32>) -> tensor<*xi8>
-// CHECK: %[[cast_1:.*]] = "tf.Cast"(%[[conv_quant]]) {Truncate = false} : (tensor<*xi8>) -> tensor<*xf32>
-// CHECK: %[[avgpool:.*]] = "tf.AvgPool"(%[[cast_1]]) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: %[[cast_1:.*]] = "tf.Cast"(%[[conv_quant]]) <{Truncate = false}> : (tensor<*xi8>) -> tensor<*xf32>
+// CHECK: %[[avgpool:.*]] = "tf.AvgPool"(%[[cast_1]]) <{data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]}> : (tensor<*xf32>) -> tensor<*xf32>
 // CHECK: %[[round:.*]] = "tf.Round"(%[[avgpool]]) : (tensor<*xf32>) -> tensor<*xf32>
-// CHECK: %[[cast_2:.*]] = "tf.Cast"(%[[round]]) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi8>
+// CHECK: %[[cast_2:.*]] = "tf.Cast"(%[[round]]) <{Truncate = false}> : (tensor<*xf32>) -> tensor<*xi8>
 // CHECK: %[[dequantize:.*]] = "tf.PartitionedCall"(%[[cast_2]]
 // CHECK-SAME: f = @dequantize_i8
 // CHECK: return %[[dequantize]]
@@ -252,35 +253,35 @@
   }
 
 // CHECK-LABEL: @conv_with_dump
-// CHECK-DAG: %[[w0_float:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}-0.282878935, -0.211567819
-// CHECK-DAG: %[[b0_float:.*]] = "tf.Const"() {value = dense<[-0.0192535277, -5.998660e-03]> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-DAG: %[[w1_float:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}0.208403707, 0.478067577
-// CHECK-DAG: %[[b1_float:.*]] = "tf.Const"() {value = dense<[-0.0291469581, 0.0106381178]> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-DAG: %[[w0_quantized:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}-59, -44
-// CHECK-DAG: %[[b0_quantized:.*]] = "tf.Const"() {value = dense<[-1040, -324]> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK-DAG: %[[w1_quantized:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}44, 100
-// CHECK-DAG: %[[b1_quantized:.*]] = "tf.Const"() {value = dense<[-4312, 1574]> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK-DAG: %[[in_scale:.*]] = "tf.Const"() {value = dense<0.00387597573> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[in_out_zp:.*]] = "tf.Const"() {value = dense<-128> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG: %[[w0_scale:.*]] = "tf.Const"() {value = dense<0.00477493973> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[w_b_zp:.*]]  = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG: %[[b0_scale:.*]] = "tf.Const"() {value = dense<1.85075514E-5> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[mid_scale:.*]] = "tf.Const"() {value = dense<0.00141507247> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[w1_scale:.*]] = "tf.Const"() {value = dense<0.00477652298> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[b1_scale:.*]] = "tf.Const"() {value = dense<6.75912588E-6> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[out_scale:.*]] = "tf.Const"() {value = dense<7.24974147E-4> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[arg_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[in_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantize_i8}
-// CHECK-DAG: %[[conv0_quantized:.*]] = "tf.PartitionedCall"(%[[arg_quantized]], %[[w0_quantized]], %[[b0_quantized]], %[[in_scale]], %[[in_out_zp]], %[[w0_scale]], %[[w_b_zp]], %[[b0_scale]], %[[w_b_zp]], %[[mid_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu6_fn_1}
-// CHECK-DAG: %[[conv0_dequantized:.*]] = "tf.PartitionedCall"(%[[conv0_quantized]], %[[mid_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @dequantize_i8}
-// CHECK-DAG: %[[conv1_quantized:.*]] = "tf.PartitionedCall"(%[[conv0_quantized]], %[[w1_quantized]], %[[b1_quantized]], %[[mid_scale]], %[[in_out_zp]], %[[w1_scale]], %[[w_b_zp]], %[[b1_scale]], %[[w_b_zp]], %[[out_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu6_fn_0}
-// CHECK-DAG: %[[conv1_dequantized_0:.*]] = "tf.PartitionedCall"(%[[conv1_quantized]], %[[out_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @dequantize_i8}
-// CHECK-DAG: %[[conv1_dequantized_1:.*]] = "tf.PartitionedCall"(%[[conv1_quantized]], %[[out_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @dequantize_i8}
+// CHECK-DAG: %[[w0_float:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-0.282878935, -0.211567819
+// CHECK-DAG: %[[b0_float:.*]] = "tf.Const"() <{value = dense<[-0.0192535277, -5.998660e-03]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK-DAG: %[[w1_float:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}0.208403707, 0.478067577
+// CHECK-DAG: %[[b1_float:.*]] = "tf.Const"() <{value = dense<[-0.0291469581, 0.0106381178]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK-DAG: %[[w0_quantized:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-59, -44
+// CHECK-DAG: %[[b0_quantized:.*]] = "tf.Const"() <{value = dense<[-1040, -324]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-DAG: %[[w1_quantized:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}44, 100
+// CHECK-DAG: %[[b1_quantized:.*]] = "tf.Const"() <{value = dense<[-4312, 1574]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-DAG: %[[in_scale:.*]] = "tf.Const"() <{value = dense<0.00387597573> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[in_out_zp:.*]] = "tf.Const"() <{value = dense<-128> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[w0_scale:.*]] = "tf.Const"() <{value = dense<0.00477493973> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[w_b_zp:.*]]  = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[b0_scale:.*]] = "tf.Const"() <{value = dense<1.85075514E-5> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[mid_scale:.*]] = "tf.Const"() <{value = dense<0.00141507247> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[w1_scale:.*]] = "tf.Const"() <{value = dense<0.00477652298> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[b1_scale:.*]] = "tf.Const"() <{value = dense<6.75912588E-6> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[out_scale:.*]] = "tf.Const"() <{value = dense<7.24974147E-4> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[arg_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[in_scale]], %[[in_out_zp]]) <{config = "", config_proto = "", executor_type = "", f = @quantize_i8}>
+// CHECK-DAG: %[[conv0_quantized:.*]] = "tf.PartitionedCall"(%[[arg_quantized]], %[[w0_quantized]], %[[b0_quantized]], %[[in_scale]], %[[in_out_zp]], %[[w0_scale]], %[[w_b_zp]], %[[b0_scale]], %[[w_b_zp]], %[[mid_scale]], %[[in_out_zp]]) <{config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu6_fn_1}>
+// CHECK-DAG: %[[conv0_dequantized:.*]] = "tf.PartitionedCall"(%[[conv0_quantized]], %[[mid_scale]], %[[in_out_zp]]) <{config = "", config_proto = "", executor_type = "", f = @dequantize_i8}>
+// CHECK-DAG: %[[conv1_quantized:.*]] = "tf.PartitionedCall"(%[[conv0_quantized]], %[[w1_quantized]], %[[b1_quantized]], %[[mid_scale]], %[[in_out_zp]], %[[w1_scale]], %[[w_b_zp]], %[[b1_scale]], %[[w_b_zp]], %[[out_scale]], %[[in_out_zp]]) <{config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu6_fn_0}>
+// CHECK-DAG: %[[conv1_dequantized_0:.*]] = "tf.PartitionedCall"(%[[conv1_quantized]], %[[out_scale]], %[[in_out_zp]]) <{config = "", config_proto = "", executor_type = "", f = @dequantize_i8}>
+// CHECK-DAG: %[[conv1_dequantized_1:.*]] = "tf.PartitionedCall"(%[[conv1_quantized]], %[[out_scale]], %[[in_out_zp]]) <{config = "", config_proto = "", executor_type = "", f = @dequantize_i8}>
 // CHECK-DAG: %[[identity:.*]] = "tf.Identity"(%[[conv1_dequantized_1]])
-// CHECK-DAG: %[[conv0_float:.*]] = "tf.PartitionedCall"(%arg0, %[[w0_float]], %[[b0_float]]) {config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2_00}
-// CHECK-DAG: %[[conv1_float:.*]] = "tf.PartitionedCall"(%[[conv0_dequantized]], %[[w1_float]], %[[b1_float]]) {config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_00}
-// CHECK-DAG: "tf.DumpTensor"(%[[conv0_dequantized]]) {device = "", enabled = true, file_name = "quantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}
-// CHECK-DAG: "tf.DumpTensor"(%[[conv0_float]]) {device = "", enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}
-// CHECK-DAG: "tf.DumpTensor"(%[[conv1_dequantized_0]]) {device = "", enabled = true, file_name = "quantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}
-// CHECK-DAG: "tf.DumpTensor"(%[[conv1_float]]) {device = "", enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}
+// CHECK-DAG: %[[conv0_float:.*]] = "tf.PartitionedCall"(%arg0, %[[w0_float]], %[[b0_float]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2_00}> {device = ""}
+// CHECK-DAG: %[[conv1_float:.*]] = "tf.PartitionedCall"(%[[conv0_dequantized]], %[[w1_float]], %[[b1_float]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_00}> {device = ""}
+// CHECK-DAG: "tf.DumpTensor"(%[[conv0_dequantized]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}> {device = ""}
+// CHECK-DAG: "tf.DumpTensor"(%[[conv0_float]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}> {device = ""}
+// CHECK-DAG: "tf.DumpTensor"(%[[conv1_dequantized_0]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> {device = ""}
+// CHECK-DAG: "tf.DumpTensor"(%[[conv1_float]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> {device = ""}
 // CHECK-DAG: return %[[identity]]
 }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_drq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_drq.mlir
index 0731e6b..6ec9928 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_drq.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_drq.mlir
@@ -13,11 +13,11 @@
   }
 
 // CHECK-LABEL: func @matmul
-// CHECK-DAG: %[[q_w:.*]]  = "tf.Const"() {value = #tf_type<tensor_proto : "0x746
-// CHECK-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<3.93700805E-9> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK: %0 = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "",
-// CHECK-SAME: f = @quantized_matmul_fn_0} : (tensor<2x12xf32>, tensor<12x2x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
+// CHECK-DAG: %[[q_w:.*]]  = "tf.Const"() <{value = #tf_type<tensor_proto : "0x746
+// CHECK-DAG: %[[scale:.*]] = "tf.Const"() <{value = dense<3.93700805E-9> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[zp:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %0 = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) <{config = "", config_proto = "", executor_type = "",
+// CHECK-SAME: f = @quantized_matmul_fn_0}> : (tensor<2x12xf32>, tensor<12x2x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
 
 // CHECK-LABEL: func private @quantized_matmul_fn_0
 // CHECK:  %0 = "tf.UniformQuantizedDotHybrid"(%arg0, %arg1, %arg2, %arg3)
@@ -48,11 +48,11 @@
   }
 
 // CHECK-LABEL: func @conv
-// CHECK-DAG: %[[q_w:.*]] = "tf.Const"() {value = #tf_type<tensor_proto : "0x746674
-// CHECK-DAG: %[[w_scale:.*]] = "tf.Const"() {value = dense<0.0157480314> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[w_zp:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK: %[[quantize_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[w_scale]], %[[w_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_fn_1} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
-// CHECK: %[[quantize_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[w_scale]], %[[w_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_fn_0} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
+// CHECK-DAG: %[[q_w:.*]] = "tf.Const"() <{value = #tf_type<tensor_proto : "0x746674
+// CHECK-DAG: %[[w_scale:.*]] = "tf.Const"() <{value = dense<0.0157480314> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[w_zp:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %[[quantize_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[w_scale]], %[[w_zp]]) <{config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_fn_1}> : (tensor<1x2x2x3xf32>, tensor<2x3x3x2x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
+// CHECK: %[[quantize_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[w_scale]], %[[w_zp]]) <{config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_fn_0}> : (tensor<1x2x2x3xf32>, tensor<2x3x3x2x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
 // CHECK: return %[[quantize_1]], %[[quantize_2]]
 
 // CHECK-LABEL: func private @quantized_conv2d_fn_0
@@ -102,17 +102,17 @@
   }
 
 // CHECK-LABEL: func @depthwise_conv
-// CHECK-DAG: %[[bias:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32>
-// CHECK-DAG: %[[q_w1:.*]] = "tf.Const"() {value = #tf_type<tensor_proto : "0x746674
+// CHECK-DAG: %[[bias:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<3xf32>}> : () -> tensor<3xf32>
+// CHECK-DAG: %[[q_w1:.*]] = "tf.Const"() <{value = #tf_type<tensor_proto : "0x746674
 // CHECK-SAME:                                                                     -> tensor<2x3x1x3x!tf_type.qint8>
-// CHECK-DAG: %[[q_w2:.*]] = "tf.Const"() {value = #tf_type<tensor_proto : "0x746674
+// CHECK-DAG: %[[q_w2:.*]] = "tf.Const"() <{value = #tf_type<tensor_proto : "0x746674
 // CHECK-SAME:                                                                     -> tensor<2x3x1x6x!tf_type.qint8>
-// CHECK-DAG: %[[w_scale:.*]] = "tf.Const"() {value = dense<0.0236220472> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[w_zp:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG: %[[w_scale:.*]] = "tf.Const"() <{value = dense<0.0236220472> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[w_zp:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 
-// CHECK: %[[quantize_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w1]], %[[w_scale]], %[[w_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_depthwise_conv2d_fn_1} : (tensor<1x3x4x3xf32>, tensor<2x3x1x3x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
-// CHECK: %[[quantize_1_add:.*]] = "tf.BiasAdd"(%[[quantize_1]], %[[bias]]) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32>
-// CHECK: %[[quantize_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w2]], %[[w_scale]], %[[w_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_depthwise_conv2d_fn_0} : (tensor<1x3x4x3xf32>, tensor<2x3x1x6x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
+// CHECK: %[[quantize_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w1]], %[[w_scale]], %[[w_zp]]) <{config = "", config_proto = "", executor_type = "", f = @quantized_depthwise_conv2d_fn_1}> : (tensor<1x3x4x3xf32>, tensor<2x3x1x3x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
+// CHECK: %[[quantize_1_add:.*]] = "tf.BiasAdd"(%[[quantize_1]], %[[bias]]) <{data_format = "NHWC"}> {device = ""} : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32>
+// CHECK: %[[quantize_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w2]], %[[w_scale]], %[[w_zp]]) <{config = "", config_proto = "", executor_type = "", f = @quantized_depthwise_conv2d_fn_0}> : (tensor<1x3x4x3xf32>, tensor<2x3x1x6x!tf_type.qint8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
 // CHECK: return %[[quantize_1_add]], %[[quantize_2]]
 
 // CHECK-LABEL: func private @quantized_depthwise_conv2d_fn_0
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_weight_only.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_weight_only.mlir
index 8c07861..ba8f213 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_weight_only.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_weight_only.mlir
@@ -15,19 +15,19 @@
 }
 
 // PerTensor-LABEL: func @matmul
-// PerTensor-DAG: %[[q_w:.*]] = "tf.Const"() {value = dense<0> : tensor<12x2xi8>} : () -> tensor<12x2xi8>
-// PerTensor-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<3.93700805E-9> : tensor<f32>} : () -> tensor<f32>
-// PerTensor-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// PerTensor: %[[out:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "",
-// PerTensor-SAME: f = @quantized_matmul_fn_0} : (tensor<2x12xf32>, tensor<12x2xi8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
+// PerTensor-DAG: %[[q_w:.*]] = "tf.Const"() <{value = dense<0> : tensor<12x2xi8>}> : () -> tensor<12x2xi8>
+// PerTensor-DAG: %[[scale:.*]] = "tf.Const"() <{value = dense<3.93700805E-9> : tensor<f32>}> : () -> tensor<f32>
+// PerTensor-DAG: %[[zp:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// PerTensor: %[[out:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) <{config = "", config_proto = "", executor_type = "",
+// PerTensor-SAME: f = @quantized_matmul_fn_0}> : (tensor<2x12xf32>, tensor<12x2xi8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
 // PerTensor: return %[[out]]
 
 // PerChannel-LABEL: func @matmul
-// PerChannel-DAG: %[[q_w:.*]] = "tf.Const"() {value = dense<0> : tensor<12x2xi8>} : () -> tensor<12x2xi8>
-// PerChannel-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<3.93700805E-9> : tensor<f32>} : () -> tensor<f32>
-// PerChannel-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// PerChannel: %[[out:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "",
-// PerChannel-SAME: f = @quantized_matmul_fn_0} : (tensor<2x12xf32>, tensor<12x2xi8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
+// PerChannel-DAG: %[[q_w:.*]] = "tf.Const"() <{value = dense<0> : tensor<12x2xi8>}> : () -> tensor<12x2xi8>
+// PerChannel-DAG: %[[scale:.*]] = "tf.Const"() <{value = dense<3.93700805E-9> : tensor<f32>}> : () -> tensor<f32>
+// PerChannel-DAG: %[[zp:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// PerChannel: %[[out:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) <{config = "", config_proto = "", executor_type = "",
+// PerChannel-SAME: f = @quantized_matmul_fn_0}> : (tensor<2x12xf32>, tensor<12x2xi8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
 // PerChannel: return %[[out]]
 
 // -----
@@ -51,23 +51,23 @@
   }
 
 // PerTensor-LABEL: func @conv
-// PerTensor-DAG: %[[q_w:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8>
-// PerTensor-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<{{[0-9\.Ee\+\-]+}}> : tensor<f32>} : () -> tensor<f32>
-// PerTensor-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor<i32>} : () -> tensor<i32>
-// PerTensor: %[[out_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "",
-// PerTensor-SAME: f = @quantized_conv2d_fn_1} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xi8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
-// PerTensor: %[[out_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "",
-// PerTensor-SAME: f = @quantized_conv2d_fn_0} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xi8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
+// PerTensor-DAG: %[[q_w:.*]] = "tf.Const"() <{value = dense<{{[0-9]+}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2xi8>
+// PerTensor-DAG: %[[scale:.*]] = "tf.Const"() <{value = dense<{{[0-9\.Ee\+\-]+}}> : tensor<f32>}> : () -> tensor<f32>
+// PerTensor-DAG: %[[zp:.*]] = "tf.Const"() <{value = dense<{{[0-9]+}}> : tensor<i32>}> : () -> tensor<i32>
+// PerTensor: %[[out_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) <{config = "", config_proto = "", executor_type = "",
+// PerTensor-SAME: f = @quantized_conv2d_fn_1}> : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xi8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
+// PerTensor: %[[out_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) <{config = "", config_proto = "", executor_type = "",
+// PerTensor-SAME: f = @quantized_conv2d_fn_0}> : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xi8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
 // PerTensor: return %[[out_1]], %[[out_2]]
 
 // PerChannel-LABEL: func @conv
-// PerChannel-DAG: %[[q_w:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8>
-// PerChannel-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<{{[0-9\.Ee\+\-]+}}> : tensor<2xf32>} : () -> tensor<2xf32>
-// PerChannel-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor<2xi32>} : () -> tensor<2xi32>
-// PerChannel: %[[out_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "",
-// PerChannel-SAME: f = @quantized_conv2d_fn_1} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xi8>, tensor<2xf32>, tensor<2xi32>) -> tensor<*xf32>
-// PerChannel: %[[out_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "",
-// PerChannel-SAME: f = @quantized_conv2d_fn_0} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xi8>, tensor<2xf32>, tensor<2xi32>) -> tensor<*xf32>
+// PerChannel-DAG: %[[q_w:.*]] = "tf.Const"() <{value = dense<{{[0-9]+}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2xi8>
+// PerChannel-DAG: %[[scale:.*]] = "tf.Const"() <{value = dense<{{[0-9\.Ee\+\-]+}}> : tensor<2xf32>}> : () -> tensor<2xf32>
+// PerChannel-DAG: %[[zp:.*]] = "tf.Const"() <{value = dense<{{[0-9]+}}> : tensor<2xi32>}> : () -> tensor<2xi32>
+// PerChannel: %[[out_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) <{config = "", config_proto = "", executor_type = "",
+// PerChannel-SAME: f = @quantized_conv2d_fn_1}> : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xi8>, tensor<2xf32>, tensor<2xi32>) -> tensor<*xf32>
+// PerChannel: %[[out_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) <{config = "", config_proto = "", executor_type = "",
+// PerChannel-SAME: f = @quantized_conv2d_fn_0}> : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xi8>, tensor<2xf32>, tensor<2xi32>) -> tensor<*xf32>
 // PerChannel: return %[[out_1]], %[[out_2]]
 
 }
@@ -98,30 +98,30 @@
   }
 
 // PerTensor-LABEL: func @depthwise_conv
-// PerTensor-DAG: %[[q_w1:.*]] = "tf.Const"() {value = dense<127> : tensor<2x3x3x1xi8>}
-// PerTensor-DAG: %[[q_w2:.*]] = "tf.Const"() {value = dense<127> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8>
-// PerTensor-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<0.0236220472> : tensor<f32>} : () -> tensor<f32>
-// PerTensor-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// PerTensor-DAG: %[[bias:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3xf32>}
-// PerTensor: %[[out_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w1]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "",
-// PerTensor-SAME: f = @quantized_depthwise_conv2d_fn_1} : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xi8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
+// PerTensor-DAG: %[[q_w1:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x3x3x1xi8>}>
+// PerTensor-DAG: %[[q_w2:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2xi8>
+// PerTensor-DAG: %[[scale:.*]] = "tf.Const"() <{value = dense<0.0236220472> : tensor<f32>}> : () -> tensor<f32>
+// PerTensor-DAG: %[[zp:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// PerTensor-DAG: %[[bias:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<3xf32>}>
+// PerTensor: %[[out_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w1]], %[[scale]], %[[zp]]) <{config = "", config_proto = "", executor_type = "",
+// PerTensor-SAME: f = @quantized_depthwise_conv2d_fn_1}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xi8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
 // PerTensor: %[[out_1_add:.*]]  = "tf.BiasAdd"(%[[out_1]], %[[bias]])
-// PerTensor: %[[out_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w2]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "",
-// PerTensor-SAME: f = @quantized_depthwise_conv2d_fn_0} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xi8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
+// PerTensor: %[[out_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w2]], %[[scale]], %[[zp]]) <{config = "", config_proto = "", executor_type = "",
+// PerTensor-SAME: f = @quantized_depthwise_conv2d_fn_0}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xi8>, tensor<f32>, tensor<i32>) -> tensor<*xf32>
 // PerTensor: return %[[out_1_add]], %[[out_2]]
 
 // PerChannel-LABEL: func @depthwise_conv
-// PerChannel-DAG: %[[bias1:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32>
-// PerChannel-DAG: %[[q_w1:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor<2x3x3x1xi8>} : () -> tensor<2x3x3x1xi8>
-// PerChannel-DAG: %[[q_w2:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8>
-// PerChannel-DAG: %[[scale1:.*]] = "tf.Const"() {value = dense<{{[0-9\.Ee\+\-]+}}> : tensor<3xf32>} : () -> tensor<3xf32>
-// PerChannel-DAG: %[[scale2:.*]] = "tf.Const"() {value = dense<{{[0-9\.Ee\+\-]+}}> : tensor<6xf32>} : () -> tensor<6xf32>
-// PerChannel-DAG: %[[zp1:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor<3xi32>} : () -> tensor<3xi32>
-// PerChannel-DAG: %[[zp2:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor<6xi32>} : () -> tensor<6xi32>
-// PerChannel: %[[out_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w1]], %[[scale1]], %[[zp1]]) {config = "", config_proto = "", executor_type = "",
-// PerChannel-SAME: f = @quantized_depthwise_conv2d_fn_1} : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xi8>, tensor<3xf32>, tensor<3xi32>) -> tensor<*xf32>
+// PerChannel-DAG: %[[bias1:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<3xf32>}> : () -> tensor<3xf32>
+// PerChannel-DAG: %[[q_w1:.*]] = "tf.Const"() <{value = dense<{{[0-9]+}}> : tensor<2x3x3x1xi8>}> : () -> tensor<2x3x3x1xi8>
+// PerChannel-DAG: %[[q_w2:.*]] = "tf.Const"() <{value = dense<{{[0-9]+}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2xi8>
+// PerChannel-DAG: %[[scale1:.*]] = "tf.Const"() <{value = dense<{{[0-9\.Ee\+\-]+}}> : tensor<3xf32>}> : () -> tensor<3xf32>
+// PerChannel-DAG: %[[scale2:.*]] = "tf.Const"() <{value = dense<{{[0-9\.Ee\+\-]+}}> : tensor<6xf32>}> : () -> tensor<6xf32>
+// PerChannel-DAG: %[[zp1:.*]] = "tf.Const"() <{value = dense<{{[0-9]+}}> : tensor<3xi32>}> : () -> tensor<3xi32>
+// PerChannel-DAG: %[[zp2:.*]] = "tf.Const"() <{value = dense<{{[0-9]+}}> : tensor<6xi32>}> : () -> tensor<6xi32>
+// PerChannel: %[[out_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w1]], %[[scale1]], %[[zp1]]) <{config = "", config_proto = "", executor_type = "",
+// PerChannel-SAME: f = @quantized_depthwise_conv2d_fn_1}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xi8>, tensor<3xf32>, tensor<3xi32>) -> tensor<*xf32>
 // PerChannel: %[[out_1_add:.*]]  = "tf.BiasAdd"(%[[out_1]], %[[bias1]])
-// PerChannel: %[[out_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w2]], %[[scale2]], %[[zp2]]) {config = "", config_proto = "", executor_type = "",
-// PerChannel-SAME: f = @quantized_depthwise_conv2d_fn_0} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xi8>, tensor<6xf32>, tensor<6xi32>) -> tensor<*xf32>
+// PerChannel: %[[out_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w2]], %[[scale2]], %[[zp2]]) <{config = "", config_proto = "", executor_type = "",
+// PerChannel-SAME: f = @quantized_depthwise_conv2d_fn_0}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xi8>, tensor<6xf32>, tensor<6xi32>) -> tensor<*xf32>
 // PerChannel: return %[[out_1_add]], %[[out_2]]
 }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir
index 38f0662..38b4127 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir
@@ -1,4 +1,5 @@
 // RUN: tf-quant-opt %s -split-input-file -quant-insert-quantized-functions -quant-quantize-composite-functions='target-opset=XLA' | FileCheck %s
+// RUN: tf-quant-opt %s -split-input-file -quant-insert-quantized-functions -quant-quantize-composite-functions='target-opset=XLA enable-per-channel-quantization=true' | FileCheck --check-prefix=PerChannel %s
 
 module {
   func.func @conv_with_single_layer(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>) {
@@ -32,7 +33,7 @@
 // CHECK-LABEL: func private @quantized_conv2d_with_bias_and_relu6_float_output_fn_0
 // CHECK-SAME: (%arg0: tensor<1x2x2x3xi8>, %arg1: tensor<2x2x3x2xi8>, %arg2: tensor<2xi32>, %arg3: tensor<f32>, %arg4: tensor<i32>, %arg5: tensor<2xf32>, %arg6: tensor<2xi32>, %arg7: tensor<2xf32>, %arg8: tensor<2xi32>, %arg9: tensor<f32>, %arg10: tensor<i32>) -> tensor<*xf32>
 // CHECK:      %[[CONV2D_0:.*]] = "tf.Conv2D"
-// CHECK-SAME: {dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}
+// CHECK-SAME: <{dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}>
 
 // CHECK: -------- Quantization Summary --------
 // CHECK: Number of quantized layers in the model
@@ -123,7 +124,7 @@
 // CHECK: %[[conv_quant:.*]] = "tf.PartitionedCall"(%[[quantize]]
 // CHECK-SAME: f = @quantized_conv2d_with_bias_and_relu6_fn_0
 // CHECK-SAME: (tensor<1x2x2x3xi8>, tensor<2x2x3x2xi8>, tensor<2xi32>, tensor<f32>, tensor<i32>, tensor<2xf32>, tensor<2xi32>, tensor<2xf32>, tensor<2xi32>, tensor<f32>, tensor<i32>) -> tensor<*xi8>
-// CHECK: %[[maxpool:.*]] = "tf.MaxPool"(%[[conv_quant]]) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<*xi8>) -> tensor<*xi8>
+// CHECK: %[[maxpool:.*]] = "tf.MaxPool"(%[[conv_quant]]) <{data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]}> : (tensor<*xi8>) -> tensor<*xi8>
 // CHECK: %[[dequantize:.*]] = "tf.PartitionedCall"(%[[maxpool]]
 // CHECK-SAME: f = @dequantize_i8
 // CHECK: return %[[dequantize]]
@@ -297,34 +298,94 @@
     func.return %2 : tensor<*xf32>
   }
 
-// CHECK-LABE: @conv_with_dump
-// CHECK-DAG: %[[w0_float:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}-0.282878935, -0.211567819
-// CHECK-DAG: %[[b0_float:.*]] = "tf.Const"() {value = dense<[-0.0192535277, -5.998660e-03]> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-DAG: %[[w1_float:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}0.208403707, 0.478067577
-// CHECK-DAG: %[[b1_float:.*]] = "tf.Const"() {value = dense<[-0.0291469581, 0.0106381178]> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-DAG: %[[w0_quantized:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}-59, -44
-// CHECK-DAG: %[[b0_quantized:.*]] = "tf.Const"() {value = dense<[-1040, -324]> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK-DAG: %[[w1_quantized:.*]] = "tf.Const"() {value = dense<{{\[\[\[\[}}44, 100
-// CHECK-DAG: %[[b1_quantized:.*]] = "tf.Const"() {value = dense<[-4312, 1574]> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK-DAG: %[[in_scale:.*]] = "tf.Const"() {value = dense<0.00387597573> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[in_out_zp:.*]] = "tf.Const"() {value = dense<-128> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG: %[[w0_scale:.*]] = "tf.Const"() {value = dense<0.00477493973> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[w_b_zp:.*]]  = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG: %[[b0_scale:.*]] = "tf.Const"() {value = dense<1.85075514E-5> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[mid_scale:.*]] = "tf.Const"() {value = dense<0.00141507247> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[w1_scale:.*]] = "tf.Const"() {value = dense<0.00477652298> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[b1_scale:.*]] = "tf.Const"() {value = dense<6.75912588E-6> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[out_scale:.*]] = "tf.Const"() {value = dense<7.24974147E-4> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[in_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantize_i8}
-// CHECK-DAG: %[[conv0_dequantized:.*]] = "tf.PartitionedCall"(%[[quantized]], %[[w0_quantized]], %[[b0_quantized]], %[[in_scale]], %[[in_out_zp]], %[[w0_scale]], %[[w_b_zp]], %[[b0_scale]], %[[w_b_zp]], %[[mid_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu6_float_output_fn_1}
-// CHECK-DAG: %[[conv0_quantized:.*]] = "tf.PartitionedCall"(%[[quantized]], %[[w0_quantized]], %[[b0_quantized]], %[[in_scale]], %[[in_out_zp]], %[[w0_scale]], %[[w_b_zp]], %[[b0_scale]], %[[w_b_zp]], %[[mid_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu6_fn_1}
-// CHECK-DAG: %[[conv1_dequantized:.*]] = "tf.PartitionedCall"(%[[conv0_quantized]], %[[w1_quantized]], %[[b1_quantized]], %[[mid_scale]], %[[in_out_zp]], %[[w1_scale]], %[[w_b_zp]], %[[b1_scale]], %[[w_b_zp]], %[[out_scale]], %[[in_out_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu6_float_output_fn_0}
+// CHECK-LABEL: func @conv_with_dump
+// CHECK-DAG: %[[w0_float:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-0.282878935, -0.211567819
+// CHECK-DAG: %[[b0_float:.*]] = "tf.Const"() <{value = dense<[-0.0192535277, -5.998660e-03]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK-DAG: %[[w1_float:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}0.208403707, 0.478067577
+// CHECK-DAG: %[[b1_float:.*]] = "tf.Const"() <{value = dense<[-0.0291469581, 0.0106381178]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK-DAG: %[[w0_quantized:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-59, -44
+// CHECK-DAG: %[[b0_quantized:.*]] = "tf.Const"() <{value = dense<[-1040, -324]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-DAG: %[[w1_quantized:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}44, 100
+// CHECK-DAG: %[[b1_quantized:.*]] = "tf.Const"() <{value = dense<[-4312, 1574]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-DAG: %[[in_scale:.*]] = "tf.Const"() <{value = dense<0.00387597573> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[in_out_zp:.*]] = "tf.Const"() <{value = dense<-128> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[w0_scale:.*]] = "tf.Const"() <{value = dense<0.00477493973> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[w_b_zp:.*]]  = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[b0_scale:.*]] = "tf.Const"() <{value = dense<1.85075514E-5> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[mid_scale:.*]] = "tf.Const"() <{value = dense<0.00141507247> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[w1_scale:.*]] = "tf.Const"() <{value = dense<0.00477652298> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[b1_scale:.*]] = "tf.Const"() <{value = dense<6.75912588E-6> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[out_scale:.*]] = "tf.Const"() <{value = dense<7.24974147E-4> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[in_scale]], %[[in_out_zp]]) <{config = "", config_proto = "", executor_type = "", f = @quantize_i8}>
+// CHECK-DAG: %[[conv0_dequantized:.*]] = "tf.PartitionedCall"(%[[quantized]], %[[w0_quantized]], %[[b0_quantized]], %[[in_scale]], %[[in_out_zp]], %[[w0_scale]], %[[w_b_zp]], %[[b0_scale]], %[[w_b_zp]], %[[mid_scale]], %[[in_out_zp]]) <{config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu6_float_output_fn_1}>
+// CHECK-DAG: %[[conv0_quantized:.*]] = "tf.PartitionedCall"(%[[quantized]], %[[w0_quantized]], %[[b0_quantized]], %[[in_scale]], %[[in_out_zp]], %[[w0_scale]], %[[w_b_zp]], %[[b0_scale]], %[[w_b_zp]], %[[mid_scale]], %[[in_out_zp]]) <{config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu6_fn_1}>
+// CHECK-DAG: %[[conv1_dequantized:.*]] = "tf.PartitionedCall"(%[[conv0_quantized]], %[[w1_quantized]], %[[b1_quantized]], %[[mid_scale]], %[[in_out_zp]], %[[w1_scale]], %[[w_b_zp]], %[[b1_scale]], %[[w_b_zp]], %[[out_scale]], %[[in_out_zp]]) <{config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu6_float_output_fn_0}>
 // CHECK-DAG: %[[identity:.*]] = "tf.Identity"(%[[conv1_dequantized]])
-// CHECK-DAG: %[[conv0_float:.*]] = "tf.PartitionedCall"(%arg0, %[[w0_float]], %[[b0_float]]) {config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2_00}
-// CHECK-DAG: %[[conv1_float:.*]] = "tf.PartitionedCall"(%[[conv0_dequantized]], %[[w1_float]], %[[b1_float]]) {config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_00}
-// CHECK-DAG: "tf.DumpTensor"(%[[conv0_dequantized]]) {device = "", enabled = true, file_name = "quantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}
-// CHECK-DAG: "tf.DumpTensor"(%[[conv0_float]]) {device = "", enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}
-// CHECK-DAG: "tf.DumpTensor"(%[[conv1_dequantized]]) {device = "", enabled = true, file_name = "quantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}
-// CHECK-DAG: "tf.DumpTensor"(%[[conv1_float]]) {device = "", enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}
+// CHECK-DAG: %[[conv0_float:.*]] = "tf.PartitionedCall"(%arg0, %[[w0_float]], %[[b0_float]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2_00}> {device = ""}
+// CHECK-DAG: %[[conv1_float:.*]] = "tf.PartitionedCall"(%[[conv0_dequantized]], %[[w1_float]], %[[b1_float]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_00}> {device = ""}
+// CHECK-DAG: "tf.DumpTensor"(%[[conv0_dequantized]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}> {device = ""}
+// CHECK-DAG: "tf.DumpTensor"(%[[conv0_float]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}> {device = ""}
+// CHECK-DAG: "tf.DumpTensor"(%[[conv1_dequantized]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> {device = ""}
+// CHECK-DAG: "tf.DumpTensor"(%[[conv1_float]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "conv_with_dump", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> {device = ""}
 // CHECK-DAG: return %[[identity]]
+
+// PerChannel-LABEL: func @conv_with_dump
+// PerChannel-DAG: %[[PerChannel_w0_float:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-0.282878935, -0.211567819
+// PerChannel-DAG: %[[b0_float:.*]] = "tf.Const"() <{value = dense<[-0.0192535277, -5.998660e-03]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// PerChannel-DAG: %[[w1_float:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}0.208403707, 0.478067577
+// PerChannel-DAG: %[[b1_float:.*]] = "tf.Const"() <{value = dense<[-0.0291469581, 0.0106381178]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// PerChannel-DAG: %[[w0_quantized:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-59, -77
+// PerChannel-DAG: %[[b0_quantized:.*]] = "tf.Const"() <{value = dense<[-1040, -561]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// PerChannel-DAG: %[[w1_quantized:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}45, 100
+// PerChannel-DAG: %[[b1_quantized:.*]] = "tf.Const"() <{value = dense<[-4411, 1574]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// PerChannel-DAG: %[[in_scale:.*]] = "tf.Const"() <{value = dense<0.00387597573> : tensor<f32>}> : () -> tensor<f32>
+// PerChannel-DAG: %[[in_out_zp:.*]] = "tf.Const"() <{value = dense<-128> : tensor<i32>}> : () -> tensor<i32>
+// PerChannel-DAG: %[[w0_scale:.*]] = "tf.Const"() <{value = dense<[0.00477493973, 0.00275693159]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// PerChannel-DAG: %[[w_b_zp:.*]]  = "tf.Const"() <{value = dense<0> : tensor<2xi32>}> : () -> tensor<2xi32>
+// PerChannel-DAG: %[[b0_scale:.*]] = "tf.Const"() <{value = dense<[1.85075514E-5, 1.06858006E-5]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// PerChannel-DAG: %[[mid_scale:.*]] = "tf.Const"() <{value = dense<0.00141507247> : tensor<f32>}> : () -> tensor<f32>
+// PerChannel-DAG: %[[w1_scale:.*]] = "tf.Const"() <{value = dense<[0.00467005931, 0.00477652298]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// PerChannel-DAG: %[[b1_scale:.*]] = "tf.Const"() <{value = dense<[6.60847217E-6, 6.75912588E-6]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// PerChannel-DAG: %[[out_scale:.*]] = "tf.Const"() <{value = dense<7.24974147E-4> : tensor<f32>}> : () -> tensor<f32>
+}
+
+// -----
+
+module {
+  func.func @conv_with_per_channel_and_tensor_weight(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> {
+    %cst = "tf.Const"() {device = "", value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>} : () -> tensor<2xf32>
+    %cst_0 = "tf.Const"() {device = "", value = dense<[[[[-0.630731344, 0.277245104], [0.54962182, 0.927732646], [0.180364341, 1.90948534]], [[-0.764542698, -0.287541777], [-0.211145893, -1.59367061], [-0.708605706, 1.79999375]], [[-0.954062759, 0.197947085], [-0.614013135, -0.966769516], [0.612640202, -1.45540595]]], [[[-0.418223292, 0.234433219], [5.057390e-01, 1.86747122], [0.899269938, 0.145780042]], [[0.335351914, 1.02572429], [0.084816426, 1.79729116], [-0.664676845, 0.310017586]], [[-0.795477629, -7.709830e-01], [0.581315517, 0.740075528], [0.921566545, 1.85318887]]]]> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
+    %0 = "quantfork.stats"(%arg0) {layerStats = dense<[4.6128589E-5, 0.999998927]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32>
+    %1 = "tf.PartitionedCall"(%0, %cst_0, %cst) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32>
+    %2 = "quantfork.stats"(%1) {layerStats = dense<[3.50919247, 6.000000e+00]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32>
+    %3 = "tf.Identity"(%2) {device = ""} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32>
+    %4 = "quantfork.stats"(%3) {layerStats = dense<[3.50919247, 6.000000e+00]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32>
+    func.return %4 : tensor<1x3x4x2xf32>
+  }
+  func.func private @composite_conv2d_with_bias_and_relu6_fn_1(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "composite_conv2d_with_bias_and_relu6_fn_1", tf.tf_quant.composite_function} {
+    %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32>
+    %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<1x3x4x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32>
+    %2 = "tf.Relu6"(%1) {device = ""} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32>
+    func.return %2 : tensor<1x3x4x2xf32>
+  }
+
+// CHECK-LABEL: func @conv_with_per_channel_and_tensor_weight
+// CHECK-DAG: %[[b0_quantized:.*]] = "tf.Const"() <{value = dense<[120654, 119646]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-DAG: %[[w0_quantized:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-42, 18
+// CHECK-DAG: %[[in_scale:.*]] = "tf.Const"() <{value = dense<0.0039215642> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[in_out_zp:.*]] = "tf.Const"() <{value = dense<-128> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[w0_scale:.*]] = "tf.Const"() <{value = dense<0.0150353173> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[w_b_zp:.*]]  = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[b0_scale:.*]] = "tf.Const"() <{value = dense<5.89619667E-5> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[mid_scale:.*]] = "tf.Const"() <{value = dense<0.0235294122> : tensor<f32>}> : () -> tensor<f32>
+
+// PerChannel-LABEL: func @conv_with_per_channel_and_tensor_weight
+// PerChannel-DAG: %[[b0_quantized:.*]] = "tf.Const"() <{value = dense<[241481, 119646]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// PerChannel-DAG: %[[w0_quantized:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-84, 18
+// PerChannel-DAG: %[[in_scale:.*]] = "tf.Const"() <{value = dense<0.0039215642> : tensor<f32>}> : () -> tensor<f32>
+// PerChannel-DAG: %[[in_out_zp:.*]] = "tf.Const"() <{value = dense<-128> : tensor<i32>}> : () -> tensor<i32>
+// PerChannel-DAG: %[[w0_scale:.*]] = "tf.Const"() <{value = dense<[0.0075123054, 0.0150353173]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// PerChannel-DAG: %[[w_b_zp:.*]]  = "tf.Const"() <{value = dense<0> : tensor<2xi32>}> : () -> tensor<2xi32>
+// PerChannel-DAG: %[[b0_scale:.*]] = "tf.Const"() <{value = dense<[2.94599886E-5, 5.89619667E-5]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// PerChannel-DAG: %[[mid_scale:.*]] = "tf.Const"() <{value = dense<0.0235294122> : tensor<f32>}> : () -> tensor<f32>
 }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir
index c500b3c..e3bda3f 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir
@@ -15,7 +15,7 @@
 
 // CHECK: %[[cst:.*]] = "arith.constant"() <{value = dense<0.000000e+00> : tensor<2x1024xf32>}> : () -> tensor<2x1024xf32>
 // CHECK: %[[q_cst:.*]] = "quantfork.qcast"(%[[cst]]) : (tensor<2x1024xf32>) -> tensor<2x1024x!quant.uniform<i8<-127:127>:f32, 3.9370078740157481E-9>>
-// CHECK: %[[out:.*]] = "tf.PartitionedCall"(%arg0, %[[q_cst]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024x!quant.uniform<i8<-127:127>:f32, 3.9370078740157481E-9>>) -> tensor<*xf32>
+// CHECK: %[[out:.*]] = "tf.PartitionedCall"(%arg0, %[[q_cst]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x2x2x3xf32>, tensor<2x1024x!quant.uniform<i8<-127:127>:f32, 3.9370078740157481E-9>>) -> tensor<*xf32>
 // CHECK: "func.return"(%[[out]]) : (tensor<*xf32>) -> ()
 }
 
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_weights.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_weights.mlir
index c41d43d..7f7a509 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_weights.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_weights.mlir
@@ -8,7 +8,7 @@
   }
 
 // CHECK-LABEL: func @not_quantize_const
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x1024xf32>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x1024xf32>
 // CHECK: return %[[W]] : tensor<2x1024xf32>
 }
 
@@ -22,15 +22,15 @@
   }
 
 // CHECK-LABEL: func @matmul
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<127> : tensor<2x1024xi8>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x1024xi8>
 // CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x1024xi8>) -> tensor<2x1024xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<2x1024xi8>) -> tensor<2x1024xf32>
-// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %[[DEQUANTIZED]]) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x1024xi8>) -> tensor<2x1024xf32>
+// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %[[DEQUANTIZED]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 // CHECK: return %[[MATMUL]] : tensor<*xf32>
 
 // CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32>
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<0.0157480314> : tensor<f32>
-// CHECK: %[[CASTED_W:.*]] = "tf.Cast"(%arg0) {Truncate = false} : (tensor<*xi8>) -> tensor<*xf32>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0157480314> : tensor<f32>
+// CHECK: %[[CASTED_W:.*]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<*xi8>) -> tensor<*xf32>
 // CHECK: %[[DEQUANTIZED:.*]] = "tf.Mul"(%[[CASTED_W]], %[[SCALE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
 // CHECK: return %[[DEQUANTIZED]] : tensor<*xf32>
 }
@@ -48,7 +48,7 @@
 // CHECK-LABEL: func @not_quantize_matmul_without_const
 // CHECK: %[[ORIGINAL_IDENTITY_1:.*]] = "tf.Identity"(%arg0) {device = ""} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
 // CHECK: %[[ORIGINAL_IDENTITY_2:.*]] = "tf.Identity"(%arg1) {device = ""} : (tensor<2x1024xf32>) -> tensor<2x1024xf32>
-// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%[[ORIGINAL_IDENTITY_1]], %[[ORIGINAL_IDENTITY_2]]) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%[[ORIGINAL_IDENTITY_1]], %[[ORIGINAL_IDENTITY_2]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 // CHECK: return %[[MATMUL]] : tensor<*xf32>
 }
 
@@ -63,14 +63,14 @@
   }
 
 // CHECK-LABEL: func @quantize_xladotv2_bf16
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<127> : tensor<2x1024xi8>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x1024xi8>
 // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x1024xi8>) -> tensor<2x1024xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[IDENTITY]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<2x1024xi8>) -> tensor<2x1024xbf16>
-// CHECK: %[[MATMUL:.*]] = "tf.XlaDotV2"(%arg0, %[[DEQUANTIZED]]) {device = "", dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""} : (tensor<1x2x2x2xbf16>, tensor<2x1024xbf16>) -> tensor<1x2x2x1024xbf16>
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[IDENTITY]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x1024xi8>) -> tensor<2x1024xbf16>
+// CHECK: %[[MATMUL:.*]] = "tf.XlaDotV2"(%arg0, %[[DEQUANTIZED]]) <{dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""}> {device = ""} : (tensor<1x2x2x2xbf16>, tensor<2x1024xbf16>) -> tensor<1x2x2x1024xbf16>
 // CHECK: return %[[MATMUL]] : tensor<1x2x2x1024xbf16>
 
 // CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xbf16>
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<1.574710e-02> : tensor<bf16>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<1.574710e-02> : tensor<bf16>
 }
 
 // -----
@@ -87,17 +87,17 @@
   }
 
 // CHECK-LABEL: func @matmul_with_identity_and_reshape
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<127> : tensor<1024x2xi8>
-// CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[2, 1024]> : tensor<2xi32>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<1024x2xi8>
+// CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() <{value = dense<[2, 1024]> : tensor<2xi32>
 // CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<1024x2xi8>) -> tensor<1024x2xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<1024x2xi8>) -> tensor<1024x2xf32>
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<1024x2xi8>) -> tensor<1024x2xf32>
 // CHECK: %[[ORIGINAL_IDENTITY:.*]] = "tf.Identity"(%[[DEQUANTIZED]]) {device = ""} : (tensor<1024x2xf32>) -> tensor<1024x2xf32>
 // CHECK: %[[RESHAPED_W:.*]] = "tf.Reshape"(%[[ORIGINAL_IDENTITY]], %[[SHAPE]]) : (tensor<1024x2xf32>, tensor<2xi32>) -> tensor<2x1024xf32>
-// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %[[RESHAPED_W]]) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %[[RESHAPED_W]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 // CHECK: return %[[MATMUL]] : tensor<*xf32>
 
 // CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32>
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<0.0157480314> : tensor<f32>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0157480314> : tensor<f32>
 }
 
 // -----
@@ -113,16 +113,16 @@
   }
 
 // CHECK-LABEL: func @conv2d
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<127> : tensor<2x3x3x512xi8>
-// CHECK-DAG: %[[BIAS:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x3x3x512xi8>
+// CHECK-DAG: %[[BIAS:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>
 // CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x3x3x512xi8>) -> tensor<2x3x3x512xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<2x3x3x512xi8>) -> tensor<2x3x3x512xf32>
-// CHECK: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZED:.*]]) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32>
-// CHECK: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[BIAS]]) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x3x3x512xi8>) -> tensor<2x3x3x512xf32>
+// CHECK: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZED:.*]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32>
+// CHECK: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[BIAS]]) <{data_format = "NHWC"}> {device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
 // CHECK: return %[[BIASADD]] : tensor<*xf32>
 
 // CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32>
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<0.0236220472> : tensor<f32>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0236220472> : tensor<f32>
 }
 
 // -----
@@ -138,14 +138,14 @@
   }
 
 // CHECK-LABEL: func @depthwise_conv
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<127> : tensor<2x3x3x512xi8>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x3x3x512xi8>
 // CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x3x3x512xi8>) -> tensor<2x3x3x512xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<2x3x3x512xi8>) -> tensor<2x3x3x512xf32>
-// CHECK: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZED]]) {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1]} : (tensor<1x3x4x512xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32>
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x3x3x512xi8>) -> tensor<2x3x3x512xf32>
+// CHECK: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZED]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1]}> {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", device = ""} : (tensor<1x3x4x512xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32>
 // CHECK: return %[[DEPTHWISE_CONV2D]] : tensor<*xf32>
 
 // CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32>
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<0.00787401571> : tensor<f32>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.00787401571> : tensor<f32>
 }
 
 // -----
@@ -160,16 +160,16 @@
   }
 
 // CHECK-LABEL: func @quantize_sharded_weights_with_xladot
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<127> : tensor<512x512xi8>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<512x512xi8>
 // CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<512x512xi8>) -> tensor<512x512xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<512x512xi8>) -> tensor<512x512xf32>
-// CHECK: %[[SHARDED_W:.*]] = "tf.XlaSharding"(%[[DEQUANTIZED]]) {_XlaSharding = "\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01", device = "", sharding = "\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01", unspecified_dims = []} : (tensor<512x512xf32>) -> tensor<512x512xf32>
-// CHECK: %[[XLADOT:.*]] = "tf.XlaDotV2"(%arg0, %[[SHARDED_W]]) {device = "", dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""} : (tensor<?x?x?x?xf32>, tensor<512x512xf32>) -> tensor<?x?x?x?xf32>
-// CHECK: %[[ORIGINAL_CAST:.*]] = "tf.Cast"(%[[XLADOT]]) {Truncate = false} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xbf16>
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<512x512xi8>) -> tensor<512x512xf32>
+// CHECK: %[[SHARDED_W:.*]] = "tf.XlaSharding"(%[[DEQUANTIZED]]) <{_XlaSharding = "\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01", sharding = "\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01"}> {device = "", unspecified_dims = []} : (tensor<512x512xf32>) -> tensor<512x512xf32>
+// CHECK: %[[XLADOT:.*]] = "tf.XlaDotV2"(%arg0, %[[SHARDED_W]]) <{dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""}> {device = ""} : (tensor<?x?x?x?xf32>, tensor<512x512xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[ORIGINAL_CAST:.*]] = "tf.Cast"(%[[XLADOT]]) <{Truncate = false}> : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xbf16>
 // CHECK: return %[[ORIGINAL_CAST]] : tensor<?x?x?x?xbf16>
 
 // CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32>
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<0.0787401571> : tensor<f32>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0787401571> : tensor<f32>
 }
 
 // -----
@@ -184,16 +184,16 @@
   }
 
 // CHECK-LABEL: func @quantize_sharded_weights_with_xladot_with_identity
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<127> : tensor<512x512xi8>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<512x512xi8>
 // CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<512x512xi8>) -> tensor<512x512xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<512x512xi8>) -> tensor<512x512xf32>
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<512x512xi8>) -> tensor<512x512xf32>
 // CHECK: %[[IDENTITY_W:.*]] = "tf.Identity"(%[[DEQUANTIZED]]) {device = ""} : (tensor<512x512xf32>) -> tensor<512x512xf32>
-// CHECK: %[[SHARDED_W:.*]] = "tf.XlaSharding"(%[[IDENTITY_W]]) {_XlaSharding = "\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01", device = "", sharding = "\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01", unspecified_dims = []} : (tensor<512x512xf32>) -> tensor<512x512xf32>
-// CHECK: %[[XLADOT:.*]] = "tf.XlaDotV2"(%arg0, %[[SHARDED_W]]) {device = "", dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""} : (tensor<?x?x?x?xf32>, tensor<512x512xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[SHARDED_W:.*]] = "tf.XlaSharding"(%[[IDENTITY_W]]) <{_XlaSharding = "\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01", sharding = "\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01"}> {device = "", unspecified_dims = []} : (tensor<512x512xf32>) -> tensor<512x512xf32>
+// CHECK: %[[XLADOT:.*]] = "tf.XlaDotV2"(%arg0, %[[SHARDED_W]]) <{dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""}> {device = ""} : (tensor<?x?x?x?xf32>, tensor<512x512xf32>) -> tensor<?x?x?x?xf32>
 // CHECK: return %[[XLADOT]] : tensor<?x?x?x?xf32>
 
 // CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32>
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<0.0787401571> : tensor<f32>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0787401571> : tensor<f32>
 }
 
 // -----
@@ -208,16 +208,16 @@
   }
 
 // CHECK-LABEL: func @quantize_xlagather
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<127> : tensor<200x100x300xi8>} : () -> tensor<200x100x300xi8>
-// CHECK-DAG: %[[IDX:.*]] = "tf.Const"() {value = dense<[1, 1, 300]> : tensor<3xi64>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<200x100x300xi8>}> : () -> tensor<200x100x300xi8>
+// CHECK-DAG: %[[IDX:.*]] = "tf.Const"() <{value = dense<[1, 1, 300]> : tensor<3xi64>
 // CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<200x100x300xi8>) -> tensor<200x100x300xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<200x100x300xi8>) -> tensor<200x100x300xf32>
-// CHECK: %[[GATHER:.*]] = "tf.XlaGather"(%[[DEQUANTIZED]], %arg0, %[[IDX]]) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xf32>
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<200x100x300xi8>) -> tensor<200x100x300xf32>
+// CHECK: %[[GATHER:.*]] = "tf.XlaGather"(%[[DEQUANTIZED]], %arg0, %[[IDX]]) <{dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01 \01", indices_are_sorted = true}> : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xf32>
 // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[GATHER]]) {device = ""} : (tensor<1x300x10xf32>) -> tensor<1x300x10xf32>
 // CHECK: return %[[IDENTITY]] : tensor<1x300x10xf32>
 
 // CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32>
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<0.0787401571> : tensor<f32>} : () -> tensor<f32>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0787401571> : tensor<f32>}> : () -> tensor<f32>
 }
 
 // -----
@@ -236,17 +236,17 @@
   }
 
 // CHECK-LABEL: func @partitioned_call
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<127> : tensor<2x1024xi8>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x1024xi8>
 // CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x1024xi8>) -> tensor<2x1024xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<2x1024xi8>) -> tensor<2x1024xf32>
-// CHECK: %[[OUTPUT:.*]] = "tf.PartitionedCall"(%arg0, %[[DEQUANTIZED]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x1024xi8>) -> tensor<2x1024xf32>
+// CHECK: %[[OUTPUT:.*]] = "tf.PartitionedCall"(%arg0, %[[DEQUANTIZED]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 // CHECK: return %[[OUTPUT]] : tensor<*xf32>
 
 // CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32>
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<0.0314960629> : tensor<f32>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0314960629> : tensor<f32>
 
 // CHECK-LABEL: func private @composite_matmul_fn
-// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 // CHECK: return %[[MATMUL]] : tensor<*xf32>
 }
 
@@ -272,21 +272,21 @@
 }
 
 // CHECK-LABEL: func @recursive_partitioned_call(%arg0: tensor<1x2x2x3xf32>) -> tensor<*xf32>
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<127> : tensor<2x1024xi8>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x1024xi8>
 // CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x1024xi8>) -> tensor<2x1024xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<2x1024xi8>) -> tensor<2x1024xf32>
-// CHECK: %[[OUTPUT:.*]] = "tf.PartitionedCall"(%arg0, %[[DEQUANTIZED]]) {config = "", config_proto = "", executor_type = "", f = @outer_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x1024xi8>) -> tensor<2x1024xf32>
+// CHECK: %[[OUTPUT:.*]] = "tf.PartitionedCall"(%arg0, %[[DEQUANTIZED]]) <{config = "", config_proto = "", executor_type = "", f = @outer_fn}> : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 // CHECK: return %[[OUTPUT]] : tensor<*xf32>
 
 // CHECK-LABEL: func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32>
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<0.0314960629> : tensor<f32>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0314960629> : tensor<f32>
 
 // CHECK-LABEL: func private @outer_fn
-// CHECK: %[[OUTER_OUTPUT:.*]] = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @inner_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %[[OUTER_OUTPUT:.*]] = "tf.PartitionedCall"(%arg0, %arg1) <{config = "", config_proto = "", executor_type = "", f = @inner_fn}> : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 // CHECK: return %[[OUTER_OUTPUT]] : tensor<*xf32>
 
 // CHECK-LABEL: func private @inner_fn
-// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 // CHECK: return %[[MATMUL]] : tensor<*xf32>
 
 // -----
@@ -302,17 +302,17 @@
   }
 
 // CHECK-LABEL: func @matmul_multiuses
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<127> : tensor<2x1024xi8>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x1024xi8>
 // CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x1024xi8>) -> tensor<2x1024xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<2x1024xi8>) -> tensor<2x1024xf32>
-// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[DEQUANTIZED]]) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
-// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg1, %[[DEQUANTIZED]]) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x1024xi8>) -> tensor<2x1024xf32>
+// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[DEQUANTIZED]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg1, %[[DEQUANTIZED]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 // CHECK: %[[ORIGINAL_IDENTITY:.*]] = "tf.Identity"(%[[DEQUANTIZED]]) {device = ""} : (tensor<2x1024xf32>) -> tensor<2x1024xf32>
-// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[ORIGINAL_IDENTITY]]) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[ORIGINAL_IDENTITY]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 // CHECK: return %[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]] : tensor<*xf32>, tensor<*xf32>, tensor<*xf32>
 
 // CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32>
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<0.0157480314> : tensor<f32>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0157480314> : tensor<f32>
 }
 
 // -----
@@ -327,15 +327,15 @@
   }
 
 // CHECK-LABEL: func @matmul_multiuses
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<127> : tensor<2x1024xi8>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x1024xi8>
 // CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x1024xi8>) -> tensor<2x1024xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<2x1024xi8>) -> tensor<2x1024xf32>
-// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %[[DEQUANTIZED]]) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x1024xi8>) -> tensor<2x1024xf32>
+// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %[[DEQUANTIZED]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32>
 // CHECK: %[[ADD:.*]] = "tf.AddV2"(%arg1, %[[DEQUANTIZED]]) {device = ""} : (tensor<2x1024xf32>, tensor<2x1024xf32>) -> tensor<2x1024xf32>
 // CHECK: return %[[MATMUL]], %[[ADD]] : tensor<*xf32>, tensor<2x1024xf32>
 
 // CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32>
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<0.0157480314> : tensor<f32>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0157480314> : tensor<f32>
 }
 
 // -----
@@ -375,21 +375,21 @@
 }
 
 // CHECK-LABEL: func @matmul_with_while
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<127> : tensor<1024x1024xi8>
-// CHECK-DAG: %[[CNT:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<1024x1024xi8>
+// CHECK-DAG: %[[CNT:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<1024x1024xi8>) -> tensor<1024x1024xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<1024x1024xi8>) -> tensor<1024x1024xf32>
-// CHECK: %[[WHILE:.*]] = "tf.While"(%[[CNT]], %[[CNT]], %[[CNT]], %arg0, %[[DEQUANTIZED]]) {T = [i32, i32, i32, f32, f32], _lower_using_switch_merge = true, _num_original_outputs = 5 : i64, _read_only_resource_inputs = [], body = @while_body, cond = @while_cond, device = "", is_stateless = true, output_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<1x1024>, #tf_type.shape<1024x1024>], parallel_iterations = 10 : i64, shape_invariant} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x1024xf32>, tensor<1024x1024xf32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x1024xf32>, tensor<1024x1024xf32>)
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<1024x1024xi8>) -> tensor<1024x1024xf32>
+// CHECK: %[[WHILE:.*]] = "tf.While"(%[[CNT]], %[[CNT]], %[[CNT]], %arg0, %[[DEQUANTIZED]]) <{body = @while_body, cond = @while_cond, is_stateless = true, parallel_iterations = 10 : i64, shape_invariant}> {T = [i32, i32, i32, f32, f32], _lower_using_switch_merge = true, _num_original_outputs = 5 : i64, _read_only_resource_inputs = [], device = "", output_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<1x1024>, #tf_type.shape<1024x1024>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x1024xf32>, tensor<1024x1024xf32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x1024xf32>, tensor<1024x1024xf32>)
 // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[WHILE:.*]]) {device = ""} : (tensor<1x1024xf32>) -> tensor<1x1024xf32>
 // CHECK: return %[[IDENTITY]] : tensor<1x1024xf32>
 
 // CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32>
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<0.00787401571> : tensor<f32>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.00787401571> : tensor<f32>
 
 // CHECK-LABEL: func private @while_body(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<1x1024xf32>, %arg4: tensor<1024x1024xf32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x1024xf32>, tensor<1024x1024xf32>)
-// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg3, %arg4) {device = "", transpose_a = false, transpose_b = false} : (tensor<1x1024xf32>, tensor<1024x1024xf32>) -> tensor<1x1024xf32>
+// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg3, %arg4) <{transpose_a = false, transpose_b = false}> {device = ""} : (tensor<1x1024xf32>, tensor<1024x1024xf32>) -> tensor<1x1024xf32>
 // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%arg4) {device = ""} : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
-// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg3, %[[IDENTITY]]) {device = "", transpose_a = false, transpose_b = false} : (tensor<1x1024xf32>, tensor<1024x1024xf32>) -> tensor<1x1024xf32>
+// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg3, %[[IDENTITY]]) <{transpose_a = false, transpose_b = false}> {device = ""} : (tensor<1x1024xf32>, tensor<1024x1024xf32>) -> tensor<1x1024xf32>
 // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[MATMUL_1]], %[[MATMUL_2]]) {device = ""} : (tensor<1x1024xf32>, tensor<1x1024xf32>) -> tensor<1x1024xf32>
 
 // CHECK-LABEL: func private @while_cond(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<1x1024xf32>, %arg4: tensor<1024x1024xf32>) -> tensor<i1>
@@ -401,13 +401,13 @@
   func.func @matmul_with_while_bf16(%arg0: tensor<1x1024xbf16>) -> tensor<1x1024xbf16> {
     %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
     %cst_0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-    %cst_1 = "tf.Const"(){value = dense<1.0> : tensor<1024x1024xbf16>} : () -> tensor<1024x1024xbf16>
+    %cst_1 = "tf.Const"() {value = dense<1.0> : tensor<1024x1024xbf16>} : () -> tensor<1024x1024xbf16>
     %0:5 = "tf.While"(%cst_0, %cst, %cst_0, %arg0, %cst_1) {T = [i32, i32, i32, bf16, bf16],_lower_using_switch_merge = true, _num_original_outputs = 5 : i64, _read_only_resource_inputs = [], body = @while_body, cond = @while_cond, device = "", is_stateless = true, output_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<1x1024>, #tf_type.shape<1024x1024>], parallel_iterations = 10 : i64, shape_invariant} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x1024xbf16>, tensor<1024x1024xbf16>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x1024xbf16>, tensor<1024x1024xbf16>)
     %1 = "tf.Identity"(%0#3) {device = ""} : (tensor<1x1024xbf16>) -> tensor<1x1024xbf16>
     func.return %1 : tensor<1x1024xbf16>
   }
 
-  func.func private @while_body(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<1x1024xbf16>, %arg4: tensor<1024x1024xbf16>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x1024xbf16>, tensor<1024x1024xbf16>) 
+  func.func private @while_body(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<1x1024xbf16>, %arg4: tensor<1024x1024xbf16>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x1024xbf16>, tensor<1024x1024xbf16>)
   {
     %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
     %0 = "tf.AddV2"(%arg2, %cst) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
@@ -432,20 +432,20 @@
 }
 
 // CHECK-LABEL: func @matmul_with_while_bf16
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<127> : tensor<1024x1024xi8>
-// CHECK-DAG: %[[CNT:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<1024x1024xi8>
+// CHECK-DAG: %[[CNT:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[W]]) : (tensor<1024x1024xi8>) -> tensor<1024x1024xi8>
-// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[IDENTITY]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<1024x1024xi8>) -> tensor<1024x1024xbf16>
-// CHECK: %[[WHILE:.*]] = "tf.While"(%[[CNT]], %[[CNT]], %[[CNT]], %arg0, %[[DEQUANTIZED]]) {T = [i32, i32, i32, bf16, bf16], _lower_using_switch_merge = true, _num_original_outputs = 5 : i64, _read_only_resource_inputs = [], body = @while_body, cond = @while_cond, device = "", is_stateless = true, output_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<1x1024>, #tf_type.shape<1024x1024>], parallel_iterations = 10 : i64, shape_invariant} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x1024xbf16>, tensor<1024x1024xbf16>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x1024xbf16>, tensor<1024x1024xbf16>)
+// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[IDENTITY]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<1024x1024xi8>) -> tensor<1024x1024xbf16>
+// CHECK: %[[WHILE:.*]] = "tf.While"(%[[CNT]], %[[CNT]], %[[CNT]], %arg0, %[[DEQUANTIZED]]) <{body = @while_body, cond = @while_cond, is_stateless = true, parallel_iterations = 10 : i64, shape_invariant}> {T = [i32, i32, i32, bf16, bf16], _lower_using_switch_merge = true, _num_original_outputs = 5 : i64, _read_only_resource_inputs = [], device = "", output_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<1x1024>, #tf_type.shape<1024x1024>]} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x1024xbf16>, tensor<1024x1024xbf16>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x1024xbf16>, tensor<1024x1024xbf16>)
+// CHECK: %[[ORIGINAL_IDENTITY:.*]] = "tf.Identity"(%[[WHILE:.*]]) {device = ""} : (tensor<1x1024xbf16>) -> tensor<1x1024xbf16>
 
 // CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xbf16>
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<7.873530e-03> : tensor<bf16>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<7.873530e-03> : tensor<bf16>
 
 // CHECK-LABEL: func private @while_body(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<1x1024xbf16>, %arg4: tensor<1024x1024xbf16>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x1024xbf16>, tensor<1024x1024xbf16>) {
-// CHECK: %[[MATMUL_1:.*]] = "tf.XlaDotV2"(%arg3, %arg4) {device = "", dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""} : (tensor<1x1024xbf16>, tensor<1024x1024xbf16>) -> tensor<1x1024xbf16>
+// CHECK: %[[MATMUL_1:.*]] = "tf.XlaDotV2"(%arg3, %arg4) <{dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""}> {device = ""} : (tensor<1x1024xbf16>, tensor<1024x1024xbf16>) -> tensor<1x1024xbf16>
 // CHECK: %[[IDENTITY_2:.*]] = "tf.Identity"(%arg4) {device = ""} : (tensor<1024x1024xbf16>) -> tensor<1024x1024xbf16>
-// CHECK: %[[MATMUL_2:.*]] = "tf.XlaDotV2"(%arg3, %[[IDENTITY_2]]) {device = "", dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""} : (tensor<1x1024xbf16>, tensor<1024x1024xbf16>) -> tensor<1x1024xbf16>
+// CHECK: %[[MATMUL_2:.*]] = "tf.XlaDotV2"(%arg3, %[[IDENTITY_2]]) <{dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""}> {device = ""} : (tensor<1x1024xbf16>, tensor<1024x1024xbf16>) -> tensor<1x1024xbf16>
 // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[MATMUL_1]], %[[MATMUL_2]]) {device = ""} : (tensor<1x1024xbf16>, tensor<1x1024xbf16>) -> tensor<1x1024xbf16>
 
 // CHECK-LABEL: func private @while_cond(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<1x1024xbf16>, %arg4: tensor<1024x1024xbf16>) -> tensor<i1> {
@@ -482,7 +482,7 @@
 }
 
 // CHECK-LABEL: func @matmul_with_while_returning_mutated_value
-// CHECK-DAG: %[[W:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<1024x1024xf32>} : () -> tensor<1024x1024xf32>
+// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<1024x1024xf32>}> : () -> tensor<1024x1024xf32>
 
 // -----
 module {
@@ -499,27 +499,27 @@
   }
 
 // CHECK-LABEL: func @multiple_quantizable_ops_in_graph
-// CHECK-DAG: %[[W_1:.*]] = "tf.Const"() {value = dense<127> : tensor<2x3x3x1024xi8>} : () -> tensor<2x3x3x1024xi8>
-// CHECK-DAG: %[[W_2:.*]] = "tf.Const"() {value = dense<127> : tensor<3x3x1024x1xi8>} : () -> tensor<3x3x1024x1xi8>
-// CHECK-DAG: %[[W_3:.*]] = "tf.Const"() {value = dense<127> : tensor<1024x3x4x3xi8>} : () -> tensor<1024x3x4x3xi8>
-// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {device = "", value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG: %[[W_1:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x3x3x1024xi8>}> : () -> tensor<2x3x3x1024xi8>
+// CHECK-DAG: %[[W_2:.*]] = "tf.Const"() <{value = dense<127> : tensor<3x3x1024x1xi8>}> : () -> tensor<3x3x1024x1xi8>
+// CHECK-DAG: %[[W_3:.*]] = "tf.Const"() <{value = dense<127> : tensor<1024x3x4x3xi8>}> : () -> tensor<1024x3x4x3xi8>
+// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> {device = ""} : () -> tensor<i32>
 // CHECK: %[[IDENTITY_1:.*]] = "tf.Identity"(%[[W_1]]) : (tensor<2x3x3x1024xi8>) -> tensor<2x3x3x1024xi8>
-// CHECK: %[[DEQUANTIZED_1:.*]] = "tf.PartitionedCall"(%[[IDENTITY_1]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform__} : (tensor<2x3x3x1024xi8>) -> tensor<2x3x3x1024xf32>
+// CHECK: %[[DEQUANTIZED_1:.*]] = "tf.PartitionedCall"(%[[IDENTITY_1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform__}> : (tensor<2x3x3x1024xi8>) -> tensor<2x3x3x1024xf32>
 // CHECK: %[[IDENTITY_2:.*]] = "tf.Identity"(%[[W_2]]) : (tensor<3x3x1024x1xi8>) -> tensor<3x3x1024x1xi8>
-// CHECK: %[[DEQUANTIZED_2:.*]] = "tf.PartitionedCall"(%[[IDENTITY_2]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform_} : (tensor<3x3x1024x1xi8>) -> tensor<3x3x1024x1xf32>
+// CHECK: %[[DEQUANTIZED_2:.*]] = "tf.PartitionedCall"(%[[IDENTITY_2]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform_}> : (tensor<3x3x1024x1xi8>) -> tensor<3x3x1024x1xf32>
 // CHECK: %[[IDENTITY_3:.*]] = "tf.Identity"(%[[W_3]]) : (tensor<1024x3x4x3xi8>) -> tensor<1024x3x4x3xi8>
-// CHECK: %[[DEQUANTIZED_3:.*]] = "tf.PartitionedCall"(%[[IDENTITY_3]]) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<1024x3x4x3xi8>) -> tensor<1024x3x4x3xf32>
-// CHECK: %[[GATHER:.*]] = "tf.GatherV2"(%[[DEQUANTIZED_3]], %arg0, %[[AXIS]]) {batch_dims = 0 : i64, device = ""} : (tensor<1024x3x4x3xf32>, tensor<1xi32>, tensor<i32>) -> tensor<1x3x4x3xf32>
-// CHECK: %[[CONV_1:.*]] = "tf.Conv2D"(%[[GATHER]], %[[DEQUANTIZED_1]]) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x1024xf32>) -> tensor<1x3x2x1024xf32>
-// CHECK: %[[CONV_2:.*]] = "tf.Conv2D"(%[[CONV_1]], %[[DEQUANTIZED_2]]) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x2x1024xf32>, tensor<3x3x1024x1xf32>) -> tensor<1x3x1x1xf32>
+// CHECK: %[[DEQUANTIZED_3:.*]] = "tf.PartitionedCall"(%[[IDENTITY_3]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<1024x3x4x3xi8>) -> tensor<1024x3x4x3xf32>
+// CHECK: %[[GATHER:.*]] = "tf.GatherV2"(%[[DEQUANTIZED_3]], %arg0, %[[AXIS]]) <{batch_dims = 0 : i64}> {device = ""} : (tensor<1024x3x4x3xf32>, tensor<1xi32>, tensor<i32>) -> tensor<1x3x4x3xf32>
+// CHECK: %[[CONV_1:.*]] = "tf.Conv2D"(%[[GATHER]], %[[DEQUANTIZED_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> {device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x1024xf32>) -> tensor<1x3x2x1024xf32>
+// CHECK: %[[CONV_2:.*]] = "tf.Conv2D"(%[[CONV_1]], %[[DEQUANTIZED_2]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> {device = ""} : (tensor<1x3x2x1024xf32>, tensor<3x3x1024x1xf32>) -> tensor<1x3x1x1xf32>
 
 // CHECK-LABEL: func private @composite_dequantize_uniform__
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<0.00866141729> : tensor<f32>} : () -> tensor<f32>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.00866141729> : tensor<f32>}> : () -> tensor<f32>
 
 // CHECK-LABEL: func private @composite_dequantize_uniform_
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<0.00866141729> : tensor<f32>} : () -> tensor<f32>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.00866141729> : tensor<f32>}> : () -> tensor<f32>
 
 // CHECK-LABEL: func private @composite_dequantize_uniform
-// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() {value = dense<0.00866141729> : tensor<f32>} : () -> tensor<f32>
+// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.00866141729> : tensor<f32>}> : () -> tensor<f32>
 
-}
\ No newline at end of file
+}
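Note on the CHECK-line churn above and in the test updates that follow: the TF dialect now stores its registered (inherent) attributes as op properties, so the generic MLIR printer emits them in a <{...}> group ahead of the discardable attribute dictionary. A minimal before/after sketch, built from the MatMul pattern in the tests above (SSA names are illustrative):

  // Old printing: every attribute, inherent or discardable, sat in one dictionary.
  %0 = "tf.MatMul"(%lhs, %rhs) {device = "", transpose_a = false, transpose_b = false} : (tensor<1x1024xf32>, tensor<1024x1024xf32>) -> tensor<1x1024xf32>
  // New printing: inherent attributes move into the <{...}> properties group;
  // discardable ones such as device stay in the trailing {...} dictionary.
  %0 = "tf.MatMul"(%lhs, %rhs) <{transpose_a = false, transpose_b = false}> {device = ""} : (tensor<1x1024xf32>, tensor<1024x1024xf32>) -> tensor<1x1024xf32>

Only the textual form changes; the attribute values and the graphs under test stay the same, which is why every hunk in these .mlir files is a pure CHECK-pattern update.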
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir
index 4356d08..f24b639 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir
@@ -23,7 +23,7 @@
 // CHECK-DAG: [[weight:%.+]] = "arith.constant"() <{value = dense_resource<__elided__> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2x!quant.uniform<i8:f32, 0.074855112561992565:-1>>
 // CHECK: [[q_input:%.+]] = "quantfork.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform<i8:f32, 0.58810077742034317:-128>>
 // CHECK-NEXT: [[q_bias:%.+]] = "quantfork.qcast"([[bias]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i32:f32, 0.044022349891595126>>
-// CHECK-NEXT: [[conv:%.+]] = "tf.PartitionedCall"([[q_input]], [[weight]], [[q_bias]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @[[composite_fn:composite_conv2d_with_bias_and_relu6_fn.*]]} : (tensor<1x3x4x3x!quant.uniform<i8:f32, 0.58810077742034317:-128>>, tensor<2x3x3x2x!quant.uniform<i8:f32, 0.074855112561992565:-1>>, tensor<2x!quant.uniform<i32:f32, 0.044022349891595126>>) -> tensor<*x!quant.uniform<i8:f32, 0.023529411764705882:-128>>
+// CHECK-NEXT: [[conv:%.+]] = "tf.PartitionedCall"([[q_input]], [[weight]], [[q_bias]]) <{config = "", config_proto = "", executor_type = "", f = @[[composite_fn:composite_conv2d_with_bias_and_relu6_fn.*]]}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x3x!quant.uniform<i8:f32, 0.58810077742034317:-128>>, tensor<2x3x3x2x!quant.uniform<i8:f32, 0.074855112561992565:-1>>, tensor<2x!quant.uniform<i32:f32, 0.044022349891595126>>) -> tensor<*x!quant.uniform<i8:f32, 0.023529411764705882:-128>>
 // CHECK-NEXT: [[res:%.+]] = "quantfork.dcast"([[conv]]) : (tensor<*x!quant.uniform<i8:f32, 0.023529411764705882:-128>>) -> tensor<*xf32>
 // CHECK-NEXT: "func.return"([[res]]) : (tensor<*xf32>) -> ()
 
@@ -127,10 +127,10 @@
 // CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_1
 // CHECK-SAME: (tensor<1x3x4x3x!quant.uniform<i8:f32, 0.58810077742034317:-128>>, tensor<2x3x3x2x!quant.uniform<i8:f32, 0.074855112561992565:-1>>, tensor<2x!quant.uniform<i32:f32, 0.044022349891595126>>) -> tensor<*x!quant.uniform<i8:f32, 0.023529411764705882:-128>>
 // CHECK: %[[scast:.*]] = "quantfork.scast"(%[[conv]]
-// CHECK: %[[fcast:.*]] = "tf.Cast"(%[[scast]]) {Truncate = false} : (tensor<*xi8>) -> tensor<*xf32>
+// CHECK: %[[fcast:.*]] = "tf.Cast"(%[[scast]]) <{Truncate = false}> : (tensor<*xi8>) -> tensor<*xf32>
 // CHECK: %[[avgpool_f32:.*]] = "tf.AvgPool"(%[[fcast]])
 // CHECK-SAME: (tensor<*xf32>) -> tensor<*xf32>
 // CHECK: %[[round:.*]] = "tf.Round"(%[[avgpool_f32]])
-// CHECK: %[[icast:.*]] = "tf.Cast"(%[[round]]) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi8>
+// CHECK: %[[icast:.*]] = "tf.Cast"(%[[round]]) <{Truncate = false}> : (tensor<*xf32>) -> tensor<*xi8>
 // CHECK: %[[reshape:.*]] = "tf.Reshape"(%[[icast]]
 // CHECK: %[[sc2:.*]] = "quantfork.scast"(%[[reshape]])
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/remove_var_init_by_const.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/remove_var_init_by_const.mlir
index d5e1820..da78ef7 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/remove_var_init_by_const.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/remove_var_init_by_const.mlir
@@ -79,7 +79,7 @@
   // CHECK-NOT: "tf.AssignVariableOp"
   // CHECK: %[[CST:.*]] = "tf.Const"()
   // CHECK-NEXT: %[[IDENTITY:.*]] = "tf.Identity"(%[[CST]])
-  // CHECK-NEXT: %[[VAR:.*]] = "tf.VarHandleOp"() {{{.*shared_name = "var_1".*}}}
+  // CHECK-NEXT: %[[VAR:.*]] = "tf.VarHandleOp"() <{{{.*shared_name = "var_1".*}}}>
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[VAR]], %[[IDENTITY]])
 }
 
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops.mlir
index c7fae4c..04677b6 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops.mlir
@@ -57,17 +57,17 @@
   }
 
 // CHECK-LABEL: func @conv_with_bias_and_relu
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG: %[[CONST_3:.*]] = "tf.Const"() {value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-// CHECK-DAG: %[[CONST_4:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<4x2xi32>} : () -> tensor<4x2xi32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[CONST_3:.*]] = "tf.Const"() <{value = dense<0> : tensor<2x2xi32>}> : () -> tensor<2x2xi32>
+// CHECK-DAG: %[[CONST_4:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<4x2xi32>}> : () -> tensor<4x2xi32>
 // CHECK-DAG-SAME{LITERAL}: value = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]>
-// CHECK-DAG: %[[CONST_5:.*]] = "tf.Const"() {value = dense<-128> : tensor<i8>} : () -> tensor<i8>
-// CHECK-DAG: %[[CONST_6:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8>
-// CHECK-DAG: %[[CONST_7:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2xi32>
+// CHECK-DAG: %[[CONST_5:.*]] = "tf.Const"() <{value = dense<-128> : tensor<i8>}> : () -> tensor<i8>
+// CHECK-DAG: %[[CONST_6:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2xi8>
+// CHECK-DAG: %[[CONST_7:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2xi32>
 // CHECK-DAG-SAME{LITERAL}: value = dense<[[[[-22016, -23680]]]]>
-// CHECK-DAG: %[[CONST_8:.*]] = "tf.Const"() {value = dense<[162, 160]> : tensor<2xi32>} : () -> tensor<2xi32>
+// CHECK-DAG: %[[CONST_8:.*]] = "tf.Const"() <{value = dense<[162, 160]> : tensor<2xi32>}> : () -> tensor<2xi32>
 // CHECK: %[[PADV2_0:.*]] = "tf.PadV2"({{.*}}, %[[CONST_4]], %[[CONST_5]]) : (tensor<1x3x4x3xi8>, tensor<4x2xi32>, tensor<i8>) -> tensor<1x4x5x3xi8>
 // CHECK: %[[XLACONVV2_0:.*]] = "tf.XlaConvV2"(%[[PADV2_0]], %[[CONST_6]], %[[CONST_0]], %[[CONST_3]], %[[CONST_1]], %[[CONST_1]], %[[CONST_2]])
 // CHECK-SAME: (tensor<1x4x5x3xi8>, tensor<2x3x3x2xi8>, tensor<2xi32>, tensor<2x2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<i32>) -> tensor<1x3x2x2xi32>
@@ -144,16 +144,16 @@
   }
 
 // CHECK-LABEL: func @depthwise_conv_with_bias_and_relu6
-// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<4x2xi32>} : () -> tensor<4x2xi32>
-// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<-128> : tensor<i8>} : () -> tensor<i8>
-// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<2x3x1x3xi8>} : () -> tensor<2x3x1x3xi8>
-// CHECK-DAG: %[[CONST_3:.*]] = "tf.Const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK-DAG: %[[CONST_4:.*]] = "tf.Const"() {value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-// CHECK-DAG: %[[CONST_5:.*]] = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32>
-// CHECK-DAG: %[[CONST_6:.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG: %[[CONST_7:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<1x1x1x3xi32>} : () -> tensor<1x1x1x3xi32>
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<4x2xi32>}> : () -> tensor<4x2xi32>
+// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<-128> : tensor<i8>}> : () -> tensor<i8>
+// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<2x3x1x3xi8>}> : () -> tensor<2x3x1x3xi8>
+// CHECK-DAG: %[[CONST_3:.*]] = "tf.Const"() <{value = dense<2> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-DAG: %[[CONST_4:.*]] = "tf.Const"() <{value = dense<0> : tensor<2x2xi32>}> : () -> tensor<2x2xi32>
+// CHECK-DAG: %[[CONST_5:.*]] = "tf.Const"() <{value = dense<1> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-DAG: %[[CONST_6:.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[CONST_7:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<1x1x1x3xi32>}> : () -> tensor<1x1x1x3xi32>
 // CHECK-DAG-SAME{LITERAL}: value = dense<[[[[55040, -15104, -21376]]]]>
-// CHECK-DAG: %[[CONST_8:.*]] = "tf.Const"() {value = dense<[129, 166, 221]> : tensor<3xi32>} : () -> tensor<3xi32>
+// CHECK-DAG: %[[CONST_8:.*]] = "tf.Const"() <{value = dense<[129, 166, 221]> : tensor<3xi32>}> : () -> tensor<3xi32>
 // CHECK: %[[PADV2_0:.*]] = "tf.PadV2"({{.*}}, %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3xi8>, tensor<4x2xi32>, tensor<i8>) -> tensor<1x4x5x3xi8>
 // CHECK: %[[XLACONVV2_0:.*]] = "tf.XlaConvV2"(%[[PADV2_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], %[[CONST_5]], %[[CONST_5]], %[[CONST_6]])
 // CHECK-SAME: (tensor<1x4x5x3xi8>, tensor<2x3x1x3xi8>, tensor<2xi32>, tensor<2x2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<i32>) -> tensor<1x2x2x3xi32>
@@ -198,14 +198,14 @@
   }
 
 // CHECK-LABEL: func @dynamic_shaped_conv2d_with_bias_and_relu6_inlined
-// CHECK-DAG: %[[filter:.*]] = "tf.Const"() {device = "", value = dense<2> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8>
+// CHECK-DAG: %[[filter:.*]] = "tf.Const"() <{value = dense<2> : tensor<2x3x3x2xi8>}> {device = ""} : () -> tensor<2x3x3x2xi8>
 // CHECK-DAG: %[[input_shape:.*]] = "tf.Shape"({{.*}}) : (tensor<?x?x?x3xi8>) -> tensor<4xi32>
-// CHECK-DAG: %[[input_dim_1:.*]] = "tf.StridedSlice"(%[[input_shape]], {{.*}}, {{.*}}, {{.*}}) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
-// CHECK-DAG: %[[input_dim_2:.*]] = "tf.StridedSlice"(%[[input_shape]], {{.*}}, {{.*}}, {{.*}}) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+// CHECK-DAG: %[[input_dim_1:.*]] = "tf.StridedSlice"(%[[input_shape]], {{.*}}, {{.*}}, {{.*}}) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+// CHECK-DAG: %[[input_dim_2:.*]] = "tf.StridedSlice"(%[[input_shape]], {{.*}}, {{.*}}, {{.*}}) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
 // CHECK-DAG: %[[padding_rank_1:.*]] = "tf.Concat"({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) : (tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<8xi32>
 // CHECK-DAG: %[[padding_rank_2:.*]] = "tf.Reshape"(%[[padding_rank_1]], {{.*}}) : (tensor<8xi32>, tensor<2xi64>) -> tensor<4x2xi32>
 // CHECK-DAG: %[[input_padded:.*]] = "tf.PadV2"(%{{.*}}, %[[padding_rank_2]], {{.*}}) : (tensor<?x?x?x3xi8>, tensor<4x2xi32>, tensor<i8>) -> tensor<?x?x?x3xi8>
-// CHECK: %[[conv_output:.*]] = "tf.XlaConvV2"(%[[input_padded]], %[[filter]], {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) {batch_group_count = 1 : i64, dimension_numbers = "{{.*}}", precision_config = ""} : (tensor<?x?x?x3xi8>, tensor<2x3x3x2xi8>, tensor<2xi32>, tensor<2x2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<i32>) -> tensor<?x?x?x2xi32>
+// CHECK: %[[conv_output:.*]] = "tf.XlaConvV2"(%[[input_padded]], %[[filter]], {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) <{batch_group_count = 1 : i64, dimension_numbers = "{{.*}}", precision_config = ""}> : (tensor<?x?x?x3xi8>, tensor<2x3x3x2xi8>, tensor<2xi32>, tensor<2x2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<i32>) -> tensor<?x?x?x2xi32>
 // CHECK: %[[conv_output_sub:.*]] = "tf.Sub"(%[[conv_output]], {{.*}}) : (tensor<?x?x?x2xi32>, tensor<1x1x1x2xi32>) -> tensor<?x?x?x2xi32>
 // CHECK: %[[conv_output_add:.*]] = "tf.AddV2"(%[[conv_output_sub]], {{.*}}) {device = ""} : (tensor<?x?x?x2xi32>, tensor<2xi32>) -> tensor<?x?x?x2xi32>
 }
@@ -264,7 +264,7 @@
   }
 
 // CHECK-LABEL: func @conv_with_filter_larger_than_1MB
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<-264192> : tensor<1x1x1x512xi32>} : () -> tensor<1x1x1x512xi32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<-264192> : tensor<1x1x1x512xi32>}> : () -> tensor<1x1x1x512xi32>
 // CHECK: %[[PADV2_0:.*]] = "tf.PadV2"
 // CHECK: %[[XLACONVV2_0:.*]] = "tf.XlaConvV2"(%[[PADV2_0]]
 // CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLACONVV2_0]], %[[CONST]])
@@ -297,8 +297,8 @@
     return %12 : tensor<1x3xf32>
   }
 // CHECK-LABEL: func @matmul_with_relu
-// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() {device = "", value = dense<1> : tensor<1024x3xi8>} : () -> tensor<1024x3xi8>
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<-131072> : tensor<1x3xi32>} : () -> tensor<1x3xi32>
+// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() <{value = dense<1> : tensor<1024x3xi8>}> {device = ""} : () -> tensor<1024x3xi8>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<-131072> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
 // CHECK: %[[MATMUL:.*]] = "tf.XlaDotV2"({{.*}}, %[[WEIGHT]])
 // CHECK-SAME: (tensor<1x1024xi8>, tensor<1024x3xi8>) -> tensor<1x3xi32>
 // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[MATMUL]], %[[CONST]]) : (tensor<1x3xi32>, tensor<1x3xi32>) -> tensor<1x3xi32>
@@ -479,11 +479,11 @@
   }
 
 // CHECK-LABEL: func @conv3d_with_static_shape
-// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() {device = "", value = dense<1> : tensor<2x3x3x3x2xi8>} : () -> tensor<2x3x3x3x2xi8>
+// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() <{value = dense<1> : tensor<2x3x3x3x2xi8>}> {device = ""} : () -> tensor<2x3x3x3x2xi8>
 // CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {{.*}} : () -> tensor<5x2xi32>
 // CHECK-DAG-SAME{LITERAL}: value = dense<[[0, 0], [0, 1], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi32>
-// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<-43> : tensor<i8>} : () -> tensor<i8>
-// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() {value = dense<-2322> : tensor<1x1x1x1x2xi32>} : () -> tensor<1x1x1x1x2xi32>
+// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<-43> : tensor<i8>}> : () -> tensor<i8>
+// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() <{value = dense<-2322> : tensor<1x1x1x1x2xi32>}> : () -> tensor<1x1x1x1x2xi32>
 
 // CHECK: %[[PAD:.*]] = "tf.PadV2"({{.*}}, %[[CONST]], %[[CONST_1]])
 // CHECK: %[[CONV:.*]] = "tf.XlaConvV2"(%[[PAD]], %[[WEIGHT]]
@@ -524,9 +524,9 @@
   }
 
 // CHECK-LABEL: func @conv3d_with_dynamic_shape
-// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() {device = "", value = dense<1> : tensor<2x3x3x3x2xi8>} : () -> tensor<2x3x3x3x2xi8>
-// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<-43> : tensor<i8>} : () -> tensor<i8>
-// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() {value = dense<-2322> : tensor<1x1x1x1x2xi32>} : () -> tensor<1x1x1x1x2xi32>
+// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() <{value = dense<1> : tensor<2x3x3x3x2xi8>}> {device = ""} : () -> tensor<2x3x3x3x2xi8>
+// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<-43> : tensor<i8>}> : () -> tensor<i8>
+// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() <{value = dense<-2322> : tensor<1x1x1x1x2xi32>}> : () -> tensor<1x1x1x1x2xi32>
 
 // CHECK: %[[CONCAT:.*]] = "tf.Concat"({{.*}})
 // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%[[CONCAT]], {{.*}}) : (tensor<10xi32>, tensor<2xi64>) -> tensor<5x2xi32>
@@ -565,7 +565,7 @@
   }
 
 // CHECK-LABEL: func @batch_matmul
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<-131072> : tensor<20x30x1x3xi32>} : () -> tensor<20x30x1x3xi32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<-131072> : tensor<20x30x1x3xi32>}> : () -> tensor<20x30x1x3xi32>
 // CHECK: %[[CAST:.*]] = "tf.Cast"
 // CHECK: %[[XLADOTV2_0:.*]] = "tf.XlaDotV2"(%[[CAST]]
 // CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLADOTV2_0]], %[[CONST]]) : (tensor<20x30x64x3xi32>, tensor<20x30x1x3xi32>) -> tensor<20x30x64x3xi32>
@@ -602,7 +602,7 @@
   }
 
 // CHECK-LABEL: func @broadcasting_weight_batch_matmul
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<[2, 1024, 3]> : tensor<3xi64>} : () -> tensor<3xi64>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<[2, 1024, 3]> : tensor<3xi64>}> : () -> tensor<3xi64>
 // CHECK: %[[CAST:.*]] = "tf.Cast"
 // CHECK: %[[BROADCAST_TO:.*]] = "tf.BroadcastTo"({{.*}}, %[[CONST]]) : (tensor<1024x3xi8>, tensor<3xi64>) -> tensor<2x1024x3xi8>
 // CHECK: %[[XLADOTV2_0:.*]] = "tf.XlaDotV2"(%[[CAST]], %[[BROADCAST_TO]])
@@ -639,8 +639,8 @@
   }
 
 // CHECK-LABEL: func @broadcasting_input_batch_matmul
-// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() {device = "", value = {{.*}} : tensor<2x2x1024x3xi8>} : () -> tensor<2x2x1024x3xi8>
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<[2, 2, 1, 1024]> : tensor<4xi64>} : () -> tensor<4xi64>
+// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() <{value = {{.*}} : tensor<2x2x1024x3xi8>}> {device = ""} : () -> tensor<2x2x1024x3xi8>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<[2, 2, 1, 1024]> : tensor<4xi64>}> : () -> tensor<4xi64>
 // CHECK: %[[CAST:.*]] = "tf.Cast"
 // CHECK: %[[BROADCAST_TO:.*]] = "tf.BroadcastTo"(%[[CAST]], %[[CONST]]) : (tensor<2x1x1024xi8>, tensor<4xi64>) -> tensor<2x2x1x1024xi8>
 // CHECK: %[[XLADOTV2_0:.*]] = "tf.XlaDotV2"(%[[BROADCAST_TO]], %[[WEIGHT]])
@@ -677,14 +677,14 @@
   }
 
 // CHECK-LABEL: func @dynamic_shape_batch_matmul
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
-// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
-// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
-// CHECK-DAG: %[[CONST_3:.*]] = "tf.Const"() {value = dense<[1024, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
-// CHECK-DAG: %[[CONST_4:.*]] = "tf.Const"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64>
-// CHECK-DAG: %[[CONST_5:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() {device = "", value = {{.*}} : tensor<1024x3xi8>} : () -> tensor<1024x3xi8>
-// CHECK: %[[CAST:.*]] = "tf.Cast"({{.*}}) {Truncate = false, device = ""} : (tensor<?x1x1024xf32>) -> tensor<?x1x1024xi8>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK-DAG: %[[CONST_3:.*]] = "tf.Const"() <{value = dense<[1024, 3]> : tensor<2xi64>}> : () -> tensor<2xi64>
+// CHECK-DAG: %[[CONST_4:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi64>}> : () -> tensor<0xi64>
+// CHECK-DAG: %[[CONST_5:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() <{{{value = .* : tensor<1024x3xi8>}}}> {device = ""} : () -> tensor<1024x3xi8>
+// CHECK: %[[CAST:.*]] = "tf.Cast"({{.*}}) <{Truncate = false}> {device = ""} : (tensor<?x1x1024xf32>) -> tensor<?x1x1024xi8>
 // CHECK: %[[SHAPE:.*]] = "tf.Shape"(%[[CAST]]) : (tensor<?x1x1024xi8>) -> tensor<3xi64>
 // CHECK: %[[SLICE_1:.*]] = "tf.Slice"(%[[SHAPE]], %[[CONST]], %[[CONST_2]]) : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64>
 // CHECK: %[[SLICE_2:.*]] = "tf.Slice"(%[[SHAPE]], %[[CONST_2]], %[[CONST_1]]) : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64>
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops_large_constants.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops_large_constants.mlir
index 775ab82..3c0c366 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops_large_constants.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops_large_constants.mlir
@@ -56,7 +56,7 @@
   }
 
 // CHECK-LABEL: func @conv_with_filter_larger_than_1GB
-// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<-237772800> : tensor<1x1x1x512xi32>} : () -> tensor<1x1x1x512xi32>
+// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<-237772800> : tensor<1x1x1x512xi32>}> : () -> tensor<1x1x1x512xi32>
 // CHECK: %[[PADV2_0:.*]] = "tf.PadV2"
 // CHECK: %[[XLACONVV2_0:.*]] = "tf.XlaConvV2"(%[[PADV2_0]]
 // CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLACONVV2_0]], %[[CONST]])
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/unfreeze_constants.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/unfreeze_constants.mlir
index ddf33e3..b7b4fa1 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/unfreeze_constants.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/unfreeze_constants.mlir
@@ -15,7 +15,7 @@
 // CHECK-SAME: tf_saved_model.initializer_type = "restore_op"
 
 // Check that variable is initialized by assigning the const value within the initializer function.
-// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<8xf32>}
+// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<8xf32>}>
 // CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}}
 // CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[CST_0]])
 
@@ -44,11 +44,11 @@
 // CHECK-SAME: tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]
 // CHECK-SAME: tf_saved_model.initializer_type = "restore_op"
 
-// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {{{.*value = dense<1.000000e\+00> : tensor<8xf32>.*}}}
+// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{{{.*value = dense<1.000000e\+00> : tensor<8xf32>.*}}}>
 // CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}}
 // CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[CST_0]])
 
-// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {{{.*value = dense<2.000000e\+00> : tensor<8xf32>.*}}}
+// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{{{.*value = dense<2.000000e\+00> : tensor<8xf32>.*}}}>
 // CHECK-DAG: %[[VAR_HANDLE_1:.*]] = "tf.VarHandleOp"()  {{.*shared_name = "const_1".*}}
 // CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_1]], %[[CST_1]])
 
@@ -84,11 +84,11 @@
 // CHECK-SAME: tf_saved_model.exported_names = ["tf_saved_model.session_initializer_init"]
 // CHECK-SAME: tf_saved_model.initializer_type = "restore_op"
 
-// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<8xf32>}
+// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<8xf32>}>
 // CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"()
 // CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[CST_0]])
 
-// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<8xf32>}
+// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<8xf32>}>
 // CHECK-DAG: %[[VAR_HANDLE_1:.*]] = "tf.VarHandleOp"()
 // CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_1]], %[[CST_1]])
 
@@ -123,7 +123,7 @@
 // CHECK-SAME: tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]
 // CHECK-SAME: tf_saved_model.initializer_type = "restore_op"
 
-// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {value = dense<3.000000e+00> : tensor<8xf32>}
+// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<3.000000e+00> : tensor<8xf32>}>
 // CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"()
 // CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[CST_0]])
 
@@ -185,7 +185,7 @@
 
 // Check that `tf.VarHandleOp` is only created for the constant that is larger
 // than the threshold (16 bytes for this test).
-// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {{{.*value = dense<5.000000e\+00> : tensor<8xf32>.*}}}
+// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{{{.*value = dense<5.000000e\+00> : tensor<8xf32>.*}}}>
 // CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}}
 // CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[CST_0]])
 
@@ -199,8 +199,8 @@
 // CHECK: @serving_default
 // CHECK-DAG: %[[VAR_HANDLE_2:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}} : () -> tensor<!tf_type.resource<tensor<8xf32>>>
 // CHECK-DAG: %[[READ_VAR_0:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_2]]) : (tensor<!tf_type.resource<tensor<8xf32>>>) -> tensor<8xf32>
-// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {{{.*value = dense<5.000000e\+00> : tensor<4xf32>.*}}}
-// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {{{.*value = dense<0> : tensor<i64>.*}}}
+// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{{{.*value = dense<5.000000e\+00> : tensor<4xf32>.*}}}>
+// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() <{{{.*value = dense<0> : tensor<i64>.*}}}>
 // CHECK-DAG: %[[CONCAT:.*]] = "tf.ConcatV2"(%[[READ_VAR_0]], %[[CST_1]], %[[AXIS]])
 // CHECK: return %[[CONCAT]] : tensor<12xf32>
 }
@@ -214,7 +214,7 @@
 
 module attributes {tf_saved_model.semantics} {
 // CHECK: func.func @init_func_restore_op()
-// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<8xf32>}
+// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<8xf32>}>
 // Check that the variable's shared_name contains the fused loc's items joined
 // by the delimiter "_" and suffixed with a number.
 // CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "apple_banana_0".*}}
@@ -247,7 +247,7 @@
     %cst_2 = "tf.Const"() {value = dense<1.0> : tensor<1x5x5x1024xf32>} : () -> tensor<1x5x5x1024xf32>
     // Check that these constants are unfrozen.
     // CHECK: func private @__inference_main
-    // CHECK: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {container = "", shared_name = "const_0"} : () -> tensor<!tf_type.resource<tensor<1x5x5x1024xf32>>>
+    // CHECK: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() <{container = "", shared_name = "const_0"}> : () -> tensor<!tf_type.resource<tensor<1x5x5x1024xf32>>>
     // CHECK: %[[READ_VAR_0:.*]] = "tf.ReadVariableOp"(%0) : (tensor<!tf_type.resource<tensor<1x5x5x1024xf32>>>) -> tensor<1x5x5x1024xf32>
     %0:3 = "tf.While"(%cst_0, %cst_1, %arg0) {T = [i32, i32, f32], _lower_using_switch_merge = true, _num_original_outputs = 4 : i64, _read_only_resource_inputs = [], body = @while_body, cond = @while_cond, device = "", is_stateless = true, output_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<1x5x5x1024>], parallel_iterations = 10 : i64, shape_invariant} : (tensor<i32>, tensor<i32>, tensor<1x5x5x1024xf32>) -> (tensor<i32>, tensor<i32>, tensor<1x5x5x1024xf32>)
     %1 = "tf.AddV2"(%0#2, %cst_2) {device = ""} : (tensor<1x5x5x1024xf32>, tensor<1x5x5x1024xf32>) -> tensor<1x5x5x1024xf32>
@@ -260,7 +260,7 @@
     %cst_0 = "tf.Const"() {value = dense<1.0> : tensor<1x5x5x1024xf32>} : () -> tensor<1x5x5x1024xf32>
     // Check that these constants are remained in constants.
     // CHECK: func private @while_body
-    // CHECK-DAG:  %[[CST_0:.*]]= "tf.Const"() {value = dense<1.000000e+00> : tensor<1x5x5x1024xf32>} : () -> tensor<1x5x5x1024xf32>
+    // CHECK-DAG:  %[[CST_0:.*]]= "tf.Const"() <{value = dense<1.000000e+00> : tensor<1x5x5x1024xf32>}> : () -> tensor<1x5x5x1024xf32>
     %0 = "tf.AddV2"(%arg0, %cst) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
     %1 = "tf.Identity"(%0) {device = ""} : (tensor<i32>) -> tensor<i32>
     %2 = "tf.Identity"(%arg1) {device = ""} : (tensor<i32>) -> tensor<i32>
@@ -269,13 +269,13 @@
     return %1, %2, %5 : tensor<i32>, tensor<i32>, tensor<1x5x5x1024xf32>
   }
 
-  func.func private @while_cond(%arg0: tensor<i32> {tf._user_specified_name = "while/loop_counter"}, %arg1: tensor<i32> {tf._user_specified_name = "while/maximum_iterations"}, %arg2: tensor<1x5x5x1024xf32>) -> tensor<i1> 
+  func.func private @while_cond(%arg0: tensor<i32> {tf._user_specified_name = "while/loop_counter"}, %arg1: tensor<i32> {tf._user_specified_name = "while/maximum_iterations"}, %arg2: tensor<1x5x5x1024xf32>) -> tensor<i1>
   attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<1x5x5x1024>], tf._original_func_name = "while_cond_60"} {
     %cst = "tf.Const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
     %cst_0 = "tf.Const"() {value = dense<5.0> : tensor<f32>} : () -> tensor<f32>
     // Check that these constants are remained in constants.
     // CHECK: func private @while_cond
-    // CHECK-DAG:  %[[CST:.*]]= "tf.Const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+    // CHECK-DAG:  %[[CST:.*]]= "tf.Const"() <{value = dense<[0, 1, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
     %0 = "tf.Sum"(%arg2, %cst) {device = "", keep_dims = false} : (tensor<1x5x5x1024xf32>, tensor<4xi32>) -> tensor<f32>
     %1 = "tf.Less"(%0, %cst_0) {device = ""} : (tensor<f32>, tensor<f32>) -> tensor<i1>
     %2 = "tf.Identity"(%1) {device = ""} : (tensor<i1>) -> tensor<i1>
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 78d87e8..01f225c 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -353,11 +353,14 @@
         ":attribute_utils",
         ":convert_type",
         ":dynamic_shape_utils",
+        ":tensorflow_all_ops_inc_gen",
         ":tensorflow_attributes",
         ":tensorflow_op_interfaces",
         ":tensorflow_op_interfaces_inc_gen",
+        ":tensorflow_remaining_ops_inc_gen",
         ":tensorflow_side_effects",
         ":tensorflow_structs",
+        ":tensorflow_tfrt_ops_inc_gen",
         ":tensorflow_traits",
         ":tensorflow_types",
         ":tf_arith_ops_folder",
@@ -369,6 +372,8 @@
         "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_canonicalize_inc_gen",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ControlFlowInterfaces",
         "@llvm-project//mlir:DerivedAttributeOpInterface",
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
index f5a6cf4..46cfa42 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
@@ -387,7 +387,7 @@
   //   packed_input
   //     %b as %block_arg1: type
   const int32_t n = this->getN();
-  const int32_t num_replicated_inputs = getOperandSegmentSizes()[0];
+  const int32_t num_replicated_inputs = getProperties().operandSegmentSizes[0];
   const int32_t num_replicated_block_args = num_replicated_inputs / n;
 
   if (getNumOperands()) {
@@ -502,7 +502,7 @@
 
   Block& block = op.getBody().front();
 
-  auto operandSegmentSizes = op.getOperandSegmentSizes();
+  auto operandSegmentSizes = op.getProperties().operandSegmentSizes;
   const int32_t num_replicated_inputs = operandSegmentSizes[0];
   const int32_t num_packed_inputs = operandSegmentSizes[1];
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td
index c014738..3431273 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td
@@ -39,7 +39,6 @@
 }];
 
   let cppNamespace = "::mlir::tf_device";
-  let usePropertiesForAttributes = 0;
 }
 
 //===----------------------------------------------------------------------===//
@@ -263,7 +262,6 @@
     Variadic<AnyType>:$replicated_inputs,
     Variadic<AnyType>:$packed_inputs,
 
-    DenseI32ArrayAttr:$operandSegmentSizes,
     ConfinedAttr<I32Attr, [IntMinValue<2>]>:$n,
     OptionalAttr<DictionaryAttr>:$devices
   );
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
index c3fdc7a..a5bb005 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
@@ -48,7 +48,6 @@
   }];
 
   let cppNamespace = "::mlir::TF";
-  let usePropertiesForAttributes = 0;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h
index f9fa3d1..e81742e 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h
@@ -19,6 +19,7 @@
 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_
 
+#include "mlir/Bytecode/BytecodeOpInterface.h"  // from @llvm-project  // IWYU pragma: keep
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/Dialect/Traits.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index 5e4d546..a763b50 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -312,7 +312,7 @@
       [Terminator,
        Pure,
        NativeOpTrait<"ReturnLike", [], "", "">,
-       ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp"]>,
+       ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp", "GeneratorDatasetRegionOp"]>,
        DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface,
            ["getMutableSuccessorOperands"]>,
       ]> {
@@ -389,6 +389,57 @@
   let hasCanonicalizer = 1;
 }
 
+def TF_GeneratorDatasetRegionOp : TF_Op<"GeneratorDatasetRegion",
+      [AttrSizedOperandSegments,
+       DeclareOpInterfaceMethods<RegionBranchOpInterface, [
+           "areTypesCompatible",
+           "getEntrySuccessorOperands",
+           "getRegionInvocationBounds",
+           "getSuccessorRegions"
+       ]>,
+       SingleBlockImplicitTerminator<"YieldOp">,
+       TF_GeneratorOpSideEffect,
+      ]> {
+  let summary = "Regional version of GeneratorDataset";
+
+  let description = [{
+Creates a dataset that invokes its 'next' region to generate elements. Conceptually,
+within MLIR, we treat this op as if it fills a buffer with all the results right away,
+and those results are then passed (through the variant tensor result) to
+MakeIterator / IteratorGetNext. Note that the actual TF implementation differs: It
+generates the next element just in time, during IteratorGetNext.
+
+init_extra_args: Additional arguments to pass to 'init'.
+next_extra_args: Additional arguments to pass to 'next'. (Passed after the
+                 normal arguments which are from the return values of 'init'.)
+finalize_extra_args: Additional arguments to pass to 'finalize'. (Passed after
+                 the normal arguments which are from the return values of 'init'.)
+  }];
+
+  let arguments = (ins
+    Variadic<TF_Tensor>:$init_func_other_args,
+    Variadic<TF_Tensor>:$next_func_other_args,
+    Variadic<TF_Tensor>:$finalize_func_other_args,
+
+    ConfinedAttr<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
+    ConfinedAttr<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
+    DefaultValuedOptionalAttr<StrAttr, "\"\"">:$metadata
+  );
+
+  let results = (outs
+    TF_VariantTensor:$handle
+  );
+
+  let regions = (region SizedRegion<1>:$init,
+                        SizedRegion<1>:$next,
+                        SizedRegion<1>:$finalize
+                        );
+
+  TF_DerivedOperandTypeListAttr Tinit_func_args = TF_DerivedOperandTypeListAttr<0>;
+  TF_DerivedOperandTypeListAttr Tnext_func_args = TF_DerivedOperandTypeListAttr<1>;
+  TF_DerivedOperandTypeListAttr Tfinalize_func_args = TF_DerivedOperandTypeListAttr<2>;
+}
+
 def TF_LegacyCallOp : TF_Op<"LegacyCall",
       [CallOpInterface,
        DeclareOpInterfaceMethods<SymbolUserOpInterface>, Pure]> {
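The three regions of the new op mirror the functional GeneratorDataset contract: 'init' turns init_func_other_args into the iterator state, 'next' receives that state (followed by next_func_other_args) and yields one element per invocation, and 'finalize' receives the state (followed by finalize_func_other_args) for cleanup. A rough generic-form sketch of an instance, under the assumption of a single i64 state value and i64 elements (the operand split, types, and region bodies are illustrative, not taken from a test in this change, and the op's verifier may constrain them further):

  %ds = "tf.GeneratorDatasetRegion"(%limit) <{operandSegmentSizes = array<i32: 1, 0, 0>,
        output_shapes = [#tf_type.shape<>], output_types = [i64]}> ({
    ^bb0(%n: tensor<i64>):            // init: build the iterator state
      "tf.Yield"(%n) : (tensor<i64>) -> ()
  }, {
    ^bb0(%state: tensor<i64>):        // next: produce the next element
      "tf.Yield"(%state) : (tensor<i64>) -> ()
  }, {
    ^bb0(%state: tensor<i64>):        // finalize: release the state
      "tf.Yield"() : () -> ()
  }) : (tensor<i64>) -> tensor<!tf_type.variant>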
@@ -455,9 +506,7 @@
     Variadic<TF_StrTensor>:$dense_keys,
     Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,
 
-    TF_ShapeAttrArray:$dense_shapes,
-    DenseI32ArrayAttr:$resultSegmentSizes,
-    DenseI32ArrayAttr:$operandSegmentSizes
+    TF_ShapeAttrArray:$dense_shapes
   );
 
   let results = (outs
@@ -491,8 +540,7 @@
     Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,
 
     ConfinedAttr<I64Attr, [IntMinValue<0>]>:$num_sparse,
-    TF_ShapeAttrArray:$dense_shapes,
-    DenseI32ArrayAttr:$resultSegmentSizes
+    TF_ShapeAttrArray:$dense_shapes
   );
 
   let results = (outs
@@ -2169,6 +2217,8 @@
     TF_Float32Tensor:$embedding_table,
     TF_Int32Tensor:$num_minibatches_per_physical_sparse_core,
 
+    F32Attr:$clip_weight_min,
+    F32Attr:$clip_weight_max,
     StrAttr:$table_name
   );
 
@@ -2191,6 +2241,8 @@
     TF_Float32Tensor:$accumulator,
     TF_Int32Tensor:$num_minibatches_per_physical_sparse_core,
 
+    F32Attr:$clip_weight_min,
+    F32Attr:$clip_weight_max,
     StrAttr:$table_name
   );
 
@@ -2220,6 +2272,8 @@
     F32Attr:$beta1,
     F32Attr:$beta2,
     F32Attr:$epsilon,
+    F32Attr:$clip_weight_min,
+    F32Attr:$clip_weight_max,
     StrAttr:$table_name
   );
 
@@ -2249,6 +2303,8 @@
     F32Attr:$beta1,
     F32Attr:$beta2,
     F32Attr:$epsilon,
+    F32Attr:$clip_weight_min,
+    F32Attr:$clip_weight_max,
     StrAttr:$table_name
   );
 
@@ -2279,6 +2335,8 @@
     F32Attr:$learning_rate_power,
     F32Attr:$l1_regularization_strength,
     F32Attr:$l2_regularization_strength,
+    F32Attr:$clip_weight_min,
+    F32Attr:$clip_weight_max,
     StrAttr:$table_name
   );
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 4f21118..cee0e40 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -19,55 +19,60 @@
 #include <array>
 #include <cassert>
 #include <complex>
+#include <cstddef>
 #include <cstdint>
-#include <functional>
 #include <iterator>
-#include <limits>
-#include <numeric>
 #include <optional>
 #include <string>
 #include <tuple>
 #include <type_traits>
 
+#include "absl/log/check.h"
+#include "absl/strings/str_cat.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/StringSwitch.h"
 #include "llvm/ADT/iterator_range.h"
-#include "llvm/Support/Casting.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/Dialect/Traits.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributeInterfaces.h"  // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
-#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
-#include "mlir/IR/DialectImplementation.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
+#include "mlir/IR/OperationSupport.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/Region.h"  // from @llvm-project
+#include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/IR/ValueRange.h"  // from @llvm-project
+#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
 #include "mlir/Interfaces/ControlFlowInterfaces.h"  // from @llvm-project
+#include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
 #include "mlir/Parser/Parser.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h"
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_canonicalization_helper.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_device_helper.h"
@@ -75,12 +80,14 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h"
 #include "tensorflow/core/framework/kernel_shape_util.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
 
@@ -1708,9 +1715,10 @@
 
 LogicalResult ConstOp::inferReturnTypes(
     MLIRContext* context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties, RegionRange regions,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<Type>& inferredReturnTypes) {
-  auto value = attributes.get("value");
+  ConstOpAdaptor adaptor(operands, attributes, properties, regions);
+  auto value = adaptor.getValue();
   if (!value) return emitOptionalError(location, "missing attribute 'value'");
   if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
     inferredReturnTypes.assign({elem_attr.getType()});
@@ -1951,13 +1959,13 @@
 
 LogicalResult Conv2DOp::inferReturnTypeComponents(
     MLIRContext* context, std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties,
-    RegionRange regions,
+    ValueShapeRange operands, DictionaryAttr attributes,
+    OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
-  Conv2DOpAdaptor op(operands.getValues(), attributes);
+  Conv2DOpAdaptor op(operands.getValues(), attributes, properties, regions);
   ArrayRef<Attribute> explicit_padding;
   ArrayAttr explicit_pad =
-      attributes.get("explicit_paddings").dyn_cast_or_null<::mlir::ArrayAttr>();
+      op.getExplicitPaddings().dyn_cast_or_null<::mlir::ArrayAttr>();
   if (!explicit_pad) {
     explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({});
   }
@@ -2150,17 +2158,12 @@
 
 LogicalResult Conv3DOp::inferReturnTypeComponents(
     MLIRContext* context, std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties,
-    RegionRange regions,
+    ValueShapeRange operands, DictionaryAttr attributes,
+    OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
-  Conv3DOpAdaptor op(operands.getValues(), attributes);
-  ArrayRef<Attribute> explicit_padding;
-  ArrayAttr explicit_pad =
-      attributes.get("explicit_paddings").dyn_cast_or_null<::mlir::ArrayAttr>();
-  if (!explicit_pad) {
-    explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({});
-  }
-  explicit_padding = explicit_pad.getValue();
+  Conv3DOpAdaptor op(operands.getValues(), attributes, properties, regions);
+  ArrayAttr explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({});
+  ArrayRef<Attribute> explicit_padding = explicit_pad.getValue();
 
   return inferConvReturnTypeComponents(location, op, explicit_padding,
                                        inferredReturnShapes);
@@ -2969,6 +2972,70 @@
 }
 
 //===----------------------------------------------------------------------===//
+// GeneratorDatasetRegionOp
+//===----------------------------------------------------------------------===//
+
+bool GeneratorDatasetRegionOp::areTypesCompatible(Type t1, Type t2) {
+  return true;  // Don't enforce type checking across control-flow edges.
+}
+
+void GeneratorDatasetRegionOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<InvocationBounds>& invocationBounds) {
+  // We invoke `init` once, `finalize` once, and `next` any number of times.
+  invocationBounds.emplace_back(InvocationBounds(1, 1));          // init
+  invocationBounds.emplace_back(InvocationBounds::getUnknown());  // next
+  invocationBounds.emplace_back(InvocationBounds(1, 1));          // finalize
+}
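As a quick illustration of the bounds reported above (a sketch only, not part of the change; the `op` and `bounds` names are made up for this example, and it assumes a `GeneratorDatasetRegionOp` value is already in hand):

  // Hypothetical caller-side view of the bounds populated above.
  SmallVector<InvocationBounds, 3> bounds;
  op.getRegionInvocationBounds(/*operands=*/{}, bounds);
  // bounds[0]: {1, 1}    -> `init` runs exactly once.
  // bounds[1]: unknown   -> `next` may run any number of times.
  // bounds[2]: {1, 1}    -> `finalize` runs exactly once.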
+
+OperandRange GeneratorDatasetRegionOp::getEntrySuccessorOperands(
+    RegionBranchPoint point) {
+  auto end = this->getOperation()->operand_end();
+  if (point.isParent()) {
+    // The op itself doesn't branch back to itself.
+    return ::mlir::OperandRange(end, end);
+  } else if (point.getRegionOrNull() == &getInit()) {
+    return getInitFuncOtherArgs();
+  } else if (point.getRegionOrNull() == &getNext()) {
+    return getNextFuncOtherArgs();
+  } else /* finalize region */ {
+    return getFinalizeFuncOtherArgs();
+  }
+}
+
+void GeneratorDatasetRegionOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor>& regions) {
+  int n;
+  if (point.isParent()) {
+    // The op itself branches to `init` first.
+    regions.push_back(
+        RegionSuccessor(&getInit(), getInit().front().getArguments()));
+  } else if (point.getRegionOrNull() == &getInit()) {
+    // `init` branches to `next`, passing along the arguments given to `init`'s
+    // yield. Said arguments precede the "other args".
+    n = getInitFuncOtherArgs().size();
+    regions.push_back(RegionSuccessor(
+        &getNext(), getNext().front().getArguments().drop_back(n)));
+  } else if (point.getRegionOrNull() == &getNext()) {
+    // `next` branches to itself, or to `finalize`, passing all arguments given
+    // to `next`'s yield.
+
+    // The number of values we're passing along.
+    int num = getNext().front().getTerminator()->getNumOperands();
+
+    // Only the first `num` block arguments of `next` and `finalize` receive
+    // these values; the remaining arguments are the extra values passed in
+    // from the parent op.
+    regions.push_back(RegionSuccessor(
+        &getNext(), getNext().front().getArguments().slice(0, num)));
+    regions.push_back(RegionSuccessor(
+        &getFinalize(), getFinalize().front().getArguments().slice(0, num)));
+  } else {
+    // `finalize` branches back to the op itself, not passing any arguments.
+    regions.push_back(RegionSuccessor());
+  }
+}
+
+//===----------------------------------------------------------------------===//
 // GatherV2Op
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index 4ac5417..01cbbb9 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -18,11 +18,12 @@
 
 #include <algorithm>
 #include <array>
+#include <cassert>
+#include <climits>
+#include <cstddef>
 #include <cstdint>
-#include <functional>
 #include <iterator>
 #include <limits>
-#include <numeric>
 #include <optional>
 #include <string>
 #include <tuple>
@@ -34,16 +35,15 @@
 #include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/StringSwitch.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/raw_ostream.h"
+#include "llvm/Support/MathExtras.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/Dialect/Traits.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
@@ -53,26 +53,28 @@
 #include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
-#include "mlir/IR/DialectImplementation.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
-#include "mlir/IR/OpImplementation.h"  // from @llvm-project
+#include "mlir/IR/OperationSupport.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/Region.h"  // from @llvm-project
+#include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/TypeRange.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/IR/ValueRange.h"  // from @llvm-project
 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
 #include "mlir/Interfaces/ControlFlowInterfaces.h"  // from @llvm-project
+#include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
 #include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
 #include "mlir/Parser/Parser.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h"
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_canonicalization_helper.h"
@@ -80,14 +82,12 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/util/tensor_format.h"
 
 namespace mlir {
 namespace TF {
@@ -4389,6 +4389,12 @@
           this->getOperation(), 1,
           this->getOperation()->getOperands().size() - 1);
     }
+  } else if (auto regionOp = llvm::dyn_cast<GeneratorDatasetRegionOp>(
+                 this->getOperation()->getParentOp())) {
+    if (&regionOp.getFinalize() == this->getOperation()->getParentRegion()) {
+      // `finalize`'s returns get discarded.
+      return MutableOperandRange(this->getOperation(), 0, 0);
+    }
   }
   return MutableOperandRange(this->getOperation());
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td
index 39748a6..726b829 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td
@@ -83,7 +83,6 @@
   }];
 
   let cppNamespace = "::mlir::tf_saved_model";
-  let usePropertiesForAttributes = 0;
 }
 
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc
index 8cce823..879aa62 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc
@@ -15,14 +15,19 @@
 
 #include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h"
 
+#include <cstdint>
+
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
-#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
-#include "tensorflow/core/framework/resource_handle.h"
 
 //===----------------------------------------------------------------------===//
 // _TfrtGetResourceOp
@@ -84,6 +89,35 @@
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// IfrtCallOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult IfrtCallOp::verify() {
+  auto func = getOperation()->getParentOfType<mlir::func::FuncOp>();
+  if (func != nullptr && func->hasAttr("tfrt_ifrt_serving.program_id")) {
+    return emitOpError() << "cannot be nested inside an IFRT program";
+  }
+
+  for (mlir::Value arg : getArgs()) {
+    if (mlir::getElementTypeOrSelf(arg.getType())
+            .isa<mlir::TF::ResourceType>()) {
+      return emitOpError()
+             << "does not support passing '!tf.resource' values as arguments";
+    }
+  }
+
+  for (mlir::Value result : getResults()) {
+    if (mlir::getElementTypeOrSelf(result.getType())
+            .isa<mlir::TF::ResourceType>()) {
+      return emitOpError()
+             << "does not support returning '!tf.resource' values as results";
+    }
+  }
+
+  return mlir::success();
+}
+
 }  // namespace TF
 }  // namespace mlir
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h
index e4a41d4..c6c3eb2 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h
@@ -16,6 +16,7 @@
 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TFRT_OPS_H_
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TFRT_OPS_H_
 
+#include "mlir/Bytecode/BytecodeOpInterface.h"  // from @llvm-project  // IWYU pragma: keep
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // from @llvm-project
 #include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td
index a0e2935..3b7d00f 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td
@@ -63,6 +63,36 @@
   let hasVerifier = 1;
 }
 
+def TF_IfrtCallOp : TF_Op<"IfrtCall", []> {
+  let summary = "Invokes a program via IFRT on a device";
+
+  let description = [{
+    This op calls an IFRT program uniquely identified by the given program id.
+
+    During lowering from a `tf_device.cluster_func` op to a `tf.IfrtCall` op,
+    the region owned by the former will be outlined to a function with a
+    `tfrt_ifrt_serving.program_id` attribute. After that, the runtime ensures
+    that the outlined function is compiled into an executable and is available
+    for lookup from `IfrtCall` TF ops.
+
+    This op also takes a `variable_names` attribute to bind the variables
+    (weights) by name.
+  }];
+
+  let arguments = (ins
+    Variadic<TF_Tensor> : $args,
+    I64Attr : $program_id,
+    StrArrayAttr : $variable_names
+  );
+
+  let results = (outs Variadic<TF_Tensor> : $results);
+
+  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
+  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
+
+  let hasVerifier = 1;
+}
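As a rough illustration of the description above (a sketch only, not part of this change; the tensor types, program id, and variable name are invented), a lowered call could look something like this in MLIR assembly, using the inherent-attribute syntax the updated tests check for:

  %res = "tf.IfrtCall"(%input, %w)
      <{program_id = 1234 : i64, variable_names = ["w"]}>
      : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32>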
+
 // TODO(chky): Consider adding this op to tensorflow core ops.
 def TF_PwStreamResultsOp : TF_Op<"PwStreamResults"> {
   let summary = "Streams results back to the controller";
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/batchmatmul_to_einsum.mlir b/tensorflow/compiler/mlir/tensorflow/tests/batchmatmul_to_einsum.mlir
index 30dd120..4b6a33c 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/batchmatmul_to_einsum.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/batchmatmul_to_einsum.mlir
@@ -2,42 +2,42 @@
 
 func.func @test_batch_matmul_to_einsum(%arg0: tensor<1x2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<1x2x4xf32> {
   // CHECK-LABEL: test_batch_matmul_to_einsum
-  // CHECK: "tf.Einsum"(%arg0, %arg1) {equation = "...mk,...kn->...mn"} : (tensor<1x2x3xf32>, tensor<3x4xf32>) -> tensor<1x2x4xf32>
+  // CHECK: "tf.Einsum"(%arg0, %arg1) <{equation = "...mk,...kn->...mn"}> : (tensor<1x2x3xf32>, tensor<3x4xf32>) -> tensor<1x2x4xf32>
   %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x2x3xf32>, tensor<3x4xf32>) -> tensor<1x2x4xf32>
   func.return %0: tensor<1x2x4xf32>
 }
 
 func.func @test_batch_matmul_broadcast_to_einsum(%arg0: tensor<2x2x4xf32>, %arg1: tensor<2x4x2xf32>) -> tensor<2x2x2xf32> {
   // CHECK-LABEL: test_batch_matmul_broadcast_to_einsum
-  // CHECK: "tf.Einsum"(%arg0, %arg1) {equation = "...mk,...kn->...mn"} : (tensor<2x2x4xf32>, tensor<2x4x2xf32>) -> tensor<2x2x2xf32>
+  // CHECK: "tf.Einsum"(%arg0, %arg1) <{equation = "...mk,...kn->...mn"}> : (tensor<2x2x4xf32>, tensor<2x4x2xf32>) -> tensor<2x2x2xf32>
   %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x2x4xf32>, tensor<2x4x2xf32>) -> tensor<2x2x2xf32>
   func.return %0: tensor<2x2x2xf32>
 }
 
 func.func @test_batch_matmul_dynamic_shape_both_arg_to_einsum(%arg0: tensor<1x2x?xf32>, %arg1: tensor<?x4xf32>) -> tensor<1x2x4xf32> {
   // CHECK-LABEL: test_batch_matmul_dynamic_shape_both_arg_to_einsum
-  // CHECK: "tf.Einsum"(%arg0, %arg1) {equation = "...mk,...kn->...mn"} : (tensor<1x2x?xf32>, tensor<?x4xf32>) -> tensor<1x2x4xf32>
+  // CHECK: "tf.Einsum"(%arg0, %arg1) <{equation = "...mk,...kn->...mn"}> : (tensor<1x2x?xf32>, tensor<?x4xf32>) -> tensor<1x2x4xf32>
   %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x2x?xf32>, tensor<?x4xf32>) -> tensor<1x2x4xf32>
   func.return %0: tensor<1x2x4xf32>
 }
 
 func.func @test_batch_matmul_dynamic_shape_one_arg_to_einsum(%arg0: tensor<1x2x?xf32>, %arg1: tensor<3x4xf32>) -> tensor<1x2x4xf32> {
   // CHECK-LABEL: test_batch_matmul_dynamic_shape_one_arg_to_einsum
-  // CHECK: "tf.Einsum"(%arg0, %arg1) {equation = "...mk,...kn->...mn"} : (tensor<1x2x?xf32>, tensor<3x4xf32>) -> tensor<1x2x4xf32>
+  // CHECK: "tf.Einsum"(%arg0, %arg1) <{equation = "...mk,...kn->...mn"}> : (tensor<1x2x?xf32>, tensor<3x4xf32>) -> tensor<1x2x4xf32>
   %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x2x?xf32>, tensor<3x4xf32>) -> tensor<1x2x4xf32>
   func.return %0: tensor<1x2x4xf32>
 }
 
 func.func @test_batch_matmul_adj_to_einsum(%arg0: tensor<1x2x3xf32>, %arg1: tensor<4x3xf32>) -> tensor<1x2x4xf32> {
   // CHECK-LABEL: test_batch_matmul_adj_to_einsum
-  // CHECK: %[[RES_EINSUM:[0-9]*]] = "tf.Einsum"(%arg0, %arg1) {equation = "...mk,...nk->...mn"} : (tensor<1x2x3xf32>, tensor<4x3xf32>) -> tensor<1x2x4xf32>
+  // CHECK: %[[RES_EINSUM:[0-9]*]] = "tf.Einsum"(%arg0, %arg1) <{equation = "...mk,...nk->...mn"}> : (tensor<1x2x3xf32>, tensor<4x3xf32>) -> tensor<1x2x4xf32>
   // CHECK: return %[[RES_EINSUM]] : tensor<1x2x4xf32>
   %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x2x3xf32>, tensor<4x3xf32>) -> tensor<1x2x4xf32>
   func.return %0: tensor<1x2x4xf32>
 }
 
 func.func @test_batch_matmulV2_adj_to_einsum(%arg0: tensor<1x3x2xf32>, %arg1: tensor<3x4xf32>) -> tensor<1x2x4xf32> {
-  // CHECK: %[[RES_EINSUM:[0-9]*]] = "tf.Einsum"(%arg0, %arg1) {equation = "...km,...kn->...mn"} : (tensor<1x3x2xf32>, tensor<3x4xf32>) -> tensor<1x2x4xf32>
+  // CHECK: %[[RES_EINSUM:[0-9]*]] = "tf.Einsum"(%arg0, %arg1) <{equation = "...km,...kn->...mn"}> : (tensor<1x3x2xf32>, tensor<3x4xf32>) -> tensor<1x2x4xf32>
   // CHECK: return %[[RES_EINSUM]] : tensor<1x2x4xf32>
   %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = false} : (tensor<1x3x2xf32>, tensor<3x4xf32>) -> tensor<1x2x4xf32>
   func.return %0: tensor<1x2x4xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir
index 64b0bfb..2704fc3 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir
@@ -62,10 +62,10 @@
 // CHECK:    %[[SUB1:.*]], %[[SUB1_control:.*]] = tf_executor.island(%[[ADD2_control]]) wraps "tf.Sub"(%arg0, %arg1)
 // CHECK:    %[[MUL:.*]], %[[MUL_control:.*]] = tf_executor.island wraps "tf.Mul"(%[[SUB1]], %arg1)
 // CHECK:    %[[SUB2:.*]], %[[SUB2_control:.*]] = tf_executor.island(%[[ADD2_control]], %[[MUL_control]]) wraps "tf.Sub"(%[[ADD1]], %[[SUB1]])
-// CHECK:    %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island wraps "tf.Print"(%[[SUB2]]) {message = "sub result"}
+// CHECK:    %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island wraps "tf.Print"(%[[SUB2]]) <{message = "sub result"}>
 // CHECK:    %[[ISLAND1:.*]] = tf_executor.island(%[[ADD2_control]], %[[MUL_control]]) wraps "tf.NoOp"()
 // CHECK:    %[[ADD3:.*]], %[[ADD3_control:.*]] = tf_executor.island(%[[ISLAND1]], %[[ADD2_control]]) wraps "tf.Add"(%[[ADD2]], %[[ADD2]])
-// CHECK:    %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD3]]) {message = "add result"}
+// CHECK:    %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD3]]) <{message = "add result"}>
 // CHECK:    tf_executor.fetch %[[ADD2]], %[[MUL]], %[[PRINT1_control]], %[[PRINT2_control:.*]] :
 // CHECK:  }
 // CHECK:  return %[[GRAPH]]#0, %[[GRAPH]]#1
@@ -87,7 +87,7 @@
 // CHECK:  %[[GRAPH:.*]]:2 = tf_executor.graph {
 // CHECK:    %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
 // CHECK:    %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1_control:.*]], %arg1)
-// CHECK:    %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2_control:.*]]) {message = "add result"}
+// CHECK:    %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2_control:.*]]) <{message = "add result"}>
 // CHECK:    tf_executor.fetch %[[ADD1]], %[[ADD2]], %[[PRINT_control]] :
 // CHECK:  }
 // CHECK:  return %[[GRAPH]]#0, %[[GRAPH]]#1
@@ -116,11 +116,11 @@
 // CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph {
 // CHECK:   %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
 // CHECK:   %[[LESS:.*]], %[[LESS_control:.*]] = tf_executor.island wraps "tf.Less"(%arg1, %arg1)
-// CHECK:   %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD1]]) {message = "add result 1"}
+// CHECK:   %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD1]]) <{message = "add result 1"}>
 // CHECK:   %[[ISLAND1:.*]] = tf_executor.island(%[[LESS_control]], %[[PRINT1_control]]) wraps "tf.NoOp"()
 // CHECK:   %[[SWITCH_false:.*]], %[[SWITCH_true:.*]], {{.*}} = tf_executor.Switch %[[ADD1]], %[[LESS]], %[[ISLAND1]]
 // CHECK:   %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[SWITCH_false]], %arg1)
-// CHECK:   %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2]]) {message = "add result 2"}
+// CHECK:   %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2]]) <{message = "add result 2"}>
 // CHECK:   %[[MERGE:.*]], %[[MERGE_index:.*]], %{{.*}} = tf_executor.Merge %[[ADD2]], %[[SWITCH_true]], %[[PRINT2_control]]
 // CHECK:   tf_executor.fetch %[[MERGE]], %[[MERGE_index]]
 // CHECK: }
@@ -141,7 +141,7 @@
 
 // CHECK-LABEL: func @control_flow_plumbing
 // CHECK: %[[GRAPH:.*]] = tf_executor.graph {
-// CHECK:   %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%arg0) {message = "Random Print"}
+// CHECK:   %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%arg0) <{message = "Random Print"}>
 // CHECK:   %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island(%[[PRINT_control]]) wraps "tf.Add"(%arg0, %arg1)
 // CHECK:   %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1)
 // CHECK:   tf_executor.fetch %[[ADD2]] : tensor<*xi32>
@@ -193,7 +193,7 @@
 // CHECK:   %[[READ0:.*]], %[[READ0_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg0)
 // CHECK:   %[[ASSIGN0_CONTROL:.*]] = tf_executor.island(%[[READ0_CONTROL]]) wraps "tf.AssignVariableOp"(%arg0, %arg2)
 // CHECK:   %[[READ1:.*]], %[[READ1_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg1)
-// CHECK:   %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v0"}
+// CHECK:   %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() <{container = "c", shared_name = "v0"}>
 // CHECK:   %[[READ2:.*]], %[[READ2_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[VH0]])
 // CHECK:   %[[ASSIGN1_CONTROL:.*]] = tf_executor.island(%[[READ1_CONTROL]]) wraps "tf.AssignVariableOp"(%arg1, %[[READ0:.*]])
 // CHECK:   %[[ASSIGN2_CONTROL:.*]] = tf_executor.island(%[[ASSIGN0_CONTROL]]) wraps "tf.AssignVariableOp"(%arg0, %[[READ2]])
@@ -222,8 +222,8 @@
 
 // CHECK-LABEL: func @unknown_side_effecting_op
 // CHECK: tf_executor.graph {
-// CHECK:   %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v0"}
-// CHECK:   %[[VH1:.*]], %[[VH1_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v1"}
+// CHECK:   %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() <{container = "c", shared_name = "v0"}>
+// CHECK:   %[[VH1:.*]], %[[VH1_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() <{container = "c", shared_name = "v1"}>
 // CHECK:   %[[READ0:.*]], %[[READ0_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[VH0]])
 // CHECK:   %[[ASSIGN0_CONTROL:.*]] = tf_executor.island wraps "tf.AssignVariableOp"(%[[VH1]], %arg0)
 // CHECK:   %[[UNKNOWN_CONTROL:.*]] = tf_executor.island(%[[READ0_CONTROL]], %[[ASSIGN0_CONTROL]]) wraps "tf._UnknownSideEffectingOp_"()
@@ -544,16 +544,17 @@
       tf_executor.yield %0 : tensor<i64>
     }
     // CHECK: "tf_device.launch"()
+    // CHECK-SAME: <{device = "/job:worker/replica:0/task:0/device:CPU:0"}>
     // CHECK:   "tf.OpC"(%[[VAL_0]]) : (tensor<i64>) -> ()
     // CHECK:   "tf.OpD"() : () -> ()
     // CHECK:   tf_device.return
-    // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
+    // CHECK: }) : () -> ()
     %island2 = tf_executor.island {
-      "tf_device.launch"() ({
+      "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:CPU:0"}> ({
         "tf.OpC"(%island1#0) : (tensor<i64>) -> ()
         "tf.OpD"() : () -> ()
         tf_device.return
-      }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
+      }) : () -> ()
       tf_executor.yield
     }
     // CHECK: tf_executor.fetch
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cannonicalize_ops_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cannonicalize_ops_outside_compilation.mlir
index d3be639..bc9155a 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/cannonicalize_ops_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/cannonicalize_ops_outside_compilation.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -tf-tpu-bridge 2>&1 | FileCheck %s
+// RUN: tf-opt %s -tf-cluster-tpu-bridge-v2 -tfrt-lower-cluster-to-runtime-ops-tpu 2>&1 | FileCheck %s
 
 // This test verifies that the tail extraction is not terminated prematurely
 // because the outside compilation attribute could be removed in
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index 2adf884..612f01c 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -19,29 +19,29 @@
 // CHECK-LABEL: testGatherToV2
 // Ensures that axis param and batch_dims attr use their default values of 0.
 func.func @testGatherToV2(%params: tensor<4x3xf32>, %indices: tensor<1x2xi32>) -> tensor<2x3xf32> {
-  // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: "tf.GatherV2"(%arg0, %arg1, %[[AXIS]]) {batch_dims = 0 : i64, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<4x3xf32>, tensor<1x2xi32>, tensor<i32>) -> tensor<2x3xf32>
+  // CHECK: %[[AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK: "tf.GatherV2"(%arg0, %arg1, %[[AXIS]]) <{batch_dims = 0 : i64}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<4x3xf32>, tensor<1x2xi32>, tensor<i32>) -> tensor<2x3xf32>
   %0 = "tf.Gather"(%params, %indices) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<4x3xf32>, tensor<1x2xi32>) -> tensor<2x3xf32>
   func.return %0: tensor<2x3xf32>
 }
 
 // CHECK-LABEL: testBatchMatMulToV2
 func.func @testBatchMatMulToV2(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>) -> tensor<2x3x7xf32> {
-  // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false, device = "/job:localhost/replica:0/task:0/device:GPU:0"}
-  %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32>
+  // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) <{adj_x = false, adj_y = false}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
+  %0 = "tf.BatchMatMul"(%arg0, %arg1) <{adj_x = false, adj_y = false}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32>
   func.return %0: tensor<2x3x7xf32>
 }
 
 // CHECK-LABEL: testDynamicBatchMatMulToV2
 func.func @testDynamicBatchMatMulToV2(%arg0: tensor<2x3x5xf32>, %arg1: tensor<?x5x7xf32>) -> tensor<2x3x7xf32> {
-  // CHECK: "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false, device = "/job:localhost/replica:0/task:0/device:GPU:0"}
+  // CHECK: "tf.BatchMatMul"(%arg0, %arg1) <{adj_x = false, adj_y = false}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
   %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3x5xf32>, tensor<?x5x7xf32>) -> tensor<2x3x7xf32>
   func.return %0: tensor<2x3x7xf32>
 }
 
 // CHECK-LABEL: testBatchMatMulToMatMul
 func.func @testBatchMatMulToMatMul(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>) -> tensor<2x2xf32> {
-  // CHECK: %0 = "tf.MatMul"(%arg0, %arg1) {device = "/job:localhost/replica:0/task:0/device:GPU:0", transpose_a = false, transpose_b = false} : (tensor<2x3xf32>, tensor<3x2xf32>) -> tensor<2x2xf32>
+  // CHECK: %0 = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3xf32>, tensor<3x2xf32>) -> tensor<2x2xf32>
   %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3xf32>, tensor<3x2xf32>) -> tensor<2x2xf32>
   // CHECK: return %0
   func.return %0: tensor<2x2xf32>
@@ -49,7 +49,7 @@
 
 // CHECK-LABEL: testBatchMatMulV2ToMatMul
 func.func @testBatchMatMulV2ToMatMul(%arg0: tensor<4x3xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> {
-  // CHECK: %0 = "tf.MatMul"(%arg0, %arg1) {device = "/job:localhost/replica:0/task:0/device:GPU:0", transpose_a = true, transpose_b = false} : (tensor<4x3xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>
+  // CHECK: %0 = "tf.MatMul"(%arg0, %arg1) <{transpose_a = true, transpose_b = false}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<4x3xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>
   %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = false, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<4x3xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>
   // CHECK: return %0
   func.return %0: tensor<3x5xf32>
@@ -58,7 +58,7 @@
 
 // CHECK-LABEL: testBiasAddV1ToBiasAdd
 func.func @testBiasAddV1ToBiasAdd(%arg0: tensor<*xf32>, %arg1: tensor<128xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC", device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
+  // CHECK: "tf.BiasAdd"(%arg0, %arg1) <{data_format = "NHWC"}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
   %0 = "tf.BiasAddV1"(%arg0, %arg1) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
   func.return %0: tensor<*xf32>
 }
@@ -124,8 +124,8 @@
   %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xi32>
   func.return %0, %1: tensor<8x16x32x64xi32>, tensor<8x16x32x64xi32>
 
-  // CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xi32>
-  // CHECK: %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xi32>
+  // CHECK: %0 = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xi32>
+  // CHECK: %1 = "tf.Cast"(%arg0) <{Truncate = true}> : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xi32>
   // CHECK: return %0, %1
 }
 
@@ -135,8 +135,8 @@
   %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<?xf32>) -> tensor<10xf32>
   func.return %0, %1: tensor<10xf32>, tensor<10xf32>
 
-  // CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<?xf32>) -> tensor<10xf32>
-  // CHECK: %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<?xf32>) -> tensor<10xf32>
+  // CHECK: %0 = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<?xf32>) -> tensor<10xf32>
+  // CHECK: %1 = "tf.Cast"(%arg0) <{Truncate = true}> : (tensor<?xf32>) -> tensor<10xf32>
   // CHECK: return %0, %1
 }
 
@@ -181,11 +181,11 @@
 func.func @testConcatCwiseBinaryOnInnerDim(%arg0: tensor<?x1xf32>,
   %arg1: tensor<?x1xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<?x2xf32> {
 
-  // CHECK-DAG: %[[LHS_AXIS:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
+  // CHECK-DAG: %[[LHS_AXIS:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}>
 
-  // CHECK: %[[ADD_LHS_CONCAT:.*]] = "tf.Pack"(%arg2, %arg3) {axis = 0 : i64, device = "/job:localhost/replica:0/task:0/device:GPU:0"}
+  // CHECK: %[[ADD_LHS_CONCAT:.*]] = "tf.Pack"(%arg2, %arg3) <{axis = 0 : i64}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
   // CHECK: %[[MUL_LHS_CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %[[LHS_AXIS]]) {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
-  // CHECK: %[[MUL_RHS_CONCAT:.*]] = "tf.Pack"(%arg2, %arg3) {axis = 0 : i64, device = "/job:localhost/replica:0/task:0/device:GPU:0"}
+  // CHECK: %[[MUL_RHS_CONCAT:.*]] = "tf.Pack"(%arg2, %arg3) <{axis = 0 : i64}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
 
   // CHECK: %[[MUL:.*]] = "tf.Mul"(%[[MUL_LHS_CONCAT]], %[[MUL_RHS_CONCAT]]) {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
   // CHECK-SAME: (tensor<?x2xf32>, tensor<2xf32>) -> tensor<?x2xf32>
@@ -209,11 +209,11 @@
 func.func @testConcatCwiseBinaryPreserveAxisType(%arg0: tensor<?x1xf32>,
   %arg1: tensor<?x1xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<?x2xf32> {
 
-  // CHECK-DAG: %[[LHS_AXIS:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>}
+  // CHECK-DAG: %[[LHS_AXIS:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}>
 
-  // CHECK: %[[ADD_LHS_CONCAT:.*]] = "tf.Pack"(%arg2, %arg3) {axis = 0 : i64, device = "/job:localhost/replica:0/task:0/device:GPU:0"}
+  // CHECK: %[[ADD_LHS_CONCAT:.*]] = "tf.Pack"(%arg2, %arg3) <{axis = 0 : i64}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
   // CHECK: %[[MUL_LHS_CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %[[LHS_AXIS]]) {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
-  // CHECK: %[[MUL_RHS_CONCAT:.*]] = "tf.Pack"(%arg2, %arg3) {axis = 0 : i64, device = "/job:localhost/replica:0/task:0/device:GPU:0"}
+  // CHECK: %[[MUL_RHS_CONCAT:.*]] = "tf.Pack"(%arg2, %arg3) <{axis = 0 : i64}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
 
   // CHECK: %[[MUL:.*]] = "tf.Mul"(%[[MUL_LHS_CONCAT]], %[[MUL_RHS_CONCAT]])
   // CHECK-SAME: {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
@@ -287,8 +287,8 @@
 // Synthesize binary ops when 1 of the 3 concat inputs is a non-binary op.
 // CHECK-LABEL: testConcatCwiseBinarySynthMulOp3Inputs
 func.func @testConcatCwiseBinarySynthMulOp3Inputs(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x3xf32> {
-  // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-NEXT: %[[CONST0:.*]] = "tf.Const"() {value = dense<[2.000000e+00, 3.000000e+00, 1.000000e+00]>
+  // CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-NEXT: %[[CONST0:.*]] = "tf.Const"() <{value = dense<[2.000000e+00, 3.000000e+00, 1.000000e+00]>
   // CHECK: %[[CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %arg2, %[[CONST]]) {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
   // CHECK: "tf.Mul"(%[[CONCAT]], %[[CONST0]]) {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
   %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
@@ -303,7 +303,7 @@
 
 // Similar to the above, with tf.Sub as the binary op kind.
 func.func @testConcatCwiseBinarySynthSubOp3Inputs(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x3xf32> {
-  // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[2.000000e+00, 3.000000e+00, 0.000000e+00]>
+  // CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<[2.000000e+00, 3.000000e+00, 0.000000e+00]>
   // CHECK: %[[CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %arg2,
   // CHECK: "tf.Sub"(%[[CONCAT]], %[[CONST]])
   %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
@@ -681,8 +681,8 @@
 // CHECK-LABEL: func @testStaticAndIdenticalTypeForEqualOp
 func.func @testStaticAndIdenticalTypeForEqualOp(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
   // CHECK:      "tf.Equal"(%arg0, %arg1)
-  // CHECK-SAME:   device = "/job:localhost/replica:0/task:0/device:GPU:0"
   // CHECK-SAME:   incompatible_shape_error = true
+  // CHECK-SAME:   device = "/job:localhost/replica:0/task:0/device:GPU:0"
   %0 = "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
   func.return %0: tensor<2xi1>
 }
@@ -690,8 +690,8 @@
 // CHECK-LABEL: func @testStaticAndIdenticalTypeForNotEqualOp
 func.func @testStaticAndIdenticalTypeForNotEqualOp(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
   // CHECK:      "tf.NotEqual"(%arg0, %arg1)
-  // CHECK-SAME:   device = "/job:localhost/replica:0/task:0/device:GPU:0"
   // CHECK-SAME:   incompatible_shape_error = true
+  // CHECK-SAME:   device = "/job:localhost/replica:0/task:0/device:GPU:0"
   %0 = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
   func.return %0: tensor<2xi1>
 }
@@ -707,8 +707,8 @@
 // CHECK-LABEL: func @testKnownGoodBroadcastForNotEqualOp
 func.func @testKnownGoodBroadcastForNotEqualOp(%arg0: tensor<1x?xi32>, %arg1: tensor<?x1xi32>) -> tensor<?x?xi1> {
   // CHECK:      "tf.NotEqual"(%arg0, %arg1)
-  // CHECK-SAME:   device = "/job:localhost/replica:0/task:0/device:GPU:0"
   // CHECK-SAME:   incompatible_shape_error = true
+  // CHECK-SAME:   device = "/job:localhost/replica:0/task:0/device:GPU:0"
   %0 = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<1x?xi32>, tensor<?x1xi32>) -> tensor<?x?xi1>
   func.return %0: tensor<?x?xi1>
 }
@@ -740,8 +740,8 @@
 // CHECK-LABEL: func @testScalarForNotEqualOp
 func.func @testScalarForNotEqualOp(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i1> {
   // CHECK:      "tf.NotEqual"(%arg0, %arg1)
-  // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:GPU:0"
   // CHECK-SAME: incompatible_shape_error = true
+  // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:GPU:0"
   %0 = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
   func.return %0: tensor<i1>
 }
@@ -752,7 +752,7 @@
   %1 = "tf.LogicalNot"(%0) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<8x16xi1>) -> tensor<8x16xi1>
   func.return %1: tensor<8x16xi1>
 
-  // CHECK: %[[NE:.*]] = "tf.NotEqual"(%arg0, %arg1) {device = "/job:localhost/replica:0/task:0/device:GPU:0", incompatible_shape_error = true}
+  // CHECK: %[[NE:.*]] = "tf.NotEqual"(%arg0, %arg1) <{incompatible_shape_error = true}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
   // CHECK: return %[[NE]]
 }
 
@@ -762,7 +762,7 @@
   %1 = "tf.LogicalNot"(%0) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<8x16xi1>) -> tensor<8x16xi1>
   func.return %1: tensor<8x16xi1>
 
-  // CHECK: %[[NE:.*]] = "tf.Equal"(%arg0, %arg1) {device = "/job:localhost/replica:0/task:0/device:GPU:0", incompatible_shape_error = true}
+  // CHECK: %[[NE:.*]] = "tf.Equal"(%arg0, %arg1) <{incompatible_shape_error = true}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
   // CHECK: return %[[NE]]
 }
 
@@ -811,7 +811,7 @@
   %0 = "tf.Size"(%arg0) : (tensor<3x5x7xf32>) -> tensor<i32>
   func.return %0: tensor<i32>
 
-// CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<105> : tensor<i32>} : () -> tensor<i32>
+// CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<105> : tensor<i32>}> : () -> tensor<i32>
 // CHECK: return %[[CONST]] : tensor<i32>
 }
 
@@ -873,7 +873,7 @@
 
 // CHECK-LABEL: @identityTranspose
 func.func @identityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x5x6xf32> {
-  %0 = "tf.Const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32>
+  %0 = "tf.Const"() <{value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>}> : () -> tensor<5xi32>
   %1 = "tf.Transpose"(%arg0, %0) : (tensor<2x3x4x5x6xf32>, tensor<5xi32>) -> tensor<2x3x4x5x6xf32>
 
   func.return %1 : tensor<2x3x4x5x6xf32>
@@ -895,7 +895,7 @@
   %1 = "tf.Transpose"(%arg0, %0) : (tensor<2x3x4x5x6xf32>, tensor<5xi32>) -> tensor<2x3x4x6x5xf32>
 
   func.return %1 : tensor<2x3x4x6x5xf32>
-  // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[0, 1, 2, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<[0, 1, 2, 4, 3]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK: %[[TRANS:.*]] = "tf.Transpose"(%arg0, %[[CONST]]) : (tensor<2x3x4x5x6xf32>, tensor<5xi32>) -> tensor<2x3x4x6x5xf32>
   // CHECK: return %[[TRANS]]
 }
@@ -924,8 +924,8 @@
 
   func.return %result : tensor<1x4x4x8xf32>
 
-  // CHECK-DAG: %[[CONST1:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
-  // CHECK-DAG: %[[CONST2:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
+  // CHECK-DAG: %[[CONST1:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}>
+  // CHECK-DAG: %[[CONST2:.*]] = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}>
   // CHECK: %[[TRANS1:.*]] = "tf.Transpose"(%arg0, %[[CONST1]]) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
   // CHECK: %[[TRANS2:.*]] = "tf.Transpose"(%[[TRANS1]], %[[CONST2]]) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32>
   // CHECK: return %[[TRANS2]]
@@ -951,8 +951,8 @@
 
   func.return %3 : tensor<4x1x4x8xf32>
 
-  // CHECK-DAG: %[[CONST1:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
-  // CHECK-DAG: %[[CONST2:.*]] = "tf.Const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi32>}
+  // CHECK-DAG: %[[CONST1:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}>
+  // CHECK-DAG: %[[CONST2:.*]] = "tf.Const"() <{value = dense<[2, 0, 3, 1]> : tensor<4xi32>}>
   // CHECK: %[[TRANS1:.*]] = "tf.Transpose"(%arg0, %[[CONST1]]) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
   // CHECK: %[[TRANS2:.*]] = "tf.Transpose"(%[[TRANS1]], %[[CONST2]]) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<4x1x4x8xf32>
   // CHECK: return %[[TRANS2]]
@@ -969,8 +969,8 @@
 func.func @addNWithZerosFloat(%arg0: tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) {
   %0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
   %1 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
-  // CHECK-DAG: [[ZERO:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>}
-  // CHECK-DAG: [[ONE:%.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>}
+  // CHECK-DAG: [[ZERO:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>}>
+  // CHECK-DAG: [[ONE:%.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2xf32>}>
   // CHECK: [[ADD_N:%.*]] = "tf.AddN"(%arg0, [[ZERO]], [[ONE]])
   // CHECK: return %arg0, %arg0, [[ZERO]], [[ADD_N]]
   %2 = "tf.AddN"(%arg0, %1, %1) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
@@ -984,8 +984,8 @@
 func.func @addNWithZerosInt(%arg0: tensor<2xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) {
   %0 = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32>
   %1 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
-  // CHECK-DAG: [[ZERO:%.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>}
-  // CHECK-DAG: [[ONE:%.*]] = "tf.Const"() {value = dense<1> : tensor<2xi32>}
+  // CHECK-DAG: [[ZERO:%.*]] = "tf.Const"() <{value = dense<0> : tensor<2xi32>}>
+  // CHECK-DAG: [[ONE:%.*]] = "tf.Const"() <{value = dense<1> : tensor<2xi32>}>
   // CHECK: [[ADD_N:%.*]] = "tf.AddN"(%arg0, [[ZERO]], [[ONE]])
   // CHECK: return %arg0, %arg0, [[ZERO]], [[ADD_N]]
   %2 = "tf.AddN"(%arg0, %1, %1) : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
@@ -998,7 +998,7 @@
 // CHECK-LABEL: func @addNSkipFoldingIfBroadcasting
 func.func @addNSkipFoldingIfBroadcasting(%arg0: tensor<1xf32>) -> tensor<10xf32> {
   %0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<10xf32>} : () -> tensor<10xf32>
-  // CHECK: [[ZERO:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<10xf32>}
+  // CHECK: [[ZERO:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<10xf32>}>
   // CHECK: [[ADD_N:%.*]] = "tf.AddN"(%arg0, [[ZERO]])
   // CHECK: return [[ADD_N]]
   %1 = "tf.AddN"(%arg0, %0) : (tensor<1xf32>, tensor<10xf32>) -> tensor<10xf32>
@@ -1014,8 +1014,8 @@
 
 // CHECK-LABEL: func @ToBool_0DScalarInt
 func.func @ToBool_0DScalarInt(%arg0: tensor<i32>) -> tensor<i1> {
-  // CHECK: [[Zero:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]]) {device = "/job:localhost/replica:0/task:0/device:GPU:0", incompatible_shape_error = true}
+  // CHECK: [[Zero:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]]) <{incompatible_shape_error = true}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
   // CHECK: return [[NE]]
   %0 = "tf.ToBool"(%arg0) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<i32>) -> tensor<i1>
   func.return %0 : tensor<i1>
@@ -1023,8 +1023,8 @@
 
 // CHECK-LABEL: func @ToBool_0DScalarFloat
 func.func @ToBool_0DScalarFloat(%arg0: tensor<f32>) -> tensor<i1> {
-  // CHECK: [[Zero:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]]) {device = "/job:localhost/replica:0/task:0/device:GPU:0", incompatible_shape_error = true}
+  // CHECK: [[Zero:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]]) <{incompatible_shape_error = true}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
   // CHECK: return [[NE]]
   %0 = "tf.ToBool"(%arg0) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<f32>) -> tensor<i1>
   func.return %0 : tensor<i1>
@@ -1032,8 +1032,8 @@
 
 // CHECK-LABEL: func @ToBool_0DScalarString
 func.func @ToBool_0DScalarString(%arg0: tensor<!tf_type.string>) -> tensor<i1> {
-  // CHECK: [[EmptyStr:%.*]] = "tf.Const"() {value = dense<""> : tensor<!tf_type.string>} : () -> tensor<!tf_type.string>
-  // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[EmptyStr]]) {device = "/job:localhost/replica:0/task:0/device:GPU:0", incompatible_shape_error = true} : (tensor<!tf_type.string>, tensor<!tf_type.string>) -> tensor<i1>
+  // CHECK: [[EmptyStr:%.*]] = "tf.Const"() <{value = dense<""> : tensor<!tf_type.string>}> : () -> tensor<!tf_type.string>
+  // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[EmptyStr]]) <{incompatible_shape_error = true}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<!tf_type.string>, tensor<!tf_type.string>) -> tensor<i1>
   // CHECK: return [[NE]] : tensor<i1>
   %0 = "tf.ToBool"(%arg0) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<!tf_type.string>) -> tensor<i1>
   func.return %0 : tensor<i1>
@@ -1041,7 +1041,7 @@
 
 // CHECK-LABEL: func @ToBool_1DTensor
 func.func @ToBool_1DTensor(%arg0: tensor<1xf32>) -> tensor<i1> {
-  // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+  // CHECK: [[Const:%.*]] = "tf.Const"() <{value = dense<true> : tensor<i1>}> : () -> tensor<i1>
   // CHECK: return [[Const]]
   %0 = "tf.ToBool"(%arg0) : (tensor<1xf32>) -> tensor<i1>
   func.return %0 : tensor<i1>
@@ -1049,7 +1049,7 @@
 
 // CHECK-LABEL: func @ToBool_1DTensorZeroDim
 func.func @ToBool_1DTensorZeroDim(%arg0: tensor<0xf32>) -> tensor<i1> {
-  // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
+  // CHECK: [[Const:%.*]] = "tf.Const"() <{value = dense<false> : tensor<i1>}> : () -> tensor<i1>
   // CHECK: return [[Const]]
   %0 = "tf.ToBool"(%arg0) : (tensor<0xf32>) -> tensor<i1>
   func.return %0 : tensor<i1>
@@ -1057,7 +1057,7 @@
 
 // CHECK-LABEL: func @ToBool_2DTensor
 func.func @ToBool_2DTensor(%arg0: tensor<1x5xf32>) -> tensor<i1> {
-  // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+  // CHECK: [[Const:%.*]] = "tf.Const"() <{value = dense<true> : tensor<i1>}> : () -> tensor<i1>
   // CHECK: return [[Const]]
   %0 = "tf.ToBool"(%arg0) : (tensor<1x5xf32>) -> tensor<i1>
   func.return %0 : tensor<i1>
@@ -1065,7 +1065,7 @@
 
 // CHECK-LABEL: func @ToBool_2DTensorZeroDim
 func.func @ToBool_2DTensorZeroDim(%arg0: tensor<1x0xf32>) -> tensor<i1> {
-  // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
+  // CHECK: [[Const:%.*]] = "tf.Const"() <{value = dense<false> : tensor<i1>}> : () -> tensor<i1>
   // CHECK: return [[Const]]
   %0 = "tf.ToBool"(%arg0) : (tensor<1x0xf32>) -> tensor<i1>
   func.return %0 : tensor<i1>
@@ -1098,7 +1098,7 @@
   "tf.AssignVariableOp"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<f32>) -> ()
   func.return %1: tensor<f32>
 
- // CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource>
+ // CHECK: %0 = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource>
  // CHECK: %1 = "tf.ReadVariableOp"(%0) : (tensor<*x!tf_type.resource>) -> tensor<f32>
  // CHECK: "tf.AssignVariableOp"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<f32>) -> ()
  // CHECK: return %1
@@ -1118,7 +1118,7 @@
 
 // CHECK-LABEL: testRankOfRankedTensor
 func.func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor<i32> {
-  // CHECK:[[VAL0:%.+]] = "tf.Const"() {value = dense<3> : tensor<i32>}
+  // CHECK:[[VAL0:%.+]] = "tf.Const"() <{value = dense<3> : tensor<i32>}>
   %0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<i32>
 
   // CHECK: return [[VAL0]]
@@ -1143,16 +1143,16 @@
 func.func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) {
   %0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
   %1 = "tf.Const"() {value = dense<23.0> : tensor<f32>} : () -> tensor<f32>
-  // CHECK-DAG: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>}
+  // CHECK-DAG: "tf.Const"() <{value = dense<2.300000e+01> : tensor<3x2x1xf32>}>
   %2 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor<f32>) -> tensor<3x2x1xf32>
-  // CHECK-DAG: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>}
+  // CHECK-DAG: "tf.Const"() <{value = dense<2.300000e+01> : tensor<3x2x1xf32>}>
   %3 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor<f32>) -> tensor<*xf32>
 
   %complex_cst = "tf.Const"() {value = dense<(0.000000e+00,1.000000e+00)> : tensor<complex<f32>>} : () -> tensor<complex<f32>>
   // Here, custom folder doesn't handle complex dtypes and it is folded through
   // the constant folding hook.
   // TODO(hinsu): Handle complex dtypes in the custom folder for FillOp.
-  // CHECK-DAG: "tf.Const"() {value = dense<(0.000000e+00,1.000000e+00)> : tensor<3x2x1xcomplex<f32>>} : () -> tensor<*xcomplex<f32>>
+  // CHECK-DAG: "tf.Const"() <{value = dense<(0.000000e+00,1.000000e+00)> : tensor<3x2x1xcomplex<f32>>}> : () -> tensor<*xcomplex<f32>>
   %4 = "tf.Fill"(%0, %complex_cst) : (tensor<3xi32>, tensor<complex<f32>>) -> tensor<*xcomplex<f32>>
 
   func.return %2, %3, %4 : tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>
@@ -1164,13 +1164,13 @@
   %1 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
 
   // CHECK: %0 = "tf.PartitionedCall"(%arg0, %arg1)
-  // CHECK-SAME: device = "noodle"
   // CHECK-SAME: f = @sub
+  // CHECK-SAME: device = "noodle"
   %2 = "tf.If"(%0, %arg0, %arg1) {then_branch = @add, else_branch = @sub, output_shapes = [#tf_type.shape<>], device = "noodle", is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK: %1 = "tf.StatefulPartitionedCall"(%0, %arg1)
+  // CHECK-SAME: f = @add
   // CHECK-SAME: _underscore_attr = "something"
   // CHECK-SAME: device = "noodle"
-  // CHECK-SAME: f = @add
   %3 = "tf.If"(%1, %2, %arg1) {then_branch = @add, else_branch = @sub, output_shapes = [#tf_type.shape<>], device = "noodle", _underscore_attr = "something", is_stateless = false} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
 
   // CHECK: %2 = "tf.If"
@@ -1233,13 +1233,13 @@
 func.func @eliminatePassThroughIfRegion(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<!tf_type.resource>) -> (tensor<f32>) {
   // CHECK: %[[PRED:.*]] = "tf._SomeOp"() : () -> tensor<i1>
   %pred = "tf._SomeOp"() : () -> tensor<i1>
-  // CHECK: %[[IF_OUTPUT:.*]] = "tf.IfRegion"(%[[PRED]]) ({
+  // CHECK: %[[IF_OUTPUT:.*]] = "tf.IfRegion"(%[[PRED]]) <{is_stateless = true}> ({
   // CHECK:   %[[MUL:.*]] = "tf.Mul"(%[[ARG0]], %[[ARG1]])
   // CHECK:   "tf.Yield"(%[[MUL]]) : (tensor<f32>)
   // CHECK:  },  {
   // CHECK:    %[[SUB:.*]] = "tf.Sub"(%[[ARG0]], %[[ARG1]])
   // CHECK:    "tf.Yield"(%[[SUB]]) : (tensor<f32>)
-  // CHECK:  }) {device = "/job:localhost/replica:0/task:0/device:GPU:0", is_stateless = true} : (tensor<i1>) -> tensor<f32>
+  // CHECK:  }) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<i1>) -> tensor<f32>
   %0:4 = "tf.IfRegion"(%pred) ({
       %true_value = "tf.Mul"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
       "tf.Yield"(%arg1, %arg2, %true_value, %arg2) : (tensor<f32>, tensor<!tf_type.resource>, tensor<f32>, tensor<!tf_type.resource>) -> ()
@@ -1260,7 +1260,7 @@
 func.func @eliminatePassThroughCaseRegion(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<!tf_type.resource>) -> (tensor<f32>) {
   // CHECK: %[[INDEX:.*]] = "tf._SomeOp"() : () -> tensor<i32>
   %index = "tf._SomeOp"() : () -> tensor<i32>
-  // CHECK: %[[CASE_OUTPUT:.*]] = "tf.CaseRegion"(%[[INDEX]]) ({
+  // CHECK: %[[CASE_OUTPUT:.*]] = "tf.CaseRegion"(%[[INDEX]]) <{is_stateless = true}> ({
   // CHECK:   %[[MUL:.*]] = "tf.Mul"(%[[ARG0]], %[[ARG1]])
   // CHECK:   "tf.Yield"(%[[MUL]]) : (tensor<f32>)
   // CHECK:  },  {
@@ -1269,7 +1269,7 @@
   // CHECK:  },  {
   // CHECK:    %[[ADD:.*]] = "tf.AddV2"(%[[ARG0]], %[[ARG1]])
   // CHECK:    "tf.Yield"(%[[ADD]]) : (tensor<f32>)
-  // CHECK:  }) {device = "/job:localhost/replica:0/task:0/device:GPU:0", is_stateless = true} : (tensor<i32>) -> tensor<f32>
+  // CHECK:  }) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<i32>) -> tensor<f32>
   %0:3 = "tf.CaseRegion"(%index) ({
       %mul = "tf.Mul"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
       "tf.Yield"(%arg1, %mul, %arg2) : (tensor<f32>, tensor<f32>, tensor<!tf_type.resource>) -> ()
@@ -1293,13 +1293,13 @@
   %3 = arith.constant dense<0> : tensor<i32>
 
   // CHECK: PartitionedCall
-  // CHECK-SAME: device = "noodle"
   // CHECK-SAME: f = @add
+  // CHECK-SAME: device = "noodle"
   %4 = "tf.Case"(%2, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf_type.shape<>], device = "noodle", is_stateless = false} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK: PartitionedCall
+  // CHECK-SAME: f = @sub
   // CHECK-SAME: _cluster_launch = "not_ready"
   // CHECK-SAME: device = "noodle"
-  // CHECK-SAME: f = @sub
   %5 = "tf.Case"(%3, %4, %arg1) {branches = [@sub, @add], output_shapes = [#tf_type.shape<>], device= "noodle", _cluster_launch = "not_ready", is_stateless = false} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
   func.return %5 : tensor<f32>
 }
@@ -1317,7 +1317,7 @@
 // CHECK-LABEL: testBatchToSpaceToBatchToSpaceND
 // CHECK-SAME: ([[INPUT:%.*]]: tensor<?x?x?x?xf32>, [[CROPS:%.*]]: tensor<?x?xi32>)
 func.func @testBatchToSpaceToBatchToSpaceND(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?xi32>) -> tensor<*xf32> {
-  // CHECK: [[BLOCK_SHAPE:%.*]] = "tf.Const"() {value = dense<8> : tensor<2xi64>}
+  // CHECK: [[BLOCK_SHAPE:%.*]] = "tf.Const"() <{value = dense<8> : tensor<2xi64>}>
   // CHECK: [[BATCH_TO_SHAPE_ND:%.*]] = "tf.BatchToSpaceND"([[INPUT]], [[BLOCK_SHAPE]], [[CROPS]]) {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
   %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 8 : i64, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?x?x?xf32>, tensor<?x?xi32>) -> tensor<*xf32>
   // CHECK: return [[BATCH_TO_SHAPE_ND]]
@@ -1612,8 +1612,8 @@
 func.func @testIfDropOutputShapes(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
 ^bb0(%arg0: tensor<i1>, %arg1: tensor<2xf32>):
   // CHECK: "tf.If"
-  // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:GPU:0"
   // CHECK-NOT: output_shapes
+  // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:GPU:0"
   %1 = "tf.If"(%arg0, %arg1) {
     then_branch = @testIfThen, else_branch = @testIfElse, is_stateless = false, output_shapes = [#tf_type.shape<>], device = "/job:localhost/replica:0/task:0/device:GPU:0"
   } : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
@@ -1647,10 +1647,10 @@
 
 // CHECK-LABEL: @testMatrixDiag
 func.func @testMatrixDiag(%diag: tensor<2x4xf32>) -> tensor<2x4x4xf32> {
-  // CHECK-DAG: %[[MINUS1:.*]] = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-DAG: %[[ZEROI:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-DAG: %[[ZEROF:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK-DAG: "tf.MatrixDiagV3"(%arg0, %[[ZEROI]], %[[MINUS1]], %[[MINUS1]], %[[ZEROF]]) {align = "RIGHT_LEFT", device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x4xf32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<f32>) -> tensor<2x4x4xf32>
+  // CHECK-DAG: %[[MINUS1:.*]] = "tf.Const"() <{value = dense<-1> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-DAG: %[[ZEROI:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-DAG: %[[ZEROF:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK-DAG: "tf.MatrixDiagV3"(%arg0, %[[ZEROI]], %[[MINUS1]], %[[MINUS1]], %[[ZEROF]]) <{align = "RIGHT_LEFT"}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x4xf32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<f32>) -> tensor<2x4x4xf32>
   %0 = "tf.MatrixDiag"(%diag) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x4xf32>) -> tensor<2x4x4xf32>
   func.return %0 : tensor<2x4x4xf32>
 }
@@ -1660,9 +1660,9 @@
   %0 = "tf.MatrixSetDiag"(%arg0, %arg1) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<3x3xi64>, tensor<3xi64>) -> tensor<3x3xi64>
   func.return %0 : tensor<3x3xi64>
 
-  // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
+  // CHECK: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
   // CHECK: %[[RES:.*]] = "tf.MatrixSetDiagV3"(%arg0, %arg1, %[[ZERO]])
-  // CHECK-SAME: {align = "RIGHT_LEFT", device = "/job:localhost/replica:0/task:0/device:GPU:0"}
+  // CHECK-SAME: <{align = "RIGHT_LEFT"}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
   // CHECK-SAME: (tensor<3x3xi64>, tensor<3xi64>, tensor<i32>) -> tensor<3x3xi64>
 }
 
@@ -1672,7 +1672,7 @@
   func.return %0 : tensor<3x3xi64>
 
   // CHECK: %[[RES:.*]] = "tf.MatrixSetDiagV3"(%arg0, %arg1, %arg2)
-  // CHECK-SAME: {align = "LEFT_LEFT", device = "/job:localhost/replica:0/task:0/device:GPU:0"}
+  // CHECK-SAME: <{align = "LEFT_LEFT"}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
 }
 
 // CHECK-LABEL: @testVariableToVariableV2
@@ -1680,7 +1680,7 @@
   // CHECK-NOT: "tf.Variable"
 
   %0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
-  // CHECK: "tf.VariableV2"() {container = "", device = "/job:localhost/replica:0/task:0/device:GPU:0", shape = #tf_type.shape<>, shared_name = "var"}
+  // CHECK: "tf.VariableV2"() <{container = "", shape = #tf_type.shape<>, shared_name = "var"}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
   %1 = "tf.Variable"() {container = "", dtype = i32, shared_name = "var", shape = #tf_type.shape<>, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : () -> tensor<!tf_type.int32ref>
   %2 = "tf.Assign"(%1, %0) : (tensor<!tf_type.int32ref>, tensor<i32>) -> (tensor<!tf_type.int32ref>)
 
@@ -1691,7 +1691,7 @@
 func.func @testUnpackAndCwiseUnary(%arg0: tensor<?x2xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
 
   // CHECK: %[[NEG:.*]] = "tf.Neg"(%arg0) {device = ""}
-  // CHECK: %[[UNPACK:.*]]:2 = "tf.Unpack"(%[[NEG]]) {axis = 1 : i64, device = ""}
+  // CHECK: %[[UNPACK:.*]]:2 = "tf.Unpack"(%[[NEG]]) <{axis = 1 : i64}> {device = ""}
   %unpacked:2 = "tf.Unpack"(%arg0) {axis = 1 : i64, device = ""}
                 : (tensor<?x2xf32>) -> (tensor<?xf32>, tensor<?xf32>)
   %0 = "tf.Neg"(%unpacked#0): (tensor<?xf32>) -> tensor<?xf32>
@@ -1708,7 +1708,7 @@
   %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
   %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
   func.return %3 : tensor<2xi32>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
   // CHECK: return %[[CST]]
 }
 
@@ -1719,7 +1719,7 @@
   %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi64>
   %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64>
   func.return %3 : tensor<2xi64>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
   // CHECK: return %[[CST]]
 }
 
@@ -1730,7 +1730,7 @@
   %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
   %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
   func.return %3 : tensor<?xi32>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<?xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<?xi32>
   // CHECK: return %[[CST]]
 }
 
@@ -1741,7 +1741,7 @@
   %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
   %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
   func.return %3 : tensor<i32>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
   // CHECK: return %[[CST]]
 }
 
@@ -1752,7 +1752,7 @@
   %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi64>
   %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
   func.return %3 : tensor<i64>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
   // CHECK: return %[[CST]]
 }
 
@@ -1763,7 +1763,7 @@
   %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
   %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
   func.return %3 : tensor<*xi32>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<*xi32>
   // CHECK: return %[[CST]]
 }
 
@@ -1775,7 +1775,7 @@
   %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
   %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
   func.return %4 : tensor<i32>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
   // CHECK: return %[[CST]]
 }
 
@@ -1787,7 +1787,7 @@
   %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
   %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
   func.return %4 : tensor<i32>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
   // CHECK: return %[[CST]]
 }
 
@@ -1811,7 +1811,7 @@
   %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32>
   %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
   func.return %4 : tensor<2xi32>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
   // CHECK: return %[[CST]]
 }
 
@@ -1822,7 +1822,7 @@
   %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
   %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
   func.return %3 : tensor<3xi32>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[1, 2, 3]> : tensor<3xi32>}> : () -> tensor<3xi32>
   // CHECK: return %[[CST]]
 }
 
@@ -1834,7 +1834,7 @@
   %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x4x?xf32>) -> tensor<5xi32>
   %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<5xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
   func.return %4 : tensor<2xi32>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[2, 4]> : tensor<2xi32>}> : () -> tensor<2xi32>
   // CHECK: return %[[CST]]
 }
 
@@ -1845,7 +1845,7 @@
   %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
   %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
   func.return %3 : tensor<3xi32>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[1, 2, 3]> : tensor<3xi32>}> : () -> tensor<3xi32>
   // CHECK: return %[[CST]]
 }
 
@@ -1857,7 +1857,7 @@
   %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32>
   %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
   func.return %4 : tensor<1xi32>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<3> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK: return %[[CST]]
 }
 
@@ -1869,7 +1869,7 @@
   %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
   %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
   func.return %4 : tensor<2xi32>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[3, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
   // CHECK: return %[[CST]]
 }
 
@@ -1881,7 +1881,7 @@
   %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
   %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
   func.return %4 : tensor<2xi32>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[3, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
   // CHECK: return %[[CST]]
 }
 
@@ -1893,7 +1893,7 @@
   %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32>
   %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
   func.return %4 : tensor<3xi32>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[3, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
   // CHECK: return %[[CST]]
 }
 
@@ -1905,7 +1905,7 @@
   %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
   %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32>
   func.return %4 : tensor<0xi32>
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
   // CHECK: return %[[CST]]
 }
 
@@ -1914,7 +1914,7 @@
   %0 = "tf.EnsureShape"(%arg0) {shape = #tf_type.shape<10x20>} : (tensor<10x20xf32>) -> tensor<10x20xf32>
   %1 = "tf.EnsureShape"(%arg0) {shape = #tf_type.shape<?x20>} : (tensor<10x20xf32>) -> tensor<10x20xf32>
   // Failing case which should not be folded.
-  // CHECK: %[[NF:.*]] = "tf.EnsureShape"(%arg0) {shape = #tf_type.shape<20x10>}
+  // CHECK: %[[NF:.*]] = "tf.EnsureShape"(%arg0) <{shape = #tf_type.shape<20x10>}>
   %2 = "tf.EnsureShape"(%arg0) {shape = #tf_type.shape<20x10>} : (tensor<10x20xf32>) -> tensor<20x10xf32>
   // CHECK: return %arg0, %arg0, %[[NF]]
   func.return %0, %1, %2: tensor<10x20xf32>, tensor<10x20xf32>, tensor<20x10xf32>
@@ -1924,7 +1924,7 @@
 func.func @testConvertPackToReshapeAxis0(%arg0: tensor<2x3xf32>) -> tensor<1x2x3xf32> {
   %0 = "tf.Pack"(%arg0) {axis = 0 : i64, _xla_outside_compilation = "1", device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3xf32>) -> tensor<1x2x3xf32>
   func.return %0 : tensor<1x2x3xf32>
-  // CHECK: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi64>} : () -> tensor<3xi64>
+  // CHECK: %[[SHAPE:.*]] = "tf.Const"() <{value = dense<[1, 2, 3]> : tensor<3xi64>}> : () -> tensor<3xi64>
   // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) {_xla_outside_compilation = "1", device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3xf32>, tensor<3xi64>) -> tensor<1x2x3xf32>
   // CHECK: return %[[RESHAPE]] : tensor<1x2x3xf32>
 }
@@ -1933,7 +1933,7 @@
 func.func @testConvertPackToReshapeAxis1(%arg0: tensor<2x3xf32>) -> tensor<2x1x3xf32> {
   %0 = "tf.Pack"(%arg0) {axis = 1 : i64, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3xf32>) -> tensor<2x1x3xf32>
   func.return %0 : tensor<2x1x3xf32>
-  // CHECK: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[2, 1, 3]> : tensor<3xi64>} : () -> tensor<3xi64>
+  // CHECK: %[[SHAPE:.*]] = "tf.Const"() <{value = dense<[2, 1, 3]> : tensor<3xi64>}> : () -> tensor<3xi64>
   // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3xf32>, tensor<3xi64>) -> tensor<2x1x3xf32>
   // CHECK: return %[[RESHAPE]] : tensor<2x1x3xf32>
 }
@@ -1942,7 +1942,7 @@
 func.func @testDontConvertPackToReshapeDynamicShape(%arg0: tensor<2x?xf32>) -> tensor<1x2x?xf32> {
   %0 = "tf.Pack"(%arg0) {axis = 0 : i64} : (tensor<2x?xf32>) -> tensor<1x2x?xf32>
   func.return %0 : tensor<1x2x?xf32>
-  // CHECK: %[[PACK:.*]] = "tf.Pack"(%arg0) {axis = 0 : i64} : (tensor<2x?xf32>) -> tensor<1x2x?xf32>
+  // CHECK: %[[PACK:.*]] = "tf.Pack"(%arg0) <{axis = 0 : i64}> : (tensor<2x?xf32>) -> tensor<1x2x?xf32>
   // CHECK: return %[[PACK]] : tensor<1x2x?xf32>
 }
 
@@ -1950,7 +1950,7 @@
 func.func @while_with_id_passthrough(%arg0: tensor<7xf32> {tf._user_specified_name = "x"}) -> tensor<?xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "x", outputs = "identity_RetVal"}} {
   %0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
   %1 = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: %[[SHAPE:.*]] = "tf.Const"() {value = dense<7> : tensor<1xi32>}
+  // CHECK: %[[SHAPE:.*]] = "tf.Const"() <{value = dense<7> : tensor<1xi32>}>
   %2 = "tf.Const"() {value = dense<7> : tensor<1xi32>} : () -> tensor<1xi32>
   %3 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
   %4 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
@@ -1983,7 +1983,7 @@
 func.func @testConvertQuantizeAndDequantizeV2ToQuantizeAndDequantizeV4(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<?x?xf32> {
   %0 = "tf.QuantizeAndDequantizeV2"(%arg0, %arg1, %arg2) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
   func.return %0 : tensor<?x?xf32>
-  // CHECK: %[[QUANT:.*]] = "tf.QuantizeAndDequantizeV4"(%arg0, %arg1, %arg2) {axis = -1 : i64, device = "/job:localhost/replica:0/task:0/device:GPU:0", narrow_range = false, num_bits = 8 : i64, range_given = false, round_mode = "HALF_TO_EVEN", signed_input = true} : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
+  // CHECK: %[[QUANT:.*]] = "tf.QuantizeAndDequantizeV4"(%arg0, %arg1, %arg2) <{axis = -1 : i64, narrow_range = false, num_bits = 8 : i64, range_given = false, round_mode = "HALF_TO_EVEN", signed_input = true}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
   // CHECK: return %[[QUANT]] : tensor<?x?xf32>
 }
 
@@ -1991,10 +1991,10 @@
 func.func @testHashTableAndInitializeTableToV2(%arg0: tensor<!tf_type.string>) {
   // CHECK: [[handle:%.*]] = "tf.HashTableV2"()
   // CHECK-SAME: container = ""
-  // CHECK-SAME: device = ""
   // CHECK-SAME: key_dtype = !tf_type.string
   // CHECK-SAME: shared_name = "table"
   // CHECK-SAME: value_dtype = i32
+  // CHECK-SAME: device = ""
   // CHECK-SAME: () -> tensor<!tf_type.resource>
   %handle = "tf.HashTable"() {container = "", device = "", shared_name = "table", key_dtype = !tf_type.string, value_dtype = i32} : () -> tensor<*x!tf_type.stringref>
 
@@ -2008,10 +2008,10 @@
 func.func @testHashTableAndLookupTableSizeToV2() -> tensor<i64> {
   // CHECK: [[handle:%.*]] = "tf.HashTableV2"()
   // CHECK-SAME: container = ""
-  // CHECK-SAME: device = ""
   // CHECK-SAME: key_dtype = !tf_type.string
   // CHECK-SAME: shared_name = "table"
   // CHECK-SAME: value_dtype = i32
+  // CHECK-SAME: device = ""
   // CHECK-SAME: () -> tensor<!tf_type.resource>
   %handle = "tf.HashTable"() {container = "", device = "", shared_name = "table", key_dtype = !tf_type.string, value_dtype = i32} : () -> tensor<*x!tf_type.stringref>
 
@@ -2025,10 +2025,10 @@
 func.func @testHashTableAndLookupTableFindToV2(%arg0: tensor<!tf_type.string>, %arg1: tensor<i32>) -> tensor<i32> {
   // CHECK: [[handle:%.*]] = "tf.HashTableV2"()
   // CHECK-SAME: container = ""
-  // CHECK-SAME: device = ""
   // CHECK-SAME: key_dtype = !tf_type.string
   // CHECK-SAME: shared_name = "table"
   // CHECK-SAME: value_dtype = i32
+  // CHECK-SAME: device = ""
   // CHECK-SAME: () -> tensor<!tf_type.resource>
   %handle = "tf.HashTable"() {container = "", device = "", shared_name = "table", key_dtype = !tf_type.string, value_dtype = i32} : () -> tensor<*x!tf_type.stringref>
 
@@ -2041,9 +2041,9 @@
 // CHECK-LABEL: testDivNoNanAndMulNoNanWithConstantY
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<2xf32>)
 func.func @testDivNoNanAndMulNoNanWithConstantY(%arg0: tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) {
-  // CHECK: %[[CON1:.*]] = "tf.Const"() {value = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32>
-  // CHECK-NEXT: %[[CON2:.*]] = "tf.Const"() {value = dense<[1.000000e+01, 0.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32>
-  // CHECK-NEXT: %[[CON3:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
+  // CHECK: %[[CON1:.*]] = "tf.Const"() <{value = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32>
+  // CHECK-NEXT: %[[CON2:.*]] = "tf.Const"() <{value = dense<[1.000000e+01, 0.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32>
+  // CHECK-NEXT: %[[CON3:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>
   // CHECK-NEXT: %[[RES1:.*]] = "tf.Div"(%[[ARG0]], %[[CON1]]) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
   // CHECK-NEXT: %[[RES2:.*]] = "tf.MulNoNan"(%[[ARG0]], %[[CON2]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
   // CHECK-NEXT: return %[[RES1]], %[[RES2]], %[[CON3]] : tensor<2xf32>, tensor<2xf32>, tensor<2xf32>
@@ -2060,9 +2060,9 @@
 // CHECK-LABEL: testComplexDivNoNanAndMulNoNanWithConstantY
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<2xcomplex<f32>>)
 func.func @testComplexDivNoNanAndMulNoNanWithConstantY(%arg0: tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) {
-  // CHECK-NEXT: %[[COMP2:.*]] = "tf.Const"() {value = dense<[(0.000000e+00,0.000000e+00), (2.000000e+00,0.000000e+00)]> : tensor<2xcomplex<f32>>} : () -> tensor<2xcomplex<f32>>
-  // CHECK-NEXT: %[[COMP1:.*]] = "tf.Const"() {value = dense<[(1.000000e+00,3.000000e+00), (2.000000e+00,4.000000e+00)]> : tensor<2xcomplex<f32>>} : () -> tensor<2xcomplex<f32>>
-  // CHECK-NEXT: %[[COMP3:.*]] = "tf.Const"() {value = dense<(0.000000e+00,0.000000e+00)> : tensor<2xcomplex<f32>>} : () -> tensor<2xcomplex<f32>>
+  // CHECK-NEXT: %[[COMP2:.*]] = "tf.Const"() <{value = dense<[(0.000000e+00,0.000000e+00), (2.000000e+00,0.000000e+00)]> : tensor<2xcomplex<f32>>}> : () -> tensor<2xcomplex<f32>>
+  // CHECK-NEXT: %[[COMP1:.*]] = "tf.Const"() <{value = dense<[(1.000000e+00,3.000000e+00), (2.000000e+00,4.000000e+00)]> : tensor<2xcomplex<f32>>}> : () -> tensor<2xcomplex<f32>>
+  // CHECK-NEXT: %[[COMP3:.*]] = "tf.Const"() <{value = dense<(0.000000e+00,0.000000e+00)> : tensor<2xcomplex<f32>>}> : () -> tensor<2xcomplex<f32>>
   // CHECK-NEXT: %[[RES1:.*]] = "tf.Mul"(%[[ARG0]], %[[COMP1]]) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
   // CHECK-NEXT: %[[RES2:.*]] = "tf.DivNoNan"(%[[ARG0]], %[[COMP2]]) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
   // CHECK-NEXT: return %[[RES1]], %[[RES2]], %[[COMP3]] : tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>
@@ -2104,7 +2104,7 @@
 // CHECK-LABEL: testComplexDivNoNanOpWithNonConstantY
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<2xcomplex<f32>>, %[[ARG1:.*]]: tensor<2xcomplex<f32>>, %[[ARG2:.*]]: tensor<2xf32>)
 func.func @testComplexDivNoNanOpWithNonConstantY(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f32>>, %arg2: tensor<2xf32>) -> (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) {
-  // CHECK: %[[CON1:.*]] = "tf.Const"() {value = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32>
+  // CHECK: %[[CON1:.*]] = "tf.Const"() <{value = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32>
   // CHECK-NEXT: %[[NONCON2:.*]] = "tf.Sub"(%[[ARG0]], %[[ARG1]]) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
   // CHECK-NEXT: %[[NONCON3:.*]] = "tf.Complex"(%[[CON1]], %[[ARG2]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xcomplex<f32>>
   // CHECK-NEXT: %[[RES1:.*]] = "tf.MulNoNan"(%[[ARG0]], %[[ARG1]]) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>
@@ -2128,7 +2128,7 @@
   %rhs_dilation = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32>
   %padding = "tf.Const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
   %strides = "tf.Const"() {value = dense<[3, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
-  // CHECK: "tf.XlaConvV2"(%arg0, %arg1, %cst_3, %cst_2, %cst_0, %cst_1, %cst) {batch_group_count = 1 : i64, device = "/job:localhost/replica:0/task:0/device:GPU:0", dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
+  // CHECK: "tf.XlaConvV2"(%arg0, %arg1, %cst_3, %cst_2, %cst_0, %cst_1, %cst) <{batch_group_count = 1 : i64, dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
   %0 = "tf.XlaConv"(%lhs, %rhs, %strides, %padding, %lhs_dilation, %rhs_dilation, %feature_group_count) {dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = "", device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
   func.return %0 : tensor<8x4x14x14x16xf32>
 }
@@ -2136,7 +2136,7 @@
 
 // CHECK-LABEL: testXlaReduceToXlaVariadicReduceV2
 func.func @testXlaReduceToXlaVariadicReduceV2(%arg0: tensor<*xbf16>, %arg1: tensor<*xbf16>) -> tensor<*xbf16> {
-  // CHECK: "tf.XlaVariadicReduceV2"(%arg0, %arg1) {device = "/job:localhost/replica:0/task:0/device:GPU:0", dimensions_to_reduce = [], operandSegmentSizes = array<i32: 1, 1>, reducer = @sum1} : (tensor<*xbf16>, tensor<*xbf16>) -> tensor<*xbf16>
+  // CHECK: "tf.XlaVariadicReduceV2"(%arg0, %arg1) <{dimensions_to_reduce = [], operandSegmentSizes = array<i32: 1, 1>, reducer = @sum1}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<*xbf16>, tensor<*xbf16>) -> tensor<*xbf16>
   %0 = "tf.XlaReduce"(%arg0, %arg1) {dimensions_to_reduce = [], reducer = @sum1, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<*xbf16>, tensor<*xbf16>) -> tensor<*xbf16>
   func.return %0 : tensor<*xbf16>
 }
@@ -2148,7 +2148,7 @@
 
 // CHECK-LABEL: testXlaVariadicReduceToV2
 func.func @testXlaVariadicReduceToV2(%arg0: tensor<3x4xf32>, %arg1: tensor<f32>) -> tensor<?x?xf32> {
-  // CHECK:  "tf.XlaVariadicReduceV2"(%arg0, %arg1) {device = "/job:localhost/replica:0/task:0/device:GPU:0", dimensions_to_reduce = [], operandSegmentSizes = array<i32: 1, 1>, reducer = @sum2} : (tensor<3x4xf32>, tensor<f32>) -> tensor<?x?xf32>
+  // CHECK:  "tf.XlaVariadicReduceV2"(%arg0, %arg1) <{dimensions_to_reduce = [], operandSegmentSizes = array<i32: 1, 1>, reducer = @sum2}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<3x4xf32>, tensor<f32>) -> tensor<?x?xf32>
   %0 = "tf.XlaVariadicReduce"(%arg0, %arg1) {dimensions_to_reduce = [], reducer = @sum2, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<3x4xf32>, tensor<f32>) -> tensor<?x?xf32>
   func.return %0 : tensor<?x?xf32>
 }
@@ -2229,8 +2229,8 @@
   %1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<!tf_type.variant<tensor<1x32xf32>>>, tensor<i32>, tensor<2xi32>) -> tensor<1x32xf32>
   func.return %1 : tensor<1x32xf32>
 
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: %[[RES:.*]] = "tf.GatherV2"(%arg0, %arg2, %cst) {batch_dims = 0 : i64, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<1600x1x32xf32>, tensor<i32>, tensor<i32>) -> tensor<1x32xf32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK: %[[RES:.*]] = "tf.GatherV2"(%arg0, %arg2, %cst) <{batch_dims = 0 : i64}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<1600x1x32xf32>, tensor<i32>, tensor<i32>) -> tensor<1x32xf32>
 }
 
 // CHECK-LABEL: testTensorListGetItemMultipleUsers
@@ -2240,9 +2240,9 @@
   %2 = "tf.TensorListGetItem"(%0, %arg3, %arg1) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<!tf_type.variant<tensor<1x32xf32>>>, tensor<i32>, tensor<2xi32>) -> tensor<1x32xf32>
   func.return %1, %2 : tensor<1x32xf32>, tensor<1x32xf32>
 
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: %[[RES0:.*]] = "tf.GatherV2"(%arg0, %arg2, %cst) {batch_dims = 0 : i64, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<1600x1x32xf32>, tensor<i32>, tensor<i32>) -> tensor<1x32xf32>
-  // CHECK: %[[RES1:.*]] = "tf.GatherV2"(%arg0, %arg3, %cst) {batch_dims = 0 : i64, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<1600x1x32xf32>, tensor<i32>, tensor<i32>) -> tensor<1x32xf32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK: %[[RES0:.*]] = "tf.GatherV2"(%arg0, %arg2, %cst) <{batch_dims = 0 : i64}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<1600x1x32xf32>, tensor<i32>, tensor<i32>) -> tensor<1x32xf32>
+  // CHECK: %[[RES1:.*]] = "tf.GatherV2"(%arg0, %arg3, %cst) <{batch_dims = 0 : i64}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<1600x1x32xf32>, tensor<i32>, tensor<i32>) -> tensor<1x32xf32>
 }
 
 // CHECK-LABEL: testUnaryIdempotent
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize_compile_and_replicate_attributes.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize_compile_and_replicate_attributes.mlir
index 9f6dc7b..2243958 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize_compile_and_replicate_attributes.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize_compile_and_replicate_attributes.mlir
@@ -3,7 +3,7 @@
 // CHECK-LABEL: func.func @convert_tpu_replicate
 func.func @convert_tpu_replicate() {
   tf_executor.graph {
-    // CHECK: tf_executor.island wraps "tf.TPUReplicateMetadata"() {_replication_info = "cluster", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], name = "TPUReplicateMetadata", num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : () -> ()
+    // CHECK: tf_executor.island wraps "tf.TPUReplicateMetadata"() <{allow_soft_placement = false, computation_shape = [], device_assignment = [], host_compute_core = [], num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_spmd_for_xla_partitioning = false, use_tpu = true}> {_replication_info = "cluster", _xla_compile_device_type = "TPU", device = "", name = "TPUReplicateMetadata"} : () -> ()
     %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], name = "TPUReplicateMetadata", num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_tpu = true, use_spmd_for_xla_partitioning = false} : () -> ()
     %outputs_0, %control_0 = tf_executor.island wraps "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "y", shape = "tfshape$dim { }"} : () -> tensor<0xf32>
     %outputs_1, %control_1 = tf_executor.island wraps "tf.TPUReplicatedInput"(%outputs_0) {N = 1 : i64, T = "tfdtype$DT_FLOAT", device = "", name = "input1"} : (tensor<0xf32>) -> tensor<0xf32>
@@ -21,7 +21,7 @@
 
 // CHECK-LABEL: func.func @convert_xla_must_compile
 func.func @convert_xla_must_compile(%arg0: tensor<i32>) -> tensor<i32> {
-  // CHECK: "tf.StatefulPartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @stateful_pcall_func} : (tensor<i32>) -> tensor<i32>
+  // CHECK: "tf.StatefulPartitionedCall"(%arg0) <{config = "", config_proto = "", executor_type = "", f = @stateful_pcall_func}> {_xla_compile_device_type = "CPU", device = "/device:CPU:0"} : (tensor<i32>) -> tensor<i32>
   %0 = "tf.StatefulPartitionedCall"(%arg0) {_XlaMustCompile = true, config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @stateful_pcall_func} : (tensor<i32>) -> (tensor<i32>)
   func.return %0 : tensor<i32>
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir
index 16acdce..5f04582 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir
@@ -10,6 +10,7 @@
     %2 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
 
     // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch"
+    // CHECK-SAME: <{device = "tpu0"}>
     // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) : (tensor<?xi32>) -> tensor<?xi32>
     %3 = "tf.B"(%2) {device = "tpu0"} : (tensor<?xi32>) -> tensor<?xi32>
 
@@ -17,7 +18,7 @@
     %4 = "tf.C"(%2, %3) {device = "tpu0"} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
 
     // CHECK: tf_device.return %[[C_OUTPUT]]
-    // CHECK: {device = "tpu0"} : () -> tensor<?xi32>
+    // CHECK: : () -> tensor<?xi32>
 
     // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[TPU0_OUTPUT]])
     %5 = "tf.D"(%4) : (tensor<?xi32>) -> tensor<?xi32>
@@ -40,6 +41,7 @@
         %2 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
 
         // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch"
+        // CHECK-SAME: <{device = "tpu0"}>
         // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) : (tensor<?xi32>) -> tensor<?xi32>
         %3 = "tf.B"(%2) {device = "tpu0"} : (tensor<?xi32>) -> tensor<?xi32>
 
@@ -47,7 +49,7 @@
         %4 = "tf.C"(%2, %3) {device = "tpu0"} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
 
         // CHECK: tf_device.return %[[C_OUTPUT]]
-        // CHECK: {device = "tpu0"} : () -> tensor<?xi32>
+        // CHECK: : () -> tensor<?xi32>
 
         // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[TPU0_OUTPUT]])
         %5 = "tf.D"(%4) : (tensor<?xi32>) -> tensor<?xi32>
@@ -71,6 +73,7 @@
       %1:2 = tf_executor.island {
 
         // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch"
+        // CHECK-SAME: <{device = "tpu0"}>
         // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) : (tensor<?xi32>) -> tensor<?xi32>
         %3 = "tf.A"(%arg0) {device = "tpu0"} : (tensor<?xi32>) -> tensor<?xi32>
 
@@ -78,7 +81,7 @@
         %4 = "tf.B"(%3, %arg0) {device = "tpu0"} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
 
         // CHECK: tf_device.return %[[B_OUTPUT]]
-        // CHECK: {device = "tpu0"} : () -> tensor<?xi32>
+        // CHECK: : () -> tensor<?xi32>
 
         // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[TPU0_OUTPUT]])
         %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
@@ -104,6 +107,7 @@
 
       %2:2 = tf_executor.island {
         // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch"
+        // CHECK: <{device = "tpu0"}>
         // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) : (tensor<?xi32>) -> tensor<?xi32>
         %3 = "tf.A"(%arg0) {device = "tpu0"} : (tensor<?xi32>) -> tensor<?xi32>
 
@@ -111,7 +115,7 @@
         %4 = "tf.B"(%3, %1#0) {device = "tpu0"} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
 
         // CHECK: tf_device.return %[[B_OUTPUT]]
-        // CHECK: {device = "tpu0"} : () -> tensor<?xi32>
+        // CHECK: : () -> tensor<?xi32>
 
         // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[TPU0_OUTPUT]])
         %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
@@ -135,11 +139,12 @@
       %1:2 = tf_executor.island {
 
         // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch"
+        // CHECK: <{device = "tpu0"}>
         // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"() : () -> tensor<?xi32>
         %3 = "tf.A"() {device = "tpu0"} : () -> tensor<?xi32>
 
         // CHECK: tf_device.return %[[A_OUTPUT]]
-        // CHECK: {device = "tpu0"} : () -> tensor<?xi32>
+        // CHECK: : () -> tensor<?xi32>
 
         // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[TPU0_OUTPUT]])
         %4 = "tf.B"(%3) : (tensor<?xi32>) -> tensor<?xi32>
@@ -166,6 +171,7 @@
         %2 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
 
         // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch"
+        // CHECK: <{device = "tpu0"}>
         // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) : (tensor<?xi32>) -> tensor<?xi32>
         %3 = "tf.B"(%2) {device = "tpu0"} : (tensor<?xi32>) -> tensor<?xi32>
 
@@ -173,7 +179,7 @@
         %4 = "tf.C"(%2, %3) {device = "tpu0"} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
 
         // CHECK: tf_device.return %[[C_OUTPUT]]
-        // CHECK: {device = "tpu0"} : () -> tensor<?xi32>
+        // CHECK: : () -> tensor<?xi32>
 
         // CHECK: %[[GPU0_OUTPUT:[0-9]*]] = "tf_device.launch"
         // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[TPU0_OUTPUT]]) : (tensor<?xi32>) -> tensor<?xi32>
@@ -204,6 +210,7 @@
         %2 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
 
         // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch"
+        // CHECK: <{device = "tpu0"}>
         // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) : (tensor<?xi32>) -> tensor<?xi32>
         %3 = "tf.B"(%2) {device = "tpu0"} : (tensor<?xi32>) -> tensor<?xi32>
 
@@ -211,7 +218,7 @@
         %4 = "tf.C"(%2, %3) {device = "tpu0"} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
 
         // CHECK: tf_device.return %[[C_OUTPUT]]
-        // CHECK: {device = "tpu0"} : () -> tensor<?xi32>
+        // CHECK: : () -> tensor<?xi32>
 
         // CHECK: %[[GPU0_OUTPUT:[0-9]*]] = "tf_device.launch"
         // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[A_OUTPUT]]) : (tensor<?xi32>) -> tensor<?xi32>
@@ -248,6 +255,7 @@
         // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[ARG_0]])
 
         // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch"
+        // CHECK: <{device = "tpu0"}>
         // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) : (tensor<?xi32>) -> tensor<?xi32>
         %3 = "tf.B"(%2) {device = "tpu0"} : (tensor<?xi32>) -> tensor<?xi32>
 
@@ -257,7 +265,6 @@
         %5 = "tf.D"(%2, %3) {device = "tpu0"} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
 
         // CHECK: tf_device.return %[[D_OUTPUT]]
-        // CHECK: {device = "tpu0"} : () -> tensor<?xi32>
 
         // CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[C_OUTPUT]], %[[TPU0_OUTPUT]]) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
         %6 = "tf.E"(%4, %5) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
@@ -296,12 +303,11 @@
         %4 = "tf.C"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
 
         // CHECK: %[[TPU0_OUTPUT1:[0-9]*]] = "tf_device.launch"
+        // CHECK: <{device = "tpu0"}>
         // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[A_OUTPUT]], %[[TPU0_OUTPUT0]]) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
         // CHECK: tf_device.return %[[D_OUTPUT]]
         %5 = "tf.D"(%2, %3) {device = "tpu0"} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
 
-        // CHECK: {device = "tpu0"} : () -> tensor<?xi32>
-
         // CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[C_OUTPUT]], %[[TPU0_OUTPUT1]]) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
         %6 = "tf.E"(%4, %5) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
 
@@ -358,11 +364,12 @@
         %2 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
 
         // CHECK: %[[GPU0_OUTPUT:[0-9]*]] = "tf_device.launch"
+        // CHECK: <{device = "gpu0"}>
         // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[ARG_0]])
         // CHECK: tf_device.return %[[C_OUTPUT]]
-        // CHECK: {device = "gpu0"} : () -> tensor<?xi32>
 
         // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch"
+        // CHECK: <{device = "tpu0"}>
         // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) : (tensor<?xi32>) -> tensor<?xi32>
         %3 = "tf.B"(%2) {device = "tpu0"} : (tensor<?xi32>) -> tensor<?xi32>
 
@@ -372,7 +379,6 @@
         %5 = "tf.D"(%2, %3) {device = "tpu0"} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
 
         // CHECK: tf_device.return %[[D_OUTPUT]]
-        // CHECK: {device = "tpu0"} : () -> tensor<?xi32>
 
         // CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[GPU0_OUTPUT]], %[[TPU0_OUTPUT]]) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
         %6 = "tf.E"(%4, %5) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir
index a77e449..90f1cfc 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir
@@ -10,7 +10,7 @@
       // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]])
       %2 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
 
-      // CHECK: %[[CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster_func"(%[[A_OUTPUT]]) {func = @[[CLUSTER:.*]]}
+      // CHECK: %[[CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster_func"(%[[A_OUTPUT]]) <{func = @[[CLUSTER:.*]]}>
       %3 = "tf_device.cluster"() ({
         %4 = "tf.B"(%2) : (tensor<?xi32>) -> tensor<?xi32>
         tf_device.return %4 : tensor<?xi32>
@@ -42,7 +42,7 @@
       // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]])
       %2 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
 
-      // CHECK: %[[CLUSTER_0_OUTPUT:[0-9]*]] = "tf_device.cluster_func"(%[[A_OUTPUT]]) {func = @[[CLUSTER_0:.*]]}
+      // CHECK: %[[CLUSTER_0_OUTPUT:[0-9]*]] = "tf_device.cluster_func"(%[[A_OUTPUT]]) <{func = @[[CLUSTER_0:.*]]}>
       %3 = "tf_device.cluster"() ({
         %6 = "tf.B"(%2) : (tensor<?xi32>) -> tensor<?xi32>
         tf_device.return %6 : tensor<?xi32>
@@ -51,7 +51,7 @@
       // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[CLUSTER_0_OUTPUT]])
       %4 = "tf.D"(%3) : (tensor<?xi32>) -> tensor<?xi32>
 
-      // CHECK: %[[CLUSTER_1_OUTPUT:[0-9]*]] = "tf_device.cluster_func"(%[[CLUSTER_0_OUTPUT]], %[[D_OUTPUT]]) {func = @[[CLUSTER_1:.*]]}
+      // CHECK: %[[CLUSTER_1_OUTPUT:[0-9]*]] = "tf_device.cluster_func"(%[[CLUSTER_0_OUTPUT]], %[[D_OUTPUT]]) <{func = @[[CLUSTER_1:.*]]}>
       %5 = "tf_device.cluster"() ({
         %6 = "tf.E"(%3) : (tensor<?xi32>) -> tensor<?xi32>
         %7 = "tf.F"(%4, %6) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
@@ -86,7 +86,7 @@
 func.func @cluster_operands(%arg0: tensor<?xi32>) -> tensor<?xi32> {
   %0 = tf_executor.graph {
     %1:2 = tf_executor.island wraps
-      // CHECK: %[[CLUSTER_OUTPUT:[a-z0-9]*]], %{{.*}} = {{.*}} "tf_device.cluster_func"() {func = @[[CLUSTER:.*]]}
+      // CHECK: %[[CLUSTER_OUTPUT:[a-z0-9]*]], %{{.*}} = {{.*}} "tf_device.cluster_func"() <{func = @[[CLUSTER:.*]]}>
       "tf_device.cluster"() ({
         %3 = "tf.A"() : () -> tensor<?xi32>
         tf_device.return %3 : tensor<?xi32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_tf_ops_pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_tf_ops_pass.mlir
index 8d94a07..85f9c47 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/cluster_tf_ops_pass.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_tf_ops_pass.mlir
@@ -43,7 +43,7 @@
 // CHECK: func @while_body(%[[ARG_0:.*]]: tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"})
 // CHECK-NEXT:   %[[RESULT_0:.*]] = "tf.Const"()
 // CHECK-NEXT:   %[[RESULT_1:.*]] = "tf.AddV2"(%[[ARG_0]], %[[RESULT_0]])
-// CHECK-NEXT:   %[[RESULT_2:.*]] = "tf.Const"() {value = dense<16> : tensor<i32>} : () -> tensor<i32>
+// CHECK-NEXT:   %[[RESULT_2:.*]] = "tf.Const"() <{value = dense<16> : tensor<i32>}> : () -> tensor<i32>
 // CHECK-NEXT:   tf_device.send %[[RESULT_2]] "key-0" "/job:worker/replica:0/task:1/device:CPU:0"
 // CHECK-SAME:  device = "/job:localhost/replica:0/task:0/device:CPU:0"
 // CHECK-NEXT:   tf_device.remote_run "/job:worker/replica:0/task:1" @[[BODY_PARTITION_0:.*]]() : () -> ()
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
index 3231f05..be0ab85 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
@@ -9,7 +9,7 @@
 
  // The result shape need not be static. The folding harness uses a TensorFlow
  // constant in that case.
-  // CHECK-DAG: "tf.Const"() {value = dense<[1, 32, 32, 16]> : tensor<4xi32>} : () -> tensor<?xi32>
+  // CHECK-DAG: "tf.Const"() <{value = dense<[1, 32, 32, 16]> : tensor<4xi32>}> : () -> tensor<?xi32>
   %1 = "tf.Shape"(%arg1) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<1x32x32x16xf32>) -> tensor<?xi32>
 
   // CHECK: "tf.Shape"(%arg2) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<?xi32>
@@ -28,7 +28,7 @@
   // CHECK-DAG: %[[RES_NO_FOLD:.*]] = "tf.Pow"(%arg0, %arg1)
   %0 = "tf.Pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
 
-  // CHECK-DAG: %[[POW_ZERO:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
+  // CHECK-DAG: %[[POW_ZERO:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32>
   %1 = "tf.Pow"(%arg0, %cst_zero) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
 
   // CHECK-NOT: "tf.Pow"
@@ -42,7 +42,7 @@
 func.func @testEmpty32() -> (tensor<5xi32>) {
   %0 = "tf.Const"() { value = dense<5> : tensor<i32> } : () -> tensor<i32>
 
-  // CHECK: [[VAL:%.+]] = "tf.Const"() {value = dense<0> : tensor<5xi32>}
+  // CHECK: [[VAL:%.+]] = "tf.Const"() <{value = dense<0> : tensor<5xi32>}>
   // CHECK: return [[VAL]]
   %1 = "tf.Empty"(%0) : (tensor<i32>) -> (tensor<5xi32>)
   func.return %1 : tensor<5xi32>
@@ -52,7 +52,7 @@
 func.func @testEmpty64() -> (tensor<5xi64>) {
   %0 = "tf.Const"() { value = dense<5> : tensor<i32> } : () -> tensor<i32>
 
-  // CHECK: [[VAL:%.+]] = "tf.Const"() {value = dense<0> : tensor<5xi64>}
+  // CHECK: [[VAL:%.+]] = "tf.Const"() <{value = dense<0> : tensor<5xi64>}>
   // CHECK: return [[VAL]] : tensor<5xi64>
   %1 = "tf.Empty"(%0) : (tensor<i32>) -> (tensor<5xi64>)
   func.return %1 : tensor<5xi64>
@@ -62,7 +62,7 @@
 func.func @testEmptyFloat() -> (tensor<5xf64>) {
   %0 = "tf.Const"() { value = dense<5> : tensor<i32> } : () -> tensor<i32>
 
-  // CHECK: [[VAL:%.+]] = "tf.Const"() {value =  dense<0.000000e+00> : tensor<5xf64>}
+  // CHECK: [[VAL:%.+]] = "tf.Const"() <{value =  dense<0.000000e+00> : tensor<5xf64>}>
   // CHECK: return [[VAL]]
   %1 = "tf.Empty"(%0) : (tensor<i32>) -> (tensor<5xf64>)
   func.return %1 : tensor<5xf64>
@@ -72,7 +72,7 @@
 func.func @testEmptyf16() -> (tensor<5xf16>) {
   %0 = "tf.Const"() { value = dense<5> : tensor<i32> } : () -> tensor<i32>
 
-  // CHECK: [[VAL:%.+]] = "tf.Const"() {value =  dense<0.000000e+00> : tensor<5xf16>}
+  // CHECK: [[VAL:%.+]] = "tf.Const"() <{value =  dense<0.000000e+00> : tensor<5xf16>}>
   // CHECK: return [[VAL]]
   %1 = "tf.Empty"(%0) : (tensor<i32>) -> (tensor<5xf16>)
   func.return %1 : tensor<5xf16>
@@ -82,7 +82,7 @@
 func.func @testEmptybf16() -> (tensor<5xbf16>) {
   %0 = "tf.Const"() { value = dense<5> : tensor<i32> } : () -> tensor<i32>
 
-  // CHECK: [[VAL:%.+]] = "tf.Const"() {value =  dense<0.000000e+00> : tensor<5xbf16>}
+  // CHECK: [[VAL:%.+]] = "tf.Const"() <{value =  dense<0.000000e+00> : tensor<5xbf16>}>
   // CHECK: return [[VAL]]
   %1 = "tf.Empty"(%0) : (tensor<i32>) -> (tensor<5xbf16>)
   func.return %1 : tensor<5xbf16>
@@ -91,8 +91,8 @@
 // CHECK-LABEL: func @testShapeN
 func.func @testShapeN(%arg0: tensor<f32>, %arg1: tensor<1x32x32x16xf32>) -> (tensor<0xi64>, tensor<4xi64>) {
 
-  // CHECK-DAG: %[[SHAPE0:.*]] = "tf.Const"() {value = dense<> : tensor<0xi64>}
-  // CHECK-DAG: %[[SHAPE1:.*]] = "tf.Const"() {value = dense<[1, 32, 32, 16]> : tensor<4xi64>}
+  // CHECK-DAG: %[[SHAPE0:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi64>}>
+  // CHECK-DAG: %[[SHAPE1:.*]] = "tf.Const"() <{value = dense<[1, 32, 32, 16]> : tensor<4xi64>}>
   %0:2 = "tf.ShapeN"(%arg0, %arg1) : (tensor<f32>, tensor<1x32x32x16xf32>) -> (tensor<0xi64>, tensor<4xi64>)
 
   // CHECK: return %[[SHAPE0]], %[[SHAPE1]]
@@ -101,8 +101,8 @@
 
 // CHECK-LABEL: func @testShapeNPartialStatic
 func.func @testShapeNPartialStatic(%arg0: tensor<f32>, %arg1: tensor<2x?x3xf32>, %arg2: tensor<1x32x32x16xf32>, %arg3: tensor<*xf32>) -> (tensor<0xi64>, tensor<3xi64>, tensor<4xi64>, tensor<?xi64>) {
-  // CHECK-DAG: %[[SHAPE0:.*]] = "tf.Const"() {value = dense<> : tensor<0xi64>}
-  // CHECK-DAG: %[[SHAPE2:.*]] = "tf.Const"() {value = dense<[1, 32, 32, 16]> : tensor<4xi64>}
+  // CHECK-DAG: %[[SHAPE0:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi64>}>
+  // CHECK-DAG: %[[SHAPE2:.*]] = "tf.Const"() <{value = dense<[1, 32, 32, 16]> : tensor<4xi64>}>
   // CHECK: %[[SHAPE13:.*]]:2 = "tf.ShapeN"(%arg1, %arg3) : (tensor<2x?x3xf32>, tensor<*xf32>) -> (tensor<3xi64>, tensor<?xi64>)
   %0:4 = "tf.ShapeN"(%arg0, %arg1, %arg2, %arg3) : (tensor<f32>, tensor<2x?x3xf32>, tensor<1x32x32x16xf32>, tensor<*xf32>) -> (tensor<0xi64>, tensor<3xi64>, tensor<4xi64>, tensor<?xi64>)
 
@@ -112,8 +112,8 @@
 
 // CHECK-LABEL: func @testShapeNOneDynamic
 func.func @testShapeNOneDynamic(%arg0: tensor<f32>, %arg1: tensor<1x32x32x16xf32>, %arg2: tensor<*xf32>) -> (tensor<0xi64>, tensor<4xi64>, tensor<?xi64>) {
-  // CHECK-DAG: %[[SHAPE0:.*]] = "tf.Const"() {value = dense<> : tensor<0xi64>}
-  // CHECK-DAG: %[[SHAPE1:.*]] = "tf.Const"() {value = dense<[1, 32, 32, 16]> : tensor<4xi64>}
+  // CHECK-DAG: %[[SHAPE0:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi64>}>
+  // CHECK-DAG: %[[SHAPE1:.*]] = "tf.Const"() <{value = dense<[1, 32, 32, 16]> : tensor<4xi64>}>
   // CHECK: %[[SHAPE2:.*]] = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<?xi64>
   %0:3 = "tf.ShapeN"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<1x32x32x16xf32>, tensor<*xf32>) -> (tensor<0xi64>, tensor<4xi64>, tensor<?xi64>)
 
@@ -140,8 +140,8 @@
   %2 = "tf.LeakyRelu"(%arg0) {alpha = 3.0 : f32} : (tensor<16xf32>) -> tensor<16xf32>
   // CHECK-DAG: [[POS:%.*]] = "tf.Const{{.*}} dense<5.000000e+00> : tensor<f32>
   // CHECK-DAG: [[NEG:%.*]] = "tf.Const{{.*}} dense<-1.000000e+00> : tensor<f32>
-  // CHECK: [[NC1:%.*]] = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32} : (tensor<16xf32>) -> tensor<16xf32>
-  // CHECK: [[NC2:%.*]] = "tf.LeakyRelu"(%arg0) {alpha = 3.000000e+00 : f32} : (tensor<16xf32>) -> tensor<16xf32>
+  // CHECK: [[NC1:%.*]] = "tf.LeakyRelu"(%arg0) <{alpha = 2.000000e-01 : f32}> : (tensor<16xf32>) -> tensor<16xf32>
+  // CHECK: [[NC2:%.*]] = "tf.LeakyRelu"(%arg0) <{alpha = 3.000000e+00 : f32}> : (tensor<16xf32>) -> tensor<16xf32>
   // CHECK: return [[NC1]], [[POS]], [[NEG]], [[NC2]]
   func.return %no, %0, %1, %2 : tensor<16xf32>, tensor<f32>, tensor<f32>, tensor<16xf32>
 }
@@ -295,8 +295,8 @@
   %3 = "tf.Minimum"(%0, %1) {random_attr = "hello"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
   func.return %2, %3: tensor<i32>, tensor<i32>
 
-// CHECK-DAG: %[[CST:.*]] = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG: %[[CST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG: %[[CST:.*]] = "tf.Const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG: %[[CST1:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK-NEXT: return %[[CST]], %[[CST1]]
 }
 
@@ -454,7 +454,7 @@
   %1 = "tf.Mul"(%arg0, %0) : (tensor<1x6x8x1xf32>, tensor<f32>) -> tensor<1x6x8x1xf32>
   func.return %1 : tensor<1x6x8x1xf32>
   // CHECK-LABEL: DontRemoveTrivialMul
-  // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
   // CHECK: %[[RESULT:.*]] = "tf.Mul"(%arg0, %[[CONST]]) : (tensor<1x6x8x1xf32>, tensor<f32>) -> tensor<1x6x8x1xf32>
   // CHECK: return %[[RESULT]] : tensor<1x6x8x1xf32>
 }
@@ -517,7 +517,7 @@
   %s1 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
   %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) {} : (tensor<2xi32>, tensor<2xi32>) -> (tensor<0xi32>, tensor<0xi32>)
 
-  // CHECK-DAG: %[[R:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK-DAG: %[[R:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
   // CHECK-NOT: tf.BroadcastGradientArgs
   // CHECK: return %[[R]], %[[R]]
 
@@ -529,8 +529,8 @@
   %s0 = "tf.Const"() {value = dense<[4]> : tensor<1xi32>} : () -> tensor<1xi32>
   %s1 = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
   %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) {} : (tensor<1xi32>, tensor<2xi32>) -> (tensor<1xi32>, tensor<0xi32>)
-  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
-  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
   // CHECK-NOT: tf.BroadcastGradientArgs
   // CHECK: return %[[R0]], %[[R1]]
 
@@ -542,8 +542,8 @@
   %s2 = "tf.Const"() {value = dense<[501, 1, 32, 1280]> : tensor<4xi32>} : () -> tensor<4xi32>
   %s3 = "tf.Const"() {value = dense<[  1, 1,  1, 1280]> : tensor<4xi32>} : () -> tensor<4xi32>
   %r2, %r3 = "tf.BroadcastGradientArgs"(%s2, %s3) {} : (tensor<4xi32>, tensor<4xi32>) -> (tensor<1xi32>, tensor<3xi32>)
-  // CHECK-DAG: %[[R2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
-  // CHECK-DAG: %[[R3:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK-DAG: %[[R2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
+  // CHECK-DAG: %[[R3:.*]] = "tf.Const"() <{value = dense<[0, 1, 2]> : tensor<3xi32>}> : () -> tensor<3xi32>
   // CHECK-NOT: tf.BroadcastGradientArgs
   // CHECK: return %[[R2]], %[[R3]]
 
@@ -555,7 +555,7 @@
   %s4 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
   %s5 = "tf.Const"() {value = dense<[1, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
   %r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<0xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>)
-  // CHECK: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK: %[[R0:.*]] = "tf.Const"() <{value = dense<[0, 1, 2]> : tensor<3xi32>}> : () -> tensor<3xi32>
   // CHECK-NOT: tf.BroadcastGradientArgs
   // CHECK: return %[[R0]], %[[R0]]
 
@@ -567,8 +567,8 @@
   %s4 = "tf.Const"() {value = dense<[1, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
   %s5 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
   %r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<3xi32>, tensor<0xi32>) -> (tensor<2xi32>, tensor<3xi32>)
-  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
-  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() <{value = dense<[0, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
+  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() <{value = dense<[0, 1, 2]> : tensor<3xi32>}> : () -> tensor<3xi32>
   // CHECK-NOT: tf.BroadcastGradientArgs
   // CHECK: return %[[R0]], %[[R1]]
 
@@ -580,7 +580,7 @@
   %s4 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
   %s5 = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> tensor<1xi32>
   %r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<0xi32>, tensor<1xi32>) -> (tensor<1xi32>, tensor<1xi32>)
-  // CHECK: %[[R0:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: %[[R0:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NOT: tf.BroadcastGradientArgs
   // CHECK: return %[[R0]], %[[R0]]
 
@@ -592,8 +592,8 @@
   %s4 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
   %s5 = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> tensor<1xi32>
   %r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<0xi32>, tensor<1xi32>) -> (tensor<1xi32>, tensor<0xi32>)
-  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
-  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
   // CHECK-NOT: tf.BroadcastGradientArgs
   // CHECK: return %[[R0]], %[[R1]]
 
@@ -605,8 +605,8 @@
   %s0 = "tf.Const"() {value = dense<[1, 4, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
   %s1 = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
   %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) {} : (tensor<3xi32>, tensor<2xi32>) -> (tensor<2xi32>, tensor<2xi32>)
-  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
-  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() <{value = dense<[0, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
+  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
   // CHECK-NOT: tf.BroadcastGradientArgs
   // CHECK: return %[[R0]], %[[R1]]
 
@@ -618,8 +618,8 @@
   %s0 = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
   %s1 = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
   %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) {} : (tensor<0xi32>, tensor<2xi32>) -> (tensor<2xi32>, tensor<0xi32>)
-  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
-  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
+  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
   // CHECK-NOT: tf.BroadcastGradientArgs
   // CHECK: return %[[R0]], %[[R1]]
 
@@ -631,8 +631,8 @@
   %s0 = "tf.Const"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64>
   %s1 = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi64>} : () -> tensor<2xi64>
   %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) {} : (tensor<0xi64>, tensor<2xi64>) -> (tensor<2xi64>, tensor<0xi64>)
-  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
-  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64>
+  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
+  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi64>}> : () -> tensor<0xi64>
   // CHECK-NOT: tf.BroadcastGradientArgs
   // CHECK: return %[[R0]], %[[R1]]
 
@@ -643,7 +643,7 @@
 func.func @testEmptyResults(%arg0: tensor<0x2xf32>) -> tensor<0x2xf32> {
   %indices = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
 
-  // CHECK: "tf.Const"() {value = dense<> : tensor<0x2xf32>} : () -> tensor<0x2xf32>
+  // CHECK: "tf.Const"() <{value = dense<> : tensor<0x2xf32>}> : () -> tensor<0x2xf32>
   %0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<0xi32>, tensor<0x2xf32>) -> tensor<0x2xf32>
   func.return %0 : tensor<0x2xf32>
 }
@@ -668,7 +668,7 @@
   %cst_1 = arith.constant dense<4> : tensor<i32>
   %cst_2 = arith.constant dense<1> : tensor<i32>
 
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[0, 1, 2, 3]> : tensor<4xi32>}> : () -> tensor<?xi32>
   // CHECK: return %[[CST]]
   %0 = "tf.Range"(%cst, %cst_1, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
   func.return %0 : tensor<?xi32>
@@ -680,7 +680,7 @@
   %cst_1 = arith.constant dense<4> : tensor<ui32>
   %cst_2 = arith.constant dense<1> : tensor<ui32>
 
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 1, 2, 3]> : tensor<4xui32>} : () -> tensor<?xui32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[0, 1, 2, 3]> : tensor<4xui32>}> : () -> tensor<?xui32>
   // CHECK: return %[[CST]]
   %0 = "tf.Range"(%cst, %cst_1, %cst_2) : (tensor<ui32>, tensor<ui32>, tensor<ui32>) -> tensor<?xui32>
   func.return %0 : tensor<?xui32>
@@ -692,7 +692,7 @@
   %cst_1 = arith.constant dense<4.0> : tensor<f32>
   %cst_2 = arith.constant dense<1.0> : tensor<f32>
 
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>}> : () -> tensor<?xf32>
   // CHECK: return %[[CST]]
   %0 = "tf.Range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
   func.return %0 : tensor<?xf32>
@@ -700,7 +700,7 @@
 
 // CHECK-LABEL: func @testLogicalAndFoldsWithConstantFalse
 func.func @testLogicalAndFoldsWithConstantFalse(%arg0: tensor<i1>) -> (tensor<i1>) {
-  // CHECK: [[CST:%.+]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
+  // CHECK: [[CST:%.+]] = "tf.Const"() <{value = dense<false> : tensor<i1>}> : () -> tensor<i1>
   %cst = arith.constant dense<false> : tensor<i1>
 
   %0 = "tf.LogicalAnd"(%cst, %arg0) : (tensor<i1>, tensor<i1>) -> tensor<i1>
@@ -711,7 +711,7 @@
 
 // CHECK-LABEL: func @testLogicalAndFoldsWithConstantFalseSecondArg
 func.func @testLogicalAndFoldsWithConstantFalseSecondArg(%arg0: tensor<i1>) -> (tensor<i1>) {
-  // CHECK: [[CST:%.+]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
+  // CHECK: [[CST:%.+]] = "tf.Const"() <{value = dense<false> : tensor<i1>}> : () -> tensor<i1>
   %cst = arith.constant dense<false> : tensor<i1>
 
   %0 = "tf.LogicalAnd"(%arg0, %cst) : (tensor<i1>, tensor<i1>) -> tensor<i1>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant_op_device_assignment.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant_op_device_assignment.mlir
index 5d890aa..2e32647 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/constant_op_device_assignment.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/constant_op_device_assignment.mlir
@@ -2,8 +2,8 @@
 
 // CHECK: func @replace_const_op_test
 func.func @replace_const_op_test() {
-  // CHECK-NEXT: %[[RESULT_0:.*]] = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:1", value = dense<2.000000e+00> : tensor<f32>}
-  // CHECK-NEXT: %[[RESULT_1:.*]] = "tf.Const"() {device = "/job:worker/replica:0/task:0/device:CPU:0", value = dense<2.000000e+00> : tensor<f32>}
+  // CHECK-NEXT: %[[RESULT_0:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<f32>}> {device = "/job:worker/replica:0/task:0/device:CPU:1"}
+  // CHECK-NEXT: %[[RESULT_1:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<f32>}> {device = "/job:worker/replica:0/task:0/device:CPU:0"}
   // CHECK-NEXT: %[[RESULT_2:.*]] = "tf.AddV2"(%[[RESULT_1]], %[[RESULT_1]]) {device = "/job:worker/replica:0/task:0/device:CPU:0"}
   // CHECK-NEXT: %[[RESULT_3:.*]] = "tf.AddV2"(%[[RESULT_0]], %[[RESULT_0]]) {device = "/job:worker/replica:0/task:0/device:CPU:1"}
   %0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
@@ -14,7 +14,7 @@
 
 // CHECK: func @no_change_test
 func.func @no_change_test() -> ()  {
-  // CHECK-NEXT: %[[RESULT_0:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-NEXT: %[[RESULT_0:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
   // CHECK-NEXT: %[[RESULT_1:.*]] = "tf.AddV2"(%[[RESULT_0]], %[[RESULT_0]]) : (tensor<i64>, tensor<i64>) -> tensor<i64>
   %0 = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
   %1 = "tf.AddV2"(%0, %0) : (tensor<i64>, tensor<i64>) -> tensor<i64>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/convert-tf-control-flow-to-scf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/convert-tf-control-flow-to-scf.mlir
index 7713781..6368981 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/convert-tf-control-flow-to-scf.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/convert-tf-control-flow-to-scf.mlir
@@ -20,11 +20,11 @@
   // CHECK-NEXT: %[[RES:.*]]:2 = scf.if %[[COND]] -> (tensor<*xf32>, tensor<4xf32>) {
   // CHECK-NEXT:   %[[CALL:.*]] = func.call @test_if_then1(%[[ARG1]]) : (tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT:   %[[ADD:.*]] = "tf.AddV2"(%[[CALL]], %[[CALL]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-  // CHECK-NEXT:   %[[CAST:.*]] = "tf.Cast"(%[[CALL]]) {Truncate = false} : (tensor<4xf32>) -> tensor<*xf32>
+  // CHECK-NEXT:   %[[CAST:.*]] = "tf.Cast"(%[[CALL]]) <{Truncate = false}> : (tensor<4xf32>) -> tensor<*xf32>
   // CHECK-NEXT:   scf.yield %[[CAST]], %[[ADD]] : tensor<*xf32>, tensor<4xf32>
   // CHECK-NEXT: } else {
   // CHECK-NEXT:   %[[CALL_0:.*]] = func.call @test_if_else1(%[[ARG1]]) : (tensor<4xf32>) -> tensor<4xf32>
-  // CHECK-NEXT:   %[[CAST_0:.*]] = "tf.Cast"(%[[CALL_0]]) {Truncate = false} : (tensor<4xf32>) -> tensor<*xf32>
+  // CHECK-NEXT:   %[[CAST_0:.*]] = "tf.Cast"(%[[CALL_0]]) <{Truncate = false}> : (tensor<4xf32>) -> tensor<*xf32>
   // CHECK-NEXT:   scf.yield %[[CAST_0]], %[[CALL_0]] : tensor<*xf32>, tensor<4xf32>
   // CHECK-NEXT: }
   // CHECK-NEXT: return %[[RES]]#0, %[[RES]]#1 : tensor<*xf32>, tensor<4xf32>
@@ -72,7 +72,7 @@
   }) {is_stateless = false} : (tensor<f32>, tensor<*xf32>) -> (tensor<f32>, tensor<*xf32>)
   func.return %0#0 : tensor<f32>
 
-  // CHECK-NEXT: %[[CST:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK-NEXT: %[[CST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
   // CHECK-NEXT: %[[RES:.*]]:2 = scf.while (%[[ARG3:.*]] = %[[ARG0]], %[[ARG4:.*]] = %[[ARG2]]) : (tensor<f32>, tensor<*xf32>) -> (tensor<f32>, tensor<*xf32>) {
   // CHECK-NEXT:   %[[IDEN:.*]] = "tf.Identity"(%[[ARG3]]) : (tensor<f32>) -> tensor<f32>
   // CHECK-NEXT:   %[[ADD:.*]] = "tf.Add"(%[[ARG1]], %[[ARG3]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir
index 4968054..d473c1a 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir
@@ -49,7 +49,7 @@
   tf_executor.graph {
     // CHECK: %[[A_CONTROL:.*]] = tf_executor.island wraps "tf.OpA"() : () -> ()
     %control_A = tf_executor.island wraps "tf.OpA"() : () -> ()
-    // CHECK: %[[CHAIN_CONSTANT:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    // CHECK: %[[CHAIN_CONSTANT:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
     // CHECK: %[[WHILE_OUT:.*]]:6, %[[WHILE_CONTROL:.*]] = tf_executor.island(%[[A_CONTROL]]) wraps "tf.While"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_2]], %[[CHAIN_CONSTANT]], %[[CHAIN_CONSTANT]])
     %while_out:4, %control_while = tf_executor.island(%control_A) wraps "tf.While"(%arg0, %arg1, %arg2, %arg2) {body = @simple_independent_chains_while_body, cond = @simple_independent_chains_while_cond, is_stateless = false} : (tensor<!tf_type.resource<tensor<f32>>>, tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>, tensor<f32>) -> (tensor<!tf_type.resource<tensor<f32>>>, tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>, tensor<f32>)
     // CHECK: %[[B_CONTROL:.*]] = tf_executor.island(%[[WHILE_CONTROL]]) wraps "tf.OpB"() : () -> ()
@@ -117,7 +117,7 @@
 func.func @intersecting_chains(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor<f32>) {
   // CHECK: tf_executor.graph {
   tf_executor.graph {
-    // CHECK: %[[CHAIN_CONSTANT:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    // CHECK: %[[CHAIN_CONSTANT:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
     // CHECK: %[[WHILE_OUT:.*]]:5, %[[WHILE_CONTROL:.*]] = tf_executor.island wraps "tf.While"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_2]], %[[CHAIN_CONSTANT]])
     %while_out:4, %while_control = tf_executor.island wraps "tf.While"(%arg0, %arg1, %arg2, %arg2) {body = @intersecting_chains_while_body, cond = @intersecting_chains_while_cond, is_stateless = false} : (tensor<!tf_type.resource<tensor<f32>>>, tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>, tensor<f32>) -> (tensor<!tf_type.resource<tensor<f32>>>, tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>, tensor<f32>)
     // CHECK: tf_executor.fetch
@@ -167,12 +167,12 @@
 func.func @multiple_callers(%arg0: !tf_res, %arg1: tensor<f32>) {
   // CHECK: tf_executor.graph {
   tf_executor.graph {
-    // CHECK: %[[CHAIN_CONSTANT_0:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    // CHECK: %[[CHAIN_CONSTANT_0:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
     // CHECK: %[[WHILE_OUT:.*]]:3, %[[WHILE_CONTROL:.*]] = tf_executor.island wraps "tf.While"(%[[ARG_0]], %[[ARG_1]], %[[CHAIN_CONSTANT_0]])
     %while_0_out:2, %while_0_control = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @multiple_callers_while_body, cond = @multiple_callers_while_cond, is_stateless = false} : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>)
     // CHECK: %[[CONTROL_A:.*]] = tf_executor.island(%[[WHILE_CONTROL]]) wraps "tf.OpA"() : () -> ()
     %control_A = tf_executor.island(%while_0_control) wraps "tf.OpA"() : () -> ()
-    // CHECK: %[[CHAIN_CONSTANT_1:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    // CHECK: %[[CHAIN_CONSTANT_1:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
     // CHECK: %[[WHILE_OUT:.*]]:3, %[[WHILE_CONTROL:.*]] = tf_executor.island(%[[CONTROL_A]]) wraps "tf.While"(%[[ARG_0]], %[[ARG_1]], %[[CHAIN_CONSTANT_1]])
     %while_1_out:2, %while_1_control = tf_executor.island(%control_A) wraps "tf.While"(%arg0, %arg1) {body = @multiple_callers_while_body, cond = @multiple_callers_while_cond, is_stateless = false} : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>)
     // CHECK: tf_executor.fetch
@@ -223,7 +223,7 @@
   // CHECK: %[[GRAPH_OUT:.*]]:3 = tf_executor.graph {
   %graph:2 = tf_executor.graph {
     // CHECK: %{{.*}}, %[[CONTROL_CHAIN_0_SRC:.*]] = tf_executor.island wraps "tf.Identity"(%[[CHAIN_0]]) : (tensor<i32>) -> tensor<i32>
-    // CHECK: %[[CHAIN_CONSTANT:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    // CHECK: %[[CHAIN_CONSTANT:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
     // CHECK: %[[WHILE_OUT:.*]]:3, %[[WHILE_CONTROL:.*]] = tf_executor.island(%[[CONTROL_CHAIN_0_SRC]]) wraps "tf.While"(%[[RES_0]], %[[ARG_1]], %[[CHAIN_CONSTANT]])
     %while_out:2, %while_control = tf_executor.island() wraps "tf.While"(%arg0, %arg1) {body = @nested_loop_while_body_inner, cond = @nested_loop_while_cond_inner, is_stateless = false} : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>)
     // CHECK: %[[CHAIN_0_SINK:.*]], %{{.*}} = tf_executor.island(%[[WHILE_CONTROL]]) wraps "tf.Identity"(%[[CHAIN_0]]) : (tensor<i32>) -> tensor<i32>
@@ -252,7 +252,7 @@
 func.func @nested_while(%arg0: !tf_res, %arg1: tensor<f32>) {
   // CHECK: tf_executor.graph {
   tf_executor.graph {
-    // CHECK: %[[CHAIN_CONSTANT:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    // CHECK: %[[CHAIN_CONSTANT:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
     // CHECK: %[[WHILE_OUT:.*]]:3, %[[WHILE_CONTROL:.*]] = tf_executor.island wraps "tf.While"(%[[ARG_0]], %[[ARG_1]], %[[CHAIN_CONSTANT]])
     %while_out:2, %while_control = tf_executor.island() wraps "tf.While"(%arg0, %arg1) {body = @nested_loop_while_body_outer, cond = @nested_loop_while_cond_outer, is_stateless = false} : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>)
     // CHECK: tf_executor.fetch
@@ -396,7 +396,7 @@
 // CHECK-LABEL:   func @unique_resource_chain
 // CHECK-SAME:      %[[ARG_0:.*]]: tensor<i32>, %[[ARG_1:.*]]: tensor<f32>
 // CHECK:           tf_executor.graph
-// CHECK:             %[[WHILE:.*]]:2, %[[WHILE_CONTROL:.*]] = tf_executor.island wraps "tf.While"(%[[ARG_0]], %[[ARG_1]]) {body = @unique_resource_chain_while_body, cond = @unique_resource_chain_while_cond, is_stateless = false} : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>)
+// CHECK:             %[[WHILE:.*]]:2, %[[WHILE_CONTROL:.*]] = tf_executor.island wraps "tf.While"(%[[ARG_0]], %[[ARG_1]]) <{body = @unique_resource_chain_while_body, cond = @unique_resource_chain_while_cond, is_stateless = false}> : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>)
 // CHECK:             tf_executor.fetch
 // CHECK:           }
 // CHECK:           return
@@ -417,12 +417,12 @@
 // CHECK-LABEL:   func @unique_resource_chain_while_body
 // CHECK-SAME:      %[[ARG_0:.*]]: tensor<i32>, %[[ARG_1:.*]]: tensor<f32>
 // CHECK:           %[[GRAPH:.*]]:2 = tf_executor.graph {
-// CHECK:             %[[THOUSAND:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() {value = dense<1000> : tensor<i32>} : () -> tensor<i32>
-// CHECK:             %[[STACK_HANDLE:.*]], %{{.*}} = tf_executor.island wraps "tf.StackV2"(%[[THOUSAND]]) {elem_type = f32} : (tensor<i32>) -> tensor<!tf_type.resource<tensor<f32>>>
+// CHECK:             %[[THOUSAND:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() <{value = dense<1000> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:             %[[STACK_HANDLE:.*]], %{{.*}} = tf_executor.island wraps "tf.StackV2"(%[[THOUSAND]]) <{elem_type = f32}> : (tensor<i32>) -> tensor<!tf_type.resource<tensor<f32>>>
 // CHECK:             %{{.*}}, %[[STACK_PUSH_CONTROL:.*]] = tf_executor.island wraps "tf.StackPushV2"(%[[STACK_HANDLE]], %[[ARG_1]]) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
 // CHECK:             %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_1]], %[[ARG_1]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
 // CHECK:             %{{.*}}, %{{.*}} = tf_executor.island(%[[STACK_PUSH_CONTROL]]) wraps "tf.StackPushV2"(%[[STACK_HANDLE]], %[[ADD]]) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
-// CHECK:             %[[ONE:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK:             %[[ONE:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:             %[[COUNTER:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ONE]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK:             tf_executor.fetch %[[COUNTER]], %[[ARG_1]] : tensor<i32>, tensor<f32>
 // CHECK:           }
@@ -439,7 +439,7 @@
 // CHECK-LABEL:   func @unique_resource_chain_while_cond
 // CHECK-SAME:      %[[ARG_0:.*]]: tensor<i32>, %[[ARG_1:.*]]: tensor<f32>
 // CHECK:           %[[GRAPH:.*]] = tf_executor.graph
-// CHECK:             %[[CONST:.*]], %[[CONST_CONTROL:.*]] = tf_executor.island wraps "tf.Const"() {value = dense<1000> : tensor<i32>} : () -> tensor<i32>
+// CHECK:             %[[CONST:.*]], %[[CONST_CONTROL:.*]] = tf_executor.island wraps "tf.Const"() <{value = dense<1000> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:             %[[LESS:.*]], %[[LESS_CONTROL:.*]] = tf_executor.island wraps "tf.Less"(%[[CONST]], %[[ARG_0]]) : (tensor<i32>, tensor<i32>) -> tensor<i1>
 // CHECK:             tf_executor.fetch %[[LESS]] : tensor<i1>
 // CHECK:           }
@@ -464,8 +464,8 @@
 // CHECK-LABEL:   func @mixed_unique_resource_chain
 // CHECK-SAME:      %[[ARG_0:.*]]: tensor<i32>, %[[ARG_1:.*]]: tensor<f32>
 // CHECK:           tf_executor.graph
-// CHECK:             %[[CHAIN_TOKEN:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-// CHECK:             %[[WHILE:.*]]:3, %[[WHILE_CONTROL:.*]] = tf_executor.island wraps "tf.While"(%[[ARG_0]], %[[ARG_1]], %[[CHAIN_TOKEN]]) {body = @mixed_unique_resource_chain_while_body, cond = @mixed_unique_resource_chain_while_cond, is_stateless = false} : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<i32>, tensor<f32>, tensor<i32>)
+// CHECK:             %[[CHAIN_TOKEN:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:             %[[WHILE:.*]]:3, %[[WHILE_CONTROL:.*]] = tf_executor.island wraps "tf.While"(%[[ARG_0]], %[[ARG_1]], %[[CHAIN_TOKEN]]) <{body = @mixed_unique_resource_chain_while_body, cond = @mixed_unique_resource_chain_while_cond, is_stateless = false}> : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<i32>, tensor<f32>, tensor<i32>)
 // CHECK:             tf_executor.fetch
 // CHECK:           }
 // CHECK:           return
@@ -489,14 +489,14 @@
 // CHECK-SAME:      %[[ARG_0:.*]]: tensor<i32>, %[[ARG_1:.*]]: tensor<f32>, %[[CHAIN_TOKEN:.*]]: tensor<i32>
 // CHECK:           %[[GRAPH:.*]]:3 = tf_executor.graph
 // CHECK:             %{{.*}}, %[[CHAIN_SRC:.*]] = tf_executor.island wraps "tf.Identity"(%[[CHAIN_TOKEN]]) : (tensor<i32>) -> tensor<i32>
-// CHECK:             %[[THOUSAND:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() {value = dense<1000> : tensor<i32>} : () -> tensor<i32>
-// CHECK:             %[[STACK_HANDLE:.*]], %{{.*}} = tf_executor.island wraps "tf.StackV2"(%[[THOUSAND]]) {elem_type = f32} : (tensor<i32>) -> tensor<!tf_type.resource<tensor<f32>>>
+// CHECK:             %[[THOUSAND:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() <{value = dense<1000> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:             %[[STACK_HANDLE:.*]], %{{.*}} = tf_executor.island wraps "tf.StackV2"(%[[THOUSAND]]) <{elem_type = f32}> : (tensor<i32>) -> tensor<!tf_type.resource<tensor<f32>>>
 // CHECK:             %{{.*}}, %[[STACK_PUSH_CONTROL:.*]] = tf_executor.island wraps "tf.StackPushV2"(%[[STACK_HANDLE]], %[[ARG_1]]) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
 // CHECK:             %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_1]], %[[ARG_1]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
 // CHECK:             %{{.*}}, %{{.*}} = tf_executor.island(%[[STACK_PUSH_CONTROL]]) wraps "tf.StackPushV2"(%[[STACK_HANDLE]], %[[ADD]]) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
-// CHECK:             %[[ONE:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK:             %[[ONE:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:             %[[COUNTER:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ONE]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-// CHECK:             %[[VAR_HANDLE:.*]], %{{.*}} = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<!tf_type.resource<tensor<f32>>>
+// CHECK:             %[[VAR_HANDLE:.*]], %{{.*}} = tf_executor.island wraps "tf.VarHandleOp"() <{container = "c", shared_name = "v0"}> : () -> tensor<!tf_type.resource<tensor<f32>>>
 // CHECK:             %[[ASSIGN_CONTROL:.*]] = tf_executor.island(%[[CHAIN_SRC]]) wraps "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[ARG_1]]) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
 // CHECK:             %[[CHAIN_SINK:.*]], %{{.*}} = tf_executor.island(%[[ASSIGN_CONTROL]]) wraps "tf.Identity"(%[[CHAIN_TOKEN]]) : (tensor<i32>) -> tensor<i32>
 // CHECK:             tf_executor.fetch %[[COUNTER]], %[[ARG_1]], %[[CHAIN_SINK]] : tensor<i32>, tensor<f32>, tensor<i32>
@@ -514,7 +514,7 @@
 // CHECK-LABEL:   func @mixed_unique_resource_chain_while_cond
 // CHECK-SAME:      %[[ARG_0:.*]]: tensor<i32>, %[[ARG_1:.*]]: tensor<f32>, %[[CHAIN_TOKEN:.*]]: tensor<i32>
 // CHECK:           %[[GRAPH:.*]] = tf_executor.graph
-// CHECK:             %[[CONST:.*]], %[[CONST_CONTROL:.*]] = tf_executor.island wraps "tf.Const"() {value = dense<1000> : tensor<i32>} : () -> tensor<i32>
+// CHECK:             %[[CONST:.*]], %[[CONST_CONTROL:.*]] = tf_executor.island wraps "tf.Const"() <{value = dense<1000> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:             %[[LESS:.*]], %[[LESS_CONTROL:.*]] = tf_executor.island wraps "tf.Less"(%[[CONST]], %[[ARG_0]]) : (tensor<i32>, tensor<i32>) -> tensor<i1>
 // CHECK:             tf_executor.fetch %[[LESS]] : tensor<i1>
 // CHECK:           }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/convert_launch_func_to_tf_call.mlir b/tensorflow/compiler/mlir/tensorflow/tests/convert_launch_func_to_tf_call.mlir
index ea44e0f..4c532cd 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/convert_launch_func_to_tf_call.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/convert_launch_func_to_tf_call.mlir
@@ -11,8 +11,8 @@
       %2 = "tf.A"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
 
       // CHECK: %[[CALL_OUTPUT:[0-9]*]] = "tf.PartitionedCall"(%[[A_OUTPUT]])
-      // CHECK-SAME: device = "/device:test_device:0"
       // CHECK-SAME: f = @_func
+      // CHECK-SAME: device = "/device:test_device:0"
       %3 = "tf_device.launch_func"(%2) {device = "/device:test_device:0", func = @_func} : (tensor<?xf32>) -> tensor<?xf32>
 
       // CHECK: tf_executor.yield %[[CALL_OUTPUT]]
@@ -40,13 +40,13 @@
       %2 = "tf.A"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
 
       // CHECK: %[[CALL_OUTPUT_0:[0-9]*]] = "tf.PartitionedCall"(%[[A_OUTPUT]])
-      // CHECK-SAME: device = "/device:test_device:0"
       // CHECK-SAME: f = @_func
+      // CHECK-SAME: device = "/device:test_device:0"
       %3 = "tf_device.launch_func"(%2) {device = "/device:test_device:0", func = @_func} : (tensor<?xf32>) -> tensor<?xf32>
 
       // CHECK: %[[CALL_OUTPUT_1:[0-9]*]] = "tf.PartitionedCall"(%[[CALL_OUTPUT_0]])
-      // CHECK-SAME: device = "/device:test_device:1"
       // CHECK-SAME: f = @_func
+      // CHECK-SAME: device = "/device:test_device:1"
       %4 = "tf_device.launch_func"(%3) {device = "/device:test_device:1", func = @_func} : (tensor<?xf32>) -> tensor<?xf32>
 
       // CHECK: tf_executor.yield %[[CALL_OUTPUT_1]]
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
index 7af6ff1..300a766 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
@@ -9,7 +9,7 @@
 // CHECK-LABEL: func @decomposition_outside_cluster
 func.func @decomposition_outside_cluster() {
   %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<2x8xi32>>>
-  // CHECK:      %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
+  // CHECK:      %[[ONE:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}>
   %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
   // CLUSTER-ONLY: "tf.AssignAddVariableOp"
   // ALWAYS-DECOMPOSE-NOT:  "tf.AssignAddVariableOp"
@@ -74,7 +74,7 @@
   "tf_device.cluster"() ({
     %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<2x8xi32>>>
 
-    // CHECK:      %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
+    // CHECK:      %[[ONE:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}>
     // CHECK:      %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
     // CHECK-SAME: (tensor<*x!tf_type.resource<tensor<2x8xi32>>>) -> tensor<2x8xi32>
     // CHECK:      "tf.AddV2"(%[[RES_READ_VAL]], %[[ONE]])
@@ -98,7 +98,7 @@
 
     %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf_type.resource<tensor<i32>>>
 
-    // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
+    // CHECK: %[[ONE:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}>
     // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
     // CHECK: "tf.AddV2"(%[[RES_READ_VAL]], %[[ONE]])
     // CHECK: "tf.AssignVariableOp"
@@ -121,7 +121,7 @@
 
     %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf_type.resource<tensor<i32>>>
 
-    // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
+    // CHECK: %[[ONE:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}>
     // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
     // CHECK: "tf.Sub"(%[[RES_READ_VAL]], %[[ONE]])
     // CHECK: "tf.AssignVariableOp"
@@ -323,8 +323,8 @@
     // CHECK: [[VAR_DELTA:%.*]] = "tf.Div"([[LR_MULTIPLY]], [[DIVISOR]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
     // CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf_type.resource<tensor<*xf32>>>) -> tensor<*xf32>
     // CHECK: [[NEW_VAR:%.*]] = "tf.Sub"(%9, %8) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-    // CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) {validate_shape = false} : (tensor<*x!tf_type.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
-    // CHECK: "tf.AssignVariableOp"([[ACC_HANDLE]], [[NEW_ACC]]) {validate_shape = false} : (tensor<*x!tf_type.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+    // CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) <{validate_shape = false}> : (tensor<*x!tf_type.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+    // CHECK: "tf.AssignVariableOp"([[ACC_HANDLE]], [[NEW_ACC]]) <{validate_shape = false}> : (tensor<*x!tf_type.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
 
     %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xf32>>>
     %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xf32>>>
@@ -341,8 +341,8 @@
 func.func @decompose_resource_apply_adagrad(%arg0: tensor<f32>, %arg1: tensor<f32>) -> () {
   "tf_device.cluster"() ({
 
-    // CHECK: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xf32>>>
-    // CHECK: %[[ACCUM_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xf32>>>
+    // CHECK: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() <{container = "c", shared_name = "v"}> : () -> tensor<*x!tf_type.resource<tensor<*xf32>>>
+    // CHECK: %[[ACCUM_HANDLE:.*]] = "tf.VarHandleOp"() <{container = "c", shared_name = "v"}> : () -> tensor<*x!tf_type.resource<tensor<*xf32>>>
     // CHECK: %[[ACCUM:.*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf_type.resource<tensor<*xf32>>>) -> tensor<*xf32>
     // CHECK: %[[GRAD_SQUARE:.*]] = "tf.Mul"(%[[GRAD]], %[[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
     // CHECK: %[[ACCUM_NEW:.*]] = "tf.AddV2"(%[[ACCUM]], %[[GRAD_SQUARE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
@@ -351,8 +351,8 @@
     // CHECK: %[[DIV:.*]] = "tf.Div"(%[[LR_MULTIPLY]], %[[SQRT]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
     // CHECK: %[[VAR:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor<*x!tf_type.resource<tensor<*xf32>>>) -> tensor<*xf32>
     // CHECK: %[[VAR_NEW:.*]] = "tf.Sub"(%[[VAR]], %[[DIV]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-    // CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) {validate_shape = false} : (tensor<*x!tf_type.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
-    // CHECK: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[ACCUM_NEW]]) {validate_shape = false} : (tensor<*x!tf_type.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+    // CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) <{validate_shape = false}> : (tensor<*x!tf_type.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+    // CHECK: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[ACCUM_NEW]]) <{validate_shape = false}> : (tensor<*x!tf_type.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
     %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xf32>>>
     %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xf32>>>
 
@@ -372,7 +372,7 @@
 func.func @decompose_resource_apply_adam_non_nesterov(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>) -> () {
   "tf_device.cluster"() ({
 
-    // CHECK: [[ONE:%.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
+    // CHECK: [[ONE:%.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}>
     // CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"()
     // CHECK: [[M_HANDLE:%.*]] = "tf.VarHandleOp"()
     // CHECK: [[V_HANDLE:%.*]] = "tf.VarHandleOp"()
@@ -422,10 +422,10 @@
 func.func @decompose_resource_apply_adam_nesterov(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>) -> () {
   "tf_device.cluster"() ({
 
-  // CHECK: [[ONE:%.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
-  // CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"}
-  // CHECK: [[M_HANDLE:%.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"}
-  // CHECK: [[V_HANDLE:%.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"}
+  // CHECK: [[ONE:%.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}>
+  // CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"() <{container = "c", shared_name = "v"}>
+  // CHECK: [[M_HANDLE:%.*]] = "tf.VarHandleOp"() <{container = "c", shared_name = "v"}>
+  // CHECK: [[V_HANDLE:%.*]] = "tf.VarHandleOp"() <{container = "c", shared_name = "v"}>
   // CHECK: [[VAL_82:%.*]] = "tf.Sub"([[ONE]], [[BETA2_POWER]])
   // CHECK: [[VAL_83:%.*]] = "tf.Sqrt"([[VAL_82]])
   // CHECK: [[VAL_84:%.*]] = "tf.Sub"([[ONE]], [[BETA1_POWER]])
@@ -452,9 +452,9 @@
   // CHECK: [[VAL_105:%.*]] = "tf.Div"([[VAL_102]], [[VAL_104]])
   // CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf_type.resource<tensor<*xf32>>>) -> tensor<*xf32>
   // CHECK: [[NEW_VAR:%.*]] = "tf.Sub"([[OLD_VAR]], [[VAL_105]])
-  // CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) {validate_shape = false} : (tensor<*x!tf_type.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
-  // CHECK: "tf.AssignVariableOp"([[M_HANDLE]], [[NEW_M]]) {validate_shape = false} : (tensor<*x!tf_type.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
-  // CHECK: "tf.AssignVariableOp"([[V_HANDLE]], [[NEW_V]]) {validate_shape = false} : (tensor<*x!tf_type.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+  // CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) <{validate_shape = false}> : (tensor<*x!tf_type.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+  // CHECK: "tf.AssignVariableOp"([[M_HANDLE]], [[NEW_M]]) <{validate_shape = false}> : (tensor<*x!tf_type.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+  // CHECK: "tf.AssignVariableOp"([[V_HANDLE]], [[NEW_V]]) <{validate_shape = false}> : (tensor<*x!tf_type.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
 
     %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xf32>>>
     %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xf32>>>
@@ -474,7 +474,7 @@
 func.func @decompose_adam_with_complex_inputs(%arg0: tensor<!tf_type.resource<tensor<2xcomplex<f32>>>>, %arg1: tensor<!tf_type.resource<tensor<2xcomplex<f32>>>>, %arg2: tensor<!tf_type.resource<tensor<2xcomplex<f32>>>>, %arg3: tensor<complex<f32>>, %arg4: tensor<complex<f32>>, %arg5: tensor<complex<f32>>, %arg6: tensor<complex<f32>>, %arg7: tensor<complex<f32>>, %arg8: tensor<complex<f32>>, %arg9: tensor<2xcomplex<f32>>) attributes {tf.entry_function = {control_outputs = "Adam/update_Variable_1/ResourceApplyAdam", inputs = "_arg0,_arg1,_arg2,_arg3,_arg4,_arg5,_arg6,_arg7,_arg8,_arg9", outputs = ""}} {
   "tf_device.cluster"() ({
 
-    // CHECK: "tf.Const"() {value = dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>} : () -> tensor<complex<f32>>
+    // CHECK: "tf.Const"() <{value = dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>}> : () -> tensor<complex<f32>>
     // CHECK-NOT: tf.ResourceApplyAdam
     "tf.ResourceApplyAdam"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9) {_XlaHasReferenceVars = false, _xla_inferred_shapes = [], device = "/job:localhost/replica:0/task:0/device:TPU:0", use_locking = false, use_nesterov = false} : (tensor<!tf_type.resource<tensor<2xcomplex<f32>>>>, tensor<!tf_type.resource<tensor<2xcomplex<f32>>>>, tensor<!tf_type.resource<tensor<2xcomplex<f32>>>>, tensor<complex<f32>>, tensor<complex<f32>>, tensor<complex<f32>>, tensor<complex<f32>>, tensor<complex<f32>>, tensor<complex<f32>>, tensor<2xcomplex<f32>>) -> ()
 
@@ -489,13 +489,13 @@
 // CHECK-SAME: [[INDEX:%.+]]: tensor<?xi32>
 func.func @decompose_resource_gather_op(%indices : tensor<?xi32>) -> tensor<*xi32> {
   %0 = "tf_device.cluster"() ({
-    // CHECK: [[ZERO:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
+    // CHECK: [[ZERO:%.+]] = "tf.Const"() <{value = dense<0> : tensor<i64>}>
 
     // CHECK: [[VAR:%.+]] = "tf.VarHandleOp"
     %resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xi32>>>
 
     // CHECK: [[READVAR:%.+]] = "tf.ReadVariableOp"([[VAR]])
-    // CHECK: [[GATHER:%.+]] = "tf.GatherV2"([[READVAR]], [[INDEX]], [[ZERO]]) {batch_dims = 0 : i64} : (tensor<*xi32>, tensor<?xi32>, tensor<i64>) -> tensor<*xi32>
+    // CHECK: [[GATHER:%.+]] = "tf.GatherV2"([[READVAR]], [[INDEX]], [[ZERO]]) <{batch_dims = 0 : i64}> : (tensor<*xi32>, tensor<?xi32>, tensor<i64>) -> tensor<*xi32>
     // CHECK: return [[GATHER]]
     %1 = "tf.ResourceGather"(%resource, %indices) : (tensor<*x!tf_type.resource<tensor<*xi32>>>, tensor<?xi32>) -> (tensor<*xi32>)
     tf_device.return %1 : tensor<*xi32>
@@ -512,7 +512,7 @@
   %0 = "tf_device.cluster"() ({
     %resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<2x8x16xi32>>>
 
-    // CHECK: "tf.GatherV2"({{.+}}, {{.+}}, {{.+}}) {batch_dims = 1 : i64} : (tensor<2x8x16xi32>, tensor<5xi32>, tensor<i64>) -> tensor<2x5x16xi32>
+    // CHECK: "tf.GatherV2"({{.+}}, {{.+}}, {{.+}}) <{batch_dims = 1 : i64}> : (tensor<2x8x16xi32>, tensor<5xi32>, tensor<i64>) -> tensor<2x5x16xi32>
     %1 = "tf.ResourceGather"(%resource, %indices) {batch_dims = 1} : (tensor<*x!tf_type.resource<tensor<2x8x16xi32>>>, tensor<5xi32>) -> (tensor<2x5x16xi32>)
 
     tf_device.return %1 : tensor<2x5x16xi32>
@@ -527,7 +527,7 @@
 // CHECK-SAME:  [[VAR:%.*]]: tensor<f32>, [[MG:%.*]]: tensor<f32>, [[MS:%.*]]: tensor<f32>, [[MOM:%.*]]: tensor<f32>, [[LR:%.*]]: tensor<f32>, [[RHO:%.*]]: tensor<f32>, [[MOMENTUM:%.*]]: tensor<f32>, [[EPSILON:%.*]]: tensor<f32>, [[GRAD:%.*]]: tensor<f32>
 func.func @decompose_resource_apply_centered_RMS_prop(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>, %arg7: tensor<f32>, %arg8: tensor<f32>) -> () {
   "tf_device.cluster"() ({
-    // CHECK: [[ONE:%.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
+    // CHECK: [[ONE:%.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}>
     // CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"
     // CHECK: [[MG_HANDLE:%.*]] = "tf.VarHandleOp"
     // CHECK: [[MS_HANDLE:%.*]] = "tf.VarHandleOp"
@@ -578,14 +578,14 @@
 // CHECK-SAME:   %[[LR:.*]]: tensor<f32>, %[[RHO:.*]]: tensor<f32>, %[[MOMENTUM:.*]]: tensor<f32>, %[[EPSILON:.*]]: tensor<f32>, %[[GRAD:.*]]: tensor<f32>)
 func.func @decompose_resource_apply_RMS_prop(%arg0: tensor<*x!tf_type.resource>, %arg1: tensor<*x!tf_type.resource>, %arg2: tensor<*x!tf_type.resource>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>, %arg7: tensor<f32>) -> () {
   "tf_device.cluster"() ({
-    // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+    // CHECK: %[[ONE:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
     // CHECK: %[[MS:.*]] = "tf.ReadVariableOp"(%[[MS_HANDLE]]) : (tensor<*x!tf_type.resource>) -> tensor<*xf32>
     // CHECK: %[[MS_RHO:.*]] = "tf.Mul"(%[[MS]], %[[RHO]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
     // CHECK: %[[GRAD_SQUARE:.*]] = "tf.Square"(%[[GRAD]]) : (tensor<f32>) -> tensor<f32>
     // CHECK: %[[ONE_RHO:.*]] = "tf.Sub"(%[[ONE]], %[[RHO]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
     // CHECK: %[[MUL:.*]] = "tf.Mul"(%[[GRAD_SQUARE]], %[[ONE_RHO]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
     // CHECK: %[[MS_NEW:.*]] = "tf.AddV2"(%[[MS_RHO]], %[[MUL]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
-    // CHECK: "tf.AssignVariableOp"(%[[MS_HANDLE]], %[[MS_NEW]]) {validate_shape = false} : (tensor<*x!tf_type.resource>, tensor<*xf32>) -> ()
+    // CHECK: "tf.AssignVariableOp"(%[[MS_HANDLE]], %[[MS_NEW]]) <{validate_shape = false}> : (tensor<*x!tf_type.resource>, tensor<*xf32>) -> ()
     // CHECK: %[[MOM:.*]] = "tf.ReadVariableOp"(%[[MOM_HANDLE]]) : (tensor<*x!tf_type.resource>) -> tensor<*xf32>
     // CHECK: %[[MOMENTUM_MOM:.*]] = "tf.Mul"(%[[MOMENTUM]], %[[MOM]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
     // CHECK: %[[LR_GRAD:.*]] = "tf.Mul"(%[[LR]], %[[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
@@ -593,10 +593,10 @@
     // CHECK: %[[SQRT:.*]] = "tf.Sqrt"(%[[ADD]]) : (tensor<*xf32>) -> tensor<*xf32>
     // CHECK: %[[DIV:.*]] = "tf.Div"(%[[LR_GRAD]], %[[SQRT]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
     // CHECK: %[[MOM_NEW:.*]] = "tf.AddV2"(%[[MOMENTUM_MOM]], %[[DIV]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-    // CHECK: "tf.AssignVariableOp"(%[[MOM_HANDLE]], %[[MOM_NEW]]) {validate_shape = false} : (tensor<*x!tf_type.resource>, tensor<*xf32>) -> ()
+    // CHECK: "tf.AssignVariableOp"(%[[MOM_HANDLE]], %[[MOM_NEW]]) <{validate_shape = false}> : (tensor<*x!tf_type.resource>, tensor<*xf32>) -> ()
     // CHECK: %[[VAR:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor<*x!tf_type.resource>) -> tensor<*xf32>
     // CHECK: %[[VAR_NEW:.*]] = "tf.Sub"(%[[VAR]], %[[MOM_NEW]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-    // CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) {validate_shape = false} : (tensor<*x!tf_type.resource>, tensor<*xf32>) -> ()
+    // CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) <{validate_shape = false}> : (tensor<*x!tf_type.resource>, tensor<*xf32>) -> ()
     "tf.ResourceApplyRMSProp"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) {use_locking = false} : (tensor<*x!tf_type.resource>, tensor<*x!tf_type.resource>, tensor<*x!tf_type.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
     tf_device.return
   }) : () -> ()
@@ -609,7 +609,7 @@
 // CHECK-LABEL: @decompose_resource_scatter_add_op
 // CHECK-SAME: ([[INDEX:%.+]]: tensor<2x?xi32>, [[UPDATE:%.+]]: tensor<?x?x?xi32>)
 func.func @decompose_resource_scatter_add_op(%indices : tensor<2x?xi32>, %updates: tensor<?x?x?xi32>) {
-  // CHECK: [[CST:%.+]] = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: [[CST:%.+]] = "tf.Const"() <{value = dense<-1> : tensor<i32>}> : () -> tensor<i32>
   "tf_device.cluster"() ({
     // CHECK: [[VAR:%.+]] = "tf.VarHandleOp"
     %resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xi32>>>
@@ -630,7 +630,7 @@
 // CHECK-LABEL: @decompose_resource_scatter_add_op_1d_indices
 // CHECK-SAME: ([[INDEX:%.+]]: tensor<?xi32>, [[UPDATE:%.+]]: tensor<?x?x?xi32>)
 func.func @decompose_resource_scatter_add_op_1d_indices(%indices : tensor<?xi32>, %updates: tensor<?x?x?xi32>) {
-  // CHECK: [[CST:%.+]] = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: [[CST:%.+]] = "tf.Const"() <{value = dense<-1> : tensor<i32>}> : () -> tensor<i32>
   "tf_device.cluster"() ({
     // CHECK: [[VAR:%.+]] = "tf.VarHandleOp"
     %resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xi32>>>
@@ -679,7 +679,7 @@
 // CHECK-LABEL: @decompose_resource_scatter_update_op
 // CHECK-SAME: ([[INDEX:%.+]]: tensor<2x?xi32>, [[UPDATE:%.+]]: tensor<?x?x?xi32>)
 func.func @decompose_resource_scatter_update_op(%indices : tensor<2x?xi32>, %updates: tensor<?x?x?xi32>) {
-  // CHECK: [[CST:%.+]] = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: [[CST:%.+]] = "tf.Const"() <{value = dense<-1> : tensor<i32>}> : () -> tensor<i32>
   "tf_device.cluster"() ({
     // CHECK: [[VAR:%.+]] = "tf.VarHandleOp"
     %resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xi32>>>
@@ -782,10 +782,10 @@
     %var = "tf.VarHandleOp"() {container = "c", shared_name = "var"} : () -> tensor<*x!tf_type.resource<tensor<4xf32>>>
     %accum = "tf.VarHandleOp"() {container = "c", shared_name = "accum"} : () -> tensor<*x!tf_type.resource<tensor<4xf32>>>
 
-    // CHECK-DAG: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
-    // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-    // CHECK-DAG: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "var"} : () -> tensor<*x!tf_type.resource<tensor<4xf32>>>
-    // CHECK-DAG: %[[ACCUM_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "accum"} : () -> tensor<*x!tf_type.resource<tensor<4xf32>>>
+    // CHECK-DAG: %[[ONE:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+    // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+    // CHECK-DAG: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() <{container = "c", shared_name = "var"}> : () -> tensor<*x!tf_type.resource<tensor<4xf32>>>
+    // CHECK-DAG: %[[ACCUM_HANDLE:.*]] = "tf.VarHandleOp"() <{container = "c", shared_name = "accum"}> : () -> tensor<*x!tf_type.resource<tensor<4xf32>>>
     // CHECK-DAG: %[[GRAD_SQ:.*]] = "tf.Square"(%[[GRAD]]) : (tensor<4xf32>) -> tensor<4xf32>
     // CHECK-DAG: %[[ACCUM:.*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf_type.resource<tensor<4xf32>>>) -> tensor<4xf32>
     // CHECK-DAG: %[[ACCUM_NEW:.*]] = "tf.AddV2"(%[[ACCUM]], %[[GRAD_SQ]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
@@ -805,8 +805,8 @@
     // CHECK-DAG: %[[SCALED_L2:.*]] = "tf.Mul"(%[[ADAGRAD_LR]], %[[L2]]) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
     // CHECK-DAG: %[[DENOMINATOR:.*]] = "tf.Add"(%[[ONE]], %[[SCALED_L2]]) : (tensor<f32>, tensor<4xf32>) -> tensor<4xf32>
     // CHECK-DAG: %[[VAR_NEW:.*]] = "tf.Div"(%[[NUMERATOR]], %[[DENOMINATOR]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-    // CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) {validate_shape = false} : (tensor<*x!tf_type.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
-    // CHECK-DAG: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[ACCUM_NEW]]) {validate_shape = false} : (tensor<*x!tf_type.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
+    // CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) <{validate_shape = false}> : (tensor<*x!tf_type.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
+    // CHECK-DAG: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[ACCUM_NEW]]) <{validate_shape = false}> : (tensor<*x!tf_type.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
 
     "tf.ResourceApplyProximalAdagrad"(%var, %accum, %lr, %l1, %l2, %grad) {use_locking = false} : (tensor<*x!tf_type.resource<tensor<4xf32>>>, tensor<*x!tf_type.resource<tensor<4xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<4xf32>) -> ()
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/device_attribute_to_launch.mlir b/tensorflow/compiler/mlir/tensorflow/tests/device_attribute_to_launch.mlir
index 3384c65..4996884 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/device_attribute_to_launch.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/device_attribute_to_launch.mlir
@@ -4,10 +4,10 @@
 // CHECK-LABEL: func @single_op_launch
 func.func @single_op_launch() {
   // CHECK: "tf_device.launch"
+  // CHECK: device = "CPU:0"
   // CHECK: "tf.opA"
   // CHECK-NOT device
   // CHECK: tf_device.return
-  // CHECK: device = "CPU:0"
   "tf.opA"() {device = "CPU:0"} : () -> tensor<i1>
   func.return
 }
@@ -16,10 +16,10 @@
 // CHECK-LABEL: func @launch_return
 func.func @launch_return() -> tensor<i1> {
   // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
+  // CHECK: device = "CPU:0"
   // CHECK: %[[A_OUT:.*]] = "tf.opA"
   // CHECK-NOT device
   // CHECK: tf_device.return %[[A_OUT]]
-  // CHECK: device = "CPU:0"
   // CHECK: return %[[LAUNCH_OUT]]
   %a = "tf.opA"() {device = "CPU:0"} : () -> tensor<i1>
   func.return %a : tensor<i1>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir
index e1071f3..6666a08 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir
@@ -6,7 +6,7 @@
   // CHECK-LABEL: unary_einsum_reduce_sum_transpose
   // CHECK-DAG: %[[cst:.*]] = arith.constant dense<3> : tensor<1xi32>
   // CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<[0, 2, 1]> : tensor<3xi32>
-  // CHECK: %[[v0:.*]] = "tf.Sum"(%arg0, %[[cst]]) {keep_dims = false} : (tensor<3x4x5x6xf32>, tensor<1xi32>) -> tensor<3x4x5xf32>
+  // CHECK: %[[v0:.*]] = "tf.Sum"(%arg0, %[[cst]]) <{keep_dims = false}> : (tensor<3x4x5x6xf32>, tensor<1xi32>) -> tensor<3x4x5xf32>
   // CHECK: %[[v1:.*]] = "tf.Transpose"(%[[v0]], %[[cst_1]]) : (tensor<3x4x5xf32>, tensor<3xi32>) -> tensor<3x5x4xf32>
   // CHECK: return %[[v1]] : tensor<3x5x4xf32>
 }
@@ -16,7 +16,7 @@
   func.return %0 : tensor<3x4x5xf32>
   // CHECK-LABEL: unary_einsum_reduce_sum_transpose1
   // CHECK-DAG: %[[cst:.*]] = arith.constant dense<3> : tensor<1xi32>
-  // CHECK: %[[v0:.*]] = "tf.Sum"(%arg0, %[[cst]]) {keep_dims = false} : (tensor<3x4x5x6xf32>, tensor<1xi32>) -> tensor<3x4x5xf32>
+  // CHECK: %[[v0:.*]] = "tf.Sum"(%arg0, %[[cst]]) <{keep_dims = false}> : (tensor<3x4x5x6xf32>, tensor<1xi32>) -> tensor<3x4x5xf32>
   // CHECK: return %[[v0]] : tensor<3x4x5xf32>
 }
 
@@ -34,7 +34,7 @@
   func.return %0 : tensor<4xf32>
   // CHECK-LABEL: unary_einsum_reduce_sum
   // CHECK-DAG: %[[cst:.*]] =  arith.constant dense<[1, 2]> : tensor<2xi32>
-  // CHECK: %[[v0:.*]] = "tf.Sum"(%arg0, %[[cst]]) {keep_dims = false} : (tensor<4x5x6xf32>, tensor<2xi32>) -> tensor<4xf32>
+  // CHECK: %[[v0:.*]] = "tf.Sum"(%arg0, %[[cst]]) <{keep_dims = false}> : (tensor<4x5x6xf32>, tensor<2xi32>) -> tensor<4xf32>
   // CHECK: return %[[v0]]
 }
 
@@ -42,14 +42,14 @@
   %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,ikm->ijm"}: (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
   func.return %0 : tensor<3x4x6xf32>
   // CHECK-LABEL: einsum_basic
-  // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
 }
 
 func.func @einsum_matmul(%arg0: tensor<7x9xf32>, %arg1: tensor<9x5xf32>) -> tensor<7x5xf32> {
   %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ae,ed->ad"}: (tensor<7x9xf32>, tensor<9x5xf32>) -> tensor<7x5xf32>
   func.return %0 : tensor<7x5xf32>
   // CHECK-LABEL: einsum_matmul
-  // CHECK: %[[v0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<7x9xf32>, tensor<9x5xf32>) -> tensor<7x5xf32>
+  // CHECK: %[[v0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<7x9xf32>, tensor<9x5xf32>) -> tensor<7x5xf32>
   // CHECK: return %[[v0]] : tensor<7x5xf32>
 }
 
@@ -59,7 +59,7 @@
   // CHECK-LABEL: einsum_matmul_dynamic_size
   // CHECK-DAG: %[[cst:.*]] = arith.constant dense<[2, -1, 1, 1]> : tensor<4xi64>
   // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg1, %cst) : (tensor<2x?xf32>, tensor<4xi64>) -> tensor<2x?x1x1xf32>
-  // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %0) {adj_x = false, adj_y = false} : (tensor<2x?x?x?xf32>, tensor<2x?x1x1xf32>) -> tensor<2x?x?x1xf32>
+  // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %0) <{adj_x = false, adj_y = false}> : (tensor<2x?x?x?xf32>, tensor<2x?x1x1xf32>) -> tensor<2x?x?x1xf32>
   // CHECK: return %[[v1]] : tensor<2x?x?x1xf32>
 }
 
@@ -67,14 +67,14 @@
   %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,km->ijm"}: (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
   func.return %0 : tensor<3x4x6xf32>
   // CHECK-LABEL: einsum_broadcast
-  // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
 }
 
 func.func @einsum_broadcast4(%arg0: tensor<3x4x5x6x7xf32>, %arg1: tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32> {
   %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "abcdh,hg->abcdg"}: (tensor<3x4x5x6x7xf32>, tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32>
   func.return %0 : tensor<3x4x5x6x8xf32>
   // CHECK-LABEL: einsum_broadcast4
-  // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5x6x7xf32>, tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32>
+  // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<3x4x5x6x7xf32>, tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32>
 }
 
 func.func @einsum_reducesum(%arg0: tensor<2x5x7xf32>, %arg1: tensor<5x2xf32>) -> tensor<5x7xf32> {
@@ -86,7 +86,7 @@
   // CHECK-DAG: %[[cst_2:.*]] = arith.constant dense<[5, 7]> : tensor<2xi64>
   // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<3xi32>) -> tensor<5x7x2xf32>
   // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg1, %[[cst_1]]) : (tensor<5x2xf32>, tensor<3xi64>) -> tensor<5x2x1xf32>
-  // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<5x7x2xf32>, tensor<5x2x1xf32>) -> tensor<5x7x1xf32>
+  // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) <{adj_x = false, adj_y = false}> : (tensor<5x7x2xf32>, tensor<5x2x1xf32>) -> tensor<5x7x1xf32>
   // CHECK: %[[v3:.*]] = "tf.Reshape"(%[[v2]], %[[cst_2]]) : (tensor<5x7x1xf32>, tensor<2xi64>) -> tensor<5x7xf32>
   // CHECK: return %[[v3:.*]] : tensor<5x7xf32>
 }
@@ -99,7 +99,7 @@
   // CHECK-DAG: %[[cst_0:.*]] = arith.constant dense<[0, 2, 1]> : tensor<3xi32>
   // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<3xi32>) -> tensor<5x7x2xf32>
   // CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_0]]) : (tensor<5x3x2xf32>, tensor<3xi32>) -> tensor<5x2x3xf32>
-  // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<5x7x2xf32>, tensor<5x2x3xf32>) -> tensor<5x7x3xf32>
+  // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) <{adj_x = false, adj_y = false}> : (tensor<5x7x2xf32>, tensor<5x2x3xf32>) -> tensor<5x7x3xf32>
   // CHECK: %[[v3:.*]] = "tf.Transpose"(%[[v2]], %[[cst_0]]) : (tensor<5x7x3xf32>, tensor<3xi32>) -> tensor<5x3x7xf32>
 }
 
@@ -111,7 +111,7 @@
   // CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32>
   // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7x3xf32>, tensor<4xi32>) -> tensor<2x7x5x3xf32>
   // CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_1]]) : (tensor<2x4x7x3xf32>, tensor<4xi32>) -> tensor<2x7x3x4xf32>
-  // CHECK: "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<2x7x5x3xf32>, tensor<2x7x3x4xf32>) -> tensor<2x7x5x4xf32>
+  // CHECK: "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) <{adj_x = false, adj_y = false}> : (tensor<2x7x5x3xf32>, tensor<2x7x3x4xf32>) -> tensor<2x7x5x4xf32>
 }
 
 func.func @einsum_matrixdotprod(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<7x3x4xf32>) -> tensor<2x5x4xf32> {
@@ -122,7 +122,7 @@
   // CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<[21, 4]> : tensor<2xi64>
   // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x5x7x3xf32>, tensor<3xi64>) -> tensor<2x5x21xf32>
   // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg1, %[[cst_1]]) : (tensor<7x3x4xf32>, tensor<2xi64>) -> tensor<21x4xf32>
-  // CHECK: "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<2x5x21xf32>, tensor<21x4xf32>) -> tensor<2x5x4xf32>
+  // CHECK: "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) <{adj_x = false, adj_y = false}> : (tensor<2x5x21xf32>, tensor<21x4xf32>) -> tensor<2x5x4xf32>
 }
 
 func.func @einsum_reshapetail(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6x2xf32>) -> tensor<3x4x6x2xf32> {
@@ -132,7 +132,7 @@
   // CHECK-DAG: %[[cst:.*]] = arith.constant dense<[5, 12]> : tensor<2xi64>
   // CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<[3, 4, 6, 2]> : tensor<4xi64>
   // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg1, %[[cst]]) : (tensor<5x6x2xf32>, tensor<2xi64>) -> tensor<5x12xf32>
-  // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %[[v0]]) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<5x12xf32>) -> tensor<3x4x12xf32>
+  // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %[[v0]]) <{adj_x = false, adj_y = false}> : (tensor<3x4x5xf32>, tensor<5x12xf32>) -> tensor<3x4x12xf32>
   // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<3x4x12xf32>, tensor<4xi64>) -> tensor<3x4x6x2xf32>
   // CHECK: return %[[v2]] : tensor<3x4x6x2xf32>
 }
@@ -144,7 +144,7 @@
   // CHECK-DAG: %[[cst:.*]] = arith.constant dense<[2, 5, 1, 7]> : tensor<4xi64>
   // CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<[2, 5, 3]> : tensor<3xi64>
   // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<4xi64>) -> tensor<2x5x1x7xf32>
-  // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%[[v0]], %arg1) {adj_x = false, adj_y = false} : (tensor<2x5x1x7xf32>, tensor<2x5x7x3xf32>) -> tensor<2x5x1x3xf32>
+  // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%[[v0]], %arg1) <{adj_x = false, adj_y = false}> : (tensor<2x5x1x7xf32>, tensor<2x5x7x3xf32>) -> tensor<2x5x1x3xf32>
   // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<2x5x1x3xf32>, tensor<3xi64>) -> tensor<2x5x3xf32>
   // CHECK: return %[[v2]] : tensor<2x5x3xf32>
 }
@@ -158,7 +158,7 @@
   // CHECK-DAG: %[[cst_2:.*]] = arith.constant dense<[2, 5, 3]> : tensor<3xi64>
   // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg1, %[[cst]]) : (tensor<2x5x3x7xf32>, tensor<4xi32>) -> tensor<2x5x7x3xf32>
   // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg0, %[[cst_1]]) : (tensor<2x5x7xf32>, tensor<4xi64>) -> tensor<2x5x1x7xf32>
-  // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v1]], %[[v0]]) {adj_x = false, adj_y = false} : (tensor<2x5x1x7xf32>, tensor<2x5x7x3xf32>) -> tensor<2x5x1x3xf32>
+  // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v1]], %[[v0]]) <{adj_x = false, adj_y = false}> : (tensor<2x5x1x7xf32>, tensor<2x5x7x3xf32>) -> tensor<2x5x1x3xf32>
   // CHECK: %[[v3:.*]] = "tf.Reshape"(%[[v2]], %[[cst_2]]) : (tensor<2x5x1x3xf32>, tensor<3xi64>) -> tensor<2x5x3xf32>
   // CHECK: return %[[v3]] : tensor<2x5x3xf32>
 }
@@ -169,7 +169,7 @@
   // CHECK-LABEL: einsum_fourdreducelast
   // CHECK: %[[cst:.*]] = arith.constant dense<[0, 2, 1, 3]> : tensor<4xi32>
   // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg1, %[[cst]]) : (tensor<2x3x5x13xf32>, tensor<4xi32>) -> tensor<2x5x3x13xf32>
-  // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %[[v0]]) {adj_x = false, adj_y = false} : (tensor<2x5x7x3xf32>, tensor<2x5x3x13xf32>) -> tensor<2x5x7x13xf32>
+  // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %[[v0]]) <{adj_x = false, adj_y = false}> : (tensor<2x5x7x3xf32>, tensor<2x5x3x13xf32>) -> tensor<2x5x7x13xf32>
   // CHECK: %[[v2:.*]] = "tf.Transpose"(%[[v1]], %[[cst]]) : (tensor<2x5x7x13xf32>, tensor<4xi32>) -> tensor<2x7x5x13xf32>
   // CHECK: return %[[v2]] : tensor<2x7x5x13xf32>
 }
@@ -183,7 +183,7 @@
   // CHECK-DAG: %[[cst_2:.*]] = arith.constant dense<[0, 1, 3, 2]> : tensor<4xi32>
   // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7x3xf32>, tensor<4xi32>) -> tensor<2x7x5x3xf32>
   // CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_1]]) : (tensor<2x11x7x3xf32>, tensor<4xi32>) -> tensor<2x7x3x11xf32>
-  // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<2x7x5x3xf32>, tensor<2x7x3x11xf32>) -> tensor<2x7x5x11xf32>
+  // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) <{adj_x = false, adj_y = false}> : (tensor<2x7x5x3xf32>, tensor<2x7x3x11xf32>) -> tensor<2x7x5x11xf32>
   // CHECK: %[[v3:.*]] = "tf.Transpose"(%[[v2]], %[[cst_2]]) : (tensor<2x7x5x11xf32>, tensor<4xi32>) -> tensor<2x7x11x5xf32>
   // CHECK: return %[[v3]] : tensor<2x7x11x5xf32>
 }
@@ -196,7 +196,7 @@
   // CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32>
   // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst:.*]]) : (tensor<3x4x5x6xf32>, tensor<4xi32>) -> tensor<3x5x4x6xf32>
   // CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_1]]) : (tensor<3x7x5x6xf32>, tensor<4xi32>) -> tensor<3x5x6x7xf32>
-  // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<3x5x4x6xf32>, tensor<3x5x6x7xf32>) -> tensor<3x5x4x7xf32>
+  // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) <{adj_x = false, adj_y = false}> : (tensor<3x5x4x6xf32>, tensor<3x5x6x7xf32>) -> tensor<3x5x4x7xf32>
   // CHECK: return %[[v2]] : tensor<3x5x4x7xf32>
 }
 
@@ -204,7 +204,7 @@
   %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,j->i"}: (tensor<4x5x6xf32>, tensor<5xf32>) -> tensor<4xf32>
   func.return %0 : tensor<4xf32>
 // CHECK-LABEL: einsum_no_match
-// CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,j->i"} : (tensor<4x5x6xf32>, tensor<5xf32>) -> tensor<4xf32>
+// CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) <{equation = "ijk,j->i"}> {T = "tfdtype$DT_FLOAT"} : (tensor<4x5x6xf32>, tensor<5xf32>) -> tensor<4xf32>
 // CHECK: return %[[v0]]
 }
 
@@ -212,7 +212,7 @@
   %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,?zw->kq->i"}: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32>
   func.return %0 : tensor<4xf32>
 // CHECK-LABEL: einsum_illegal_no_match
-// CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,?zw->kq->i"} : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32>
+// CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) <{equation = "ij,?zw->kq->i"}> {T = "tfdtype$DT_FLOAT"} : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32>
 // CHECK: return %[[v0]]
 }
 
@@ -223,7 +223,7 @@
 // CHECK-DAG: %[[cst:.*]] = arith.constant dense<[2, 1, 11]> : tensor<3xi64>
 // CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<[2, 1, 1, 2]> : tensor<4xi64>
 // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x1x1x11xf32>, tensor<3xi64>) -> tensor<2x1x11xf32>
-// CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%[[v0]], %arg1) {adj_x = false, adj_y = false} : (tensor<2x1x11xf32>, tensor<2x11x2xf32>) -> tensor<2x1x2xf32>
+// CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%[[v0]], %arg1) <{adj_x = false, adj_y = false}> : (tensor<2x1x11xf32>, tensor<2x11x2xf32>) -> tensor<2x1x2xf32>
 // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<2x1x2xf32>, tensor<4xi64>) -> tensor<2x1x1x2xf32>
 // CHECK: return %[[v2]] : tensor<2x1x1x2xf32>
 }
@@ -236,14 +236,14 @@
 // CHECK-DAG: %[[cst_0:.*]] = arith.constant dense<[-1, 36, 1, 32]> : tensor<4xi64>
 // CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<[0, 1]> : tensor<2xi32>
 // CHECK-DAG: %[[cst_2:.*]] = arith.constant dense<2> : tensor<1xi32>
-// CHECK-DAG: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG: %[[cst_3:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg1, %cst) : (tensor<?x36x?x32xf32>, tensor<4xi32>) -> tensor<?x36x32x?xf32>
 // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg0, %cst_0) : (tensor<?x36x32xf32>, tensor<4xi64>) -> tensor<?x36x1x32xf32>
-// CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%1, %0) {adj_x = false, adj_y = false} : (tensor<?x36x1x32xf32>, tensor<?x36x32x?xf32>) -> tensor<?x36x1x?xf32>
+// CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%1, %0) <{adj_x = false, adj_y = false}> : (tensor<?x36x1x32xf32>, tensor<?x36x32x?xf32>) -> tensor<?x36x1x?xf32>
 // CHECK: %[[v3:.*]] = "tf.Shape"(%arg0) : (tensor<?x36x32xf32>) -> tensor<3xi32>
 // CHECK: %[[v4:.*]] = "tf.Shape"(%arg1) : (tensor<?x36x?x32xf32>) -> tensor<4xi32>
-// CHECK: %[[v5:.*]] = "tf.Gather"(%3, %cst_1) {validate_indices = true} : (tensor<3xi32>, tensor<2xi32>) -> tensor<2xi32>
-// CHECK: %[[v6:.*]] = "tf.Gather"(%4, %cst_2) {validate_indices = true} : (tensor<4xi32>, tensor<1xi32>) -> tensor<1xi32>
+// CHECK: %[[v5:.*]] = "tf.Gather"(%3, %cst_1) <{validate_indices = true}> : (tensor<3xi32>, tensor<2xi32>) -> tensor<2xi32>
+// CHECK: %[[v6:.*]] = "tf.Gather"(%4, %cst_2) <{validate_indices = true}> : (tensor<4xi32>, tensor<1xi32>) -> tensor<1xi32>
 // CHECK: %[[v7:.*]] = "tf.Concat"(%cst_3, %5, %6) : (tensor<i32>, tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32>
 // CHECK: %[[v8:.*]] = "tf.Reshape"(%2, %7) : (tensor<?x36x1x?xf32>, tensor<3xi32>) -> tensor<?x36x?xf32>
 // CHECK: return %[[v8]] : tensor<?x36x?xf32>
@@ -254,13 +254,13 @@
   func.return %0 : tensor<?x?x8x128xf32>
 // CHECK-LABEL: einsum_with_runtime_outputshape2
 // CHECK-DAG: %[[cst:.*]] = arith.constant dense<1024> : tensor<2xi64>
-// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() {value = dense<[8, 128]> : tensor<2xi32>} : () -> tensor<2xi32>
+// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<[8, 128]> : tensor<2xi32>}> : () -> tensor<2xi32>
 // CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<[0, 1]> : tensor<2xi32>
-// CHECK-DAG: %[[cst_2:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG: %[[cst_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg1, %cst) : (tensor<1024x8x128xf32>, tensor<2xi64>) -> tensor<1024x1024xf32>
-// CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %0) {adj_x = false, adj_y = false} : (tensor<?x?x1024xf32>, tensor<1024x1024xf32>) -> tensor<?x?x1024xf32>
+// CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %0) <{adj_x = false, adj_y = false}> : (tensor<?x?x1024xf32>, tensor<1024x1024xf32>) -> tensor<?x?x1024xf32>
 // CHECK: %[[v2:.*]] = "tf.Shape"(%arg0) : (tensor<?x?x1024xf32>) -> tensor<3xi32>
-// CHECK: %[[v3:.*]] = "tf.Gather"(%2, %cst_1) {validate_indices = true} : (tensor<3xi32>, tensor<2xi32>) -> tensor<2xi32>
+// CHECK: %[[v3:.*]] = "tf.Gather"(%2, %cst_1) <{validate_indices = true}> : (tensor<3xi32>, tensor<2xi32>) -> tensor<2xi32>
 // CHECK: %[[v4:.*]] = "tf.Concat"(%cst_2, %3, %cst_0) : (tensor<i32>, tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
 // CHECK: %[[v5:.*]] = "tf.Reshape"(%1, %4) : (tensor<?x?x1024xf32>, tensor<4xi32>) -> tensor<?x?x8x128xf32>
 // CHECK: return %[[v5]] : tensor<?x?x8x128xf32>
@@ -275,7 +275,7 @@
 // CHECK: %[[v0:.*]] = "tf.Shape"(%arg0) : (tensor<?x36x?xf32>) -> tensor<3xi32>
 // CHECK: %[[v1:.*]] = "tf.UnsortedSegmentProd"(%0, %cst, %cst_0) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<4xi32>
 // CHECK: %[[v2:.*]] = "tf.Reshape"(%arg0, %1) : (tensor<?x36x?xf32>, tensor<4xi32>) -> tensor<?x36x1x?xf32>
-// CHECK: %[[v3:.*]] = "tf.BatchMatMulV2"(%2, %arg1) {adj_x = false, adj_y = false} : (tensor<?x36x1x?xf32>, tensor<?x36x?x32xf32>) -> tensor<?x36x1x32xf32>
+// CHECK: %[[v3:.*]] = "tf.BatchMatMulV2"(%2, %arg1) <{adj_x = false, adj_y = false}> : (tensor<?x36x1x?xf32>, tensor<?x36x?x32xf32>) -> tensor<?x36x1x32xf32>
 // CHECK: %[[v4:.*]] =  "tf.Reshape"(%3, %cst_1) : (tensor<?x36x1x32xf32>, tensor<3xi64>) -> tensor<?x36x32xf32>
 // CHECK: return %[[v4]] : tensor<?x36x32xf32>
 }
@@ -286,14 +286,14 @@
 // CHECK-LABEL: einsum_with_runtime_shape2
 // CHECK-DAG: %[[cst:.*]] = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
 // CHECK-DAG: %[[cst_0:.*]] = arith.constant dense<[0, 1, 2, 2]> : tensor<4xi32>
-// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
 // CHECK-DAG: %[[cst_2:.*]] = arith.constant dense<[512, 8]> : tensor<2xi64>
 // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg1, %cst) : (tensor<8x8x64xf32>, tensor<3xi32>) -> tensor<8x64x8xf32>
 // CHECK: %[[v1:.*]] = "tf.Shape"(%arg0) : (tensor<?x?x8x64xf32>) -> tensor<4xi32>
 // CHECK: %[[v2:.*]] = "tf.UnsortedSegmentProd"(%1, %cst_0, %cst_1) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<3xi32>
 // CHECK: %[[v3:.*]] = "tf.Reshape"(%arg0, %2) : (tensor<?x?x8x64xf32>, tensor<3xi32>) -> tensor<?x?x512xf32>
 // CHECK: %[[v4:.*]] = "tf.Reshape"(%0, %cst_2) : (tensor<8x64x8xf32>, tensor<2xi64>) -> tensor<512x8xf32>
-// CHECK: %[[v5:.*]] = "tf.BatchMatMulV2"(%3, %4) {adj_x = false, adj_y = false} : (tensor<?x?x512xf32>, tensor<512x8xf32>) -> tensor<?x?x8xf32>
+// CHECK: %[[v5:.*]] = "tf.BatchMatMulV2"(%3, %4) <{adj_x = false, adj_y = false}> : (tensor<?x?x512xf32>, tensor<512x8xf32>) -> tensor<?x?x8xf32>
 // CHECK: return %[[v5]] : tensor<?x?x8xf32>
 }
 
@@ -305,7 +305,7 @@
 // CHECK-DAG: %[[cst_0:.*]] = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32>
 // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %cst) : (tensor<?x?x8x128xf32>, tensor<4xi32>) -> tensor<?x8x?x128xf32>
 // CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %cst_0) : (tensor<1x?x8x128xf32>, tensor<4xi32>) -> tensor<1x8x128x?xf32>
-// CHECK: %[[v3:.*]] = "tf.BatchMatMulV2"(%0, %1) {adj_x = false, adj_y = false} : (tensor<?x8x?x128xf32>, tensor<1x8x128x?xf32>) -> tensor<?x8x?x?xf32>
+// CHECK: %[[v3:.*]] = "tf.BatchMatMulV2"(%0, %1) <{adj_x = false, adj_y = false}> : (tensor<?x8x?x128xf32>, tensor<1x8x128x?xf32>) -> tensor<?x8x?x?xf32>
 // CHECK: %[[v4:.*]] =  "tf.Transpose"(%2, %cst) : (tensor<?x8x?x?xf32>, tensor<4xi32>) -> tensor<?x?x8x?xf32>
 // CHECK: return %[[v4]] : tensor<?x?x8x?xf32>
 }
@@ -314,7 +314,7 @@
   %0 = "tf.Einsum"(%arg0, %arg1) {device = "", equation = "...x,xy->...y"} : (tensor<1x512x128xf32>, tensor<128x256xf32>) -> tensor<1x512x256xf32>
   func.return %0 : tensor<1x512x256xf32>
 // CHECK-LABEL: einsum_ellipsis
-// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x512x128xf32>, tensor<128x256xf32>) -> tensor<1x512x256xf32>
+// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<1x512x128xf32>, tensor<128x256xf32>) -> tensor<1x512x256xf32>
 }
 
 func.func @einsum_ellipsis_in_both_sides(%arg0: tensor<1x11x19xf32>, %arg1: tensor<7x11x13x19xf32>) -> tensor<7x11x13xf32> {
@@ -326,7 +326,7 @@
   // CHECK-DAG: %[[cst_2:.*]] = arith.constant dense<[7, 11, 13]> : tensor<3xi64>
   // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg1, %[[cst]]) : (tensor<7x11x13x19xf32>, tensor<4xi32>) -> tensor<7x11x19x13xf32>
   // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg0, %[[cst_1]]) : (tensor<1x11x19xf32>, tensor<4xi64>) -> tensor<1x11x1x19xf32>
-  // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v1]], %[[v0]]) {adj_x = false, adj_y = false} : (tensor<1x11x1x19xf32>, tensor<7x11x19x13xf32>) -> tensor<7x11x1x13xf32>
+  // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v1]], %[[v0]]) <{adj_x = false, adj_y = false}> : (tensor<1x11x1x19xf32>, tensor<7x11x19x13xf32>) -> tensor<7x11x1x13xf32>
   // CHECK: %[[v3:.*]] = "tf.Reshape"(%[[v2]], %[[cst_2]]) : (tensor<7x11x1x13xf32>, tensor<3xi64>) -> tensor<7x11x13xf32>
   // CHECK: return %[[v3]] : tensor<7x11x13xf32>
 }
@@ -338,7 +338,7 @@
   // CHECK-DAG: %[[cst:.*]] = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
   // CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
   // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg1, %[[cst]]) : (tensor<3x2x1xf32>, tensor<3xi32>) -> tensor<1x3x2xf32>
-  // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %[[v0]]) {adj_x = false, adj_y = false} : (tensor<5x4x3xf32>, tensor<1x3x2xf32>) -> tensor<5x4x2xf32>
+  // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %[[v0]]) <{adj_x = false, adj_y = false}> : (tensor<5x4x3xf32>, tensor<1x3x2xf32>) -> tensor<5x4x2xf32>
   // CHECK: %[[v2:.*]] = "tf.Transpose"(%[[v1]], %[[cst_1]]) : (tensor<5x4x2xf32>, tensor<3xi32>) -> tensor<4x2x5xf32>
   // CHECK: return %[[v2]] : tensor<4x2x5xf32>
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/embedding_pipelining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/embedding_pipelining.mlir
index fdb64b9..8aa1cf6 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/embedding_pipelining.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/embedding_pipelining.mlir
@@ -17,7 +17,7 @@
     // CHECK: {{.*StatefulPartitionedCall.* f = @non_tpu.*}}
     // CHECK: {{.*StatefulPartitionedCall.* f = @start_step_1.*}}
     // CHECK: {{.*StatefulPartitionedCall.* f = @while_cond.*}}
-    // CHECK: {{.*tf.While.* body = @new_while_body.* cond = @new_while_cond.*}}
+    // CHECK: {{.*tf.While.* <{body = @new_while_body.* cond = @new_while_cond.*}}
     // CHECK: {{.*StatefulPartitionedCall.* f = @finish_step_nm2.*}}
     // CHECK: {{.*StatefulPartitionedCall.* f = @finish_step_nm1.*}}
     // CHECK: return
@@ -73,7 +73,7 @@
     // CHECK: {{.*StatefulPartitionedCall.* f = @non_tpu.*}}
     // CHECK: {{.*StatefulPartitionedCall.* f = @start_step_1.*}}
     // CHECK: {{.*StatefulPartitionedCall.* f = @while_cond.*}}
-    // CHECK: {{.*tf.While.* body = @new_while_body.* cond = @new_while_cond.*}}
+    // CHECK: {{.*tf.While.* <{body = @new_while_body.* cond = @new_while_cond.*}}
     // CHECK: {{.*StatefulPartitionedCall.* f = @finish_step_nm2.*}}
     // CHECK: {{.*StatefulPartitionedCall.* f = @finish_step_nm1.*}}
     // CHECK: return
@@ -112,7 +112,7 @@
   func.func private @while_body(%arg0: tensor<i32>) -> (tensor<i32>) {
     // The pipelining control flow and supporting functions stay the same as the training version above.
     // The order of these functions is also significant.
-    // CHECK: {{.*tf.While.* body = @new_while_body.* cond = @new_while_cond.* parallel_iterations = 3}}
+    // CHECK: {{.*tf.While.* <{body = @new_while_body.* cond = @new_while_cond.* parallel_iterations = 3}}
     // CHECK: return
     // metadata ops
     "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 1 : i64} : () -> ()
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/end-to-end-tpu-reshard-variables.mlir b/tensorflow/compiler/mlir/tensorflow/tests/end-to-end-tpu-reshard-variables.mlir
index 0d42c2c..90ffba03 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/end-to-end-tpu-reshard-variables.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/end-to-end-tpu-reshard-variables.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -tf-tpu-bridge 2>&1 | FileCheck %s
+// RUN: tf-opt %s -tf-cluster-tpu-bridge-v2 -tfrt-lower-cluster-to-runtime-ops-tpu 2>&1 | FileCheck %s
 
 // TPUReshardVariables should be inserted even when While functions' shapes are
 // different than While operand shapes. Test the whole tf-tpu-bridge because
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/extract_head_tail_outside_compilation.mlir
index 5f0821a..5f48061 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/extract_head_tail_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/extract_head_tail_outside_compilation.mlir
@@ -6,10 +6,10 @@
   // CHECK-LABEL: func @head_single_outside_compiled_op
   func.func @head_single_outside_compiled_op(%arg0: tensor<i32>) {
     // CHECK:      "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   "tf.A"
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT:   tf_device.return
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     //
     // CHECK:      "tf_device.cluster"
     // CHECK-NEXT:   "tf.B"
@@ -27,10 +27,10 @@
   // CHECK-LABEL: func @head_single_outside_compiled_op_no_operands
   func.func @head_single_outside_compiled_op_no_operands() {
     // CHECK:      %[[LAUNCH_OUT:.*]] = "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   %[[A_OUT:.*]] = "tf.A"
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT:   tf_device.return %[[A_OUT]]
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     //
     // CHECK:      "tf_device.cluster"
     // CHECK-NEXT:   "tf.B"(%[[LAUNCH_OUT]])
@@ -50,10 +50,10 @@
     // CHECK:      %[[A_OUT:.*]] = "tf.A"
     %a = "tf.A"() : () -> tensor<i32>
     // CHECK-NEXT: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   %[[B_OUT:.*]] = "tf.B"
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT:   tf_device.return %[[B_OUT]]
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     //
     // CHECK:      "tf_device.cluster"
     // CHECK-NEXT:   "tf.C"(%[[LAUNCH_OUT]])
@@ -71,10 +71,10 @@
   // CHECK-LABEL: func @head_aliased_output
   func.func @head_aliased_output() -> (tensor<i32>, tensor<i32>, tensor<i32>) {
     // CHECK:      %[[LAUNCH_OUT:.*]] = "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   %[[A_OUT:.*]] = "tf.A"
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT:   tf_device.return %[[A_OUT]]
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     //
     // CHECK:      %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster"
     // CHECK-NEXT:   %[[B_OUT:.*]] = "tf.B"(%[[LAUNCH_OUT]])
@@ -98,6 +98,7 @@
   // CHECK-LABEL: func @head_all_cluster_op
   func.func @head_all_cluster_op(%arg0: tensor<i32>) -> tensor<i32> {
     // CHECK:      %[[LAUNCH_OUT:.*]] = "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   %[[A_OUT:.*]] = "tf.A"
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT:   %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
@@ -105,7 +106,6 @@
     // CHECK-NEXT:   %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]], %arg0)
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT:   tf_device.return %[[C_OUT]]
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     //
     // CHECK:      "tf_device.cluster"
     // CHECK-NEXT:   tf_device.return
@@ -122,6 +122,7 @@
   // CHECK-LABEL: func @head_multiple_outside_compiled_ops
   func.func @head_multiple_outside_compiled_ops(%arg0: tensor<i32>) {
     // CHECK:      %[[LAUNCH_OUT:.*]] = "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   %[[A_OUT:.*]] = "tf.A"
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT:   %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
@@ -129,7 +130,6 @@
     // CHECK-NEXT:   "tf.C"
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT:   tf_device.return %[[B_OUT]]
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     //
     // CHECK:      "tf_device.cluster"
     // CHECK-NEXT:   "tf.D"(%[[LAUNCH_OUT]])
@@ -149,10 +149,10 @@
     // CHECK:      tf_device.replicate([%arg0, %arg1] as %[[RI:.*]]: tensor<i32>)
     //
     // CHECK-NEXT:   %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
+    // CHECK-SAME:   device = "TPU_REPLICATED_HOST_0"
     // CHECK-NEXT:     %[[A_OUT:.*]] = "tf.A"(%[[RI]])
     // CHECK-NOT:      _xla_outside_compilation
     // CHECK-NEXT:     tf_device.return %[[A_OUT]]
-    // CHECK-NEXT:   device = "TPU_REPLICATED_HOST_0"
     //
     // CHECK:        "tf_device.cluster"
     // CHECK-NEXT:     "tf.B"(%[[LAUNCH_OUT]])
@@ -215,10 +215,10 @@
     // CHECK-DAG:  device_assignment = []
     //
     // CHECK:      "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   "tf.B"(%[[CLUSTER_OUT]])
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT:   tf_device.return
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     "tf_device.cluster"() ({
       %a = "tf.A"() : () -> tensor<i32>
       "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
@@ -241,10 +241,10 @@
     // CHECK-DAG:  device_assignment = []
     //
     // CHECK:      %[[LAUNCH_OUT:.*]] = "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]])
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT:   tf_device.return %[[B_OUT]]
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     %cluster = "tf_device.cluster"() ({
       %a = "tf.A"() : () -> tensor<i32>
       %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
@@ -268,12 +268,12 @@
     // CHECK-DAG:  device_assignment = []
     //
     // CHECK:      "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   %[[C_OUT:.*]] = "tf.C"(%arg0, %[[CLUSTER_OUT]]#1)
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT:   "tf.D"(%[[C_OUT]], %arg0, %[[CLUSTER_OUT]]#0)
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT:   tf_device.return
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     "tf_device.cluster"() ({
       %a = "tf.A"() : () -> tensor<i32>
       %b = "tf.B"(%arg0) : (tensor<i32>) -> tensor<i32>
@@ -299,13 +299,13 @@
     // CHECK-DAG:  device_assignment = []
     //
     // CHECK:      "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   %[[C_OUT:.*]] = "tf.C"(%arg0, %[[CLUSTER_OUT]]#2)
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK:        "tf.IfRegion"
     // CHECK:          "tf.D"(%[[C_OUT]], %arg0, %[[CLUSTER_OUT]]#0)
     // CHECK-NOT:      _xla_outside_compilation
     // CHECK:        tf_device.return
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     "tf_device.cluster"() ({
       %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
       %a = "tf.A"() : () -> tensor<i32>
@@ -339,10 +339,10 @@
     // CHECK-DAG:  device_assignment = []
     //
     // CHECK:      %[[LAUNCH_OUT:.*]] = "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   %[[D_OUT:.*]] = "tf.D"(%[[CLUSTER_OUT]]#0, %[[A_OUT]])
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT:   tf_device.return
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     %cluster:5 = "tf_device.cluster"() ({
       %c = "tf.C"()  : () -> tensor<i32>
       %d = "tf.D"(%c, %a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
@@ -367,10 +367,10 @@
     // CHECK-DAG:    device_assignment = []
     //
     // CHECK-NEXT:   "tf_device.launch"()
+    // CHECK-SAME:   device = "TPU_REPLICATED_HOST_0"
     // CHECK-NEXT:     %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]], %[[RI]])
     // CHECK-NOT:      _xla_outside_compilation
     // CHECK-NEXT:     tf_device.return
-    // CHECK-NEXT:   device = "TPU_REPLICATED_HOST_0"
     tf_device.replicate([%arg0, %arg1] as %ri : tensor<i32>) {n = 2 : i32} {
       "tf_device.cluster"() ({
         %a = "tf.A"(%ri) : (tensor<i32>) -> tensor<i32>
@@ -402,10 +402,10 @@
   // CHECK-LABEL: func @head_tail_simple_extraction
   func.func @head_tail_simple_extraction(%arg0: tensor<i32>) -> tensor<i32> {
     // CHECK:      %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   %[[A_OUT:.*]] = "tf.A"(%arg0)
     // CHECK-NOT:      _xla_outside_compilation
     // CHECK-NEXT:   tf_device.return %[[A_OUT]]
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     //
     // CHECK:      %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
     // CHECK-NEXT:   %[[B_OUT:.*]] = "tf.B"(%[[HEAD_LAUNCH_OUT]])
@@ -417,10 +417,10 @@
     // CHECK-DAG:  device_assignment = []
     //
     // CHECK:      %[[TAIL_LAUNCH_OUT:.*]] = "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   %[[C_OUT:.*]] = "tf.C"(%[[CLUSTER_OUT]])
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT:   tf_device.return %[[C_OUT]]
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     %cluster = "tf_device.cluster"() ({
       %a = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
       %b = "tf.B"(%a) : (tensor<i32>) -> tensor<i32>
@@ -436,10 +436,10 @@
     // CHECK:      tf_device.replicate([%arg0, %arg1] as %[[RI:.*]]: tensor<i32>)
     //
     // CHECK-NEXT:   %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"()
+    // CHECK-SAME:   device = "TPU_REPLICATED_HOST_0"
     // CHECK-NEXT:     %[[A_OUT:.*]] = "tf.A"(%[[RI]])
     // CHECK-NOT:      _xla_outside_compilation
     // CHECK-NEXT:     tf_device.return %[[A_OUT]]
-    // CHECK-NEXT:   device = "TPU_REPLICATED_HOST_0"
     //
     // CHECK:        %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
     // CHECK-NEXT:     %[[B_OUT:.*]] = "tf.B"
@@ -453,10 +453,10 @@
     // CHECK-DAG:    device_assignment = []
     //
     // CHECK-NEXT:   "tf_device.launch"()
+    // CHECK-SAME:   device = "TPU_REPLICATED_HOST_0"
     // CHECK-NEXT:     "tf.D"(%[[HEAD_LAUNCH_OUT]], %[[CLUSTER_OUT]], %[[RI]])
     // CHECK-NOT:      _xla_outside_compilation
     // CHECK-NEXT:     tf_device.return
-    // CHECK-NEXT:   device = "TPU_REPLICATED_HOST_0"
     tf_device.replicate([%arg0, %arg1] as %ri : tensor<i32>) {n = 2 : i32} {
       "tf_device.cluster"() ({
         %a = "tf.A"(%ri) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
@@ -490,10 +490,10 @@
   // CHECK-LABEL: func @side_effect_head_no_operand
   func.func @side_effect_head_no_operand() {
     // CHECK:      %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"()
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   "tf.B"
     // CHECK-NEXT:   %[[C_OUT:.*]] = "tf.C"
     // CHECK-NEXT:   tf_device.return %[[C_OUT]]
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
 
     // CHECK:      "tf_device.cluster"
     // CHECK-NEXT:   "tf.Const"
@@ -518,10 +518,10 @@
     // CHECK-NEXT:   tf_device.return %[[A_OUT]]
 
     // CHECK:      "tf_device.launch"()
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   "tf.B"(%[[CLUSTER_OUT]])
     // CHECK-NEXT:   "tf.C"
     // CHECK-NEXT:   tf_device.return
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     "tf_device.cluster"() ({
       %a = "tf.A"() : () -> tensor<i32>
       "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
@@ -538,10 +538,10 @@
   // CHECK-LABEL: func @embedding_head_extraction
   func.func @embedding_head_extraction(%arg0: tensor<!tf_type.string>) {
     // CHECK:      "tf_device.launch"()
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   "tf.EnqueueTPUEmbeddingRaggedTensorBatch"
     // CHECK-NEXT:   "tf.EnqueueTPUEmbeddingArbitraryTensorBatch"
     // CHECK-NEXT:   tf_device.return
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
 
     // CHECK:      "tf_device.cluster"
     // CHECK-NEXT:   "tf.UnknownOp"
@@ -560,9 +560,9 @@
   // CHECK-LABEL: func @op_after_embedding_head_extraction
   func.func @op_after_embedding_head_extraction() {
     // CHECK:      "tf_device.launch"()
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   "tf.A"
     // CHECK-NEXT:   tf_device.return
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
 
     // CHECK:      "tf_device.cluster"
     // CHECK-NEXT:   "tf.RecvTPUEmbeddingActivations"
@@ -588,9 +588,9 @@
     // CHECK-NEXT:   tf_device.return
 
     // CHECK:      "tf_device.launch"()
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   "tf.A"
     // CHECK-NEXT:   tf_device.return
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     "tf_device.cluster"() ({
       "tf.UnknownOp"() : () -> ()
       "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> ()
@@ -607,10 +607,10 @@
   // CHECK-LABEL: func @head_single_outside_compiled_op_in_generic_pipeline
   func.func @head_single_outside_compiled_op_in_generic_pipeline(%arg0: tensor<i32>) {
     // CHECK:      "tf_device.launch"
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   "tf.A"
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT:   tf_device.return
-    // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
     //
     // CHECK:      "tf_device.cluster"
     // CHECK-NEXT:   "tf.B"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir
index 488c98a..cbd9942 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir
@@ -32,10 +32,10 @@
   func.func @nodep_single_outside_compilation() -> () {
      // CHECK: "tf_device.parallel_execute"
      // CHECK-NEXT: "tf_device.launch"
+     // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
      // CHECK-NEXT: "tf.B"
      // CHECK-NOT: _xla_outside_compilation
      // CHECK-NEXT:   tf_device.return
-     // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
      // CHECK: "tf_device.cluster"
      // CHECK-NEXT: "tf.A"
      // CHECK: device_assignment =  [], num_cores_per_replica = 1 : i64, topology =  ""
@@ -102,9 +102,9 @@
     // CHECK:      %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
     // CHECK:        %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
     // CHECK-NEXT:     "tf_device.launch"
+    // CHECK-SAME:     device = "TPU_REPLICATED_HOST_0"
     // CHECK:            "tf.B"
     // CHECK-NEXT:       tf_device.return
-    // CHECK-NEXT:     device = "TPU_REPLICATED_HOST_0"
     // CHECK:          %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster"
     // CHECK:            tf_device.return
     // CHECK:          tf_device.return %[[TPU_CLUSTER_OUTPUT]]
@@ -484,7 +484,7 @@
     // CHECK:            %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
     // CHECK-SAME:       key = "host_compute_channel_1_args"
     // CHECK:           "tf.D"(%[[RECV_OUTPUT_2]])
-    // CHECK:          "tf_device.cluster"
+    // CHECK:          "tf_device.cluster"()
     // CHECK:            %[[A_OUTPUT:[0-9]*]] = "tf.A"
     // CHECK:            "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
     // CHECK-SAME:       send_key = "host_compute_channel_0_args"
@@ -581,14 +581,14 @@
     // CHECK-NEXT:      %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
     // CHECK-SAME:      key = "if_predicate_channel_1"
     // CHECK-NEXT:       tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
+    // CHECK:            _else_func_name = "test_else_name"
+    // CHECK-SAME:       _then_func_name = "test_then_name"
+    // CHECK-SAME:       is_stateless = false
     // CHECK-NEXT:         %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
     // CHECK-SAME:         key = "host_compute_channel_0_args"
     // CHECK:              "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1)
     // CHECK-NOT:          "tf._XlaSendFromHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
     // CHECK:              "tf.Yield"() : () -> ()
-    // CHECK:            _else_func_name = "test_else_name"
-    // CHECK-SAME:       _then_func_name = "test_then_name"
-    // CHECK-SAME:       is_stateless = false
     // CHECK:          "tf_device.cluster"
     // CHECK:            %[[A_OUTPUT:[0-9]*]] = "tf.A"
     // CHECK:            %[[B_OUTPUT:[0-9]*]] = "tf.B"
@@ -596,11 +596,11 @@
     // CHECK:            "tf._XlaHostComputeMlir"
     // CHECK-SAME:       key = "if_predicate_channel_1"
     // CHECK-NEXT:       tf.IfRegion"(%[[G_OUTPUT]])
+    // CHECK:            is_stateless = false
     // CHECK:              "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]])
     // CHECK-SAME:         recv_key = "host_compute_channel_0_retvals"
     // CHECK-SAME:         send_key = "host_compute_channel_0_args"
     // CHECK-NEXT:         "tf.Yield"() : () -> ()
-    // CHECK:            is_stateless = false
     %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
       %2 = "tf_device.cluster"() ({
         %3 = "tf.A"() : () -> (tensor<2xi32>)
@@ -637,13 +637,13 @@
     // CHECK-NEXT:      %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
     // CHECK-SAME:      key = "if_predicate_channel_1"
     // CHECK-NEXT:       tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
+    // CHECK:            _else_func_name = "test_else_name"
+    // CHECK-SAME:       _then_func_name = "test_then_name"
     // CHECK-NEXT:         %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
     // CHECK-SAME:         key = "host_compute_channel_0_args"
     // CHECK:              "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1)
     // CHECK-NOT:          "tf._XlaSendFromHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
     // CHECK:              "tf.Yield"() : () -> ()
-    // CHECK:            _else_func_name = "test_else_name"
-    // CHECK-SAME        _then_func_name = "test_then_name"
     // CHECK:          "tf_device.cluster"
     // CHECK:            %[[A_OUTPUT:[0-9]*]] = "tf.A"
     // CHECK:            %[[B_OUTPUT:[0-9]*]] = "tf.B"
@@ -1044,6 +1044,7 @@
     // CHECK-DAG:        %[[PROGRAM_OUTPUT:.+]] = "tf._XlaCompileMlirPlaceholderProgramKey"
     // CHECK-DAG:        %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
     // CHECK-NEXT:       tf.WhileRegion"
+    // CHECK:            is_stateless = false
     // CHECK-NEXT:         %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
     // CHECK-SAME:         key = "while_condition_channel_0"
     // CHECK:              "tf.Yield"(%[[COND_RECV_OUTPUT]])
@@ -1051,19 +1052,18 @@
     // CHECK:              %[[D_OUTPUT:[0-9]*]] = "tf.D"
     // CHECK:              "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
     // CHECK-NEXT:         "tf.Yield"
-    // CHECK:            is_stateless = false
     // CHECK:          "tf_device.cluster"
     // CHECK:            %[[A_OUTPUT:[0-9]*]] = "tf.A"
     // CHECK:            %[[B_OUTPUT:[0-9]*]] = "tf.B"
     // CHECK:            %[[G_OUTPUT:[0-9]*]] = "tf.G"
     // CHECK-NEXT:       tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
+    // CHECK:            is_stateless = false
     // CHECK:              %[[H_OUTPUT:[0-9]*]] = "tf.H"
     // CHECK-NEXT:         "tf.XlaSendToHost"(%[[H_OUTPUT]])
     // CHECK-NEXT:         "tf.Yield"(%[[H_OUTPUT]])
     // CHECK:              %[[C_OUTPUT:[0-9]*]] = "tf.C"
     // CHECK-NEXT:         %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"
     // CHECK-NEXT:         "tf.Yield"(%[[C_OUTPUT]], %[[HOST_COMPUTE_OUTPUT]])
-    // CHECK:            is_stateless = false
     %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
       %2 = "tf_device.cluster"() ({
         %3 = "tf.A"() : () -> (tensor<3xf32>)
@@ -1839,10 +1839,10 @@
   func.func @outside_compilation_model_parallelism() -> () {
      // CHECK: "tf_device.parallel_execute"
      // CHECK-NEXT: "tf_device.launch"
+     // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
      // CHECK-NEXT: "tf.B"
      // CHECK-NOT: _xla_outside_compilation
      // CHECK-NEXT:   tf_device.return
-     // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
      // CHECK: "tf_device.cluster"
      // CHECK-NEXT: "tf.A"
      // CHECK: num_cores_per_replica = 2 : i64
@@ -2118,34 +2118,34 @@
     // CHECK:         "tf_device.launch"
     // CHECK:           %[[PROGRAM0:.+]] = "tf._XlaCompileMlirPlaceholderProgramKey"
     // CHECK:           %[[RECV0:.+]] = "tf._XlaRecvAtHost"(%[[PROGRAM0]])
-    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK-SAME:        device_ordinal = 0
     // CHECK-SAME:        key = "host_compute_channel_0_args"
+    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK:           %[[B0:.+]] = "tf.OpB"(%[[RECV0]]) : (tensor<2x2xi64>) -> tensor<2x2xi64>
     // CHECK:           "tf._XlaSendFromHost"(%[[B0]], %[[PROGRAM0]])
-    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK-SAME:        device_ordinal = 0
     // CHECK-SAME:        key = "host_compute_channel_0_retvals"
+    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK:         }, {
     // CHECK:           %[[PROGRAM1:.+]] = "tf._XlaCompileMlirPlaceholderProgramKey"
     // CHECK:           %[[RECV1:.+]] = "tf._XlaRecvAtHost"(%[[PROGRAM1]])
-    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK-SAME:        device_ordinal = 1
     // CHECK-SAME:        key = "host_compute_channel_0_args"
+    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK:           %[[B1:.+]] = "tf.OpB"(%[[RECV1]]) : (tensor<2x2xi64>) -> tensor<2x2xi64>
     // CHECK:           "tf._XlaSendFromHost"(%[[B1]], %[[PROGRAM1]])
-    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK-SAME:        device_ordinal = 1
     // CHECK-SAME:        key = "host_compute_channel_0_retvals"
+    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK:         }, {
     // CHECK:           "tf_device.cluster"
     // CHECK:             %[[A:.+]] = "tf.OpA"
-    // CHECK:             %[[A_SHARD:.+]] = "tf.XlaSpmdFullToShardShape"(%[[A]]) {dim = -1 : i64, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<2x2xi64>) -> tensor<1x2xi64>
+    // CHECK:             %[[A_SHARD:.+]] = "tf.XlaSpmdFullToShardShape"(%[[A]]) <{dim = -1 : i64, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []}> : (tensor<2x2xi64>) -> tensor<1x2xi64>
     // CHECK:             %[[B:.+]] = "tf._XlaHostComputeMlir"(%[[A_SHARD]])
     // CHECK-SAME:          manual_sharding = true
     // CHECK-SAME:          recv_key = "host_compute_channel_0_retvals"
     // CHECK-SAME:          send_key = "host_compute_channel_0_args"
-    // CHECK:             %[[B_FULL:.+]] = "tf.XlaSpmdShardToFullShape"(%[[B]]) {dim = -1 : i64, full_shape = #tf_type.shape<2x2>, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<1x2xi64>) -> tensor<2x2xi64>
+    // CHECK:             %[[B_FULL:.+]] = "tf.XlaSpmdShardToFullShape"(%[[B]]) <{dim = -1 : i64, full_shape = #tf_type.shape<2x2>, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []}> : (tensor<1x2xi64>) -> tensor<2x2xi64>
     // CHECK:             "tf.OpC"(%[[B_FULL]])
     "tf_device.cluster"() ({
       %0 = "tf.OpA"() {_XlaSharding = "\08\03\1A\02\02\01\22\02\00\01"} : () -> tensor<2x2xi64>
@@ -2178,32 +2178,32 @@
     // CHECK:           %[[DEVICE0:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
     // CHECK-SAME:        logical_core = 0
     // CHECK:           %[[RECV0:.+]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM0]], %[[DEVICE0]])
-    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK-SAME:        key = "host_compute_channel_0_args"
+    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK:           %[[B0:.+]] = "tf.OpB"(%[[RECV0]]) : (tensor<2x2xi64>) -> tensor<2x2xi64>
     // CHECK:           "tf._XlaSendFromHostV2"(%[[B0]], %[[PROGRAM0]], %[[DEVICE0]])
-    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK-SAME:        key = "host_compute_channel_0_retvals"
+    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK:         }, {
     // CHECK:           %[[PROGRAM1:.+]] = "tf._XlaCompileMlirPlaceholderProgramKey"
     // CHECK:           %[[DEVICE1:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
     // CHECK-SAME:        logical_core = 1
     // CHECK:           %[[RECV1:.+]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM1]], %[[DEVICE1]])
-    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK-SAME:        key = "host_compute_channel_0_args"
+    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK:           %[[B1:.+]] = "tf.OpB"(%[[RECV1]]) : (tensor<2x2xi64>) -> tensor<2x2xi64>
     // CHECK:           "tf._XlaSendFromHostV2"(%[[B1]], %[[PROGRAM1]], %[[DEVICE1]])
-    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK-SAME:        key = "host_compute_channel_0_retvals"
+    // CHECK-SAME:        _xla_has_host_transfer = true
     // CHECK:         }, {
     // CHECK:           "tf_device.cluster"
     // CHECK:             %[[A:.+]] = "tf.OpA"
-    // CHECK:             %[[A_SHARD:.+]] = "tf.XlaSpmdFullToShardShape"(%[[A]]) {dim = -1 : i64, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<2x2xi64>) -> tensor<1x2xi64>
+    // CHECK:             %[[A_SHARD:.+]] = "tf.XlaSpmdFullToShardShape"(%[[A]]) <{dim = -1 : i64, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []}> : (tensor<2x2xi64>) -> tensor<1x2xi64>
     // CHECK:             %[[B:.+]] = "tf._XlaHostComputeMlir"(%[[A_SHARD]])
     // CHECK-SAME:          manual_sharding = true
     // CHECK-SAME:          recv_key = "host_compute_channel_0_retvals"
     // CHECK-SAME:          send_key = "host_compute_channel_0_args"
-    // CHECK:             %[[B_FULL:.+]] = "tf.XlaSpmdShardToFullShape"(%[[B]]) {dim = -1 : i64, full_shape = #tf_type.shape<2x2>, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<1x2xi64>) -> tensor<2x2xi64>
+    // CHECK:             %[[B_FULL:.+]] = "tf.XlaSpmdShardToFullShape"(%[[B]]) <{dim = -1 : i64, full_shape = #tf_type.shape<2x2>, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []}> : (tensor<1x2xi64>) -> tensor<2x2xi64>
     // CHECK:             "tf.OpC"(%[[B_FULL]])
     tf_device.replicate() {n = 4 : i32} {
       "tf_device.cluster"() ({
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/extract_tpu_copy_with_dynamic_shape_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/extract_tpu_copy_with_dynamic_shape_op.mlir
index 2c2b36c..ec3fedf 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/extract_tpu_copy_with_dynamic_shape_op.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/extract_tpu_copy_with_dynamic_shape_op.mlir
@@ -6,10 +6,10 @@
   // CHECK-LABEL: func @valid_copy_op_in_replicated_host
 
   // CHECK: "tf_device.launch"
-  // CHECK: "TPU_REPLICATED_HOST_0"
+  // CHECK-SAME: "TPU_REPLICATED_HOST_0"
   // CHECK: "tf_device.launch"
+  // CHECK-SAME: "TPU_REPLICATED_CORE_0"
   // CHECK: "tf.TPUCopyWithDynamicShape"
-  // CHECK: "TPU_REPLICATED_CORE_0"
   func.func @valid_copy_op_in_replicated_host(
     %arg0: tensor<2048xi64> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"},
     %arg1: tensor<2048xi64> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) -> (tensor<2048xi32>, tensor<2048xi32>) {
@@ -26,10 +26,10 @@
   // CHECK-LABEL: func @valid_copy_op_in_non_replicated_host
 
   // CHECK: "tf_device.launch"
-  // CHECK: "/job:localhost/replica:0/task:0/device:CPU:0"
+  // CHECK-SAME: "/job:localhost/replica:0/task:0/device:CPU:0"
   // CHECK: "tf_device.launch"
+  // CHECK-SAME: "/job:localhost/replica:0/task:0/device:TPU:0"
   // CHECK: "tf.TPUCopyWithDynamicShape"
-  // CHECK: "/job:localhost/replica:0/task:0/device:TPU:0"
   func.func @valid_copy_op_in_non_replicated_host(
     %arg0: tensor<2048xi64> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"},
     %arg1: tensor<2048xi64> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) -> (tensor<2048xi32>, tensor<2048xi32>) {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir b/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir
index 9e7b5b2..5535bad 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir
@@ -48,7 +48,7 @@
   %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
   %1 = "tf.Equal"(%arg0, %0) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xi1>
   func.return %1 : tensor<5x7xi1>
-  // CHECK: %[[V0:.*]] = "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1>
+  // CHECK: %[[V0:.*]] = "tf.Equal"(%arg0, %arg1) <{incompatible_shape_error = true}> : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1>
   // CHECK: %[[V0]] : tensor<5x7xi1>
 }
 
@@ -58,7 +58,7 @@
   %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
   %1 = "tf.NotEqual"(%arg0, %0) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xi1>
   func.return %1 : tensor<5x7xi1>
-  // CHECK: %[[V0:.*]] = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1>
+  // CHECK: %[[V0:.*]] = "tf.NotEqual"(%arg0, %arg1) <{incompatible_shape_error = true}> : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1>
   // CHECK: %[[V0]] : tensor<5x7xi1>
 }
 
@@ -79,7 +79,7 @@
   %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<17x24xf32>, tensor<3xi64>) -> tensor<17x17x24xf32>
   %1 = "tf.BatchMatMulV2"(%arg0, %0) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
   func.return %1 : tensor<17x17x24xf32>
-  // CHECK: %[[V0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x24xf32>) -> tensor<17x17x24xf32>
+  // CHECK: %[[V0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<17x17x17xf32>, tensor<17x24xf32>) -> tensor<17x17x24xf32>
   // CHECK: %[[V0]] : tensor<17x17x24xf32>
 }
 
@@ -89,7 +89,7 @@
   %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<17x17xf32>, tensor<3xi64>) -> tensor<17x17x17xf32>
   %1 = "tf.BatchMatMulV2"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
   func.return %1 : tensor<17x17x24xf32>
-  // CHECK: %[[V0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
+  // CHECK: %[[V0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
   // CHECK: %[[V0]] : tensor<17x17x24xf32>
 }
 
@@ -108,6 +108,6 @@
   %cst = arith.constant dense<5> : tensor<2xi64>
   %0 = "tf.BroadcastTo"(%cst, %cst) : (tensor<2xi64>, tensor<2xi64>) -> tensor<5x5xi64>
   func.return %0 : tensor<5x5xi64>
-  // CHECK: %[[V0:.*]] = "tf.Const"() {value = dense<5> : tensor<5x5xi64>} : () -> tensor<5x5xi64>
+  // CHECK: %[[V0:.*]] = "tf.Const"() <{value = dense<5> : tensor<5x5xi64>}> : () -> tensor<5x5xi64>
   // CHECK: %[[V0]] : tensor<5x5xi64>
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir b/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir
index 7c18a389..a458a20 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir
@@ -464,7 +464,7 @@
     func.return %arg0, %0 : tensor<?xf32>, tensor<0xf32>
   }
   // CHECK: func.func private @f_batch_callee(%[[ARG_0:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<0xf32>)
-  // CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {{{.*value = dense<> : tensor<0xf32>.*}}} : () -> tensor<0xf32>
+  // CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<> : tensor<0xf32>}> : () -> tensor<0xf32>
   // CHECK: return %[[ARG_0]], %[[CST_0]] : tensor<?xf32>, tensor<0xf32>
 
   func.func @f(%arg: tensor<1xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
@@ -474,6 +474,6 @@
   }
   // CHECK: func.func @f(%[[ARG_1:.*]]: tensor<1xf32>)
   // Make sure that `operandSegmentSizes` attribute is also updated.
-  // CHECK-NEXT: %[[BATCH_FUNC:.*]]:2 = "tf.BatchFunction"(%[[ARG_1]]) {{{.*operandSegmentSizes = array<i32: 1, 0>.*}}} : (tensor<1xf32>) -> (tensor<*xf32>, tensor<*xf32>)
+  // CHECK-NEXT: %[[BATCH_FUNC:.*]]:2 = "tf.BatchFunction"(%[[ARG_1]]) <{{{.*operandSegmentSizes = array<i32: 1, 0>.*}}}> : (tensor<1xf32>) -> (tensor<*xf32>, tensor<*xf32>)
   // CHECK: return %[[BATCH_FUNC]]#0, %[[BATCH_FUNC]]#1 : tensor<*xf32>, tensor<*xf32>
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir
index 4339cd7..8deedc0 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir
@@ -68,14 +68,14 @@
 // CHECK:   [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor<i1>
 // CHECK:   cf.cond_br [[PRED]], ^bb1, ^bb2
 // CHECK: ^bb1:
-// CHECK:   [[CAST0:%.+]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<!tf_type.variant<tensor<f32>>>) -> tensor<!tf_type.variant>
+// CHECK:   [[CAST0:%.+]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<!tf_type.variant<tensor<f32>>>) -> tensor<!tf_type.variant>
 // CHECK:   [[THEN:%.+]] = call @testIfThen([[CAST0]]) : (tensor<!tf_type.variant>) -> tensor<!tf_type.variant>
-// CHECK:   [[CAST1:%.+]] = "tf.Cast"([[THEN]]) {Truncate = false} : (tensor<!tf_type.variant>) -> tensor<!tf_type.variant<tensor<f32>>>
+// CHECK:   [[CAST1:%.+]] = "tf.Cast"([[THEN]]) <{Truncate = false}> : (tensor<!tf_type.variant>) -> tensor<!tf_type.variant<tensor<f32>>>
 // CHECK:   cf.br ^bb3([[CAST1]] : tensor<!tf_type.variant<tensor<f32>>>)
 // CHECK: ^bb2:
-// CHECK:   [[CAST2:%.+]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<!tf_type.variant<tensor<f32>>>) -> tensor<!tf_type.variant>
+// CHECK:   [[CAST2:%.+]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<!tf_type.variant<tensor<f32>>>) -> tensor<!tf_type.variant>
 // CHECK:   [[ELSE:%.+]] = call @testIfElse([[CAST2]]) : (tensor<!tf_type.variant>) -> tensor<!tf_type.variant>
-// CHECK:   [[CAST3:%.+]] = "tf.Cast"([[ELSE]]) {Truncate = false} : (tensor<!tf_type.variant>) -> tensor<!tf_type.variant<tensor<f32>>>
+// CHECK:   [[CAST3:%.+]] = "tf.Cast"([[ELSE]]) <{Truncate = false}> : (tensor<!tf_type.variant>) -> tensor<!tf_type.variant<tensor<f32>>>
 // CHECK:   cf.br ^bb3([[CAST3]] : tensor<!tf_type.variant<tensor<f32>>>)
 // CHECK: ^bb3([[BBARG0:%.+]]: tensor<!tf_type.variant<tensor<f32>>>):
 // CHECK:   return [[BBARG0]] : tensor<!tf_type.variant<tensor<f32>>>
@@ -201,20 +201,20 @@
     cond = @testWhileCond, body = @testWhileBody, is_stateless = false
   } : (tensor<!tf_type.variant<tensor<1x3xf32>>>) -> (tensor<!tf_type.variant<tensor<*xf32>>>)
   func.return %0 : tensor<!tf_type.variant<tensor<*xf32>>>
-// CHECK:   [[CASTENTRY:%.+]] = "tf.Cast"(%arg0) {Truncate = false} : (tensor<!tf_type.variant<tensor<1x3xf32>>>) -> tensor<!tf_type.variant>
+// CHECK:   [[CASTENTRY:%.+]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<!tf_type.variant<tensor<1x3xf32>>>) -> tensor<!tf_type.variant>
 // CHECK:   cf.br ^bb1([[CASTENTRY]] : tensor<!tf_type.variant>)
 // CHECK: ^bb1([[CONDARG0:%.+]]: tensor<!tf_type.variant>):        // 2 preds: ^bb0, ^bb2
 // CHECK:   [[CONTINUE:%.+]] = call @testWhileCond([[CONDARG0]]) : (tensor<!tf_type.variant>) -> tensor<i1>
 // CHECK:   [[TOBOOL:%.+]] = "tf.ToBool"([[CONTINUE]]) : (tensor<i1>) -> tensor<i1>
 // CHECK:   [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor<i1>
-// CHECK:   [[CASTCONDARG0:%.+]] = "tf.Cast"([[CONDARG0]]) {Truncate = false} : (tensor<!tf_type.variant>) -> tensor<!tf_type.variant<tensor<1x?xf32>>>
+// CHECK:   [[CASTCONDARG0:%.+]] = "tf.Cast"([[CONDARG0]]) <{Truncate = false}> : (tensor<!tf_type.variant>) -> tensor<!tf_type.variant<tensor<1x?xf32>>>
 // CHECK:   cf.cond_br [[PRED]], ^bb2([[CASTCONDARG0]] : tensor<!tf_type.variant<tensor<1x?xf32>>>), ^bb3([[CASTCONDARG0]] : tensor<!tf_type.variant<tensor<1x?xf32>>>)
 // CHECK: ^bb2([[BODYARG0:%.+]]: tensor<!tf_type.variant<tensor<1x?xf32>>>):       // pred: ^bb1
 // CHECK:   [[WHILERET:%.+]] = call @testWhileBody([[BODYARG0]]) : (tensor<!tf_type.variant<tensor<1x?xf32>>>) -> tensor<!tf_type.variant<tensor<?x?xf32>>>
-// CHECK:   [[CASTWHILERET:%.+]] = "tf.Cast"([[WHILERET]]) {Truncate = false} : (tensor<!tf_type.variant<tensor<?x?xf32>>>) -> tensor<!tf_type.variant>
+// CHECK:   [[CASTWHILERET:%.+]] = "tf.Cast"([[WHILERET]]) <{Truncate = false}> : (tensor<!tf_type.variant<tensor<?x?xf32>>>) -> tensor<!tf_type.variant>
 // CHECK:   cf.br ^bb1([[CASTWHILERET]] : tensor<!tf_type.variant>)
 // CHECK: ^bb3([[EXITARG0:%.+]]: tensor<!tf_type.variant<tensor<1x?xf32>>>):       // pred: ^bb1
-// CHECK:   [[CASTEXITARG0:%.+]] = "tf.Cast"([[EXITARG0]]) {Truncate = false} : (tensor<!tf_type.variant<tensor<1x?xf32>>>) -> tensor<!tf_type.variant<tensor<*xf32>>>
+// CHECK:   [[CASTEXITARG0:%.+]] = "tf.Cast"([[EXITARG0]]) <{Truncate = false}> : (tensor<!tf_type.variant<tensor<1x?xf32>>>) -> tensor<!tf_type.variant<tensor<*xf32>>>
 // CHECK:   return [[CASTEXITARG0]] : tensor<!tf_type.variant<tensor<*xf32>>>
 
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir
index d426267..c5cf589 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir
@@ -14,17 +14,17 @@
   } : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
 
   // CHECK: "tf.IfRegion"
+  // CHECK-SAME: <{_else_func_name = "testIf1Else"
+  // CHECK-SAME: _then_func_name = "testIf1Then"
+  // CHECK-SAME: is_stateless = false
   // CHECK: [[Result0:%.*]] = func.call @testIf1Then
   // CHECK: "tf.Yield"([[Result0]])
   // CHECK: [[Result1:%.*]] = func.call @testIf1Else
   // CHECK: "tf.Yield"([[Result1]])
   // CHECK: _attr0 = 10
   // CHECK-SAME: _attr1 = true
-  // CHECK-SAME: _else_func_name = "testIf1Else"
-  // CHECK-SAME: _then_func_name = "testIf1Then"
   // CHECK-NOT: attr2 =
   // CHECK-NOT: else_branch
-  // CHECK-SAME: is_stateless = false
   // CHECK-NOT: then_branch
   // CHECK-SAME: }
   func.return %0 : tensor<*xf32>
@@ -179,6 +179,7 @@
   } : (tensor<*xf32>) -> (tensor<*xf32>)
 
   // CHECK: [[Result0:%.*]] = "tf.WhileRegion"
+  // CHECK-SAME: is_stateless = true
   // CHECK: ^bb0([[CARG0:%[^:]*]]:
   // CHECK: [[Result1:%.*]] = func.call @testWhileCond
   // CHECK: "tf.Yield"([[Result1]], [[CARG0]])
@@ -189,7 +190,6 @@
   // CHECK-NOT: attr2 =
   // CHECK-NOT: cond =
   // CHECK-NOT: body =
-  // CHECK-SAME: is_stateless = true
   // CHECK: return [[Result0]]
   func.return %1 : tensor<*xf32>
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/fused_kernel_matcher.mlir b/tensorflow/compiler/mlir/tensorflow/tests/fused_kernel_matcher.mlir
index 8458cb8..38380de 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/fused_kernel_matcher.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/fused_kernel_matcher.mlir
@@ -6,7 +6,7 @@
 
 // CHECK-LABEL: conv2DBiasAdd_noActivation
 func.func @conv2DBiasAdd_noActivation(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>) {
-  // CHECK: %[[VAL_0:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) {TArgs = [f32], data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd"], num_args = 1 : i64, operandSegmentSizes = array<i32: 1, 1, 1, 0>, padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
+  // CHECK: %[[VAL_0:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) <{data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd"], num_args = 1 : i64, operandSegmentSizes = array<i32: 1, 1, 1, 0>, padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> {TArgs = [f32]} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
   // CHECK: %[[VAL_1:.*]] = "tf.Identity"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32>
   // CHECK: return %[[VAL_1]]
   %0 = "tf.Conv2D"(%arg2, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
@@ -17,7 +17,7 @@
 
 // CHECK-LABEL: conv2DBiasAdd_reluActivation
 func.func @conv2DBiasAdd_reluActivation(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>) {
-  // CHECK: %[[VAL_0:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) {TArgs = [f32], data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd", "Relu"], num_args = 1 : i64, operandSegmentSizes = array<i32: 1, 1, 1, 0>, padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
+  // CHECK: %[[VAL_0:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) <{data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd", "Relu"], num_args = 1 : i64, operandSegmentSizes = array<i32: 1, 1, 1, 0>, padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> {TArgs = [f32]} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
   // CHECK: %[[VAL_1:.*]] = "tf.Identity"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32>
   // CHECK: return %[[VAL_1]]
   %0 = "tf.Conv2D"(%arg2, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
@@ -29,7 +29,7 @@
 
 // CHECK-LABEL: conv2DBiasAdd_relu6Activation
 func.func @conv2DBiasAdd_relu6Activation(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>) {
-  // CHECK: %[[VAL_0:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) {TArgs = [f32], data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd", "Relu6"], num_args = 1 : i64, operandSegmentSizes = array<i32: 1, 1, 1, 0>, padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
+  // CHECK: %[[VAL_0:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) <{data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd", "Relu6"], num_args = 1 : i64, operandSegmentSizes = array<i32: 1, 1, 1, 0>, padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> {TArgs = [f32]} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
   // CHECK: %[[VAL_1:.*]] = "tf.Identity"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32>
   // CHECK: return %[[VAL_1]]
   %0 = "tf.Conv2D"(%arg2, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
@@ -41,7 +41,7 @@
 
 // CHECK-LABEL: conv2DBiasAdd_eluActivation
 func.func @conv2DBiasAdd_eluActivation(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>) {
-  // CHECK: %[[VAL_0:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) {TArgs = [f32], data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd", "Elu"], num_args = 1 : i64, operandSegmentSizes = array<i32: 1, 1, 1, 0>, padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
+  // CHECK: %[[VAL_0:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) <{data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd", "Elu"], num_args = 1 : i64, operandSegmentSizes = array<i32: 1, 1, 1, 0>, padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> {TArgs = [f32]} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
   // CHECK: %[[VAL_1:.*]] = "tf.Identity"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32>
   // CHECK: return %[[VAL_1]]
   %0 = "tf.Conv2D"(%arg2, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
@@ -64,7 +64,7 @@
 
 // CHECK-LABEL: conv2DBiasAdd_biasAddMultipleUse
 func.func @conv2DBiasAdd_biasAddMultipleUse(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
-  // CHECK-DAG: %[[VAL:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) {TArgs = [f32], data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd"], num_args = 1 : i64, operandSegmentSizes = array<i32: 1, 1, 1, 0>, padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
+  // CHECK-DAG: %[[VAL:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) <{data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd"], num_args = 1 : i64, operandSegmentSizes = array<i32: 1, 1, 1, 0>, padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> {TArgs = [f32]} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
   // CHECK-DAG: %[[VAL_0:.*]] = "tf.Elu"(%[[VAL]]) : (tensor<*xf32>) -> tensor<*xf32>
   // CHECK-DAG: %[[VAL_1:.*]] = "tf.Identity"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32>
   // CHECK-DAG: %[[VAL_2:.*]] = "tf.Identity"(%[[VAL]]) : (tensor<*xf32>) -> tensor<*xf32>
@@ -89,7 +89,7 @@
 // CHECK-LABEL: conv2D_noFusion1
 func.func @conv2D_noFusion1(%arg0: tensor<*xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>) {
   // CHECK-NOT: "tf._FusedConv2D"
-  %0 = "tf.Conv2D"(%arg2, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
+  %0 = "tf.Conv2D"(%arg2, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
   // The result of the conv must be the first input to BiasAdd to be fusable.
   %1 = "tf.BiasAdd"(%arg0, %0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
   %2 = "tf.Elu"(%1) : (tensor<*xf32>) -> tensor<*xf32>
@@ -114,7 +114,7 @@
 
 // CHECK-LABEL: matmulBiasAdd
 func.func @matmulBiasAdd(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
-  // CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
+  // CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) <{epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd"], transpose_a = false, transpose_b = false}> : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
   // CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
   // CHECK: return %[[VAL_4]]
   %3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
@@ -125,7 +125,7 @@
 
 // CHECK-LABEL: matmulBiasAdd_relu
 func.func @matmulBiasAdd_relu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
-  // CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Relu"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
+  // CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) <{epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Relu"], transpose_a = false, transpose_b = false}> : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
   // CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
   // CHECK: return %[[VAL_4]]
   %3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
@@ -137,7 +137,7 @@
 
 // CHECK-LABEL: matmulBiasAdd_relu6
 func.func @matmulBiasAdd_relu6(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
-  // CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Relu6"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
+  // CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) <{epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Relu6"], transpose_a = false, transpose_b = false}> : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
   // CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
   // CHECK: return %[[VAL_4]]
   %3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
@@ -149,7 +149,7 @@
 
 // CHECK-LABEL: matmulBiasAdd_elu
 func.func @matmulBiasAdd_elu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
-  // CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Elu"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
+  // CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) <{epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Elu"], transpose_a = false, transpose_b = false}> : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
   // CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
   // CHECK: return %[[VAL_4]]
   %3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt
index c2f4d7a..e820f36 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt
@@ -54,5 +54,5 @@
 # the names are matching between the function definition and the uses / call
 # site (a numerical suffix may be appended).
 
-# CHECK: "tf.LegacyCall"(%outputs) {_disable_call_shape_inference = false, device = "", f = @foo0}
+# CHECK: "tf.LegacyCall"(%outputs) <{_disable_call_shape_inference = false, f = @foo0}> {device = ""}
 # CHECK: func private @foo0
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt
index f954657..02dd85f 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt
@@ -68,4 +68,4 @@
 }
 
 # CHECK: func @main
-# CHECK: "tf.LegacyCall"(%arg0) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @test_func_name0}
+# CHECK: "tf.LegacyCall"(%arg0) <{_disable_call_shape_inference = true, f = @test_func_name0}> {_tpu_replicate = "cluster", device = ""}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt
index 4b937a1..7244b60 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt
@@ -121,8 +121,8 @@
 # Verify that functions from the library are properly imported.
 
 # CHECK-LABEL:  func @main() {
-# CHECK:    "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo110}
-# CHECK:    "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo111}
+# CHECK:    "tf.LegacyCall"() <{_disable_call_shape_inference = false, f = @foo110}> {device = ""}
+# CHECK:    "tf.LegacyCall"() <{_disable_call_shape_inference = false, f = @foo111}> {device = ""}
 
 # CHECK-LABEL:  func private @foo110()
 # CHECK-LABEL:  func private @foo111()
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt
index 66847dc..ea5957a 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt
@@ -88,7 +88,7 @@
 # CHECK:    tf_executor.graph
 # CHECK:      "tf.VarHandleOp"()
 # CHECK:      "tf.LegacyCall"
-# CHECK-SAME:   {_disable_call_shape_inference = true, device = "", f = @test_func_name0}
+# CHECK-SAME:   <{_disable_call_shape_inference = true, f = @test_func_name0}> {device = ""}
 # CHECK:      tf_executor.fetch
 # CHECK:    return
 # CHECK:  func private @test_func_name0
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt
index eb59318..f515761 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt
@@ -54,10 +54,10 @@
 # Verify that functions from the library are properly imported.
 
 # CHECK-LABEL:  func @main() {
-# CHECK:    "tf.LegacyCall"() {_disable_call_shape_inference = true, device = "", f = @foo0}
-# CHECK:    "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar0}
+# CHECK:    "tf.LegacyCall"() <{_disable_call_shape_inference = true, f = @foo0}> {device = ""}
+# CHECK:    "tf.LegacyCall"() <{_disable_call_shape_inference = false, f = @bar0}> {device = ""}
 
 # CHECK-LABEL:  func private @foo0()
-# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar0}
+# CHECK: "tf.LegacyCall"() <{_disable_call_shape_inference = false, f = @bar0}> {device = ""}
 
 # CHECK-LABEL:  func private @bar0()
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt
index fd33be7..868ee80 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt
@@ -1,7 +1,7 @@
 # RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s | FileCheck %s
 
 # CHECK:"tf.MlirPassthroughOp"
-# CHECK: mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A %add = \22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A return %ret : tensor<10x10xf32>\0A}\0A"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
+# CHECK: mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A %add = \22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A return %ret : tensor<10x10xf32>\0A}\0A"}> {device = ""} : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
 
 node {
   name: "x"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/guarantee-all-funcs-one-use.mlir b/tensorflow/compiler/mlir/tensorflow/tests/guarantee-all-funcs-one-use.mlir
index 52f29ba..0dd1813 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/guarantee-all-funcs-one-use.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/guarantee-all-funcs-one-use.mlir
@@ -73,9 +73,9 @@
 // Test stateful and stateless partitioned calls.
 // CHECK-LABEL: func @f
 func.func @f() {
-  // CHECK: "tf.PartitionedCall"() {config = "",  config_proto = "", executor_type = "", f = @g} : () -> ()
+  // CHECK: "tf.PartitionedCall"() <{config = "",  config_proto = "", executor_type = "", f = @g}> : () -> ()
   "tf.PartitionedCall"() {config = "",  config_proto = "", executor_type = "", f = @g} : () -> ()
-  // CHECK: "tf.StatefulPartitionedCall"() {config = "",  config_proto = "", executor_type = "", f = @[[NEWG:.+]]} : () -> ()
+  // CHECK: "tf.StatefulPartitionedCall"() <{config = "",  config_proto = "", executor_type = "", f = @[[NEWG:.+]]}> : () -> ()
   "tf.StatefulPartitionedCall"() {config = "",  config_proto = "", executor_type = "", f = @g} : () -> ()
   func.return
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/host_launch_to_outside_compiled.mlir b/tensorflow/compiler/mlir/tensorflow/tests/host_launch_to_outside_compiled.mlir
index b6a0a2b..d786733 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/host_launch_to_outside_compiled.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/host_launch_to_outside_compiled.mlir
@@ -28,9 +28,9 @@
   func.func @single_op_launch_not_host() -> () {
     // CHECK:      "tf.A"
     // CHECK:      "tf_device.launch"
+    // CHECK-SAME:      device = "/job:worker/replica:0/task:0/device:TPU:0"
     // CHECK:        "tf.B"
     // CHECK-NOT:    _xla_outside_compilation
-    // CHECK:      device = "/job:worker/replica:0/task:0/device:TPU:0"
     // CHECK:      "tf.C"
     // CHECK-NEXT: tf_device.return
     "tf_device.cluster"() ({
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir
index f473300..4b0ba86 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir
@@ -71,8 +71,8 @@
 // CHECK-LABEL: func @inline_shape_cast(
 // CHECK-SAME:                          %[[ARG:.*]]: tensor<2xi32>
 func.func @inline_shape_cast(%arg: tensor<2xi32>) -> tensor<2xi32> {
-  // CHECK-NEXT: %[[ARG_CAST:.*]] = "tf.Cast"(%[[ARG]]) {Truncate = false} : (tensor<2xi32>) -> tensor<*xi32>
-  // CHECK-NEXT: %[[RESULT_CAST:.*]] = "tf.Cast"(%[[ARG_CAST]]) {Truncate = false} : (tensor<*xi32>) -> tensor<2xi32>
+  // CHECK-NEXT: %[[ARG_CAST:.*]] = "tf.Cast"(%[[ARG]]) <{Truncate = false}> : (tensor<2xi32>) -> tensor<*xi32>
+  // CHECK-NEXT: %[[RESULT_CAST:.*]] = "tf.Cast"(%[[ARG_CAST]]) <{Truncate = false}> : (tensor<*xi32>) -> tensor<2xi32>
   // CHECK-NEXT: return %[[RESULT_CAST]]
   %result = "tf.PartitionedCall"(%arg) {config = "", config_proto = "", executor_type = "", f = @inline_shape_cast_callee} : (tensor<2xi32>) -> tensor<2xi32>
   func.return %result : tensor<2xi32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/launch_outlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/launch_outlining.mlir
index 84825ba..91d58dff 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/launch_outlining.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/launch_outlining.mlir
@@ -10,7 +10,7 @@
       // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]])
       %2 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
 
-      // CHECK: %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[A_OUTPUT]]) {device = "/device:test_device:0", func = @[[LAUNCH:.*]]}
+      // CHECK: %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[A_OUTPUT]]) <{device = "/device:test_device:0", func = @[[LAUNCH:.*]]}>
       %3 = "tf_device.launch"() ({
         %4 = "tf.B"(%2) : (tensor<?xi32>) -> tensor<?xi32>
         tf_device.return %4 : tensor<?xi32>
@@ -42,7 +42,7 @@
       // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]])
       %2 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
 
-      // CHECK: %[[LAUNCH_0_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[A_OUTPUT]]) {device = "/device:test_device:0", func = @[[LAUNCH_0:.*]]}
+      // CHECK: %[[LAUNCH_0_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[A_OUTPUT]]) <{device = "/device:test_device:0", func = @[[LAUNCH_0:.*]]}>
       %3 = "tf_device.launch"() ({
         %6 = "tf.B"(%2) : (tensor<?xi32>) -> tensor<?xi32>
         tf_device.return %6 : tensor<?xi32>
@@ -51,7 +51,7 @@
       // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[LAUNCH_0_OUTPUT]])
       %4 = "tf.D"(%3) : (tensor<?xi32>) -> tensor<?xi32>
 
-      // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[LAUNCH_0_OUTPUT]], %[[D_OUTPUT]]) {device = "/device:test_device:0", func = @[[LAUNCH_1:.*]]}
+      // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[LAUNCH_0_OUTPUT]], %[[D_OUTPUT]]) <{device = "/device:test_device:0", func = @[[LAUNCH_1:.*]]}>
       %5 = "tf_device.launch"() ({
         %6 = "tf.E"(%3) : (tensor<?xi32>) -> tensor<?xi32>
         %7 = "tf.F"(%4, %6) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
@@ -86,7 +86,7 @@
 func.func @launch_operands(%arg0: tensor<?xi32>) -> tensor<?xi32> {
   %0 = tf_executor.graph {
     %1:2 = tf_executor.island wraps
-      // CHECK: %[[LAUNCH_OUTPUT:[a-z0-9]*]], %{{.*}} = {{.*}} "tf_device.launch_func"() {device = "/device:test_device:0", func = @[[LAUNCH:.*]]}
+      // CHECK: %[[LAUNCH_OUTPUT:[a-z0-9]*]], %{{.*}} = {{.*}} "tf_device.launch_func"() <{device = "/device:test_device:0", func = @[[LAUNCH:.*]]}>
       "tf_device.launch"() ({
         %3 = "tf.A"() : () -> tensor<?xi32>
         tf_device.return %3 : tensor<?xi32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir
index c2dfdc6..3746183 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nchw.mlir
@@ -7,7 +7,7 @@
 // CHECK-LABEL: func @transposeConv2D
 func.func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x7x7x8xf32> {
 
-  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
+  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}>
   // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
 
   // CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1)
@@ -18,7 +18,7 @@
   // CHECK-SAME: strides = [5, 8, 6, 7]
   // CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x7x7xf32>
 
-  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
+  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}>
   // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]])
   // CHECK: return %[[RES_TRANSPOSE]]
 
@@ -38,7 +38,7 @@
 func.func @transposeConv2DWithDefaultAttr(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<?x?x?x?xf32>
 {
 
-  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
+  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}>
   // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
 
   // CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1)
@@ -49,7 +49,7 @@
   // CHECK-SAME: strides = [5, 8, 6, 7]
   // CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<?x?x?x?xf32>
 
-  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
+  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}>
   // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]])
   // CHECK: return %[[RES_TRANSPOSE]]
 
@@ -77,7 +77,7 @@
   // CHECK-SAME: dst_format = "NCHW"
   // CHECK-SAME: src_format = "NHWC"
 
-  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
+  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}>
   // CHECK: %[[IN_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
   // CHECK: %[[OUT_BP_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg2, %[[ARG_PERM]])
 
@@ -117,7 +117,7 @@
   // CHECK-SAME: dst_format = "NCHW"
   // CHECK-SAME: src_format = "NHWC"
 
-  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
+  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}>
   // CHECK: %[[OUT_BP_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg2, %[[ARG_PERM]])
 
   // CHECK: %[[CONV2D_BACKPROP:[0-9]*]] = "tf.Conv2DBackpropInput"
@@ -130,7 +130,7 @@
   // CHECK-SAME: (tensor<4xi32>, tensor<1x1x3x8xf32>, tensor<1x8x32x32xf32>)
   // CHECK-SAME: -> tensor<1x3x32x32xf32>
 
-  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
+  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}>
   // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D_BACKPROP]], %[[RES_PERM]])
   // CHECK: return %[[RES_TRANSPOSE]]
 
@@ -154,7 +154,7 @@
 ) -> tensor<1x28x28x64xf32> {
 
   // CHECK: %[[ARG_PERM:.*]] = "tf.Const"()
-  // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
+  // CHECK-SAME: <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}>
   // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
 
   // CHECK: "tf.FusedBatchNormV3"
@@ -164,7 +164,7 @@
   // CHECK-SAME: -> (tensor<1x64x28x28xf32>, tensor<64xf32>,
 
   // CHECK: %[[RES_PERM:.*]] = "tf.Const"()
-  // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
+  // CHECK-SAME: <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}>
   // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]])
   // CHECK: return %[[RES_TRANSPOSE]]
 
@@ -192,7 +192,7 @@
 ) -> tensor<1x28x28x64xf32> {
 
   // CHECK: %[[ARG_PERM:.*]] = "tf.Const"()
-  // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
+  // CHECK-SAME: <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}>
 
   // CHECK: %[[ARG0_TPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
   // CHECK: %[[ARG1_TPOSE:[0-9]*]] = "tf.Transpose"(%arg1, %[[ARG_PERM]])
@@ -204,7 +204,7 @@
   // CHECK-SAME: -> (tensor<1x64x28x28xf32>,
 
   // CHECK: %[[RES_PERM:.*]] = "tf.Const"()
-  // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
+  // CHECK-SAME: <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}>
 
   // CHECK: %[[RES_TPOSE:[0-9]*]] = "tf.Transpose"
   // CHECK-SAME: (%x_backprop, %[[RES_PERM]])
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir
index 62749d1..b13da20 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_layout_assignment_to_nhwc.mlir
@@ -7,7 +7,7 @@
 // CHECK-LABEL: func @transposeConv2D
 func.func @transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x8x7x6xf32> {
 
-  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
+  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}>
   // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
 
   // CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1)
@@ -18,7 +18,7 @@
   // CHECK-SAME: strides = [5, 7, 8, 6]
   // CHECK-SAME: (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x7x6x8xf32>
 
-  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
+  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}>
   // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]])
   // CHECK: return %[[RES_TRANSPOSE]]
 
@@ -41,7 +41,7 @@
 ) -> tensor<1x64x28x28xf32> {
 
   // CHECK: %[[ARG_PERM:.*]] = "tf.Const"()
-  // CHECK-SAME: {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
+  // CHECK-SAME: <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}>
   // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
 
   // CHECK: "tf.FusedBatchNormV3"
@@ -51,7 +51,7 @@
   // CHECK-SAME: -> (tensor<1x28x28x64xf32>, tensor<64xf32>,
 
   // CHECK: %[[RES_PERM:.*]] = "tf.Const"()
-  // CHECK-SAME: {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
+  // CHECK-SAME: <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}>
   // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]])
   // CHECK: return %[[RES_TRANSPOSE]]
 
@@ -74,10 +74,10 @@
 // CHECK-LABEL: bias_add_nchw
 func.func @bias_add_nchw(%arg0: tensor<1x256x150x150xf32>, %arg1: tensor<256xf32>) -> tensor<1x256x150x150xf32> {
   // CHECK: (%[[ARG0:.*]]: tensor<1x256x150x150xf32>, %[[ARG1:.*]]: tensor<256xf32>)
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}>
   // CHECK: %[[R0:.*]] = "tf.Transpose"(%[[ARG0]], %[[CST]])
-  // CHECK: %[[R1:.*]] = "tf.BiasAdd"(%[[R0]], %[[ARG1]]) {data_format = "NHWC", device = ""}
-  // CHECK: %[[CST_0:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
+  // CHECK: %[[R1:.*]] = "tf.BiasAdd"(%[[R0]], %[[ARG1]]) <{data_format = "NHWC"}> {device = ""}
+  // CHECK: %[[CST_0:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}>
   // CHECK: "tf.Transpose"(%[[R1]], %[[CST_0]])
   %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW", device = ""} : (tensor<1x256x150x150xf32>, tensor<256xf32>) -> tensor<1x256x150x150xf32>
   func.return %0 : tensor<1x256x150x150xf32>
@@ -85,10 +85,10 @@
 
 // CHECK-LABEL: maxpool_nchw
 func.func @maxpool_nchw(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32> {
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}>
   // CHECK: %[[R0:.*]] = "tf.Transpose"(%arg0, %[[CST]])
-  // CHECK: %[[R1:.*]] = "tf.MaxPool"(%[[R0]]) {data_format = "NHWC", explicit_paddings = [], ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]}
-  // CHECK: %[[CST_0:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
+  // CHECK: %[[R1:.*]] = "tf.MaxPool"(%[[R0]]) <{data_format = "NHWC", explicit_paddings = [], ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]}>
+  // CHECK: %[[CST_0:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}>
   // CHECK: "tf.Transpose"(%[[R1]], %[[CST_0]])
   %0 = "tf.MaxPool"(%arg0)
        {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_begin.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_begin.mlir
index be36e2e..be511f9 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_begin.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_begin.mlir
@@ -3,7 +3,7 @@
 // CHECK-LABEL: func @move_across_single_op
 func.func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
 
-  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
+  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}>
   // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
   // CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
   // CHECK: return %[[TANH]]
@@ -18,7 +18,7 @@
 // CHECK-LABEL: func @move_across_multiple_ops
 func.func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
 
-  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
+  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}>
   // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
   // CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
   // CHECK: %[[RELU:[0-9]*]] = "tf.Relu"(%[[TANH]]) {{.*}} tensor<1x8x4x4xf32>
@@ -36,7 +36,7 @@
 // CHECK-LABEL: func @move_across_multi_operand_op
 func.func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
 
-  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
+  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}>
   // CHECK: %[[ARG0_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
   // CHECK: %[[ARG1_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg1, %[[ARG_PERM]])
   // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[ARG0_TRANSPOSE]], %[[ARG1_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
@@ -52,7 +52,7 @@
 // CHECK-LABEL: func @move_with_multiple_uses
 func.func @move_with_multiple_uses(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
 
-  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
+  // CHECK: %[[ARG_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}>
   // CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
   // CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
   // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[TANH]], %[[TANH]]) {{.*}} tensor<1x8x4x4xf32>
@@ -78,9 +78,9 @@
 
   func.return %3 : tensor<512x64xf32>
 
-  // CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
-  // CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-DAG: %[[CST_2:.*]] = "tf.Const"() {value = dense<[512, 64]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<[2, 0, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+  // CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-DAG: %[[CST_2:.*]] = "tf.Const"() <{value = dense<[512, 64]> : tensor<2xi32>}> : () -> tensor<2xi32>
   // CHECK: %[[EXPAND_DIMS:.*]] = "tf.ExpandDims"(%arg0, %[[CST_1]]) {device = ""} : (tensor<8x64xf32>, tensor<i32>) -> tensor<8x64x1xf32>
   // CHECK: %[[TRANSPOSE_1:.*]] = "tf.Transpose"(%[[EXPAND_DIMS]], %[[CST_0]]) : (tensor<8x64x1xf32>, tensor<3xi32>) -> tensor<1x8x64xf32>
   // CHECK: %[[TRANSPOSE_2:.*]] = "tf.Transpose"(%arg1, %[[CST_0]]) : (tensor<8x64x64xf32>, tensor<3xi32>) -> tensor<64x8x64xf32>
@@ -97,7 +97,7 @@
 
   func.return %1 : tensor<1x2x1x3xf32>
 
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 2, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
   // CHECK: %[[ADD:.*]] = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor<1x1x2x3xf32>, tensor<2x3xf32>) -> tensor<1x1x2x3xf32>
   // CHECK: %[[TRANSPOSE:.*]] = "tf.Transpose"(%[[ADD]], %[[CST]]) {device = ""} : (tensor<1x1x2x3xf32>, tensor<4xi32>) -> tensor<1x2x1x3xf32>
   // CHECK: return %[[TRANSPOSE]] : tensor<1x2x1x3xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir
index 20bf6d6..0bc9a13 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir
@@ -4,7 +4,7 @@
 // CHECK-LABEL: func @move_across_single_op
 func.func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
 
-  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
+  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}>
   // CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%arg0) {{.*}} tensor<1x4x4x8xf32>
   // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[TANH]], %[[RES_PERM]]) {{.*}} tensor<1x8x4x4xf32>
   // CHECK: return %[[RES_TRANSPOSE]]
@@ -19,7 +19,7 @@
 // CHECK-LABEL: func @move_across_multiple_ops
 func.func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
 
-  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
+  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}>
   // CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%arg0) {{.*}} tensor<1x4x4x8xf32>
   // CHECK: %[[RELU:[0-9]*]] = "tf.Relu"(%[[TANH]]) {{.*}} tensor<1x4x4x8xf32>
   // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[RELU]], %[[RES_PERM]])
@@ -36,7 +36,7 @@
 // CHECK-LABEL: func @move_across_multi_operand_op
 func.func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
 
-  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
+  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}>
   // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%arg0, %arg1) {{.*}} tensor<1x4x4x8xf32>
   // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[ADD]], %[[RES_PERM]])
   // CHECK: return %[[RES_TRANSPOSE]]
@@ -52,7 +52,7 @@
 // CHECK-LABEL: func @move_across_broadcastable_op
 func.func @move_across_broadcastable_op(%arg0: tensor<1x4x1x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
 
-  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
+  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}>
   // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1x4x1x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
   // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[ADD]], %[[RES_PERM]])
   // CHECK: return %[[RES_TRANSPOSE]]
@@ -68,7 +68,7 @@
 // CHECK-LABEL: func @move_across_double_transpose
 func.func @move_across_double_transpose(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x4x8x4xf32> {
 
-  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
+  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}>
   // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
   // CHECK: %[[RES_TRANSPOSE_0:[0-9]*]] = "tf.Transpose"(%[[ADD]], %[[RES_PERM]])
   // CHECK: %[[RES_TRANSPOSE_1:[0-9]*]] = "tf.Transpose"(%[[RES_TRANSPOSE_0]], %[[RES_PERM]])
@@ -90,8 +90,8 @@
   // MaxPool operand transpose must be folded into the op and MaxPool
   // must use NCHW data format with updated kernel size and strides.
 
-  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
-  // CHECK: %[[MAX_POOL:[0-9]*]] = "tf.MaxPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 3, 3], padding = "SAME", strides = [1, 1, 2, 2]} : (tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32>
+  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}>
+  // CHECK: %[[MAX_POOL:[0-9]*]] = "tf.MaxPool"(%arg0) <{data_format = "NCHW", ksize = [1, 1, 3, 3], padding = "SAME", strides = [1, 1, 2, 2]}> : (tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32>
   // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[MAX_POOL]], %[[RES_PERM]])
   // CHECK: return %[[RES_TRANSPOSE]]
 
@@ -112,14 +112,14 @@
 // CHECK-LABEL: func @fold_into_mean
 func.func @fold_into_mean(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64xf32> {
 
-  // CHECK: %[[RED_IDX:.*]] = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>}
+  // CHECK: %[[RED_IDX:.*]] = "tf.Const"() <{value = dense<[2, 3]> : tensor<2xi32>}>
   // CHECK: %[[MEAN:[0-9]*]] = "tf.Mean"(%arg0, %[[RED_IDX]])
   // CHECK-SAME: (tensor<1x64x112x112xf32>, tensor<2xi32>) -> tensor<1x64xf32>
   // CHECK: return %[[MEAN]]
 
-  // NOFOLD: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
+  // NOFOLD: %[[CST:.*]] = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}>
   // NOFOLD: %[[TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[CST]])
-  // NOFOLD: %[[CST_1:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>}
+  // NOFOLD: %[[CST_1:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi32>}>
   // NOFOLD: %[[MEAN:[0-9]*]] = "tf.Mean"(%[[TRANSPOSE]], %[[CST_1]])
   // NOFOLD-SAME: (tensor<1x112x112x64xf32>, tensor<2xi32>) -> tensor<1x64xf32>
   // NOFOLD: return %[[MEAN]]
@@ -138,8 +138,8 @@
 // CHECK-LABEL: func @fold_into_fused_batch_norm
 func.func @fold_into_fused_batch_norm(%arg0: tensor<1x64x112x112xf32>, %arg1: tensor<64xf32>) -> tensor<1x112x112x64xf32> {
 
-  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
-  // CHECK: "tf.FusedBatchNormV3"(%arg0, {{.*}} {data_format = "NCHW"
+  // CHECK: %[[RES_PERM:.*]] = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}>
+  // CHECK: "tf.FusedBatchNormV3"(%arg0, {{.*}} <{data_format = "NCHW"
   // CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]])
   // CHECK: return %[[RES_TRANSPOSE]]
 
@@ -165,9 +165,9 @@
 // CHECK-LABEL: func @fold_into_pad_with_extra_uses
 func.func @fold_into_pad_with_extra_uses(%arg0: tensor<1x2x4x4x3xf32>) -> (tensor<1x2x3x4x4xf32>, tensor<1x2x3x6x6xf32>) {
 
-  // CHECK: %[[PERM:.*]] = "tf.Const"() {value = dense<[0, 1, 4, 2, 3]> : tensor<5xi32>}
+  // CHECK: %[[PERM:.*]] = "tf.Const"() <{value = dense<[0, 1, 4, 2, 3]> : tensor<5xi32>}>
   // CHECK: %[[TRANSPOSE_OP:[0-9]*]] = "tf.Transpose"(%arg0, %[[PERM]])
-  // CHECK: %[[PADDING:.*]] = "tf.Const"() {value = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<5x2xi32>}
+  // CHECK: %[[PADDING:.*]] = "tf.Const"() <{value = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<5x2xi32>}>
   // CHECK: %[[PAD_OP:[0-9]*]] = "tf.Pad"(%arg0, %[[PADDING]])
   // CHECK: %[[DUP_TRANSPOSE_OP:[0-9]*]] = "tf.Transpose"(%[[PAD_OP]], %[[PERM]])
   // CHECK: return %[[TRANSPOSE_OP]], %[[DUP_TRANSPOSE_OP]]
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nhwc.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nhwc.mlir
index 5f9256f..e828196 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nhwc.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_to_nhwc.mlir
@@ -32,7 +32,7 @@
   // Shuffled paddings.
   // CHECK: %[[PADDINGS:.*]] = "tf.Const"(){{.*}}[0, 0], [3, 3], [3, 3], [0, 0]
   // NOFOLD: %[[PADDING:.*]] = "tf.Const"(){{.*}}[0, 0], [0, 0], [3, 3], [3, 3]
-  // NOFOLD: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+  // NOFOLD: %[[CST:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
   // NOFOLD: %[[TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[CST]]) : (tensor<?x224x224x3xf32>, tensor<4xi32>) -> tensor<?x3x224x224xf32>
 
   // Pad input with new paddings.
@@ -151,7 +151,7 @@
   %16 = "tf.Mean"(%15, %1) : (tensor<?x256x56x56xf32>, tensor<2xi32>) -> tensor<?x256xf32>
 
   // Mean should compute reduction over NHWC spatial dimensions.
-  // CHECK: %[[MEAN_DIMS:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>}
+  // CHECK: %[[MEAN_DIMS:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi32>}>
   // CHECK: %[[MEAN:[0-9]*]] = "tf.Mean"(%[[RELU]], %[[MEAN_DIMS]])
   // CHECK-SAME: (tensor<?x56x56x256xf32>, tensor<2xi32>) -> tensor<?x256xf32>
   // CHECK: return %[[MEAN]] : tensor<?x256xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_tfg.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_tfg.mlir
index c170a0e..0ff3cd3 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_tfg.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_tfg.mlir
@@ -4,9 +4,9 @@
 module  {
   // CHECK: tf_executor.graph
   tfg.graph #tf_type.version<producer = 919, min_consumer = 12> {
-    // CHECK: tf_executor.island wraps "tf.VarHandleOp"() {_mlir_name = "x", _output_shapes = [#tf_type.shape<>], allowed_devices = [], container = "a", device = "/device:CPU:0", dtype = i64, shape = #tf_type.shape<>, shared_name = "x"} : () -> tensor<!tf_type.resource<tensor<i64>>>
+    // CHECK: tf_executor.island wraps "tf.VarHandleOp"() <{container = "a", shared_name = "x"}> {_mlir_name = "x", _output_shapes = [#tf_type.shape<>], allowed_devices = [], device = "/device:CPU:0", dtype = i64, shape = #tf_type.shape<>} : () -> tensor<!tf_type.resource<tensor<i64>>>
     %VarHandleOp, %ctl = VarHandleOp device("/CPU:0") name("x") {_output_shapes = [#tf_type.shape<>], allowed_devices = [], container = "a", dtype = i64, shape = #tf_type.shape<>, shared_name = "x"} : () -> (tensor<!tf_type.resource<tensor<i64>>>)
-    // CHECK: tf_executor.island wraps "tf.LegacyCall"(%outputs, %outputs) {_disable_call_shape_inference = true, f = @test_func_name0} : (tensor<!tf_type.resource<tensor<i64>>>, tensor<!tf_type.resource<tensor<i64>>>) -> tensor<*x!tf_type.resource>
+    // CHECK: tf_executor.island wraps "tf.LegacyCall"(%outputs, %outputs) <{_disable_call_shape_inference = true, f = @test_func_name0}> : (tensor<!tf_type.resource<tensor<i64>>>, tensor<!tf_type.resource<tensor<i64>>>) -> tensor<*x!tf_type.resource>
     %test_func_name0, %ctl_0 = test_func_name0(%VarHandleOp, %VarHandleOp) name("called") {_disable_call_shape_inference = true, _output_shapes = [#tf_type.shape<*>]} : (tensor<!tf_type.resource<tensor<i64>>>, tensor<!tf_type.resource<tensor<i64>>>) -> (tensor<*x!tf_type.resource>)
     // CHECK: tf_executor.island wraps "tf._Retval"(%outputs_0) {T = !tf_type.resource, _mlir_name = "func_call", index = 0 : i64} : (tensor<*x!tf_type.resource>) -> ()
     %ctl_1 = _Retval(%test_func_name0) name("func_call") {T = !tf_type.resource, index = 0 : i64} : tensor<*x!tf_type.resource>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_quantized.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_quantized.mlir
index 11c9704..eedc235 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/lower_quantized.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_quantized.mlir
@@ -2,9 +2,9 @@
 
 // CHECK-LABEL: dequantize
 func.func @dequantize(%arg0: tensor<2x3x!tf_type.qint8>, %min_range: tensor<f32>, %max_range: tensor<f32>) -> tensor<2x3xf32> {
-  // CHECK-DAG: %[[HALF_RANGE:.*]] = "tf.Const"() {value = dense<1.280000e+02> : tensor<f32>}
-  // CHECK-DAG: %[[C255:.*]] = "tf.Const"() {value = dense<2.550000e+02> : tensor<f32>}
-  // CHECK-DAG: %[[CAST:.*]] = "tf.Cast"(%arg0) {Truncate = false}
+  // CHECK-DAG: %[[HALF_RANGE:.*]] = "tf.Const"() <{value = dense<1.280000e+02> : tensor<f32>}>
+  // CHECK-DAG: %[[C255:.*]] = "tf.Const"() <{value = dense<2.550000e+02> : tensor<f32>}>
+  // CHECK-DAG: %[[CAST:.*]] = "tf.Cast"(%arg0) <{Truncate = false}>
   // CHECK-DAG: %[[SHIFT:.*]] = "tf.AddV2"(%[[CAST]], %[[HALF_RANGE]])
   // CHECK-DAG: %[[DRANGE:.*]] = "tf.Sub"(%arg2, %arg1)
   // CHECK-DAG: %[[SCALE:.*]] = "tf.Div"(%[[DRANGE]], %[[C255:.*]])
@@ -18,8 +18,8 @@
 
 // CHECK-LABEL: dequantize_quint8
 func.func @dequantize_quint8(%arg0: tensor<2x3x!tf_type.quint8>, %min_range: tensor<f32>, %max_range: tensor<f32>) -> tensor<2x3xf32> {
-  // CHECK-NEXT: %[[C255:.*]] = "tf.Const"() {value = dense<2.550000e+02> : tensor<f32>}
-  // CHECK-NEXT: %[[CAST:.*]] = "tf.Cast"(%arg0) {Truncate = false}
+  // CHECK-NEXT: %[[C255:.*]] = "tf.Const"() <{value = dense<2.550000e+02> : tensor<f32>}>
+  // CHECK-NEXT: %[[CAST:.*]] = "tf.Cast"(%arg0) <{Truncate = false}>
   // CHECK-NEXT: %[[DRANGE:.*]] = "tf.Sub"(%arg2, %arg1)
   // CHECK-NEXT: %[[SCALE:.*]] = "tf.Div"(%[[DRANGE]], %[[C255:.*]])
   // CHECK-NEXT: %[[SS:.*]] = "tf.Mul"(%[[CAST]], %[[SCALE]])
@@ -32,15 +32,15 @@
 
 // CHECK-LABEL: dequantize_to_bf16
 func.func @dequantize_to_bf16(%arg0: tensor<2x3x!tf_type.qint8>, %min_range: tensor<f32>, %max_range: tensor<f32>) -> tensor<2x3xbf16> {
-  // CHECK-DAG: %[[HALF_RANGE:.*]] = "tf.Const"() {value = dense<1.280000e+02> : tensor<f32>}
-  // CHECK-DAG: %[[C255:.*]] = "tf.Const"() {value = dense<2.550000e+02> : tensor<f32>}
-  // CHECK-DAG: %[[CAST:.*]] = "tf.Cast"(%arg0) {Truncate = false}
+  // CHECK-DAG: %[[HALF_RANGE:.*]] = "tf.Const"() <{value = dense<1.280000e+02> : tensor<f32>}>
+  // CHECK-DAG: %[[C255:.*]] = "tf.Const"() <{value = dense<2.550000e+02> : tensor<f32>}>
+  // CHECK-DAG: %[[CAST:.*]] = "tf.Cast"(%arg0) <{Truncate = false}>
   // CHECK-DAG: %[[SHIFT:.*]] = "tf.AddV2"(%[[CAST]], %[[HALF_RANGE]])
   // CHECK-DAG: %[[DRANGE:.*]] = "tf.Sub"(%arg2, %arg1)
   // CHECK-DAG: %[[SCALE:.*]] = "tf.Div"(%[[DRANGE]], %[[C255:.*]])
   // CHECK-DAG: %[[SS:.*]] = "tf.Mul"(%[[SHIFT]], %[[SCALE]])
   // CHECK-DAG: %[[F32_RESULT:.*]] = "tf.AddV2"(%[[SS]], %arg1)
-  // CHECK-DAG: %[[RESULT:.*]] = "tf.Cast"(%[[F32_RESULT]]) {Truncate = false}
+  // CHECK-DAG: %[[RESULT:.*]] = "tf.Cast"(%[[F32_RESULT]]) <{Truncate = false}>
   %0 = "tf.Dequantize"(%arg0, %min_range, %max_range) : (tensor<2x3x!tf_type.qint8>, tensor<f32>, tensor<f32>) -> tensor<2x3xbf16>
 
   // CHECK-DAG: return %[[RESULT]]
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
index 432195a..83f0b56 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
@@ -2,11 +2,11 @@
 
 // CHECK-LABEL: invert_permutation
 func.func @invert_permutation(%arg0: tensor<5xi32>) -> tensor<5xi32> {
-  // CHECK-DAG: %[[UPDATES:.*]] = "tf.Const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
-  // CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-DAG: %[[cst_2:.*]] = "tf.Const"() {value = dense<1> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-DAG: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK-DAG: %[[UPDATES:.*]] = "tf.Const"() <{value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
+  // CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-DAG: %[[cst_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-DAG: %[[cst_3:.*]] = "tf.Const"() <{value = dense<0> : tensor<5xi32>}> : () -> tensor<5xi32>
 
   // CHECK-DAG: %[[INDICES:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<5xi32>, tensor<2xi32>) -> tensor<5x1xi32>
   // CHECK-DAG: %[[INDICES_1:.*]] = "tf.TensorScatterAdd"(%[[cst_3]], %[[INDICES]], %[[cst_2]]) : (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
@@ -35,7 +35,7 @@
 // CHECK-LABEL: simple_pack
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x5xf32>, %[[ARG1:.*]]: tensor<3x5xf32>
 func.func @simple_pack(%arg0: tensor<3x5xf32>, %arg1: tensor<3x5xf32>) -> tensor<2x3x5xf32> {
-  // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>}
+  // CHECK: %[[AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i64>}>
   // CHECK: %[[INP0:.*]] = "tf.ExpandDims"(%[[ARG0]], %[[AXIS]]) : (tensor<3x5xf32>, tensor<i64>) -> tensor<1x3x5xf32>
   // CHECK: %[[INP1:.*]] = "tf.ExpandDims"(%[[ARG1]], %[[AXIS]]) : (tensor<3x5xf32>, tensor<i64>) -> tensor<1x3x5xf32>
   // CHECK: "tf.ConcatV2"(%[[INP0]], %[[INP1]], %[[AXIS]]) : (tensor<1x3x5xf32>, tensor<1x3x5xf32>, tensor<i64>) -> tensor<2x3x5xf32>
@@ -71,8 +71,8 @@
 // CHECK-LABEL: func @div_no_nan
 // CHECK-SAME: (%[[X:.*]]: tensor<*xf32>, %[[Y:.*]]: tensor<*xf32>)
 func.func @div_no_nan(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK:  %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK:  %[[IS_ZERO:.*]] = "tf.Equal"(%[[Y]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
+  // CHECK:  %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK:  %[[IS_ZERO:.*]] = "tf.Equal"(%[[Y]], %[[ZERO]]) <{incompatible_shape_error = true}> : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
   // CHECK:  %[[DIV:.*]] = "tf.Div"(%[[X]], %[[Y]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
   // CHECK:  %[[RESULT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[ZERO]], %[[DIV]]) : (tensor<*xi1>, tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.DivNoNan"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
@@ -96,7 +96,7 @@
 // CHECK-SAME: (%[[LHS:.*]]: tensor<*xf32>, %[[RHS:.*]]: tensor<*xf32>)
 func.func @truncate_div_float(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>)
     -> tensor<*xf32> {
-  // CHECK:  %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK:  %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
   // CHECK:  %[[XDIVY:.*]] = "tf.Div"(%[[LHS]], %[[RHS]])
   // CHECK:  %[[MASK:.*]] = "tf.Less"(%[[XDIVY]], %[[ZERO]])
   // CHECK:  %[[CEIL:.*]] = "tf.Ceil"(%[[XDIVY]])
@@ -112,8 +112,8 @@
 // CHECK-LABEL: func @mul_no_nan
 // CHECK-SAME: (%[[X:.*]]: tensor<2x3xf32>, %[[Y:.*]]: tensor<3xf32>)
 func.func @mul_no_nan(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> tensor<2x3xf32> {
-  // CHECK:  %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK:  %[[IS_ZERO:.*]] = "tf.Equal"(%[[Y]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<3xf32>, tensor<f32>) -> tensor<3xi1>
+  // CHECK:  %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK:  %[[IS_ZERO:.*]] = "tf.Equal"(%[[Y]], %[[ZERO]]) <{incompatible_shape_error = true}> : (tensor<3xf32>, tensor<f32>) -> tensor<3xi1>
   // CHECK:  %[[MUL:.*]] = "tf.Mul"(%[[X]], %[[Y]]) : (tensor<2x3xf32>, tensor<3xf32>) -> tensor<2x3xf32>
   // CHECK:  %[[RESULT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[ZERO]], %[[MUL]]) : (tensor<3xi1>, tensor<f32>, tensor<2x3xf32>) -> tensor<2x3xf32>
   %0 = "tf.MulNoNan"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<3xf32>) -> tensor<2x3xf32>
@@ -124,9 +124,9 @@
 
 // CHECK-LABEL: @is_inf
 func.func @is_inf(%arg0: tensor<3x4xf32>) -> tensor<3x4xi1> {
-  // CHECK: %[[INF:.*]] = "tf.Const"() {value = dense<0x7F800000> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[INF:.*]] = "tf.Const"() <{value = dense<0x7F800000> : tensor<f32>}> : () -> tensor<f32>
   // CHECK: %[[ABS:.*]] = "tf.Abs"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xf32>
-  // CHECK: %[[RESULT:.*]] = "tf.Equal"(%[[ABS]], %[[INF]]) {incompatible_shape_error = true} : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x4xi1>
+  // CHECK: %[[RESULT:.*]] = "tf.Equal"(%[[ABS]], %[[INF]]) <{incompatible_shape_error = true}> : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x4xi1>
   %0 = "tf.IsInf"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xi1>
   // CHECK: return %[[RESULT]]
   func.return %0 : tensor<3x4xi1>
@@ -134,7 +134,7 @@
 
 // CHECK-LABEL: @is_nan
 func.func @is_nan(%arg0: tensor<3x4xf32>) -> tensor<3x4xi1> {
-  // CHECK: %[[RESULT:.*]] = "tf.NotEqual"(%arg0, %arg0) {incompatible_shape_error = true} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xi1>
+  // CHECK: %[[RESULT:.*]] = "tf.NotEqual"(%arg0, %arg0) <{incompatible_shape_error = true}> : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xi1>
   %0 = "tf.IsNan"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xi1>
   // CHECK: return %[[RESULT]]
   func.return %0 : tensor<3x4xi1>
@@ -149,7 +149,7 @@
 }
 
 func.func @empty(%arg0: tensor<?xi32>) -> tensor<*xf32> {
-  // CHECK-DAG: [[CST:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
+  // CHECK-DAG: [[CST:%.+]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}>
   // CHECK-DAG: [[RES:%.+]] = "tf.BroadcastTo"([[CST]], %arg0)
   %0 = "tf.Empty"(%arg0) {init = true} : (tensor<?xi32>) -> (tensor<*xf32>)
 
@@ -162,9 +162,9 @@
 func.func @l2_loss(%arg0: tensor<?x?xf32>) -> tensor<f32> {
 
   // CHECK-DAG: %[[SQUARE:.*]] = "tf.Mul"(%[[INPUT]], %[[INPUT]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-  // CHECK-DAG: %[[REDUCE_AXES:.*]] = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>}
-  // CHECK-DAG: %[[SUM:.*]] = "tf.Sum"(%[[SQUARE]], %[[REDUCE_AXES]]) {keep_dims = false} : (tensor<?x?xf32>, tensor<2xi64>) -> tensor<f32>
-  // CHECK-DAG: %[[TWO:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>}
+  // CHECK-DAG: %[[REDUCE_AXES:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[SUM:.*]] = "tf.Sum"(%[[SQUARE]], %[[REDUCE_AXES]]) <{keep_dims = false}> : (tensor<?x?xf32>, tensor<2xi64>) -> tensor<f32>
+  // CHECK-DAG: %[[TWO:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<f32>}>
   // CHECK-DAG: %[[LOSS:.*]] = "tf.Div"(%[[SUM]], %[[TWO]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
 
   %0 = "tf.L2Loss"(%arg0) : (tensor<?x?xf32>) -> tensor<f32>
@@ -183,7 +183,7 @@
 // CHECK-LABEL: pack_with_unranked
 // CHECK-SAME: %[[ARG0:.*]]: tensor<?x5xf32>, %[[ARG1:.*]]: tensor<*xf32>
 func.func @pack_with_unranked(%arg0: tensor<?x5xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<-2> : tensor<i64>}
+  // CHECK: %[[AXIS:.*]] = "tf.Const"() <{value = dense<-2> : tensor<i64>}>
   // CHECK: %[[INP0:.*]] = "tf.ExpandDims"(%[[ARG0]], %[[AXIS]]) : (tensor<?x5xf32>, tensor<i64>) -> tensor<?x1x5xf32>
   // CHECK: %[[INP1:.*]] = "tf.ExpandDims"(%[[ARG1]], %[[AXIS]]) : (tensor<*xf32>, tensor<i64>) -> tensor<*xf32>
   // CHECK: "tf.ConcatV2"(%[[INP0]], %[[INP1]], %[[AXIS]]) : (tensor<?x1x5xf32>, tensor<*xf32>, tensor<i64>) -> tensor<*xf32>
@@ -196,7 +196,7 @@
 func.func @pad(%arg0: tensor<3xf32>) -> tensor<6xf32> {
   %padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64>
   // CHECK-DAG: [[PAD:%.+]] = "tf.Const"() {{.+}} -> tensor<1x2xi64>
-  // CHECK-DAG: [[CST:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
+  // CHECK-DAG: [[CST:%.+]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}>
   // CHECK: "tf.PadV2"(%arg0, [[PAD]], [[CST]])
   %0 = "tf.Pad"(%arg0, %padding) : (tensor<3xf32>, tensor<1x2xi64>) -> tensor<6xf32>
   func.return %0 : tensor<6xf32>
@@ -206,7 +206,7 @@
 func.func @pad_bf16(%arg0: tensor<3xbf16>) -> tensor<6xbf16> {
   %padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64>
   // CHECK-DAG: [[PAD:%.+]] = "tf.Const"() {{.+}}  -> tensor<1x2xi64>
-  // CHECK-DAG: [[CST:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<bf16>}
+  // CHECK-DAG: [[CST:%.+]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<bf16>}>
   // CHECK: "tf.PadV2"(%arg0, [[PAD]], [[CST]])
   %0 = "tf.Pad"(%arg0, %padding) : (tensor<3xbf16>, tensor<1x2xi64>) -> tensor<6xbf16>
   func.return %0 : tensor<6xbf16>
@@ -221,8 +221,8 @@
 
 // CHECK-LABEL: func @BiasAddGrad_NHWC
 func.func @BiasAddGrad_NHWC(%arg0: tensor<2x3x4x5xf32>) -> tensor<5xf32> {
-  // CHECK: "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi64>}
-  // CHECK: "tf.Sum"({{.*}}) {keep_dims = false}
+  // CHECK: "tf.Const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}>
+  // CHECK: "tf.Sum"({{.*}}) <{keep_dims = false}>
 
   %0 = "tf.BiasAddGrad"(%arg0) {data_format = "NHWC"} : (tensor<2x3x4x5xf32>) -> tensor<5xf32>
   func.return %0 : tensor<5xf32>
@@ -230,8 +230,8 @@
 
 // CHECK-LABEL: func @BiasAddGrad_NCHW
 func.func @BiasAddGrad_NCHW(%arg0: tensor<2x3x4x5xf32>) -> tensor<3xf32> {
-  // CHECK: "tf.Const"() {value = dense<[0, 2, 3]> : tensor<3xi64>}
-  // CHECK: "tf.Sum"({{.*}}) {keep_dims = false}
+  // CHECK: "tf.Const"() <{value = dense<[0, 2, 3]> : tensor<3xi64>}>
+  // CHECK: "tf.Sum"({{.*}}) <{keep_dims = false}>
 
   %0 = "tf.BiasAddGrad"(%arg0) {data_format = "NCHW"} : (tensor<2x3x4x5xf32>) -> tensor<3xf32>
   func.return %0 : tensor<3xf32>
@@ -254,7 +254,7 @@
 // CHECK-LABEL: func @rsqrt_grad
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<2xf32>, %[[ARG1:.*]]: tensor<2xf32>)
 func.func @rsqrt_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<-2.000000e+00> : tensor<f32>}
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<-2.000000e+00> : tensor<f32>}>
   // CHECK: %[[LHS2:.*]] = "tf.Mul"(%[[ARG0]], %[[ARG0]])
   // CHECK: %[[LHS3:.*]] = "tf.Mul"(%[[LHS2]], %[[ARG0]])
   // CHECK: %[[DIV:.*]] = "tf.Div"(%[[ARG1]], %[[CST]])
@@ -279,7 +279,7 @@
 // CHECK-LABEL: func @sqrt_grad_unranked
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<*xcomplex<f32>>, %[[ARG1:.*]]: tensor<*xcomplex<f32>>)
 func.func @sqrt_grad_unranked(%arg0: tensor<*xcomplex<f32>>, %arg1: tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>> {
-  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<(5.000000e-01,0.000000e+00)> : tensor<complex<f32>>} : () -> tensor<complex<f32>>
+  // CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<(5.000000e-01,0.000000e+00)> : tensor<complex<f32>>}> : () -> tensor<complex<f32>>
   // CHECK: %[[MUL:.*]] = "tf.Mul"(%arg1, %[[CST]]) : (tensor<*xcomplex<f32>>, tensor<complex<f32>>) -> tensor<*xcomplex<f32>>
   // CHECK: %[[RET:.*]] = "tf.Div"(%[[MUL]], %arg0) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>>
 
@@ -292,22 +292,22 @@
 // dimension.
 // CHECK-LABEL: fourdim_space_to_batch_nd
 func.func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2x2xi64>) -> tensor<?x?x?x10xf32> {
-  // CHECK-DAG: [[PAD00:%.+]] = "tf.Const"() {value = dense<0> : tensor<1x2xi64>}
-  // CHECK-DAG: [[ZERO_I32:%.+]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK-DAG: [[ZERO_I64:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
+  // CHECK-DAG: [[PAD00:%.+]] = "tf.Const"() <{value = dense<0> : tensor<1x2xi64>}>
+  // CHECK-DAG: [[ZERO_I32:%.+]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG: [[ZERO_I64:%.+]] = "tf.Const"() <{value = dense<0> : tensor<i64>}>
   // CHECK-DAG: [[FULL_PADDINGS:%.+]] = "tf.ConcatV2"([[PAD00]], %arg2, [[PAD00]], [[ZERO_I64]])
-  // CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
+  // CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}>
   // CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]])
-  // CHECK-DAG: [[PADDINGS:%.+]]:2 = "tf.Unpack"([[FULL_PADDINGS]]) {axis = 1 : i64}
+  // CHECK-DAG: [[PADDINGS:%.+]]:2 = "tf.Unpack"([[FULL_PADDINGS]]) <{axis = 1 : i64}>
   // CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.AddV2"([[PADDINGS]]#0, [[PADDINGS]]#1)
-  // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>}
+  // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() <{value = dense<[3, 5, 7, 10]> : tensor<4xi64>}>
   // CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.AddV2"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
   // CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:4 = "tf.Split"([[ZERO_I32]], [[PADDED_SHAPE]])
   // CHECK-DAG: [[BLOCK_SHAPE_SPLITS:%.+]]:2 = "tf.Split"([[ZERO_I32]], %arg1)
   // CHECK-DAG: [[OUTER_SHAPE_0:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#1, [[BLOCK_SHAPE_SPLITS]]#0)
   // CHECK-DAG: [[OUTER_SHAPE_1:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#2, [[BLOCK_SHAPE_SPLITS]]#1)
   // CHECK-DAG: [[RESHAPED_SHAPE:%.+]] = "tf.ConcatV2"([[PADDED_SHAPE_SPLITS]]#0, [[OUTER_SHAPE_0]], [[BLOCK_SHAPE_SPLITS]]#0, [[OUTER_SHAPE_1]], [[BLOCK_SHAPE_SPLITS]]#1, [[PADDED_SHAPE_SPLITS]]#3, [[ZERO_I64]])
-  // CHECK-DAG: [[PERMUTATION:%.+]] = "tf.Const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi64>}
+  // CHECK-DAG: [[PERMUTATION:%.+]] = "tf.Const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi64>}>
   // CHECK-DAG: [[OUTPUT_BATCH_PART:%.+]] = "tf.Mul"([[PADDED_SHAPE_SPLITS]]#0, [[BLOCK_SHAPE_SPLITS]]#0)
   // CHECK-DAG: [[OUTPUT_BATCH:%.+]] = "tf.Mul"([[OUTPUT_BATCH_PART]], [[BLOCK_SHAPE_SPLITS]]#1)
   // CHECK-DAG: [[OUTPUT_SHAPE:%.+]] = "tf.ConcatV2"([[OUTPUT_BATCH]], [[OUTER_SHAPE_0]], [[OUTER_SHAPE_1]], [[PADDED_SHAPE_SPLITS]]#3, [[ZERO_I64]])
@@ -336,11 +336,11 @@
   %1 = "tf.Const"() {value = dense<[[3, 4]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
 
 
-  // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<[3, 5, 2]> : tensor<3xi64>}
-  // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<[1, 5, 3, 2]> : tensor<4xi64>}
-  // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() {value = dense<{{\[\[}}0, 0], [3, 4], [0, 0{{\]\]}}> : tensor<3x2xi64>}
-  // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
-  // CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi64>}
+  // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() <{value = dense<[3, 5, 2]> : tensor<3xi64>}>
+  // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() <{value = dense<[1, 5, 3, 2]> : tensor<4xi64>}>
+  // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() <{value = dense<{{\[\[}}0, 0], [3, 4], [0, 0{{\]\]}}> : tensor<3x2xi64>}>
+  // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}>
+  // CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi64>}>
   // CHECK-DAG: [[VAL5:%.+]] = "tf.PadV2"(%arg0, [[VAL2]], [[VAL3]])
   // CHECK-SAME: tensor<1x15x2xf32>
   // CHECK-DAG: [[VAL6:%.+]] = "tf.Reshape"([[VAL5]], [[VAL1]])
@@ -368,14 +368,14 @@
 func.func @sixdim_space_to_batch_nd(%input: tensor<3x5x7x9x10x11xf32>, %block_shape: tensor<3xi64>, %paddings: tensor<3x2xi64>) -> tensor<?x?x?x?x10x11xf32> {
   // CHECK-DAG: [[PAD00:%.+]] = "tf.Const"()
   // CHECK-DAG: [[FULL_PADDINGS:%.+]] = "tf.ConcatV2"([[PAD00]], %arg2, [[PAD00]], [[PAD00]], {{.+}})
-  // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 9, 10, 11]> : tensor<6xi64>}
+  // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() <{value = dense<[3, 5, 7, 9, 10, 11]> : tensor<6xi64>}>
   // CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:6 = "tf.Split"
   // CHECK-DAG: [[BLOCK_SHAPE_SPLITS:%.+]]:3 = "tf.Split"
   // CHECK-DAG: [[OUTER_SHAPE_0:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#1, [[BLOCK_SHAPE_SPLITS]]#0)
   // CHECK-DAG: [[OUTER_SHAPE_1:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#2, [[BLOCK_SHAPE_SPLITS]]#1)
   // CHECK-DAG: [[OUTER_SHAPE_2:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#3, [[BLOCK_SHAPE_SPLITS]]#2)
   // CHECK-DAG: [[RESHAPED_SHAPE:%.+]] = "tf.ConcatV2"([[PADDED_SHAPE_SPLITS]]#0, [[OUTER_SHAPE_0]], [[BLOCK_SHAPE_SPLITS]]#0, [[OUTER_SHAPE_1]], [[BLOCK_SHAPE_SPLITS]]#1, [[OUTER_SHAPE_2]], [[BLOCK_SHAPE_SPLITS]]#2, [[PADDED_SHAPE_SPLITS]]#4, [[PADDED_SHAPE_SPLITS]]#5, {{.+}})
-  // CHECK-DAG: [[PERMUTATION:%.+]] = "tf.Const"() {value = dense<[2, 4, 6, 0, 1, 3, 5, 7, 8]> : tensor<9xi64>}
+  // CHECK-DAG: [[PERMUTATION:%.+]] = "tf.Const"() <{value = dense<[2, 4, 6, 0, 1, 3, 5, 7, 8]> : tensor<9xi64>}>
   // CHECK-DAG: [[OUTPUT_BATCH_PART1:%.+]] = "tf.Mul"([[PADDED_SHAPE_SPLITS]]#0, [[BLOCK_SHAPE_SPLITS]]#0)
   // CHECK-DAG: [[OUTPUT_BATCH_PART2:%.+]] = "tf.Mul"([[OUTPUT_BATCH_PART1]], [[BLOCK_SHAPE_SPLITS]]#1)
   // CHECK-DAG: [[OUTPUT_BATCH:%.+]] = "tf.Mul"([[OUTPUT_BATCH_PART2]], [[BLOCK_SHAPE_SPLITS]]#2)
@@ -386,11 +386,11 @@
 
 // CHECK-LABEL: func @batchToSpace
 func.func @batchToSpace(%arg0: tensor<3x5x2xf32>) -> (tensor<1x8x2xf32>) {
-  // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<[3, 1, 5, 2]> : tensor<4xi64>}
-  // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<[1, 2, 0, 3]> : tensor<4xi64>}
-  // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() {value = dense<[1, 15, 2]> : tensor<3xi64>}
-  // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() {value = dense<[0, 3, 0]> : tensor<3xi64>}
-  // CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() {value = dense<[1, 8, 2]> : tensor<3xi64>}
+  // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() <{value = dense<[3, 1, 5, 2]> : tensor<4xi64>}>
+  // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi64>}>
+  // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() <{value = dense<[1, 15, 2]> : tensor<3xi64>}>
+  // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() <{value = dense<[0, 3, 0]> : tensor<3xi64>}>
+  // CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() <{value = dense<[1, 8, 2]> : tensor<3xi64>}>
   // CHECK-DAG: [[VAL5:%.+]] = "tf.Reshape"(%arg0, [[VAL0]])
   // CHECK-DAG: [[VAL6:%.+]] = "tf.Transpose"([[VAL5]], [[VAL1]])
   // CHECK-DAG: [[VAL7:%.+]] = "tf.Reshape"([[VAL6]], [[VAL2]])
@@ -404,11 +404,11 @@
 }
 
 func.func @fake_quant_with_min_max_args(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<1.275000e+02> : tensor<f32>}
-  // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<1.00392163> : tensor<f32>}
-  // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() {value = dense<-0.996078491> : tensor<f32>}
-  // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() {value = dense<0.00784313772> : tensor<f32>}
-  // CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>}
+  // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() <{value = dense<1.275000e+02> : tensor<f32>}>
+  // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() <{value = dense<1.00392163> : tensor<f32>}>
+  // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() <{value = dense<-0.996078491> : tensor<f32>}>
+  // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() <{value = dense<0.00784313772> : tensor<f32>}>
+  // CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<f32>}>
   // CHECK-DAG: [[VAL5:%.+]] = "tf.ClipByValue"(%arg0, [[VAL2]], [[VAL1]])
   // CHECK-DAG: [[VAL6:%.+]] = "tf.Sub"([[VAL5]], [[VAL2]])
   // CHECK-DAG: [[VAL7:%.+]] = "tf.Mul"([[VAL6]], [[VAL0]])
@@ -423,11 +423,11 @@
 }
 
 func.func @fake_quant_with_min_max_vars(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<?x?xf32> {
-  // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK-DAG: %[[VAL1:.*]] = "tf.Const"() {value = dense<2.550000e+02> : tensor<f32>} : () -> tensor<f32>
-  // CHECK-DAG: %[[VAL2:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK-DAG: %[[VAL3:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK-DAG: %[[VAL4:.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32>
+  // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK-DAG: %[[VAL1:.*]] = "tf.Const"() <{value = dense<2.550000e+02> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK-DAG: %[[VAL2:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK-DAG: %[[VAL3:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK-DAG: %[[VAL4:.*]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
   // CHECK-DAG: %[[VAL5:.*]] = "tf.Sub"(%arg2, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK-DAG: %[[VAL6:.*]] = "tf.Div"(%[[VAL5]], %[[VAL1]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK-DAG: %[[VAL7:.*]] = "tf.Div"(%[[VAL1]], %[[VAL5]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
@@ -436,17 +436,17 @@
   // CHECK-DAG: %[[VAL10:.*]] = "tf.Floor"(%[[VAL9]]) : (tensor<f32>) -> tensor<f32>
   // CHECK-DAG: %[[VAL11:.*]] = "tf.Sub"(%[[VAL9]], %[[VAL10]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK-DAG: %[[VAL12:.*]] = "tf.Greater"(%[[VAL11]], %[[VAL4]]) : (tensor<f32>, tensor<f32>) -> tensor<i1>
-  // CHECK-DAG: %[[VAL13:.*]] = "tf.Equal"(%[[VAL11]], %[[VAL4]]) {incompatible_shape_error = true} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  // CHECK-DAG: %[[VAL13:.*]] = "tf.Equal"(%[[VAL11]], %[[VAL4]]) <{incompatible_shape_error = true}> : (tensor<f32>, tensor<f32>) -> tensor<i1>
   // CHECK-DAG: %[[VAL14:.*]] = "tf.Mul"(%[[VAL9]], %[[VAL4]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK-DAG: %[[VAL15:.*]] = "tf.Floor"(%[[VAL14]]) : (tensor<f32>) -> tensor<f32>
   // CHECK-DAG: %[[VAL16:.*]] = "tf.Mul"(%[[VAL15]], %[[VAL2]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK-DAG: %[[VAL17:.*]] = "tf.Sub"(%[[VAL10]], %[[VAL16]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  // CHECK-DAG: %[[VAL18:.*]] = "tf.Equal"(%[[VAL17]], %[[VAL3]]) {incompatible_shape_error = true} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  // CHECK-DAG: %[[VAL18:.*]] = "tf.Equal"(%[[VAL17]], %[[VAL3]]) <{incompatible_shape_error = true}> : (tensor<f32>, tensor<f32>) -> tensor<i1>
   // CHECK-DAG: %[[VAL19:.*]] = "tf.LogicalAnd"(%[[VAL13]], %[[VAL18]]) : (tensor<i1>, tensor<i1>) -> tensor<i1>
   // CHECK-DAG: %[[VAL20:.*]] = "tf.LogicalOr"(%[[VAL12]], %[[VAL19]]) : (tensor<i1>, tensor<i1>) -> tensor<i1>
   // CHECK-DAG: %[[VAL21:.*]] = "tf.AddV2"(%[[VAL10]], %[[VAL3]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK-DAG: %[[INNER_SELECT:.*]] = "tf.SelectV2"(%[[VAL20]], %[[VAL21]], %[[VAL10]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
-  // CHECK-DAG: %[[IS_ZERO:.*]] = "tf.Equal"(%[[INNER_SELECT]], %[[ZERO]]) {incompatible_shape_error = true}
+  // CHECK-DAG: %[[IS_ZERO:.*]] = "tf.Equal"(%[[INNER_SELECT]], %[[ZERO]]) <{incompatible_shape_error = true}>
   // CHECK-DAG: %[[VAL22:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[ZERO]], %[[INNER_SELECT]])
   // CHECK-DAG: %[[VAL23:.*]] = "tf.ClipByValue"(%[[VAL22]], %[[ZERO]], %[[VAL1]]) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK-DAG: %[[VAL24:.*]] = "tf.Sub"(%[[ZERO]], %[[VAL23]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
@@ -469,29 +469,29 @@
 // CHECK-LABEL: SoftmaxCrossEntropyWithLogits
 // CHECK-SAME: %[[FEATURES:.*]]: tensor<2x3xf32>, %[[LABELS:.*]]: tensor<2x3xf32>
 func.func @SoftmaxCrossEntropyWithLogits(%features: tensor<2x3xf32>, %labels: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>) {
-  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64>
-  // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi64>}> : () -> tensor<1xi64>
+  // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
   // CHECK-DAG: %[[NEG_LABELS:.*]] = "tf.Neg"(%[[LABELS]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
 
   // LogSoftmax expansion.
-  // CHECK-DAG: %[[LOG_SOFTMAX_MAX:.*]] = "tf.Max"(%[[FEATURES]], %[[AXIS]]) {keep_dims = true} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
+  // CHECK-DAG: %[[LOG_SOFTMAX_MAX:.*]] = "tf.Max"(%[[FEATURES]], %[[AXIS]]) <{keep_dims = true}> : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
   // CHECK-DAG: %[[LOG_SOFTMAX_SHIFTED:.*]] = "tf.Sub"(%[[FEATURES]], %[[LOG_SOFTMAX_MAX]]) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
   // CHECK-DAG: %[[LOG_SOFTMAX_EXP:.*]] = "tf.Exp"(%[[LOG_SOFTMAX_SHIFTED]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
-  // CHECK-DAG: %[[LOG_SOFTMAX_SUM:.*]] = "tf.Sum"(%[[LOG_SOFTMAX_EXP]], %[[AXIS]]) {keep_dims = true} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
+  // CHECK-DAG: %[[LOG_SOFTMAX_SUM:.*]] = "tf.Sum"(%[[LOG_SOFTMAX_EXP]], %[[AXIS]]) <{keep_dims = true}> : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
   // CHECK-DAG: %[[LOG_SOFTMAX_LOG:.*]] = "tf.Log"(%[[LOG_SOFTMAX_SUM]]) : (tensor<2x1xf32>) -> tensor<2x1xf32>
   // CHECK-DAG: %[[LOG_SOFTMAX:.*]] = "tf.Sub"(%[[LOG_SOFTMAX_SHIFTED]], %[[LOG_SOFTMAX_LOG]]) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
 
 
-  // CHECK-DAG: %[[IS_LABEL_ZERO:.*]] = "tf.Equal"(%[[NEG_LABELS]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xi1>
+  // CHECK-DAG: %[[IS_LABEL_ZERO:.*]] = "tf.Equal"(%[[NEG_LABELS]], %[[ZERO]]) <{incompatible_shape_error = true}> : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xi1>
   // CHECK-DAG: %[[LOSS_INP:.*]] = "tf.Mul"(%[[LOG_SOFTMAX]], %[[NEG_LABELS]]) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
   // CHECK-DAG: %[[SAFE_LOSS_INP:.*]] = "tf.SelectV2"(%[[IS_LABEL_ZERO]], %[[ZERO]], %[[LOSS_INP]]) : (tensor<2x3xi1>, tensor<f32>, tensor<2x3xf32>) -> tensor<2x3xf32>
-  // CHECK-DAG: %[[LOSS:.*]] = "tf.Sum"(%[[SAFE_LOSS_INP]], %[[AXIS]]) {keep_dims = false} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2xf32>
+  // CHECK-DAG: %[[LOSS:.*]] = "tf.Sum"(%[[SAFE_LOSS_INP]], %[[AXIS]]) <{keep_dims = false}> : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2xf32>
 
   // Softmax expansion.
-  // CHECK-DAG: %[[SOFTMAX_MAX:.*]] = "tf.Max"(%arg0, %[[AXIS]]) {keep_dims = true} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
+  // CHECK-DAG: %[[SOFTMAX_MAX:.*]] = "tf.Max"(%arg0, %[[AXIS]]) <{keep_dims = true}> : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
   // CHECK-DAG: %[[SOFTMAX_SHIFTED:.*]] = "tf.Sub"(%[[FEATURES]], %[[SOFTMAX_MAX]]) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
   // CHECK-DAG: %[[SOFTMAX_EXP:.*]] = "tf.Exp"(%[[SOFTMAX_SHIFTED]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
-  // CHECK-DAG: %[[SOFTMAX_SUM:.*]] = "tf.Sum"(%[[SOFTMAX_EXP]], %[[AXIS]]) {keep_dims = true} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
+  // CHECK-DAG: %[[SOFTMAX_SUM:.*]] = "tf.Sum"(%[[SOFTMAX_EXP]], %[[AXIS]]) <{keep_dims = true}> : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
   // CHECK-DAG: %[[SOFTMAX:.*]] = "tf.Div"(%[[SOFTMAX_EXP]], %[[SOFTMAX_SUM]]) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
 
   // CHECK-DAG: %[[BACKPROP:.*]] = "tf.Sub"(%[[SOFTMAX]], %[[LABELS]]) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
@@ -534,19 +534,19 @@
 // CHECK-SAME: %[[FEATURES:.*]]: tensor<2x3xf32>, %[[SPARSE_LABELS:.*]]: tensor<2xi32>
 func.func @SparseSoftmaxCrossEntropyWithLogits(%features: tensor<2x3xf32>, %labels: tensor<2xi32>) -> (tensor<2xf32>, tensor<2x3xf32>) {
   // Convert SPARSE_LABELS to dense LABELS.
-  // CHECK-DAG: %[[DEPTH:.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-DAG: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK-DAG: %[[LABELS:.*]] = "tf.OneHot"(%[[SPARSE_LABELS]], %[[DEPTH]], %[[ONE]], %[[ZERO]]) {axis = 1 : i64} : (tensor<2xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<2x3xf32>
+  // CHECK-DAG: %[[DEPTH:.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-DAG: %[[ONE:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK-DAG: %[[LABELS:.*]] = "tf.OneHot"(%[[SPARSE_LABELS]], %[[DEPTH]], %[[ONE]], %[[ZERO]]) <{axis = 1 : i64}> : (tensor<2xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<2x3xf32>
 
   // Adjust labels to have Nan for out of range labels.
-  // CHECK-DAG: %[[ZERO_I32:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-DAG: %[[ZERO_I32:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-DAG: %[[IS_NEGATIVE:.*]] = "tf.LessEqual"(%[[ZERO_I32]], %arg1) : (tensor<i32>, tensor<2xi32>) -> tensor<2xi1>
   // CHECK-DAG: %[[IS_LESS:.*]] = "tf.Less"(%arg1, %[[DEPTH]]) : (tensor<2xi32>, tensor<i32>) -> tensor<2xi1>
   // CHECK-DAG: %[[IS_WITHIN_RANGE:.*]] = "tf.LogicalAnd"(%[[IS_NEGATIVE]], %[[IS_LESS]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
-  // CHECK-DAG: %[[NAN:.*]] = "tf.Const"() {value = dense<0x7FC00000> : tensor<f32>} : () -> tensor<f32>
+  // CHECK-DAG: %[[NAN:.*]] = "tf.Const"() <{value = dense<0x7FC00000> : tensor<f32>}> : () -> tensor<f32>
   // CHECK-DAG: %[[ZERO_OR_NAN:.*]] = "tf.SelectV2"(%[[IS_WITHIN_RANGE]], %[[ZERO]], %[[NAN]]) : (tensor<2xi1>, tensor<f32>, tensor<f32>) -> tensor<2xf32>
-  // CHECK-DAG: %[[NEG_ONE:.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64>
+  // CHECK-DAG: %[[NEG_ONE:.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi64>}> : () -> tensor<1xi64>
   // CHECK-DAG: %[[RESHAPE:.*]] = "tf.ExpandDims"(%[[ZERO_OR_NAN]], %[[NEG_ONE]]) : (tensor<2xf32>, tensor<1xi64>) -> tensor<2x1xf32>
   // CHECK-DAG: %[[ADJUSTED_LABELS:.*]] = "tf.AddV2"(%[[LABELS]], %[[RESHAPE]]) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
 
@@ -589,7 +589,7 @@
 // CHECK-LABEL: func @tanhgrad_float
 // CHECK-SAME: (%[[Y:.*]]: tensor<*xf32>, %[[DY:.*]]: tensor<*xf32>)
 func.func @tanhgrad_float(%y : tensor<*xf32>, %dy: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[ONE:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
   // CHECK: %[[Y_SQUARE:.*]] = "tf.Mul"(%[[Y]], %[[Y]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
   // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ONE]], %[[Y_SQUARE]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
   // CHECK: %[[RESULT:.*]] = "tf.Mul"(%[[DY]], %[[SUB]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
@@ -610,7 +610,7 @@
 
 // CHECK-LABEL: func @ZerosLike_unranked
 func.func @ZerosLike_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> {
-  // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
   // CHECK: %[[SHAPE:.*]] = "tf.Shape"(%arg0) : (tensor<*xi32>) -> tensor<?xi64>
   // CHECK: "tf.BroadcastTo"(%[[ZERO]], %[[SHAPE]]) : (tensor<i32>, tensor<?xi64>) -> tensor<*xi32>
 
@@ -627,7 +627,7 @@
 
 // CHECK-LABEL: func @OnesLike_unranked
 func.func @OnesLike_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> {
-  // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[ONE:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
   // CHECK: %[[SHAPE:.*]] = "tf.Shape"(%arg0) : (tensor<*xi32>) -> tensor<?xi64>
   // CHECK: "tf.BroadcastTo"(%[[ONE]], %[[SHAPE]]) : (tensor<i32>, tensor<?xi64>) -> tensor<*xi32>
 
@@ -682,8 +682,8 @@
 
 // CHECK-LABEL: func @DynamicStitch_simple
 func.func @DynamicStitch_simple(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
-  // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
-  // CHECK: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
+  // CHECK: %[[AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
+  // CHECK: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) <{axis = 0 : i64}> : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
   // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]])
   // CHECK-DAG: %[[ITEMS_0:.*]] = "tf.ExpandDims"(%[[ITEMS]]#0, %[[AXIS]])
   // CHECK: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[ITEMS_0]], %[[AXIS]]) : (tensor<1x2xf32>, tensor<1x2xf32>, tensor<i64>) -> tensor<2x2xf32>
@@ -696,12 +696,12 @@
 
 // CHECK-LABEL: DynamicStitch_scalar_matrix_indices
 func.func @DynamicStitch_scalar_matrix_indices(%arg0: tensor<2xf32>, %arg1: tensor<2x2x2xf32>) -> (tensor<5x2xf32>) {
-  // CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
+  // CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() <{value = dense<[-1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
   // CHECK-DAG: %[[INP0:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2xf32>, tensor<2xi64>) -> tensor<1x2xf32>
-  // CHECK-DAG: %[[ITEMS0:.*]] = "tf.Unpack"(%[[INP0]]) {axis = 0 : i64} : (tensor<1x2xf32>) -> tensor<2xf32>
+  // CHECK-DAG: %[[ITEMS0:.*]] = "tf.Unpack"(%[[INP0]]) <{axis = 0 : i64}> : (tensor<1x2xf32>) -> tensor<2xf32>
   // CHECK-DAG: %[[INP1:.*]] = "tf.Reshape"(%arg1, %[[SHAPE]]) : (tensor<2x2x2xf32>, tensor<2xi64>) -> tensor<4x2xf32>
-  // CHECK-DAG: %[[ITEMS1:.*]]:4 = "tf.Unpack"(%[[INP1]]) {axis = 0 : i64} : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
-  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-DAG: %[[ITEMS1:.*]]:4 = "tf.Unpack"(%[[INP1]]) <{axis = 0 : i64}> : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
+  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
   // CHECK-DAG: %[[ITEMS1_3:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#3, %[[AXIS]])
   // CHECK-DAG: %[[ITEMS1_2:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#2, %[[AXIS]])
   // CHECK-DAG: %[[ITEMS1_1:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#1, %[[AXIS]])
@@ -727,8 +727,8 @@
 
 // CHECK-LABEL: func @DynamicStitch_scalar_item
 func.func @DynamicStitch_scalar_item(%arg0: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2xf32>) -> (tensor<f32>, tensor<f32>)
-  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) <{axis = 0 : i64}> : (tensor<2xf32>) -> (tensor<f32>, tensor<f32>)
+  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
   // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]])
   // CHECK-DAG: %[[ITEMS_0:.*]] = "tf.ExpandDims"(%[[ITEMS]]#0, %[[AXIS]])
   // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[ITEMS_0]], %[[AXIS]]) : (tensor<1xf32>, tensor<1xf32>, tensor<i64>) -> tensor<2xf32>
@@ -741,8 +741,8 @@
 
 // CHECK-LABEL: func @DynamicStitch_matrix_item
 func.func @DynamicStitch_matrix_item(%arg0: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> {
-  // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)
-  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) <{axis = 0 : i64}> : (tensor<2x2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)
+  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
   // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]])
   // CHECK-DAG: %[[ITEMS_0:.*]] = "tf.ExpandDims"(%[[ITEMS]]#0, %[[AXIS]])
   // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[ITEMS_0]], %[[AXIS]]) : (tensor<1x2x2xf32>, tensor<1x2x2xf32>, tensor<i64>) -> tensor<2x2x2xf32>
@@ -762,8 +762,8 @@
 
 // CHECK-LABEL: func @DynamicStitch_duplicates
 func.func @DynamicStitch_duplicates(%arg0: tensor<2x2xf32>) -> tensor<1x2xf32> {
-  // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
-  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) <{axis = 0 : i64}> : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
+  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
   // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]])
   // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[AXIS]]) : (tensor<1x2xf32>, tensor<i64>) -> tensor<1x2xf32>
   // CHECK: return %[[RESULT]]
@@ -783,7 +783,7 @@
 
 // CHECK-LABEL: @Reciprocal_i32
 func.func @Reciprocal_i32(%arg0: tensor<*xi32>) -> tensor<*xi32> {
-  // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[ONE:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
   // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor<i32>, tensor<*xi32>) -> tensor<*xi32>
   %0 = "tf.Reciprocal"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
   func.return %0 : tensor<*xi32>
@@ -791,7 +791,7 @@
 
 // CHECK-LABEL: @Reciprocal_f32
 func.func @Reciprocal_f32(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[ONE:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
   // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Reciprocal"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   func.return %0 : tensor<*xf32>
@@ -799,7 +799,7 @@
 
 // CHECK-LABEL: @Reciprocal_complexf32
 func.func @Reciprocal_complexf32(%arg0: tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>> {
-  // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>} : () -> tensor<complex<f32>>
+  // CHECK: %[[ONE:.*]] = "tf.Const"() <{value = dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>}> : () -> tensor<complex<f32>>
   // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor<complex<f32>>, tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>>
   %0 = "tf.Reciprocal"(%arg0) : (tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>>
   func.return %0 : tensor<*xcomplex<f32>>
@@ -807,7 +807,7 @@
 
 // CHECK-LABEL: @Reciprocal_complexf64
 func.func @Reciprocal_complexf64(%arg0: tensor<*xcomplex<f64>>) -> tensor<*xcomplex<f64>> {
-  // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>>} : () -> tensor<complex<f64>>
+  // CHECK: %[[ONE:.*]] = "tf.Const"() <{value = dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>>}> : () -> tensor<complex<f64>>
   // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor<complex<f64>>, tensor<*xcomplex<f64>>) -> tensor<*xcomplex<f64>>
   %0 = "tf.Reciprocal"(%arg0) : (tensor<*xcomplex<f64>>) -> tensor<*xcomplex<f64>>
   func.return %0 : tensor<*xcomplex<f64>>
@@ -816,7 +816,7 @@
 // Inv is the same as Reciprocal
 // CHECK-LABEL: @Inv_i32
 func.func @Inv_i32(%arg0: tensor<*xi32>) -> tensor<*xi32> {
-  // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[ONE:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
   // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor<i32>, tensor<*xi32>) -> tensor<*xi32>
   %0 = "tf.Inv"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
   func.return %0 : tensor<*xi32>
@@ -824,7 +824,7 @@
 
 // CHECK-LABEL: @ScatterNd
 func.func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> {
-  // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32>
+  // CHECK: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<8xf32>}> : () -> tensor<8xf32>
   // CHECK: "tf.TensorScatterAdd"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>
 
   %shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32>
@@ -856,24 +856,24 @@
 
 // CHECK-LABEL: @round
 func.func @round(%arg0: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK-DAG: %[[HALF:.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32>
-  // CHECK-DAG: %[[TWO:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK-DAG: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK-DAG: %[[HALF:.*]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK-DAG: %[[TWO:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK-DAG: %[[ONE:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
   // CHECK: %[[ROUND_VAL:.*]] = "tf.Floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
   // CHECK: %[[FRACTION:.*]] = "tf.Sub"(%arg0, %[[ROUND_VAL]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
   // CHECK: %[[GT:.*]] = "tf.Greater"(%[[FRACTION]], %[[HALF]]) : (tensor<2xf32>, tensor<f32>) -> tensor<2xi1>
-  // CHECK: %[[EQ:.*]] = "tf.Equal"(%[[FRACTION]], %[[HALF]]) {incompatible_shape_error = true} : (tensor<2xf32>, tensor<f32>) -> tensor<2xi1>
+  // CHECK: %[[EQ:.*]] = "tf.Equal"(%[[FRACTION]], %[[HALF]]) <{incompatible_shape_error = true}> : (tensor<2xf32>, tensor<f32>) -> tensor<2xi1>
   // CHECK: %[[MUL1:.*]] = "tf.Mul"(%arg0, %[[HALF]]) : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
   // CHECK: %[[FLOOR:.*]] = "tf.Floor"(%[[MUL1]]) : (tensor<2xf32>) -> tensor<2xf32>
   // CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[FLOOR]], %[[TWO]]) : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
   // CHECK: %[[NEAREST_EVEN_INT:.*]] = "tf.Sub"(%[[ROUND_VAL]], %[[MUL2]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
-  // CHECK: %[[IS_ODD:.*]] = "tf.Equal"(%[[NEAREST_EVEN_INT]], %[[ONE]]) {incompatible_shape_error = true} : (tensor<2xf32>, tensor<f32>) -> tensor<2xi1>
+  // CHECK: %[[IS_ODD:.*]] = "tf.Equal"(%[[NEAREST_EVEN_INT]], %[[ONE]]) <{incompatible_shape_error = true}> : (tensor<2xf32>, tensor<f32>) -> tensor<2xi1>
   // CHECK: %[[AND:.*]] = "tf.LogicalAnd"(%[[EQ]], %[[IS_ODD]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
   // CHECK: %[[OR:.*]] = "tf.LogicalOr"(%[[GT]], %[[AND]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
   // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[ROUND_VAL]], %[[ONE]]) : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
   // CHECK: %[[INNER_SELECT:.*]] = "tf.SelectV2"(%[[OR]], %[[ADD]], %[[ROUND_VAL]]) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
-  // CHECK-DAG: %[[IS_ZERO:.*]] = "tf.Equal"(%[[INNER_SELECT]], %[[ZERO]]) {incompatible_shape_error = true}
+  // CHECK-DAG: %[[IS_ZERO:.*]] = "tf.Equal"(%[[INNER_SELECT]], %[[ZERO]]) <{incompatible_shape_error = true}>
   // CHECK-DAG: %[[SELECT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[ZERO]], %[[INNER_SELECT]])
   %0 = "tf.Round"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
 
@@ -890,24 +890,24 @@
 
 // CHECK-LABEL: func @rint_dynamic
 func.func @rint_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK-DAG: %[[HALF:.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32>
-  // CHECK-DAG: %[[TWO:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK-DAG: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK-DAG: %[[HALF:.*]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK-DAG: %[[TWO:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK-DAG: %[[ONE:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
   // CHECK: %[[ROUND_VAL:.*]] = "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   // CHECK: %[[FRACTION:.*]] = "tf.Sub"(%arg0, %[[ROUND_VAL]]) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   // CHECK: %[[GT:.*]] = "tf.Greater"(%[[FRACTION]], %[[HALF]]) : (tensor<?xf32>, tensor<f32>) -> tensor<?xi1>
-  // CHECK: %[[EQ:.*]] = "tf.Equal"(%[[FRACTION]], %[[HALF]]) {incompatible_shape_error = true} : (tensor<?xf32>, tensor<f32>) -> tensor<?xi1>
+  // CHECK: %[[EQ:.*]] = "tf.Equal"(%[[FRACTION]], %[[HALF]]) <{incompatible_shape_error = true}> : (tensor<?xf32>, tensor<f32>) -> tensor<?xi1>
   // CHECK: %[[MUL1:.*]] = "tf.Mul"(%arg0, %[[HALF]]) : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
   // CHECK: %[[FLOOR:.*]] = "tf.Floor"(%[[MUL1]]) : (tensor<?xf32>) -> tensor<?xf32>
   // CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[FLOOR]], %[[TWO]]) : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
   // CHECK: %[[NEAREST_EVEN_INT:.*]] = "tf.Sub"(%[[ROUND_VAL]], %[[MUL2]]) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  // CHECK: %[[IS_ODD:.*]] = "tf.Equal"(%[[NEAREST_EVEN_INT]], %[[ONE]]) {incompatible_shape_error = true} : (tensor<?xf32>, tensor<f32>) -> tensor<?xi1>
+  // CHECK: %[[IS_ODD:.*]] = "tf.Equal"(%[[NEAREST_EVEN_INT]], %[[ONE]]) <{incompatible_shape_error = true}> : (tensor<?xf32>, tensor<f32>) -> tensor<?xi1>
   // CHECK: %[[AND:.*]] = "tf.LogicalAnd"(%[[EQ]], %[[IS_ODD]]) : (tensor<?xi1>, tensor<?xi1>) -> tensor<?xi1>
   // CHECK: %[[OR:.*]] = "tf.LogicalOr"(%[[GT]], %[[AND]]) : (tensor<?xi1>, tensor<?xi1>) -> tensor<?xi1>
   // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[ROUND_VAL]], %[[ONE]]) : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
   // CHECK: %[[INNER_SELECT:.*]] = "tf.SelectV2"(%[[OR]], %[[ADD]], %[[ROUND_VAL]]) : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  // CHECK: %[[IS_ZERO:.*]] = "tf.Equal"(%[[INNER_SELECT]], %[[ZERO]]) {incompatible_shape_error = true}
+  // CHECK: %[[IS_ZERO:.*]] = "tf.Equal"(%[[INNER_SELECT]], %[[ZERO]]) <{incompatible_shape_error = true}>
   // CHECK: %[[SELECT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[ZERO]], %[[INNER_SELECT]])
   %0 = "tf.Rint"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
 
@@ -937,12 +937,12 @@
 func.func @imag_resize_nearest(%arg0: tensor<1x7x7x1xi32>) -> tensor<1x3x3x1xi32> {
   %shape = "tf.Const"() {device = "", value = dense<3> : tensor<2xi32>} : () -> tensor<2xi32>
 
-  // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<1> : tensor<i32>}
-  // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<[1, 3, 3, 1]>
-  // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() {value = dense<[1, 49, 1]>
-  // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() {value = dense<[0, 2, 4, 14, 16, 18, 28, 30, 32]> : tensor<9xi32>}
+  // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() <{value = dense<1> : tensor<i32>}>
+  // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() <{value = dense<[1, 3, 3, 1]>
+  // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() <{value = dense<[1, 49, 1]>
+  // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() <{value = dense<[0, 2, 4, 14, 16, 18, 28, 30, 32]> : tensor<9xi32>}>
   // CHECK: [[VAL4:%.+]] = "tf.Reshape"(%arg0, [[VAL2]])
-  // CHECK: [[VAL5:%.+]] = "tf.GatherV2"([[VAL4]], [[VAL3]], [[VAL0]]) {batch_dims = 0 : i64}
+  // CHECK: [[VAL5:%.+]] = "tf.GatherV2"([[VAL4]], [[VAL3]], [[VAL0]]) <{batch_dims = 0 : i64}>
   // CHECK: [[VAL6:%.+]] = "tf.Reshape"([[VAL5]], [[VAL1]])
   // CHECK: return [[VAL6]]
   %resize = "tf.ResizeNearestNeighbor"(%arg0, %shape) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x7x7x1xi32>, tensor<2xi32>) -> tensor<1x3x3x1xi32>
@@ -953,17 +953,17 @@
 func.func @imag_resize_nearest_dyn_img(%arg0: tensor<1x?x?x1xi32>) -> tensor<1x3x3x1xi32> {
   %shape = "tf.Const"() {device = "", value = dense<3> : tensor<2xi32>} : () -> tensor<2xi32>
 
-  // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<1> : tensor<i32>}
-  // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<[3, 1]> : tensor<2xi32>}
-  // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() {value = dense<9> : tensor<1xi32>}
-  // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() {value = dense<3> : tensor<1xi32>}
-  // CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>}
-  // CHECK-DAG: [[VAL5:%.+]] = "tf.Const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]>
-  // CHECK-DAG: [[VAL6:%.+]] = "tf.Const"() {value = dense<3.000000e+00> : tensor<f32>}
-  // CHECK-DAG: [[VAL7:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
+  // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() <{value = dense<1> : tensor<i32>}>
+  // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() <{value = dense<[3, 1]> : tensor<2xi32>}>
+  // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() <{value = dense<9> : tensor<1xi32>}>
+  // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() <{value = dense<3> : tensor<1xi32>}>
+  // CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() <{value = dense<[1, 3]> : tensor<2xi32>}>
+  // CHECK-DAG: [[VAL5:%.+]] = "tf.Const"() <{value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]>
+  // CHECK-DAG: [[VAL6:%.+]] = "tf.Const"() <{value = dense<3.000000e+00> : tensor<f32>}>
+  // CHECK-DAG: [[VAL7:%.+]] = "tf.Const"() <{value = dense<0> : tensor<i64>}>
   // CHECK: [[VAL8:%.+]] = "tf.Shape"(%arg0)
   // CHECK: [[VAL9:%.+]] = "tf.Cast"([[VAL8]])
-  // CHECK: [[VAL10:%.+]]:4 = "tf.Unpack"([[VAL9]]) {axis = 0 : i64}
+  // CHECK: [[VAL10:%.+]]:4 = "tf.Unpack"([[VAL9]]) <{axis = 0 : i64}>
   // CHECK: [[VAL11:%.+]] = "tf.Mul"([[VAL10]]#1, [[VAL10]]#2)
   // CHECK: [[VAL12:%.+]] = "tf.ExpandDims"([[VAL10]]#0, [[VAL7]])
   // CHECK: [[VAL13:%.+]] = "tf.ExpandDims"([[VAL10]]#3, [[VAL7]])
@@ -986,7 +986,7 @@
   // CHECK: [[VAL30:%.+]] = "tf.ExpandDims"([[VAL10]]#3, [[VAL7]])
   // CHECK: [[VAL31:%.+]] = "tf.ConcatV2"([[VAL28]], [[VAL29]], [[VAL30]], [[VAL7]])
   // CHECK: [[VAL32:%.+]] = "tf.Reshape"(%arg0, [[VAL31]])
-  // CHECK: [[VAL33:%.+]] = "tf.GatherV2"([[VAL32]], [[VAL27]], [[VAL0]]) {batch_dims = 0 : i64}
+  // CHECK: [[VAL33:%.+]] = "tf.GatherV2"([[VAL32]], [[VAL27]], [[VAL0]]) <{batch_dims = 0 : i64}>
   // CHECK: [[VAL34:%.+]] = "tf.Reshape"([[VAL33]], [[VAL14]])
   // CHECK: return [[VAL34]]
   %resize = "tf.ResizeNearestNeighbor"(%arg0, %shape) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x?x?x1xi32>, tensor<2xi32>) -> tensor<1x3x3x1xi32>
@@ -996,17 +996,17 @@
 // CHECK-LABEL: func @imag_resize_nearest_full_dyn
 func.func @imag_resize_nearest_full_dyn(%arg0: tensor<1x?x?x1xi32>, %arg1: tensor<2xi32>) -> tensor<1x?x?x1xi32> {
 
-  // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<1> : tensor<i32>}
-  // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
-  // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
-  // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
-  // CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() {value = dense<1> : tensor<1xi64>}
-  // CHECK-DAG: [[VAL5:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
+  // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() <{value = dense<1> : tensor<i32>}>
+  // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}>
+  // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}>
+  // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
+  // CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}>
+  // CHECK-DAG: [[VAL5:%.+]] = "tf.Const"() <{value = dense<0> : tensor<i64>}>
   // CHECK: [[VAL6:%.+]] = "tf.Shape"(%arg0)
   // CHECK: [[VAL7:%.+]] = "tf.Cast"([[VAL6]])
-  // CHECK: [[VAL8:%.+]]:4 = "tf.Unpack"([[VAL7]]) {axis = 0 : i64}
+  // CHECK: [[VAL8:%.+]]:4 = "tf.Unpack"([[VAL7]]) <{axis = 0 : i64}>
   // CHECK: [[VAL9:%.+]] = "tf.Mul"([[VAL8]]#1, [[VAL8]]#2)
-  // CHECK: [[VAL10:%.+]]:2 = "tf.Unpack"(%arg1) {axis = 0 : i64}
+  // CHECK: [[VAL10:%.+]]:2 = "tf.Unpack"(%arg1) <{axis = 0 : i64}>
   // CHECK: [[VAL11:%.+]] = "tf.Mul"([[VAL10]]#0, [[VAL10]]#1)
   // CHECK: [[VAL12:%.+]] = "tf.ExpandDims"([[VAL8]]#0, [[VAL5]])
   // CHECK: [[VAL13:%.+]] = "tf.ExpandDims"([[VAL10]]#0, [[VAL5]])
@@ -1040,7 +1040,7 @@
   // CHECK: [[VAL41:%.+]] = "tf.ExpandDims"([[VAL8]]#3, [[VAL5]])
   // CHECK: [[VAL42:%.+]] = "tf.ConcatV2"([[VAL39]], [[VAL40]], [[VAL41]], [[VAL5]])
   // CHECK: [[VAL43:%.+]] = "tf.Reshape"(%arg0, [[VAL42]])
-  // CHECK: [[VAL44:%.+]] = "tf.GatherV2"([[VAL43]], [[VAL38]], [[VAL0]]) {batch_dims = 0 : i64}
+  // CHECK: [[VAL44:%.+]] = "tf.GatherV2"([[VAL43]], [[VAL38]], [[VAL0]]) <{batch_dims = 0 : i64}>
   // CHECK: [[VAL45:%.+]] = "tf.Reshape"([[VAL44]], [[VAL16]])
   // CHECK: return [[VAL45]]
   %resize = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x?x?x1xi32>, tensor<2xi32>) -> tensor<1x?x?x1xi32>
@@ -1050,8 +1050,8 @@
 // CHECK-LABEL: func @xdivy
 // CHECK-SAME: (%[[X:.*]]: tensor<*xf32>, %[[Y:.*]]: tensor<*xf32>)
 func.func @xdivy(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK:  %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK:  %[[IS_ZERO:.*]] = "tf.Equal"(%[[X]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
+  // CHECK:  %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK:  %[[IS_ZERO:.*]] = "tf.Equal"(%[[X]], %[[ZERO]]) <{incompatible_shape_error = true}> : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
   // CHECK:  %[[MUL:.*]] = "tf.Div"(%[[X]], %[[Y]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
   // CHECK:  %[[RESULT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[X]], %[[MUL]]) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Xdivy"(%lhs, %rhs) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
@@ -1062,8 +1062,8 @@
 // CHECK-LABEL: func @xlog1py
 // CHECK-SAME: (%[[X:.*]]: tensor<*xf32>, %[[Y:.*]]: tensor<*xf32>)
 func.func @xlog1py(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK:  %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK:  %[[IS_ZERO:.*]] = "tf.Equal"(%[[X]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
+  // CHECK:  %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK:  %[[IS_ZERO:.*]] = "tf.Equal"(%[[X]], %[[ZERO]]) <{incompatible_shape_error = true}> : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
   // CHECK:  %[[LOG:.*]] = "tf.Log1p"(%[[Y]]) : (tensor<*xf32>) -> tensor<*xf32>
   // CHECK:  %[[MUL:.*]] = "tf.Mul"(%[[X]], %[[LOG]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
   // CHECK:  %[[RESULT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[X]], %[[MUL]]) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
@@ -1075,8 +1075,8 @@
 // CHECK-LABEL: func @xlogy
 // CHECK-SAME: (%[[X:.*]]: tensor<*xf32>, %[[Y:.*]]: tensor<*xf32>)
 func.func @xlogy(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK:  %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK:  %[[IS_ZERO:.*]] = "tf.Equal"(%[[X]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
+  // CHECK:  %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK:  %[[IS_ZERO:.*]] = "tf.Equal"(%[[X]], %[[ZERO]]) <{incompatible_shape_error = true}> : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
   // CHECK:  %[[LOG:.*]] = "tf.Log"(%[[Y]]) : (tensor<*xf32>) -> tensor<*xf32>
   // CHECK:  %[[MUL:.*]] = "tf.Mul"(%[[X]], %[[LOG]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
   // CHECK:  %[[RESULT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[X]], %[[MUL]]) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
@@ -1089,9 +1089,9 @@
 func.func @size_to_prod_shape_i32(%arg0 : tensor<1x?x2x3xf32>) -> tensor<i32> {
   %0 = "tf.Size"(%arg0) : (tensor<1x?x2x3xf32>) -> tensor<i32>
   func.return %0 : tensor<i32>
-  // CHECK: %[[CONSTANT:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[CONSTANT:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
   // CHECK: %[[SHAPE:.*]] = "tf.Shape"(%arg0) : (tensor<1x?x2x3xf32>) -> tensor<4xi32>
-  // CHECK: %[[PROD:.*]] = "tf.Prod"(%[[SHAPE]], %[[CONSTANT]]) {keep_dims = false} : (tensor<4xi32>, tensor<i32>) -> tensor<i32>
+  // CHECK: %[[PROD:.*]] = "tf.Prod"(%[[SHAPE]], %[[CONSTANT]]) <{keep_dims = false}> : (tensor<4xi32>, tensor<i32>) -> tensor<i32>
   // CHECK: return %[[PROD]]
 }
 
@@ -1099,9 +1099,9 @@
 func.func @size_to_prod_shape_i64(%arg0 : tensor<1x?x2x3xf32>) -> tensor<i64> {
   %0 = "tf.Size"(%arg0) : (tensor<1x?x2x3xf32>) -> tensor<i64>
   func.return %0 : tensor<i64>
-  // CHECK: %[[CONSTANT:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
+  // CHECK: %[[CONSTANT:.*]] = "tf.Const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
   // CHECK: %[[SHAPE:.*]] = "tf.Shape"(%arg0) : (tensor<1x?x2x3xf32>) -> tensor<4xi64>
-  // CHECK: %[[PROD:.*]] = "tf.Prod"(%[[SHAPE]], %[[CONSTANT]]) {keep_dims = false} : (tensor<4xi64>, tensor<i64>) -> tensor<i64>
+  // CHECK: %[[PROD:.*]] = "tf.Prod"(%[[SHAPE]], %[[CONSTANT]]) <{keep_dims = false}> : (tensor<4xi64>, tensor<i64>) -> tensor<i64>
   // CHECK: return %[[PROD]]
 }
 
@@ -1109,9 +1109,9 @@
 func.func @is_finite(%arg0: tensor<3x4xf32>) -> tensor<3x4xi1> {
   %0 = "tf.IsFinite"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xi1>
   func.return %0 : tensor<3x4xi1>
-  // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
   // CHECK: %[[SUB:.*]] = "tf.Sub"(%arg0, %arg0) : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
-  // CHECK: %[[RESULT:.*]] = "tf.Equal"(%[[SUB]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x4xi1>
+  // CHECK: %[[RESULT:.*]] = "tf.Equal"(%[[SUB]], %[[ZERO]]) <{incompatible_shape_error = true}> : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x4xi1>
   // CHECK: return %[[RESULT]]
 }
 
@@ -1119,9 +1119,9 @@
 func.func @is_finite_dynamic(%arg0: tensor<?x4xf32>) -> tensor<?x4xi1> {
   %0 = "tf.IsFinite"(%arg0) : (tensor<?x4xf32>) -> tensor<?x4xi1>
   func.return %0 : tensor<?x4xi1>
-  // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
   // CHECK: %[[SUB:.*]] = "tf.Sub"(%arg0, %arg0) : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
-  // CHECK: %[[RESULT:.*]] = "tf.Equal"(%[[SUB]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<?x4xf32>, tensor<f32>) -> tensor<?x4xi1>
+  // CHECK: %[[RESULT:.*]] = "tf.Equal"(%[[SUB]], %[[ZERO]]) <{incompatible_shape_error = true}> : (tensor<?x4xf32>, tensor<f32>) -> tensor<?x4xi1>
   // CHECK: return %[[RESULT]]
 }
 
@@ -1131,11 +1131,11 @@
   %0 = "tf.Roll"(%arg0, %shift, %axis) : (tensor<3x8x4xi32>, tensor<i32>, tensor<i32>) -> tensor<3x8x4xi32>
   func.return %0 : tensor<3x8x4xi32>
   // CHECK-LABEL: roll_scalar_axis
-  // CHECK-DAG:  %[[CST:.*]] = "tf.Const"() {value = dense<[0, 6, 0]> : tensor<3xi64>} : () -> tensor<3xi64>
-  // CHECK-DAG:  %[[CST0:.*]] = "tf.Const"() {value = dense<[3, 2, 4]> : tensor<3xi64>} : () -> tensor<3xi64>
-  // CHECK-DAG:  %[[CST1:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>} : () -> tensor<3xi64>
-  // CHECK-DAG:  %[[CST2:.*]] = "tf.Const"() {value = dense<[3, 6, 4]> : tensor<3xi64>} : () -> tensor<3xi64>
-  // CHECK-DAG:  %[[CST3:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-DAG:  %[[CST:.*]] = "tf.Const"() <{value = dense<[0, 6, 0]> : tensor<3xi64>}> : () -> tensor<3xi64>
+  // CHECK-DAG:  %[[CST0:.*]] = "tf.Const"() <{value = dense<[3, 2, 4]> : tensor<3xi64>}> : () -> tensor<3xi64>
+  // CHECK-DAG:  %[[CST1:.*]] = "tf.Const"() <{value = dense<0> : tensor<3xi64>}> : () -> tensor<3xi64>
+  // CHECK-DAG:  %[[CST2:.*]] = "tf.Const"() <{value = dense<[3, 6, 4]> : tensor<3xi64>}> : () -> tensor<3xi64>
+  // CHECK-DAG:  %[[CST3:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
   // CHECK:  %[[SLICE:.*]] = "tf.Slice"(%arg0, %[[CST]], %[[CST0]]) : (tensor<3x8x4xi32>, tensor<3xi64>, tensor<3xi64>) -> tensor<3x2x4xi32>
   // CHECK:  %[[SLICE1:.*]] = "tf.Slice"(%arg0, %[[CST1]], %[[CST2]]) : (tensor<3x8x4xi32>, tensor<3xi64>, tensor<3xi64>) -> tensor<3x6x4xi32>
   // CHECK:  %[[CONCAT:.*]] = "tf.ConcatV2"(%[[SLICE]], %[[SLICE1]], %[[CST3]]) : (tensor<3x2x4xi32>, tensor<3x6x4xi32>, tensor<i32>) -> tensor<3x8x4xi32>
@@ -1148,11 +1148,11 @@
   %0 = "tf.Roll"(%arg0, %shift, %axis) : (tensor<3x8x4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3x8x4xi32>
   func.return %0 : tensor<3x8x4xi32>
   // CHECK-LABEL: roll_1d_axis
-  // CHECK-DAG:  %[[CST:.*]] = "tf.Const"() {value = dense<[0, 6, 0]> : tensor<3xi64>} : () -> tensor<3xi64>
-  // CHECK-DAG:  %[[CST0:.*]] = "tf.Const"() {value = dense<[3, 2, 4]> : tensor<3xi64>} : () -> tensor<3xi64>
-  // CHECK-DAG:  %[[CST1:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>} : () -> tensor<3xi64>
-  // CHECK-DAG:  %[[CST2:.*]] = "tf.Const"() {value = dense<[3, 6, 4]> : tensor<3xi64>} : () -> tensor<3xi64>
-  // CHECK-DAG:  %[[CST3:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-DAG:  %[[CST:.*]] = "tf.Const"() <{value = dense<[0, 6, 0]> : tensor<3xi64>}> : () -> tensor<3xi64>
+  // CHECK-DAG:  %[[CST0:.*]] = "tf.Const"() <{value = dense<[3, 2, 4]> : tensor<3xi64>}> : () -> tensor<3xi64>
+  // CHECK-DAG:  %[[CST1:.*]] = "tf.Const"() <{value = dense<0> : tensor<3xi64>}> : () -> tensor<3xi64>
+  // CHECK-DAG:  %[[CST2:.*]] = "tf.Const"() <{value = dense<[3, 6, 4]> : tensor<3xi64>}> : () -> tensor<3xi64>
+  // CHECK-DAG:  %[[CST3:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
   // CHECK:  %[[SLICE:.*]] = "tf.Slice"(%arg0, %[[CST]], %[[CST0]]) : (tensor<3x8x4xi32>, tensor<3xi64>, tensor<3xi64>) -> tensor<3x2x4xi32>
   // CHECK:  %[[SLICE1:.*]] = "tf.Slice"(%arg0, %[[CST1]], %[[CST2]]) : (tensor<3x8x4xi32>, tensor<3xi64>, tensor<3xi64>) -> tensor<3x6x4xi32>
   // CHECK:  %[[CONCAT:.*]] = "tf.ConcatV2"(%[[SLICE]], %[[SLICE1]], %[[CST3]]) : (tensor<3x2x4xi32>, tensor<3x6x4xi32>, tensor<i32>) -> tensor<3x8x4xi32>
@@ -1165,15 +1165,15 @@
   %0 = "tf.Roll"(%arg0, %shift, %axis) : (tensor<3x8x4xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x8x4xi32>
   func.return %0 : tensor<3x8x4xi32>
   // CHECK-LABEL: roll_multiple_axis
-  // CHECK-DAG:  %[[CST:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>} : () -> tensor<3xi64>
-  // CHECK-DAG:  %[[CST0:.*]] = "tf.Const"() {value = dense<[2, 8, 4]> : tensor<3xi64>} : () -> tensor<3xi64>
-  // CHECK-DAG:  %[[CST1:.*]] = "tf.Const"() {value = dense<[1, 8, 4]> : tensor<3xi64>} : () -> tensor<3xi64>
-  // CHECK-DAG:  %[[CST2:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-DAG:  %[[CST3:.*]] = "tf.Const"() {value = dense<[0, 6, 0]> : tensor<3xi64>} : () -> tensor<3xi64>
-  // CHECK-DAG:  %[[CST4:.*]] = "tf.Const"() {value = dense<[3, 2, 4]> : tensor<3xi64>} : () -> tensor<3xi64>
-  // CHECK-DAG:  %[[CST5:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>} : () -> tensor<3xi64>
-  // CHECK-DAG:  %[[CST6:.*]] = "tf.Const"() {value = dense<[3, 6, 4]> : tensor<3xi64>} : () -> tensor<3xi64>
-  // CHECK-DAG:  %[[CST7:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-DAG:  %[[CST:.*]] = "tf.Const"() <{value = dense<[1, 0, 0]> : tensor<3xi64>}> : () -> tensor<3xi64>
+  // CHECK-DAG:  %[[CST0:.*]] = "tf.Const"() <{value = dense<[2, 8, 4]> : tensor<3xi64>}> : () -> tensor<3xi64>
+  // CHECK-DAG:  %[[CST1:.*]] = "tf.Const"() <{value = dense<[1, 8, 4]> : tensor<3xi64>}> : () -> tensor<3xi64>
+  // CHECK-DAG:  %[[CST2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-DAG:  %[[CST3:.*]] = "tf.Const"() <{value = dense<[0, 6, 0]> : tensor<3xi64>}> : () -> tensor<3xi64>
+  // CHECK-DAG:  %[[CST4:.*]] = "tf.Const"() <{value = dense<[3, 2, 4]> : tensor<3xi64>}> : () -> tensor<3xi64>
+  // CHECK-DAG:  %[[CST5:.*]] = "tf.Const"() <{value = dense<0> : tensor<3xi64>}> : () -> tensor<3xi64>
+  // CHECK-DAG:  %[[CST6:.*]] = "tf.Const"() <{value = dense<[3, 6, 4]> : tensor<3xi64>}> : () -> tensor<3xi64>
+  // CHECK-DAG:  %[[CST7:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
   // CHECK:      %[[SLICE:.*]] = "tf.Slice"(%arg0, %[[CST]], %[[CST0]]) : (tensor<3x8x4xi32>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x8x4xi32>
   // CHECK:      %[[SLICE1:.*]] = "tf.Slice"(%arg0, %[[CST5]], %[[CST1]]) : (tensor<3x8x4xi32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x8x4xi32>
   // CHECK:      %[[CONCAT:.*]] = "tf.ConcatV2"(%[[SLICE]], %[[SLICE1]], %[[CST2]]) : (tensor<2x8x4xi32>, tensor<1x8x4xi32>, tensor<i32>) -> tensor<3x8x4xi32>
@@ -1184,8 +1184,8 @@
 }
 
 func.func @roll_dynamic_shape(%arg0: tensor<?x8x4xi32>) -> tensor<?x8x4xi32> {
-  %axis = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-  %shift = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
+  %axis = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+  %shift = "tf.Const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
   %0 = "tf.Roll"(%arg0, %shift, %axis) : (tensor<?x8x4xi32>, tensor<i32>, tensor<i32>) -> tensor<?x8x4xi32>
   func.return %0 : tensor<?x8x4xi32>
   // CHECK-LABEL: roll_dynamic_shape
@@ -1193,7 +1193,7 @@
 }
 
 func.func @roll_non_constant_axis(%arg0: tensor<3x8x4xi32>, %arg1: tensor<i32>) -> tensor<3x8x4xi32> {
-  %shift = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
+  %shift = "tf.Const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
   %0 = "tf.Roll"(%arg0, %shift, %arg1) : (tensor<3x8x4xi32>, tensor<i32>, tensor<i32>) -> tensor<3x8x4xi32>
   func.return %0 : tensor<3x8x4xi32>
   // CHECK-LABEL: roll_non_constant_axis
@@ -1201,7 +1201,7 @@
 }
 
 func.func @roll_non_constant_shift(%arg0: tensor<3x8x4xi32>, %arg1: tensor<i32>) -> tensor<3x8x4xi32> {
-  %axis = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
+  %axis = "tf.Const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
   %0 = "tf.Roll"(%arg0, %arg1, %axis) : (tensor<3x8x4xi32>, tensor<i32>, tensor<i32>) -> tensor<3x8x4xi32>
   func.return %0 : tensor<3x8x4xi32>
   // CHECK-LABEL: roll_non_constant_shift
@@ -1213,9 +1213,9 @@
   func.return %0 : tensor<14xf32>
 
   // CHECK-LABEL: scatter_nd_updates
-  // CHECK-DAG: %[[CST:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK-DAG: %[[CST0:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
-  // CHECK-DAG: %[[CST1:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<14xf32>} : () -> tensor<14xf32>
+  // CHECK-DAG: %[[CST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK-DAG: %[[CST0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+  // CHECK-DAG: %[[CST1:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<14xf32>}> : () -> tensor<14xf32>
   // CHECK: %[[SCATTER:.*]] = "tf.TensorScatterAdd"(%cst_1, %arg1, %[[CST0]]) : (tensor<14xf32>, tensor<1x1xi32>, tensor<1xf32>) -> tensor<14xf32>
   // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[CST]], %[[SCATTER]]) : (tensor<f32>, tensor<14xf32>) -> tensor<14xf32>
   // CHECK: %[[MUL:.*]] = "tf.Mul"(%[[SUB]], %arg0) : (tensor<14xf32>, tensor<14xf32>) -> tensor<14xf32>
@@ -1229,17 +1229,17 @@
   func.return %0 : tensor<1x24xi1>
 
 // CHECK-LABEL: scatter_nd_updates_bool(
-// CHECK-DAG:       %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-// CHECK-DAG:       %[[CST0:.*]] = "tf.Const"() {value = dense<1> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
-// CHECK-DAG:       %[[CST1:.*]] = "tf.Const"() {value = dense<0> : tensor<1x24xi32>} : () -> tensor<1x24xi32>
-// CHECK:           %[[CAST0:.*]] = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x24xi1>) -> tensor<1x24xi32>
-// CHECK:           %[[CAST1:.*]] = "tf.Cast"(%arg2) {Truncate = false} : (tensor<1x2xi1>) -> tensor<1x2xi32>
+// CHECK-DAG:       %[[CST:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+// CHECK-DAG:       %[[CST0:.*]] = "tf.Const"() <{value = dense<1> : tensor<1x2xi32>}> : () -> tensor<1x2xi32>
+// CHECK-DAG:       %[[CST1:.*]] = "tf.Const"() <{value = dense<0> : tensor<1x24xi32>}> : () -> tensor<1x24xi32>
+// CHECK:           %[[CAST0:.*]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1x24xi1>) -> tensor<1x24xi32>
+// CHECK:           %[[CAST1:.*]] = "tf.Cast"(%arg2) <{Truncate = false}> : (tensor<1x2xi1>) -> tensor<1x2xi32>
 // CHECK:           %[[SCATTER:.*]] = "tf.TensorScatterAdd"(%[[CST1]], %arg1, %[[CST0]]) : (tensor<1x24xi32>, tensor<1x2x2xi32>, tensor<1x2xi32>) -> tensor<1x24xi32>
 // CHECK:           %[[SUB:.*]] = "tf.Sub"(%[[CST]], %[[SCATTER]]) : (tensor<i32>, tensor<1x24xi32>) -> tensor<1x24xi32>
 // CHECK:           %[[MUL:.*]] = "tf.Mul"(%[[SUB]], %[[CAST0]]) : (tensor<1x24xi32>, tensor<1x24xi32>) -> tensor<1x24xi32>
 // CHECK:           %[[SCATTER1:.*]] = "tf.TensorScatterAdd"(%[[CST1]], %arg1, %[[CAST1]]) : (tensor<1x24xi32>, tensor<1x2x2xi32>, tensor<1x2xi32>) -> tensor<1x24xi32>
 // CHECK:           %[[ADD:.*]] = "tf.AddV2"(%[[MUL]], %[[SCATTER1]]) : (tensor<1x24xi32>, tensor<1x24xi32>) -> tensor<1x24xi32>
-// CHECK:           %[[CAST2:.*]] = "tf.Cast"(%[[ADD]]) {Truncate = false} : (tensor<1x24xi32>) -> tensor<1x24xi1>
+// CHECK:           %[[CAST2:.*]] = "tf.Cast"(%[[ADD]]) <{Truncate = false}> : (tensor<1x24xi32>) -> tensor<1x24xi1>
 // CHECK:           return %[[CAST2]] : tensor<1x24xi1>
 }
 
@@ -1250,11 +1250,11 @@
 // CHECK-LABEL: func @simple_softmax
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>)
 func.func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
-  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64>
-  // CHECK-DAG: %[[MAX:.*]] = "tf.Max"(%[[ARG0]], %[[AXIS]]) {keep_dims = true} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
+  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi64>}> : () -> tensor<1xi64>
+  // CHECK-DAG: %[[MAX:.*]] = "tf.Max"(%[[ARG0]], %[[AXIS]]) <{keep_dims = true}> : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
   // CHECK-DAG: %[[SHIFTED:.*]] = "tf.Sub"(%[[ARG0]], %[[MAX]]) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
   // CHECK-DAG: %[[EXP:.*]] = "tf.Exp"(%[[SHIFTED]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
-  // CHECK-DAG: %[[SUM:.*]] = "tf.Sum"(%[[EXP]], %[[AXIS]]) {keep_dims = true} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
+  // CHECK-DAG: %[[SUM:.*]] = "tf.Sum"(%[[EXP]], %[[AXIS]]) <{keep_dims = true}> : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
   // CHECK-DAG: %[[RESULT:.*]] = "tf.Div"(%[[EXP]], %[[SUM]]) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
   // CHECK: return %[[RESULT]]
   %0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
@@ -1277,11 +1277,11 @@
 // CHECK-LABEL: func @simple_logsoftmax
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>)
 func.func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
-  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64>
-  // CHECK-DAG: %[[MAX:.*]] = "tf.Max"(%[[ARG0]], %[[AXIS]]) {keep_dims = true} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
+  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi64>}> : () -> tensor<1xi64>
+  // CHECK-DAG: %[[MAX:.*]] = "tf.Max"(%[[ARG0]], %[[AXIS]]) <{keep_dims = true}> : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
   // CHECK-DAG: %[[SHIFTED:.*]] = "tf.Sub"(%[[ARG0]], %[[MAX]]) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
   // CHECK-DAG: %[[EXP:.*]] = "tf.Exp"(%[[SHIFTED]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
-  // CHECK-DAG: %[[SUM:.*]] = "tf.Sum"(%[[EXP]], %[[AXIS]]) {keep_dims = true} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
+  // CHECK-DAG: %[[SUM:.*]] = "tf.Sum"(%[[EXP]], %[[AXIS]]) <{keep_dims = true}> : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x1xf32>
   // CHECK-DAG: %[[LOG:.*]] = "tf.Log"(%[[SUM]]) : (tensor<2x1xf32>) -> tensor<2x1xf32>
   // CHECK-DAG: %[[RESULT:.*]] = "tf.Sub"(%[[SHIFTED]], %[[LOG]]) : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
   // CHECK: return %[[RESULT]]
@@ -1299,10 +1299,10 @@
 // CHECK-LABEL: func @selu
 // CHECK-SAME:  (%[[FEATURES:.*]]: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> {
 func.func @selu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> {
-    // CHECK-DAG:   %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-    // CHECK-DAG:   %[[SCALE:.*]] = "tf.Const"() {value = dense<1.05070102> : tensor<f32>} : () -> tensor<f32>
-    // CHECK-DAG:   %[[SCALED_ALPHA:.*]] = "tf.Const"() {value = dense<1.75809932> : tensor<f32>} : () -> tensor<f32>
-    // CHECK-NEXT:  %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+    // CHECK-DAG:   %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+    // CHECK-DAG:   %[[SCALE:.*]] = "tf.Const"() <{value = dense<1.05070102> : tensor<f32>}> : () -> tensor<f32>
+    // CHECK-DAG:   %[[SCALED_ALPHA:.*]] = "tf.Const"() <{value = dense<1.75809932> : tensor<f32>}> : () -> tensor<f32>
+    // CHECK-NEXT:  %[[ONE:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
     // CHECK-DAG:   %[[PRED:.*]] = "tf.Greater"(%[[FEATURES]], %[[ZERO]]) : (tensor<1x4x4x3xf32>, tensor<f32>) -> tensor<1x4x4x3xi1>
     // CHECK-NEXT:  %[[SCALED_FEATURES:.*]] = "tf.Mul"(%[[FEATURES]], %[[SCALE]]) : (tensor<1x4x4x3xf32>, tensor<f32>) -> tensor<1x4x4x3xf32>
     // CHECK-NEXT:  %[[EXP:.*]] = "tf.Exp"(%[[FEATURES]]) : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
@@ -1317,9 +1317,9 @@
 // CHECK-LABEL: func @selu_grad
 // CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<4x8xf32>) -> tensor<4x8xf32> {
 func.func @selu_grad(%gradients: tensor<4x8xf32>, %features: tensor<4x8xf32>) -> tensor<4x8xf32> {
-    // CHECK-DAG:   %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
-    // CHECK-DAG:   %[[SCALE:.*]] = "tf.Const"() {value = dense<1.05070102> : tensor<f32>} : () -> tensor<f32>
-    // CHECK-DAG:   %[[SCALED_ALPHA:.*]] = "tf.Const"() {value = dense<1.75809932> : tensor<f32>} : () -> tensor<f32>
+    // CHECK-DAG:   %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+    // CHECK-DAG:   %[[SCALE:.*]] = "tf.Const"() <{value = dense<1.05070102> : tensor<f32>}> : () -> tensor<f32>
+    // CHECK-DAG:   %[[SCALED_ALPHA:.*]] = "tf.Const"() <{value = dense<1.75809932> : tensor<f32>}> : () -> tensor<f32>
     // CHECK-DAG:   %[[PRED:.*]] = "tf.Greater"(%[[FEATURES]], %[[ZERO]]) : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xi1>
     // CHECK-NEXT:  %[[SCALED_GRADIENTS:.*]] = "tf.Mul"(%[[GRADIENTS]], %[[SCALE]]) : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
     // CHECK-NEXT:  %[[FEATURES_PLUS_SCALED_ALPHA:.*]] = "tf.AddV2"(%[[FEATURES]], %[[SCALED_ALPHA]]) : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
@@ -1335,7 +1335,7 @@
 func.func @expm1(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   %0 = "tf.Expm1"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xf32>
   func.return %0 : tensor<3x4xf32>
-  // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[ONE:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
   // CHECK: %[[EXP:.*]] = "tf.Exp"(%[[ARG0]]) : (tensor<3x4xf32>) -> tensor<3x4xf32>
   // CHECK: %[[RESULT:.*]] = "tf.Sub"(%[[EXP]], %[[ONE]]) : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x4xf32>
   // CHECK: return %[[RESULT]]
@@ -1344,11 +1344,11 @@
 // CHECK-LABEL: func @matrix_band_part
 // CHECK-SAME: (%[[INPUT:.*]]: tensor<4x5xf32>, %[[NUM_LOWER:.*]]: tensor<i64>, %[[NUM_UPPER:.*]]: tensor<i64>) -> tensor<4x5xf32> {
 func.func @matrix_band_part(%input: tensor<4x5xf32>, %num_lower: tensor<i64>, %num_upper: tensor<i64>) -> tensor<4x5xf32> {
-  // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
-  // CHECK-DAG: %[[OFFSET:.*]] = "tf.Const"() {{.+}} : () -> tensor<4x5xi64>
-  // CHECK-DAG: %[[M:.*]] = "tf.Const"() {value = dense<4> : tensor<i64>} : () -> tensor<i64>
-  // CHECK-DAG: %[[N:.*]] = "tf.Const"() {value = dense<5> : tensor<i64>} : () -> tensor<i64>
-  // CHECK-DAG: %[[ZEROS_LIKE:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<4x5xf32>} : () -> tensor<4x5xf32>
+  // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
+  // CHECK-DAG: %[[OFFSET:.*]] = "tf.Const"() <{{.+}}> : () -> tensor<4x5xi64>
+  // CHECK-DAG: %[[M:.*]] = "tf.Const"() <{value = dense<4> : tensor<i64>}> : () -> tensor<i64>
+  // CHECK-DAG: %[[N:.*]] = "tf.Const"() <{value = dense<5> : tensor<i64>}> : () -> tensor<i64>
+  // CHECK-DAG: %[[ZEROS_LIKE:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<4x5xf32>}> : () -> tensor<4x5xf32>
   // CHECK-DAG: %[[LE:.*]] = "tf.Less"(%[[NUM_LOWER]], %[[ZERO]]) : (tensor<i64>, tensor<i64>) -> tensor<i1>
   // CHECK-DAG: %[[NUM_LOWER_OR_M:.*]] = "tf.SelectV2"(%[[LE]], %[[M]], %[[NUM_LOWER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
   // CHECK-DAG: %[[LE1:.*]] = "tf.Less"(%[[NUM_UPPER]], %[[ZERO]]) : (tensor<i64>, tensor<i64>) -> tensor<i1>
@@ -1373,17 +1373,17 @@
 // CHECK-LABEL: func @dynamic_shape_matrix_band_part
 // CHECK-SAME: (%[[INPUT:.*]]: tensor<?x?xf32>, %[[NUM_LOWER:.*]]: tensor<i32>, %[[NUM_UPPER:.*]]: tensor<i32>) -> tensor<?x?xf32> {
 func.func @dynamic_shape_matrix_band_part(%input: tensor<?x?xf32>, %num_lower: tensor<i32>, %num_upper: tensor<i32>) -> tensor<?x?xf32> {
-  // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-DAG: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-DAG: %[[NEG_ONE:.*]] = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-DAG: %[[ZERO_1D:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
-  // CHECK-DAG: %[[ONE_1D:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
-  // CHECK-DAG: %[[TWO_1D:.*]] = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
-  // CHECK-DAG: %[[ZERO_F32:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK-DAG: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-DAG: %[[ONE:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-DAG: %[[NEG_ONE:.*]] = "tf.Const"() <{value = dense<-1> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-DAG: %[[ZERO_1D:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  // CHECK-DAG: %[[ONE_1D:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
+  // CHECK-DAG: %[[TWO_1D:.*]] = "tf.Const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+  // CHECK-DAG: %[[ZERO_F32:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
   // CHECK-DAG: %[[SHAPE:.*]] = "tf.Shape"(%[[INPUT]]) : (tensor<?x?xf32>) -> tensor<2xi32>
-  // CHECK-DAG: %[[M:.*]] = "tf.StridedSlice"(%[[SHAPE]], %[[ZERO_1D]], %[[ONE_1D]], %[[ONE_1D]]) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+  // CHECK-DAG: %[[M:.*]] = "tf.StridedSlice"(%[[SHAPE]], %[[ZERO_1D]], %[[ONE_1D]], %[[ONE_1D]]) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
   // CHECK-DAG: %[[SHAPE1:.*]] = "tf.Shape"(%[[INPUT]]) : (tensor<?x?xf32>) -> tensor<2xi32>
-  // CHECK-DAG: %[[N:.*]] = "tf.StridedSlice"(%[[SHAPE1]], %[[ONE_1D]], %[[TWO_1D]], %[[ONE_1D]]) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+  // CHECK-DAG: %[[N:.*]] = "tf.StridedSlice"(%[[SHAPE1]], %[[ONE_1D]], %[[TWO_1D]], %[[ONE_1D]]) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
   // CHECK-DAG: %[[LE:.*]] = "tf.Less"(%[[NUM_LOWER]], %[[ZERO]]) : (tensor<i32>, tensor<i32>) -> tensor<i1>
   // CHECK-DAG: %[[NUM_LOWER_OR_M:.*]] = "tf.SelectV2"(%[[LE]], %[[M]], %[[NUM_LOWER]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
   // CHECK-DAG: %[[LE1:.*]] = "tf.Less"(%[[NUM_UPPER]], %[[ZERO]]) : (tensor<i32>, tensor<i32>) -> tensor<i1>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir
index 837c37b..ca1e4c9 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir
@@ -131,7 +131,7 @@
 // CHECK-LABEL: func @op_string_result
 func.func @op_string_result() -> tensor<i32> {
   %0 = "tf_device.cluster"() ({
-    // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
+    // CHECK: "tf.Const"() <{value = dense<1> : tensor<i32>}>
     // CHECK-NOT: _xla_outside_compilation
     // CHECK: "tf.Const"
     // CHECK-SAME: _xla_outside_compilation
@@ -148,7 +148,7 @@
 // CHECK-LABEL: func @op_string_operand
 func.func @op_string_operand(%arg0: tensor<!tf_type.string>) -> tensor<i32> {
   %0 = "tf_device.cluster"() ({
-    // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
+    // CHECK: "tf.Const"() <{value = dense<1> : tensor<i32>}>
     // CHECK-NOT: _xla_outside_compilation
     // CHECK: "tf.StringToNumber"
     // CHECK-SAME: _xla_outside_compilation
@@ -166,7 +166,7 @@
 // CHECK-LABEL: func @op_string_operand_string_result
 func.func @op_string_operand_string_result(%arg0: tensor<!tf_type.string>) -> tensor<i32> {
   %0 = "tf_device.cluster"() ({
-    // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
+    // CHECK: "tf.Const"() <{value = dense<1> : tensor<i32>}>
     // CHECK-NOT: _xla_outside_compilation
     // CHECK: "tf.Identity"
     // CHECK-SAME: _xla_outside_compilation
@@ -187,7 +187,7 @@
 // CHECK-LABEL: func @ops_inside_tf_if_outside_compiled
 func.func @ops_inside_tf_if_outside_compiled(%arg0: tensor<i1>, %arg1: tensor<!tf_type.string>) -> tensor<f32> {
   %0 = "tf_device.cluster"() ({
-    // CHECK:      "tf.Const"() {value = dense<1> : tensor<i32>}
+    // CHECK:      "tf.Const"() <{value = dense<1> : tensor<i32>}>
     // CHECK-NOT:  _xla_outside_compilation
     // CHECK:      "tf.IfRegion"
     // CHECK:        "tf.StringToNumber"
@@ -212,22 +212,22 @@
 // CHECK-LABEL: func @if_region_string_op
 func.func @if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32> {
   %0 = "tf_device.cluster"() ({
-    // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
+    // CHECK: "tf.Const"() <{value = dense<1> : tensor<i32>}>
     // CHECK-NOT: _xla_outside_compilation
     // CHECK: "tf.IfRegion"
+    // CHECK: <{is_stateless
     // CHECK-NOT: _xla_outside_compilation
     %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
     %2 = "tf.IfRegion"(%arg0) ({
       %3 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
       "tf.Yield"(%3) : (tensor<f32>) -> ()
      },  {
-      // CHECK: "tf.Const"() {_xla_outside_compilation = "auto0", value = dense<"1.0"> : tensor<!tf_type.string>}
+      // CHECK: "tf.Const"() <{value = dense<"1.0"> : tensor<!tf_type.string>}> {_xla_outside_compilation = "auto0"}
       // CHECK-NEXT: "tf.StringToNumber"
       // CHECK-SAME: _xla_outside_compilation
       %4 = "tf.Const"() {value = dense<"1.0"> : tensor<!tf_type.string>} : () -> tensor<!tf_type.string>
       %5 = "tf.StringToNumber"(%4) {out_type = f32} : (tensor<!tf_type.string>) -> tensor<f32>
       "tf.Yield"(%5) : (tensor<f32>) -> ()
-    // CHECK: {is_stateless
     }) {is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
     %6 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
     tf_device.return %6: tensor<f32>
@@ -241,34 +241,34 @@
 // CHECK-LABEL: func @nested_if_region_string_op
 func.func @nested_if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32> {
   %0 = "tf_device.cluster"() ({
-    // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
+    // CHECK: "tf.Const"() <{value = dense<1> : tensor<i32>}>
     // CHECK-NOT: _xla_outside_compilation
     // CHECK: "tf.IfRegion"
+    // CHECK: <{is_stateless
     // CHECK-NOT: _xla_outside_compilation
     %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
     %2 = "tf.IfRegion"(%arg0) ({
       %3 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
       "tf.Yield"(%3) : (tensor<f32>) -> ()
       },  {
-       // CHECK: "tf.Const"() {value = dense<true> : tensor<i1>}
+       // CHECK: "tf.Const"() <{value = dense<true> : tensor<i1>}>
+       // CHECK: <{is_stateless
        // CHECK-NOT: _xla_outside_compilation
        %4 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
        %5 = "tf.IfRegion"(%4)({
-         // CHECK: "tf.Const"() {_xla_outside_compilation = "auto0", value = dense<"1.0"> : tensor<!tf_type.string>}
+         // CHECK: "tf.Const"() <{value = dense<"1.0"> : tensor<!tf_type.string>}> {_xla_outside_compilation = "auto0"}
          // CHECK-NEXT: "tf.StringToNumber"
          // CHECK-SAME: _xla_outside_compilation
          %6 = "tf.Const"() {value = dense<"1.0"> : tensor<!tf_type.string>} : () -> tensor<!tf_type.string>
          %7 = "tf.StringToNumber"(%6) {out_type = f32} : (tensor<!tf_type.string>) -> tensor<f32>
          "tf.Yield"(%7) : (tensor<f32>) -> ()
        },  {
-         // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
+         // CHECK: "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}>
          // CHECK-NOT: _xla_outside_compilation
          %8 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
          "tf.Yield"(%8) : (tensor<f32>) -> ()
-       // CHECK: {is_stateless
        }){is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
        "tf.Yield"(%5) : (tensor<f32>) -> ()
-    // CHECK: {is_stateless
     }) {is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
     %9 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
     tf_device.return %9: tensor<f32>
@@ -282,7 +282,7 @@
 // CHECK-LABEL: func @ops_inside_while_outside_compiled
 func.func @ops_inside_while_outside_compiled(%arg0: tensor<i32>, %arg1: tensor<!tf_type.string>) -> tensor<f32> {
   %0 = "tf_device.cluster"() ({
-    // CHECK:     "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
+    // CHECK:     "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}>
     // CHECK-NOT: _xla_outside_compilation
     // CHECK:     "tf.WhileRegion"
     // CHECK:       "tf.StringToNumber"
@@ -313,9 +313,10 @@
 // CHECK-LABEL: func @while_region_unsupported_op
 func.func @while_region_unsupported_op(%arg0: tensor<i32>, %arg1: tensor<!tf_type.string>) -> tensor<f32> {
   %0 = "tf_device.cluster"() ({
-    // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
+    // CHECK: "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}>
     // CHECK-NOT: _xla_outside_compilation
     // CHECK: "tf.WhileRegion"
+    // CHECK: <{is_stateless = true
     %1 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
     %2:2 = "tf.WhileRegion"(%1, %arg0) ({
       ^bb0(%carg0: tensor<f32>, %carg1: tensor<i32>):
@@ -329,10 +330,9 @@
         // CHECK: "tf.UnsupportedOp"
         // CHECK-SAME: _xla_outside_compilation
         %3 = "tf.UnsupportedOp"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-        // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
+        // CHECK: "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}>
         %4 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
         "tf.Yield"(%4, %sub) : (tensor<f32>, tensor<i32>) -> ()
-    // CHECK: {is_stateless = true
     }) {is_stateless = true} : (tensor<f32>, tensor<i32>) -> (tensor<f32>, tensor<i32>)
     // CHECK: "tf.Identity"
     // CHECK-NOT: _xla_outside_compilation
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir b/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir
index bd5f805..a9e4fc9 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir
@@ -807,19 +807,19 @@
 func.func @two_overlapped_if_groups_with_no_dependency_merged() {
   // CHECK:      tf_device.cluster
   // CHECK:        "tf.IfRegion"
-  // CHECK:          "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK           "tf.Const"() {value = dense<5.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK           "tf.Const"() {value = dense<9.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK:          "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK           "tf.Const"() {value = dense<6.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK           "tf.Const"() {value = dense<1.000000e+01> : tensor<f32>} : () -> tensor<f32>
+  // CHECK:          "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK           "tf.Const"() <{value = dense<5.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK           "tf.Const"() <{value = dense<9.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK:          "tf.Const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK           "tf.Const"() <{value = dense<6.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK           "tf.Const"() <{value = dense<1.000000e+01> : tensor<f32>}> : () -> tensor<f32>
   // CHECK:        "tf.IfRegion"
-  // CHECK:          "tf.Const"() {value = dense<3.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK           "tf.Const"() {value = dense<7.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK           "tf.Const"() {value = dense<1.100000e+01> : tensor<f32>} : () -> tensor<f32>
-  // CHECK:          "tf.Const"() {value = dense<4.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK           "tf.Const"() {value = dense<8.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK           "tf.Const"() {value = dense<1.200000e+01> : tensor<f32>} : () -> tensor<f32>
+  // CHECK:          "tf.Const"() <{value = dense<3.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK           "tf.Const"() <{value = dense<7.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK           "tf.Const"() <{value = dense<1.100000e+01> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK:          "tf.Const"() <{value = dense<4.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK           "tf.Const"() <{value = dense<8.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK           "tf.Const"() <{value = dense<1.200000e+01> : tensor<f32>}> : () -> tensor<f32>
   // CHECK-NOT:    "tf.IfRegion"
   "tf_device.cluster"() ({
     %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
@@ -888,12 +888,12 @@
   // CHECK:          "tf.E"
   // CHECK:          "tf.F"
   // CHECK:        "tf.IfRegion"
-  // CHECK:          "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK:          "tf.Const"() {value = dense<3.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK:          "tf.Const"() {value = dense<5.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK:          "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK:          "tf.Const"() {value = dense<4.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  // CHECK;          "tf.Const"() {value = dense<6.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK:          "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK:          "tf.Const"() <{value = dense<3.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK:          "tf.Const"() <{value = dense<5.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK:          "tf.Const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK:          "tf.Const"() <{value = dense<4.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+  // CHECK;          "tf.Const"() <{value = dense<6.000000e+00> : tensor<f32>}> : () -> tensor<f32>
   // CHECK-NOT:    "tf.IfRegion"
 func.func @two_overlapped_if_groups_with_dependency_not_merged_for_first_if_region_group() {
   "tf_device.cluster"() ({
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlprogram.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlprogram.mlir
index 5b0958a..6ded59b 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/mlprogram.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mlprogram.mlir
@@ -123,7 +123,7 @@
   // CHECK-LABEL @lowers_string_ops
   // CHECK-DAG: ml_program.global public @vars.Variable_1([]) : tensor<!tf_type.string>
   func.func @lowers_string_ops(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>, %arg2: tensor<128x1xi32>, %arg3: tensor<128x90xi32>, %arg4: tensor<128x90xi32>, %arg5: tensor<128x90xi32>, %arg6: tensor<128x90x64xf32>, %arg7: tensor<128x90x64xf32>) -> tensor<!tf_type.string> {
-    // CHECK: %0 = ml_program.global_load @vars.Variable_1 : tensor<!tf_type.string>
+    // CHECK: %[[v0:.*]] = ml_program.global_load @vars.Variable_1 : tensor<!tf_type.string>
     %0 = tf_executor.graph {
       %outputs_4, %control_5 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "Variable"} : () -> tensor<!tf_type.resource<tensor<!tf_type.string>>>
       %outputs_10, %control_11 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "Variable_1"} : () -> tensor<!tf_type.resource<tensor<!tf_type.string>>>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir b/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir
index 2f74453..c0230b4 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir
@@ -22,10 +22,10 @@
   func.func @nodep_single_outside_compilation() -> () {
     // CHECK:      "tf.A"
     // CHECK:      "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   "tf.B"
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT: tf_device.return
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK: device_assignment =  [], num_cores_per_replica = 1 : i64, topology =  ""
     "tf_device.cluster"() ({
       "tf.A"() : () -> ()
@@ -45,10 +45,10 @@
     // CHECK-NEXT:   "tf_device.cluster"
     // CHECK-NEXT:     "tf.B"
     // CHECK-NEXT:     "tf_device.launch"
+    // CHECK-SAME:     device = "TPU_REPLICATED_HOST_0"
     // CHECK-NEXT:       "tf.C"
     // CHECK-NOT:        _xla_outside_compilation
     // CHECK:            tf_device.return
-    // CHECK-NEXT:     device = "TPU_REPLICATED_HOST_0"
     // CHECK: device_assignment =  [], num_cores_per_replica = 1 : i64, topology =  ""
     %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
     tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
@@ -136,10 +136,10 @@
   func.func @called_outside_compilation_callee() -> () {
     // CHECK:      "tf.A"
     // CHECK:      "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   "tf.B"
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT: tf_device.return
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     "tf.A"() : () -> ()
     "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
     "tf.C"() : () -> ()
@@ -178,10 +178,10 @@
   func.func @outside_compilation_model_parallelism() -> () {
     // CHECK:      "tf.A"
     // CHECK:      "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   "tf.B"
     // CHECK-NOT:    _xla_outside_compilation
     // CHECK-NEXT: tf_device.return
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK: num_cores_per_replica = 2 : i64
     %0 = "tf_device.cluster"() ({
       "tf.A"() : () -> ()
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir b/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir
index 45ee57a..021cad3 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir
@@ -3,8 +3,8 @@
 // CHECK-LABEL: @ShardingAttr
 func.func @ShardingAttr(%arg0: tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {mhlo.sharding = ""}) -> (tensor<128x10xf32>, tensor<10x1024xf32>, tensor<128x1024xf32>) {
 
-  // CHECK: %[[SHARDED_ARG0:.*]] = "tf.XlaSharding"(%arg0) {_XlaSharding = "\08\03\1A\02\01\02\22\02\00\01", sharding = "\08\03\1A\02\01\02\22\02\00\01"}
-  // CHECK: %[[SHARDED_ARG1:.*]] = "tf.XlaSharding"(%arg1) {_XlaSharding = "\08\01\1A\01\01\22\01\00", sharding = "\08\01\1A\01\01\22\01\00"}
+  // CHECK: %[[SHARDED_ARG0:.*]] = "tf.XlaSharding"(%arg0) <{_XlaSharding = "\08\03\1A\02\01\02\22\02\00\01", sharding = "\08\03\1A\02\01\02\22\02\00\01"}>
+  // CHECK: %[[SHARDED_ARG1:.*]] = "tf.XlaSharding"(%arg1) <{_XlaSharding = "\08\01\1A\01\01\22\01\00", sharding = "\08\01\1A\01\01\22\01\00"}>
 
   // CHECK: "tf.Identity"(%[[SHARDED_ARG1]])
   %0 = "tf.Identity"(%arg1) : (tensor<10x1024xf32>) -> tensor<10x1024xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir
index 14f20b6..faf2a96 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir
@@ -62,7 +62,7 @@
 func.func @main(%arg0: tensor<i1>) -> tensor<2xf32> {
   // CHECK-NOT: "tf.VarHandleOp"
   // CHECK-NOT: "tf.ReadVariableOp"
-  // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<4.200000e+01> : tensor<f32>}
+  // CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<4.200000e+01> : tensor<f32>}>
   // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg1, %[[CONST]])
   // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %arg1)
   // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%[[CONST]], %[[ADD2]])
@@ -133,7 +133,7 @@
 // CHECK-LABEL: func @main(%arg0: tensor<i1>) -> (tensor<2xf32>, tensor<f32> {tf.resource_name = "x"})
 func.func @main(%arg0: tensor<i1>) -> tensor<2xf32> {
   // CHECK-NOT: "tf.AssignVariableOp"
-  // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<4.200000e+01> : tensor<f32>}
+  // CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<4.200000e+01> : tensor<f32>}>
   // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%[[CONST]], %[[CONST]])
   // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %[[ADD1]])
   // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%[[CONST]], %[[ADD2]])
@@ -222,7 +222,7 @@
 func.func @main(%arg0: tensor<!tf_type.resource<tensor<f32>>>, %arg1: tensor<i1>) {
   %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor<f32>} : () -> tensor<f32>
   "tf.AssignVariableOp"(%arg0, %0) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
-  // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<1.050000e+03> : tensor<f32>}
+  // CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.050000e+03> : tensor<f32>}>
   %1 = "tf.Const"() {value = dense<1.050000e+03> : tensor<f32>} : () -> tensor<f32>
   "tf.AssignVariableOp"(%arg0, %1) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
   // CHECK-NEXT: return %[[CONST]] : tensor<f32>
@@ -241,7 +241,7 @@
 func.func @main(%arg0: tensor<!tf_type.resource<tensor<f32>>>, %arg1: tensor<i1>) -> tensor<f32> {
   %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor<f32>} : () -> tensor<f32>
   "tf.AssignVariableOp"(%arg0, %0) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
-  // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<1.050000e+03> : tensor<f32>}
+  // CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.050000e+03> : tensor<f32>}>
   %1 = "tf.Const"() {value = dense<1.050000e+03> : tensor<f32>} : () -> tensor<f32>
   "tf.AssignVariableOp"(%arg0, %1) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
   // CHECK-NEXT: return %[[CONST]], %[[CONST]] : tensor<f32>, tensor<f32>
@@ -257,13 +257,13 @@
 // CHECK-SAME: %arg1: tensor<i1>
 // CHECK-SAME: -> (tensor<f32>, tensor<f32>)
 func.func @main(%arg0: tensor<!tf_type.resource<tensor<f32>>>, %arg1: tensor<i1>) -> tensor<f32> {
-  // CHECK-NEXT: %[[CONST_0:.*]] = "tf.Const"() {value = dense<4.200000e+01> : tensor<f32>}
+  // CHECK-NEXT: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.200000e+01> : tensor<f32>}>
   %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor<f32>} : () -> tensor<f32>
   "tf.AssignVariableOp"(%arg0, %0) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
   %1 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf_type.resource<tensor<f32>>>) -> tensor<f32>
   // CHECK-NEXT: %[[ADD:[a-z0-9]+]] = "tf.AddV2"(%[[CONST_0]], %[[CONST_0]])
   %2 = "tf.AddV2"(%1, %1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  // CHECK-NEXT: %[[CONST_1:.*]] = "tf.Const"() {value = dense<1.050000e+03> : tensor<f32>}
+  // CHECK-NEXT: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<1.050000e+03> : tensor<f32>}>
   %3 = "tf.Const"() {value = dense<1.050000e+03> : tensor<f32>} : () -> tensor<f32>
   "tf.AssignVariableOp"(%arg0, %3) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
   // CHECK-NEXT: return %[[ADD]], %[[CONST_1]] : tensor<f32>, tensor<f32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir
index e26a299..eff3e38 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir
@@ -7,11 +7,11 @@
 // CHECK-NEXT:   "tf.Abs"
 func.func @testSimple(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
   // CHECK: "tf.If"
-  // CHECK-SAME: _attr0 = false
-  // CHECK-SAME: _xla_propagate_compile_time_consts = true
   // CHECK-NOT: attr1
   // CHECK-SAME: else_branch = @test_else_name
   // CHECK-SAME: then_branch = @test_then_name
+  // CHECK-SAME: _attr0 = false
+  // CHECK-SAME: _xla_propagate_compile_time_consts = true
   %0 = "tf.IfRegion"(%arg0) ({
     %1 = "tf.Abs"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
     "tf.Yield"(%1) : (tensor<*xf32>) -> ()
@@ -31,10 +31,10 @@
 // CHECK-NEXT:   "tf.Abs"
 func.func @testSimpleEmptyBranchNames(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
   // CHECK: "tf.If"
-  // CHECK-SAME: _attr0 = false
   // CHECK-NOT: attr1
   // CHECK-SAME: else_branch = @tf.IfRegion_else
   // CHECK-SAME: then_branch = @tf.IfRegion_then
+  // CHECK-SAME: _attr0 = false
   %0 = "tf.IfRegion"(%arg0) ({
     %1 = "tf.Abs"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
     "tf.Yield"(%1) : (tensor<*xf32>) -> ()
@@ -78,7 +78,7 @@
 // CHECK-NEXT: constant dense<0.0
 func.func @testIfConstant(%arg0: tensor<i1>) -> tensor<2xf32> {
   %cst_zero = arith.constant dense<0.0> : tensor<2xf32>
-  // CHECK: "tf.If"(%arg0) {{.*}} else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then
+  // CHECK: "tf.If"(%arg0) <{else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then
   %0 = "tf.IfRegion"(%arg0) ({
      "tf.Yield"(%cst_zero) : (tensor<2xf32>) -> ()
     }, {
@@ -98,7 +98,7 @@
 // CHECK: func private @tf.IfRegion1_then
 // CHECK-NEXT: "tf.LogicalNot"
 // CHECK-NEXT: "tf.Asin"
-// CHECK-NEXT: "tf.If"({{.+}}) {{.*}} else_branch = @tf.IfRegion_else, {{.+}} then_branch = @tf.IfRegion_then}
+// CHECK-NEXT: "tf.If"({{.+}}) <{else_branch = @tf.IfRegion_else, {{.+}} then_branch = @tf.IfRegion_then}
 
 // CHECK: func private @tf.IfRegion_else
 // CHECK-NEXT: "tf.Neg"
@@ -106,7 +106,7 @@
 // CHECK-NEXT: "tf.Abs"
 
 func.func @testNested(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.If"({{.+}}) {{.*}} else_branch = @tf.IfRegion1_else, {{.+}} then_branch = @tf.IfRegion1_then}
+  // CHECK: "tf.If"({{.+}}) <{else_branch = @tf.IfRegion1_else, {{.+}} then_branch = @tf.IfRegion1_then}
   %0 = "tf.IfRegion"(%arg0) ({
     // Outer Then
     %cond = "tf.LogicalNot"(%arg0) : (tensor<i1>) -> tensor<i1>
@@ -137,7 +137,7 @@
 func.func private @testIf1Then(tensor<*xf32>) -> tensor<*xf32>
 func.func private @testIf1Else(tensor<*xf32>) -> tensor<*xf32>
 func.func @testIf1Result(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: "tf.If"({{.+}}) {{.*}} else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
+  // CHECK: "tf.If"({{.+}}) <{else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
   %0 = "tf.IfRegion"(%arg0) ({
     %1 = func.call @testIf1Then(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
     "tf.Yield"(%1) : (tensor<*xf32>) -> ()
@@ -155,7 +155,7 @@
 func.func private @testIf1Then(tensor<*xf32>) -> tensor<*xf32>
 func.func private @testIf1Else(tensor<*xf32>) -> tensor<*xf32>
 func.func @testIf2Result(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK: "tf.If"({{.+}}) {{.*}} else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
+  // CHECK: "tf.If"({{.+}}) <{else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
   %0 = "tf.IfRegion"(%arg0) ({
     %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xf32>) -> tensor<*xf32>
     %2 = func.call @testIf1Then(%1) : (tensor<*xf32>) -> tensor<*xf32>
@@ -175,7 +175,7 @@
 func.func private @testIf1Then(tensor<*xf32>) -> tensor<*xf32>
 func.func private @testIf1Else(tensor<*xf32>) -> tensor<*xf32>
 func.func @testIf2Result(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK: "tf.If"({{.+}}) {{.*}} else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
+  // CHECK: "tf.If"({{.+}}) <{else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
   %0 = "tf.IfRegion"(%arg0) ({
     %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xf32>) -> tensor<?xf32>
     %2 = "tf.Cast"(%1) {Truncate = false} : (tensor<?xf32>) -> tensor<*xf32>
@@ -197,8 +197,8 @@
 func.func private @testIf1Then(tensor<*xf32>) -> tensor<*xf32>
 func.func private @testIf1Else(tensor<*xf32>) -> tensor<*xf32>
 func.func @testIfExternIncompatibleCastTrivialTransform(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<2xf32> {
-  // CHECK: %[[CAST:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi64>) -> tensor<*xf32>
-  // CHECK: "tf.If"(%arg0, %[[CAST]]) {{.*}} else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
+  // CHECK: %[[CAST:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<2xi64>) -> tensor<*xf32>
+  // CHECK: "tf.If"(%arg0, %[[CAST]]) <{else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
   %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi64>) -> tensor<*xf32>
   %0 = "tf.IfRegion"(%arg0) ({
     %2 = func.call @testIf1Then(%1) : (tensor<*xf32>) -> tensor<*xf32>
@@ -221,7 +221,7 @@
 func.func private @testIf1Then(tensor<*xf32>) -> tensor<*xf32>
 func.func private @testIf1Else(tensor<*xf32>) -> tensor<*xf32>
 func.func @testIfIncompatibleCastTrivialTransform(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<2xf32> {
-  // CHECK: "tf.If"(%arg0, %arg1) {{.*}} else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then}
+  // CHECK: "tf.If"(%arg0, %arg1) <{else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then}
   %0 = "tf.IfRegion"(%arg0) ({
     %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xi64>) -> tensor<*xf32>
     %2 = func.call @testIf1Then(%1) : (tensor<*xf32>) -> tensor<*xf32>
@@ -341,11 +341,11 @@
 // CHECK-LABEL: testValidWhileRegion
 func.func @testValidWhileRegion(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<*xf32> {
   // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1)
-  // CHECK-SAME: _attr0 = false
-  // CHECK-SAME: _xla_propagate_compile_time_consts = true
   // CHECK-NOT: attr1
   // CHECK-SAME: body = @tf.WhileRegion_body
   // CHECK-SAME: cond = @tf.WhileRegion_cond
+  // CHECK-SAME: _attr0 = false
+  // CHECK-SAME: _xla_propagate_compile_time_consts = true
   %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
     {
       // condition, check if count has reached 0
@@ -379,7 +379,7 @@
 // CHECK: "tf.NotEqual"
 // CHECK-LABEL: testWhileRegionTypeMismatch
 func.func @testWhileRegionTypeMismatch(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<*xf32> {
-  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {{.*}} body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
+  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) <{body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
   %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
     {
       // condition, check if count has reached 0
@@ -415,7 +415,7 @@
 func.func @testWhileRegionConstantSink(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<*xf32> {
   %zero = arith.constant dense<0> : tensor<i32>
   %one = arith.constant dense<1> : tensor<i32>
-  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {{.*}} body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
+  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) <{body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
   %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
     {
       ^bb0(%carg0: tensor<4xf32>, %carg1: tensor<i32>):
@@ -448,7 +448,7 @@
 func.func @testWhileRegionExternInCond(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<*xf32> {
   %cst = arith.constant dense<4> : tensor<i32>
   %limit = "tf.Add"(%arg2, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  // CHECK: [[Result:%.*]]:3 = "tf.While"(%arg0, %arg1, %{{.+}} body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
+  // CHECK: [[Result:%.*]]:3 = "tf.While"(%arg0, %arg1, %{{.+}} <{body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
   %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
     {
       ^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>):
@@ -485,7 +485,7 @@
   %zero = arith.constant dense<0> : tensor<i32>
   %cst = arith.constant dense<4> : tensor<i32>
   %stride = "tf.Add"(%arg2, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  // CHECK: [[Result:%.*]]:3 = "tf.While"(%arg0, %arg1, %{{.+}} body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
+  // CHECK: [[Result:%.*]]:3 = "tf.While"(%arg0, %arg1, %{{.+}} <{body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
   %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
     {
       ^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>):
@@ -516,7 +516,7 @@
   %stride = "tf.Add"(%arg2, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
   %cst1 = arith.constant dense<44> : tensor<i32>
   %limit = "tf.Add"(%arg2, %cst1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  // CHECK: [[Result:%.*]]:4 = "tf.While"(%arg0, %arg1, %{{.+}}, %{{.+}} body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
+  // CHECK: [[Result:%.*]]:4 = "tf.While"(%arg0, %arg1, %{{.+}}, %{{.+}} <{body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
   %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
     {
       ^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>):
@@ -545,7 +545,7 @@
 func.func @testWhileRegionSameExternInBodyAndCond(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<*xf32> {
   %cst = arith.constant dense<4> : tensor<i32>
   %stride = "tf.Add"(%arg2, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  // CHECK: [[Result:%.*]]:3 = "tf.While"(%arg0, %arg1, %{{.+}} body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
+  // CHECK: [[Result:%.*]]:3 = "tf.While"(%arg0, %arg1, %{{.+}} <{body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
   %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
     {
       ^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>):
@@ -573,7 +573,7 @@
 func.func private @while_cond(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<i1>
 func.func private @while_body(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> (tensor<*xf32>, tensor<i32>)
 func.func @testWhileRegionTrivial(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<*xf32> {
-  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {{.*}} body = @while_body, cond = @while_cond
+  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) <{body = @while_body, cond = @while_cond
   %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
     {
       ^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>):
@@ -600,7 +600,7 @@
 func.func private @while_cond(%arg0 : tensor<4xf32>, %arg1 : tensor<i32>) -> tensor<i1>
 func.func private @while_body(%arg0 : tensor<4xf32>, %arg1 : tensor<i32>) -> (tensor<4xf32>, tensor<i32>)
 func.func @testWhileRegionTrivialCasts(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<*xf32> {
-  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {{.*}} body = @while_body, cond = @while_cond
+  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) <{body = @while_body, cond = @while_cond
   %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
     {
       ^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>):
@@ -629,7 +629,7 @@
 func.func private @while_cond(%arg0 : tensor<4xf32>, %arg1 : tensor<i32>) -> tensor<i1>
 func.func private @while_body(%arg0 : tensor<4xf32>, %arg1 : tensor<i32>) -> (tensor<4xf32>, tensor<i32>)
 func.func @testWhileRegionTrivialMultipleCasts(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<*xf32> {
-  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {{.*}} body = @while_body, cond = @while_cond
+  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) <{body = @while_body, cond = @while_cond
   %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
     {
       ^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>):
@@ -662,7 +662,7 @@
 func.func private @while_cond(%arg0 : tensor<4xf32>, %arg1 : tensor<i32>) -> tensor<i1>
 func.func private @while_body(%arg0 : tensor<4xf32>, %arg1 : tensor<i32>) -> (tensor<4xi64>, tensor<i32>)
 func.func @testWhileRegionIncompatibleCast(%arg0 : tensor<*xi64>, %arg1 : tensor<i32>) -> tensor<*xi64> {
-  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {{.*}} body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
+  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) <{body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
   %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
     {
       ^bb0(%carg0: tensor<*xi64>, %carg1: tensor<i32>):
@@ -694,7 +694,7 @@
 func.func private @while_body(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %arg2 : tensor<*xf32>) -> (tensor<*xf32>, tensor<i32>)
 func.func @testWhileRegionExtern(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<*xf32> {
   %ext = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
-  // CHECK: [[Result:%.*]]:3 = "tf.While"(%arg0, %arg1, %{{.+}} body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
+  // CHECK: [[Result:%.*]]:3 = "tf.While"(%arg0, %arg1, %{{.+}} <{body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
   %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
     {
       ^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>):
@@ -723,7 +723,7 @@
 func.func private @while_cond(%arg0 : tensor<i32>, %arg1 : tensor<*xf32>) -> tensor<i1>
 func.func private @while_body(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> (tensor<*xf32>, tensor<i32>)
 func.func @testWhileRegionBlockArgMismatch(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<*xf32> {
-  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {{.*}} body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
+  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) <{body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond
   %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
     {
       ^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>):
@@ -750,7 +750,7 @@
 func.func private @while_cond(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<i32>
 func.func private @while_body(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> (tensor<*xf32>, tensor<i32>)
 func.func @testWhileRegionTrivial(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<*xf32> {
-  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {{.*}} body = @while_body, cond = @while_cond
+  // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) <{body = @while_body, cond = @while_cond
   %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
     {
       ^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>):
@@ -830,11 +830,11 @@
 // CHECK-LABEL: testValidWhileRegion
 func.func @testValidWhileRegion(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<*xf32> {
   // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1)
-  // CHECK-SAME: _attr0 = false
-  // CHECK-SAME: _xla_propagate_compile_time_consts = true
   // CHECK-NOT: attr1
   // CHECK-SAME: body = @tf.WhileRegion_body
   // CHECK-SAME: cond = @tf.WhileRegion_cond
+  // CHECK-SAME: _attr0 = false
+  // CHECK-SAME: _xla_propagate_compile_time_consts = true
   %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
     {
       // condition, check if count has reached 0
@@ -881,3 +881,82 @@
   ) { is_stateless = false, _attr0 = false, attr1 = "hello"} : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<i32>)
   func.return %0#0 : tensor<*xf32>
 }
+
+// -----
+
+func.func @init(%arg0: tensor<4xf32>) -> tensor<7xf32> {
+  %0 = builtin.unrealized_conversion_cast to tensor<7xf32>
+  return %0 : tensor<7xf32>
+}
+func.func @next(%arg0: tensor<7xf32>, %arg1: tensor<3xf32>) -> tensor<6xf32> {
+  %0 = builtin.unrealized_conversion_cast to tensor<6xf32>
+  return %0 : tensor<6xf32>
+}
+func.func @finalize(%arg0: tensor<6xf32>, %arg1: tensor<2xf32>) -> tensor<5xf32> {
+  %0 = builtin.unrealized_conversion_cast to tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+// CHECK-LABEL: testGeneratorDatasetRegion
+func.func @testGeneratorDatasetRegion(%arg0: tensor<4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<!tf_type.resource>, %arg3: tensor<2xf32>) {
+  // CHECK: "tf.GeneratorDataset"
+  // CHECK-DAG: @init
+  // CHECK-DAG: @next
+  // CHECK-DAG: @finalize
+  // CHECK: return
+  %0 = "tf.GeneratorDatasetRegion"(%arg0, %arg1, %arg2, %arg3) ({
+  ^bb0(%arg4: tensor<4xf32>):
+    %1 = func.call @init(%arg4) : (tensor<4xf32>) -> tensor<7xf32>
+    "tf.Yield"(%1) : (tensor<7xf32>) -> ()
+  }, {
+  ^bb0(%arg4: tensor<7xf32>, %arg5: tensor<3xf32>):
+    %1 = func.call @next(%arg4, %arg5) : (tensor<7xf32>, tensor<3xf32>) -> tensor<6xf32>
+    "tf.Yield"(%1) : (tensor<6xf32>) -> ()
+  }, {
+  ^bb0(%arg4: tensor<6xf32>, %arg5: tensor<2xf32>):
+    %1 = func.call @finalize(%arg4, %arg5) : (tensor<6xf32>, tensor<2xf32>) -> tensor<5xf32>
+    "tf.Yield"(%1) : (tensor<5xf32>) -> ()
+  }) {device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0", metadata = "", operandSegmentSizes = array<i32: 1, 2, 1>, output_shapes = [#tf_type.shape<>], output_types = [!tf_type.string]} : (tensor<4xf32>, tensor<3xf32>, tensor<!tf_type.resource>, tensor<2xf32>) -> tensor<!tf_type.variant>
+  return
+}
+
+// -----
+
+func.func @init(%arg0: tensor<4xf32>) -> tensor<7xf32> {
+  %0 = builtin.unrealized_conversion_cast to tensor<7xf32>
+  return %0 : tensor<7xf32>
+}
+func.func @next(%arg0: tensor<3xf32>, %arg1: tensor<7xf32>) -> tensor<6xf32> {
+  %0 = builtin.unrealized_conversion_cast to tensor<6xf32>
+  return %0 : tensor<6xf32>
+}
+func.func @finalize(%arg0: tensor<6xf32>, %arg1: tensor<2xf32>) -> tensor<5xf32> {
+  %0 = builtin.unrealized_conversion_cast to tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+// CHECK-LABEL: testGeneratorDatasetRegionWithComplexBlocks
+func.func @testGeneratorDatasetRegionWithComplexBlocks(%arg0: tensor<4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<!tf_type.resource>, %arg3: tensor<2xf32>) {
+  // CHECK: "tf.GeneratorDataset"
+  // CHECK-NOT: @init
+  // CHECK-NOT: @next
+  // CHECK-NOT: @finalize
+  // CHECK: -> tensor<!tf_type.variant>
+  // CHECK: return
+  %0 = "tf.GeneratorDatasetRegion"(%arg0, %arg1, %arg2, %arg3) ({
+  ^bb0(%arg4: tensor<4xf32>):
+    %sum = "tf.Add"(%arg4, %arg4) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+    %1 = func.call @init(%sum) : (tensor<4xf32>) -> tensor<7xf32>
+    "tf.Yield"(%1) : (tensor<7xf32>) -> ()
+  }, {
+  ^bb0(%arg4: tensor<7xf32>, %arg5: tensor<3xf32>):
+    %1 = func.call @next(%arg5, %arg4) : (tensor<3xf32>, tensor<7xf32>) -> tensor<6xf32>
+    "tf.Yield"(%1) : (tensor<6xf32>) -> ()
+  }, {
+  ^bb0(%arg4: tensor<6xf32>, %arg5: tensor<2xf32>):
+    %1 = func.call @finalize(%arg4, %arg5) : (tensor<6xf32>, tensor<2xf32>) -> tensor<5xf32>
+    %sum = "tf.Add"(%1, %1) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
+    "tf.Yield"(%sum) : (tensor<5xf32>) -> ()
+  }) {device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0", metadata = "", operandSegmentSizes = array<i32: 1, 2, 1>, output_shapes = [#tf_type.shape<>], output_types = [!tf_type.string]} : (tensor<4xf32>, tensor<3xf32>, tensor<!tf_type.resource>, tensor<2xf32>) -> tensor<!tf_type.variant>
+  return
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir
index 024caf9..ec30a7b 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir
@@ -156,19 +156,19 @@
 
 // CHECK:      %[[SHAPE:[0-9]*]] = "tf.Shape"(%[[ARG_0]])
 // CHECK-NEXT: %[[LAUNCH_A:[0-9]*]] = "tf_device.launch"
+// CHECK-SAME: device = "a"
 // CHECK-NEXT:   %[[OP_A:[0-9]*]] = "tf.opA"(%[[SHAPE]])
 // CHECK-NEXT:   tf_device.return %[[OP_A]]
-// CHECK-NEXT: device = "a"
-// CHECK-NEXT: %[[LAUNCH_B:[0-9]*]] = "tf_device.launch"
+// CHECK:      %[[LAUNCH_B:[0-9]*]] = "tf_device.launch"
+// CHECK-SAME: device = "b"
 // CHECK-NEXT:   %[[OP_B:[0-9]*]] = "tf.opB"(%[[SHAPE]], %[[LAUNCH_A]])
 // CHECK-NEXT:   tf_device.return %[[OP_B]]
-// CHECK-NEXT: device = "b"
-// CHECK-NEXT: tf_device.replicate([{{.*}}] as %[[RI:[a-z0-9]+]]: tensor<*xf32>)
+// CHECK: tf_device.replicate([{{.*}}] as %[[RI:[a-z0-9]+]]: tensor<*xf32>)
 // CHECK-NEXT:   %[[LAUNCH_C:[0-9]*]] = "tf_device.launch"
+// CHECK-SAME:   device = "c"
 // CHECK-NEXT:     %[[OP_C:[0-9]*]] = "tf.opC"(%[[RI]], %[[LAUNCH_B]])
 // CHECK-NEXT:     tf_device.return %[[OP_C]]
-// CHECK-NEXT:   device = "c"
-// CHECK-NEXT:   tf_device.return %[[SHAPE]], %[[LAUNCH_A]], %[[LAUNCH_B]], %[[LAUNCH_C]]
+// CHECK:   tf_device.return %[[SHAPE]], %[[LAUNCH_A]], %[[LAUNCH_B]], %[[LAUNCH_C]]
 
 
 // CHECK-LABEL:   func @do_not_hoist_ops_with_virtual_device
@@ -193,14 +193,14 @@
 // CHECK:  [[SHAPE:%.*]] = "tf.Shape"([[VAL_0]])
 // CHECK:  tf_device.replicate({{\[}}[[VAL_0]], [[VAL_1]]] as [[VAL_4:%.*]]: tensor<*xf32>) {devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
 // CHECK:    [[OP_A:%.*]] = "tf.opA"([[SHAPE]]) {device = "TPU_REPLICATED_CORE_0"} : (tensor<?xi32>) -> tensor<*xi32>
-// CHECK:    [[LAUNCH_B:%.*]] = "tf_device.launch"() ({
+// CHECK:    [[LAUNCH_B:%.*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}> ({
 // CHECK:      [[OP_B:%.*]] = "tf.opB"([[SHAPE]]) : (tensor<?xi32>) -> tensor<*xi32>
 // CHECK:      tf_device.return [[OP_B]] : tensor<*xi32>
-// CHECK:    }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<*xi32>
-// CHECK:    [[LAUNCH_C:%.*]] = "tf_device.launch"() ({
+// CHECK:    }) : () -> tensor<*xi32>
+// CHECK:    [[LAUNCH_C:%.*]] = "tf_device.launch"() <{device = "c"}> ({
 // CHECK:      [[OP_C:%.*]] = "tf.opC"([[SHAPE]]) {device = "TPU_REPLICATED_CORE_0"} : (tensor<?xi32>) -> tensor<*xi32>
 // CHECK:      tf_device.return [[OP_C]] : tensor<*xi32>
-// CHECK:    }) {device = "c"} : () -> tensor<*xi32>
+// CHECK:    }) : () -> tensor<*xi32>
 // CHECK:    tf_device.return [[SHAPE]], [[OP_A]], [[LAUNCH_B]], [[LAUNCH_C]]
 
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir
index 8e0e455..a27a0ff 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir
@@ -44,9 +44,11 @@
 }
 
 // CHECK: "tf.opA"
-// CHECK: _parallel_execution_ids = "r0:0", device = "CORE_0"
+// device = "CORE_0"
+// CHECK: _parallel_execution_ids = "r0:0"
 // CHECK: "tf.opA"
-// CHECK: _parallel_execution_ids = "r0:1", device = "CORE_0"
+// device = "CORE_0"
+// CHECK: _parallel_execution_ids = "r0:1"
 
 
 // Tests devices are not remapped if device is not in replicate devices.
@@ -69,9 +71,11 @@
 }
 
 // CHECK: "tf.opA"
-// CHECK: _parallel_execution_ids = "r0:0", device = "/TPU:2"
+// device = "/TPU:2"
+// CHECK: _parallel_execution_ids = "r0:0"
 // CHECK: "tf.opA"
-// CHECK: _parallel_execution_ids = "r0:1", device = "/TPU:2"
+// device = "/TPU:2"
+// CHECK: _parallel_execution_ids = "r0:1"
 
 
 // Tests devices are remapped if device is in replicate devices.
@@ -94,9 +98,11 @@
 }
 
 // CHECK: "tf.opA"
-// CHECK: _parallel_execution_ids = "r0:0", device = "/CPU:0"
+// device = "/CPU:0"
+// CHECK: _parallel_execution_ids = "r0:0"
 // CHECK: "tf.opA"
-// CHECK: _parallel_execution_ids = "r0:1", device = "/GPU:1"
+// device = "/GPU:1"
+// CHECK: _parallel_execution_ids = "r0:1"
 
 
 // Tests replicate with control dependency output has each expanded replica
@@ -268,10 +274,10 @@
   func.return
 }
 
-// CHECK: tf_executor.island wraps "tf.Const"() {_parallel_execution_ids = "r0:0", value = dense<1> : tensor<i64>}
-// CHECK: tf_executor.island wraps "tf.Const"() {_parallel_execution_ids = "r0:0", value = dense<3> : tensor<i64>}
-// CHECK: tf_executor.island wraps "tf.Const"() {_parallel_execution_ids = "r0:1", value = dense<2> : tensor<i64>}
-// CHECK: tf_executor.island wraps "tf.Const"() {_parallel_execution_ids = "r0:1", value = dense<4> : tensor<i64>}
+// CHECK: tf_executor.island wraps "tf.Const"() <{value = dense<1> : tensor<i64>}> {_parallel_execution_ids = "r0:0"}
+// CHECK: tf_executor.island wraps "tf.Const"() <{value = dense<3> : tensor<i64>}> {_parallel_execution_ids = "r0:0"}
+// CHECK: tf_executor.island wraps "tf.Const"() <{value = dense<2> : tensor<i64>}> {_parallel_execution_ids = "r0:1"}
+// CHECK: tf_executor.island wraps "tf.Const"() <{value = dense<4> : tensor<i64>}> {_parallel_execution_ids = "r0:1"}
 
 // -----
 // Tests parallel_execute nested inside replicate
@@ -305,20 +311,20 @@
 // CHECK:      tf_executor.island
 // CHECK:      tf_device.parallel_execute
 // CHECK:      tf_device.launch
+// CHECK:      <{device = "/TPU:1"}>
 // CHECK:      tf.OpA
-// CHECK:      {device = "/TPU:1"}
 // CHECK:      tf_device.launch
+// CHECK:      <{device = "/TPU:2"}>
 // CHECK:      tf.OpB
-// CHECK:      {device = "/TPU:2"}
 // CHECK:      _parallel_execution_ids = "r0:0"
 // CHECK:      tf_executor.island
 // CHECK:      tf_device.parallel_execute
 // CHECK:      tf_device.launch
+// CHECK:      <{device = "/TPU:1"}>
 // CHECK:      tf.OpA
-// CHECK:      {device = "/TPU:1"}
 // CHECK:      tf_device.launch
+// CHECK:      <{device = "/TPU:2"}>
 // CHECK:      tf.OpB
-// CHECK:      {device = "/TPU:2"}
 // CHECK:      _parallel_execution_ids = "r0:1"
 // CHECK:      tf_executor.fetch
 
@@ -343,9 +349,11 @@
 }
 
 // CHECK: "tf.opA"
-// CHECK: _parallel_execution_ids = "r4:5,r0:0", device = "/CPU:0"
+// device = "/CPU:0"
+// CHECK: _parallel_execution_ids = "r4:5,r0:0"
 // CHECK: "tf.opA"
-// CHECK: _parallel_execution_ids = "r4:5,r0:1", device = "/GPU:1"
+// device = "/GPU:1"
+// CHECK: _parallel_execution_ids = "r4:5,r0:1"
 
 // -----
 
@@ -418,10 +426,14 @@
   func.return
 }
 // CHECK: "tf.opA"
-// CHECK: _parallel_execution_ids = "r0:0", device = "/TPU:0"
+// device = "/TPU:0"
+// CHECK: _parallel_execution_ids = "r0:0"
 // CHECK: "tf.opA"
-// CHECK: _parallel_execution_ids = "r0:1", device = "/TPU:0"
+// device = "/TPU:0"
+// CHECK: _parallel_execution_ids = "r0:1"
 // CHECK: "tf.opA"
-// CHECK: _parallel_execution_ids = "r1:0", device = "/TPU:1"
+// device = "/TPU:1"
+// CHECK: _parallel_execution_ids = "r1:0"
 // CHECK: "tf.opA"
-// CHECK: _parallel_execution_ids = "r1:1", device = "/TPU:1"
+// device = "/TPU:1"
+// CHECK: _parallel_execution_ids = "r1:1"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island_legacy.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island_legacy.mlir
index 24d498e..2c47b08 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island_legacy.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island_legacy.mlir
@@ -43,10 +43,12 @@
   func.return
 }
 
-// CHECK: "tf.opA"
+// CHECK: "tf_device.launch"
 // CHECK: device = "CORE_0"
 // CHECK: "tf.opA"
+// CHECK: "tf_device.launch"
 // CHECK: device = "CORE_0"
+// CHECK: "tf.opA"
 
 
 // Tests devices are not remapped if device is not in replicate devices.
@@ -68,10 +70,12 @@
   func.return
 }
 
-// CHECK: "tf.opA"
+// CHECK: "tf_device.launch"
 // CHECK: device = "/TPU:2"
 // CHECK: "tf.opA"
+// CHECK: "tf_device.launch"
 // CHECK: device = "/TPU:2"
+// CHECK: "tf.opA"
 
 
 // Tests devices are remapped if device is in replicate devices.
@@ -93,10 +97,12 @@
   func.return
 }
 
-// CHECK: "tf.opA"
+// CHECK: "tf_device.launch"
 // CHECK: device = "/CPU:0"
 // CHECK: "tf.opA"
+// CHECK: "tf_device.launch"
 // CHECK: device = "/GPU:1"
+// CHECK: "tf.opA"
 
 
 // Tests replicate with control dependency output has each expanded replica
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
index 02c1444..9c00c8e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
@@ -221,7 +221,7 @@
 
 // CHECK-LABEL: func @cluster_with_loop
 func.func @cluster_with_loop() -> () {
-  // CHECK: %[[COUNT:.*]] = "tf.Const"() {value = dense<10> : tensor<i32>}
+  // CHECK: %[[COUNT:.*]] = "tf.Const"() <{value = dense<10> : tensor<i32>}>
   %0 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
   // CHECK: %[[VH:.*]] = "tf.VarHandleOp"()
   %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<f32>>>
@@ -253,7 +253,7 @@
   // CHECK-NEXT: %[[ADD1:.*]] = "tf.AddV2"(%[[ADD0]], %[[ADD0]])
   %add1 = "tf.AddV2"(%read1, %read1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
   "tf.AssignVariableOp"(%arg1, %add1) : (tensor<*x!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
-  // CHECK-NEXT: %[[DELTA:.*]] = "tf.Const"() {value = dense<-1> : tensor<i32>}
+  // CHECK-NEXT: %[[DELTA:.*]] = "tf.Const"() <{value = dense<-1> : tensor<i32>}>
   %constant = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
   // CHECK-NEXT: %[[ADD2:.*]] = "tf.AddV2"(%[[BARG0]], %[[DELTA]])
   %add2 = "tf.AddV2"(%arg0, %constant) : (tensor<i32>, tensor<i32>) -> tensor<i32>
@@ -299,7 +299,7 @@
   // CHECK-NEXT: return %[[CONST]]
   func.return %arg0 : tensor<*x!tf_type.resource<tensor<f32>>>
 }
-// CHECK: func @while_cond(%arg0: tensor<f32>)
+// CHECK: func @while_cond(%[[CARG0:.*]]: tensor<f32>)
 func.func @while_cond(%arg0: tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32> {
   %id = "tf.Identity"(%arg0) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource<tensor<f32>>>
   %read = "tf.ReadVariableOp"(%id) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32>
@@ -935,7 +935,7 @@
 
 // CHECK-LABEL: func @cluster_with_whileregion
 func.func @cluster_with_whileregion() -> () {
-  // CHECK: %[[COUNT:.*]] = "tf.Const"() {value = dense<10> : tensor<i32>}
+  // CHECK: %[[COUNT:.*]] = "tf.Const"() <{value = dense<10> : tensor<i32>}>
   // CHECK: %[[VH:.*]] = "tf.VarHandleOp"()
   // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]])
   // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"()
@@ -959,7 +959,7 @@
             // CHECK: (%[[BARG0:.+]]: tensor<i32>, %[[BARG1:.+]]: tensor<f32>):
             // CHECK: %[[ADD0:.*]] = "tf.AddV2"(%[[BARG1]], %[[BARG1]])
             // CHECK-NEXT: %[[ADD1:.*]] = "tf.AddV2"(%[[ADD0]], %[[ADD0]])
-            // CHECK-NEXT: %[[DELTA:.*]] = "tf.Const"() {value = dense<-1> : tensor<i32>}
+            // CHECK-NEXT: %[[DELTA:.*]] = "tf.Const"() <{value = dense<-1> : tensor<i32>}>
             // CHECK-NEXT: %[[ADD2:.*]] = "tf.AddV2"(%[[BARG0]], %[[DELTA]])
             // CHECK-NEXT: "tf.Yield"(%[[ADD2]], %[[ADD1]])
             ^bb1(%barg0: tensor<i32>, %barg1: !tf_ref, %barg2: !tf_ref, %barg3: !tf_ref):
@@ -1046,7 +1046,7 @@
   %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<4xf32>>>
   "tf_device.cluster"() ({
     "tf.IfRegion"(%arg0) ({
-       // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<0.000000e+00>
+       // CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<0.000000e+00>
        // CHECK: "tf.Yield"(%[[CONST]])
        %constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32>
        "tf.AssignVariableOp"(%0, %constant) : (tensor<*x!tf_type.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
@@ -1074,13 +1074,13 @@
   %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<4xf32>>>
   "tf_device.cluster"() ({
     "tf.IfRegion"(%arg0) ({
-       // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<0.000000e+00>
+       // CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<0.000000e+00>
        // CHECK: "tf.Yield"(%[[CONST]])
        %constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32>
        "tf.AssignVariableOp"(%0, %constant) : (tensor<*x!tf_type.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
        "tf.Yield"() : () -> ()
       }, {
-       // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<1.000000e+00>
+       // CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00>
        // CHECK: "tf.Yield"(%[[CONST]])
        %constant = "tf.Const"() {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32>
        "tf.AssignVariableOp"(%0, %constant) : (tensor<*x!tf_type.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
@@ -1123,8 +1123,8 @@
 // to not be lifted and arg1 to be lifted.
 // CHECK-LABEL: func @test_unsupported_resource_op_in_if
 func.func @test_unsupported_resource_op_in_if(%arg0: tensor<i1>) -> tensor<*xi32> {
-  // CHECK: [[VH0:%.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"}
-  // CHECK: [[VH1:%.*]] = "tf.VarHandleOp"() {container = "d", shared_name = "w"}
+  // CHECK: [[VH0:%.*]] = "tf.VarHandleOp"() <{container = "c", shared_name = "v"}>
+  // CHECK: [[VH1:%.*]] = "tf.VarHandleOp"() <{container = "d", shared_name = "w"}>
   // CHECK-NOT: "tf.ReadVariableOp"([[VH0]])
   // CHECK: [[READ1:%.*]] = "tf.ReadVariableOp"([[VH1]])
   // CHECK-NOT: "tf.ReadVariableOp"([[VH0]])
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/rewrite_tpu_embedding_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/rewrite_tpu_embedding_ops.mlir
index 0f1344c..0099129 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/rewrite_tpu_embedding_ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/rewrite_tpu_embedding_ops.mlir
@@ -2,8 +2,8 @@
 
 // CHECK-LABEL: func @recv_tpu_embedding_activations
 func.func @recv_tpu_embedding_activations() -> (tensor<512x256xf32>) {
-  // CHECK: %[[DATA:.*]] = "tf.XlaRecvTPUEmbeddingDeduplicationData"() {config = {{.*}}} : () -> tensor<!tf_type.variant>
-  // CHECK: %[[RESULT:.*]] = "tf.XlaRecvTPUEmbeddingActivations"(%[[DATA]]) {config = {{.*}}} : (tensor<!tf_type.variant>) -> tensor<512x256xf32>
+  // CHECK: %[[DATA:.*]] = "tf.XlaRecvTPUEmbeddingDeduplicationData"() <{config = {{.*}}}> : () -> tensor<!tf_type.variant>
+  // CHECK: %[[RESULT:.*]] = "tf.XlaRecvTPUEmbeddingActivations"(%[[DATA]]) <{config = {{.*}}}> : (tensor<!tf_type.variant>) -> tensor<512x256xf32>
   // CHECK: return %[[RESULT]]
   // CHECK-NOT: tf.RecvTPUEmbeddingActivations
   // CHECK-NOT: tf.SendTPUEmbeddingGradients
@@ -14,8 +14,8 @@
 
 // CHECK-LABEL: func @send_tpu_embedding_gradients
 func.func @send_tpu_embedding_gradients(%arg0: tensor<512x256xf32>) -> () {
-  // CHECK: %[[DATA:.*]] = "tf.XlaRecvTPUEmbeddingDeduplicationData"() {config = {{.*}}} : () -> tensor<!tf_type.variant>
-  // CHECK: "tf.XlaSendTPUEmbeddingGradients"(%arg0, %[[DATA]]) {config = {{.*}}, operandSegmentSizes = array<i32: 1, 0, 1>} : (tensor<512x256xf32>, tensor<!tf_type.variant>) -> ()
+  // CHECK: %[[DATA:.*]] = "tf.XlaRecvTPUEmbeddingDeduplicationData"() <{config = {{.*}}}> : () -> tensor<!tf_type.variant>
+  // CHECK: "tf.XlaSendTPUEmbeddingGradients"(%arg0, %[[DATA]]) <{config = {{.*}}, operandSegmentSizes = array<i32: 1, 0, 1>}> : (tensor<512x256xf32>, tensor<!tf_type.variant>) -> ()
   // CHECK-NOT: tf.SendTPUEmbeddingGradients
   // CHECK-NOT: tf.RecvTPUEmbeddingActivations
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index 45d5ba9..238157b 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -131,7 +131,8 @@
   // CHECK-SAME: -> tensor<1x2x3xf32>
   func.func @shape_from_if_to_region_bodies_to_output(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<*xf32> {
     %unshaped = "tf.Cast"(%arg1) : (tensor<1x2x3xf32>) -> tensor<*xf32>
-    %0 = "tf.IfRegion"(%arg0) ({
+    // CHECK: <{is_stateless = true}>
+    %0 = "tf.IfRegion"(%arg0) <{is_stateless = true}> ({
       // CHECK: "tf.Add"{{.+}}(tensor<1x2x3xf32>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
       // CHECK: "tf.Yield"{{.+}}(tensor<1x2x3xf32>) -> ()
       %1 = "tf.Add"(%unshaped, %unshaped) : (tensor<*xf32>,  tensor<*xf32>) -> tensor<*xf32>
@@ -141,8 +142,8 @@
       // CHECK: "tf.Yield"{{.+}}(tensor<1x2x3xf32>) -> ()
       %2 = "tf.Sub"(%unshaped, %unshaped) : (tensor<*xf32>,  tensor<*xf32>) -> tensor<*xf32>
       "tf.Yield"(%2) : (tensor<*xf32>) -> ()
-      // CHECK: {is_stateless = true} : (tensor<i1>) -> tensor<1x2x3xf32>
-     }) {is_stateless = true} : (tensor<i1>) -> tensor<*xf32>
+      // CHECK: (tensor<i1>) -> tensor<1x2x3xf32>
+     }) : (tensor<i1>) -> tensor<*xf32>
     // CHECK: return {{.*}} :  tensor<1x2x3xf32>
     func.return %0 : tensor<*xf32>
   }
@@ -176,7 +177,8 @@
   // CHECK-SAME: -> tensor<1x2x3xf32>
   func.func @shape_from_case_to_region_bodies_to_output(%arg0: tensor<i32>, %arg1: tensor<1x2x3xf32>) -> tensor<*xf32> {
     %unshaped = "tf.Cast"(%arg1) : (tensor<1x2x3xf32>) -> tensor<*xf32>
-    %0 = "tf.CaseRegion"(%arg0) ({
+    // CHECK: <{is_stateless = true}>
+    %0 = "tf.CaseRegion"(%arg0) <{is_stateless = true}> ({
       // CHECK: "tf.Add"{{.+}}(tensor<1x2x3xf32>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
       // CHECK: "tf.Yield"{{.+}}(tensor<1x2x3xf32>) -> ()
       %1 = "tf.Add"(%unshaped, %unshaped) : (tensor<*xf32>,  tensor<*xf32>) -> tensor<*xf32>
@@ -186,8 +188,8 @@
       // CHECK: "tf.Yield"{{.+}}(tensor<1x2x3xf32>) -> ()
       %2 = "tf.Sub"(%unshaped, %unshaped) : (tensor<*xf32>,  tensor<*xf32>) -> tensor<*xf32>
       "tf.Yield"(%2) : (tensor<*xf32>) -> ()
-      // CHECK: {is_stateless = true} : (tensor<i32>) -> tensor<1x2x3xf32>
-     }) {is_stateless = true} : (tensor<i32>) -> tensor<*xf32>
+      // CHECK: (tensor<i32>) -> tensor<1x2x3xf32>
+     }) : (tensor<i32>) -> tensor<*xf32>
     // CHECK: return {{.*}} :  tensor<1x2x3xf32>
     func.return %0 : tensor<*xf32>
   }
@@ -243,7 +245,8 @@
   func.func @shape_from_while_operands_to_cond_body_to_while_results(%arg0: tensor<i32>, %arg1: tensor<1x2x3xf32>) ->  tensor<*xf32> {
     %unshaped = "tf.Cast"(%arg1) : (tensor<1x2x3xf32>) -> tensor<*xf32>
     // CHECK: "tf.WhileRegion"
-    %0:2 = "tf.WhileRegion"(%arg0, %unshaped) ({
+    // CHECK: <{is_stateless = true}>
+    %0:2 = "tf.WhileRegion"(%arg0, %unshaped) <{is_stateless = true}> ({
        // CHECK: {{.*}}({{.+}}: tensor<i32>, {{.+}}: tensor<1x2x3xf32>):
        ^bb0(%carg0: tensor<i32>, %carg1: tensor<*xf32>):
          %limit = arith.constant dense<5> : tensor<i32>
@@ -258,8 +261,8 @@
         %neg = "tf.Neg"(%barg1) : (tensor<*xf32>) -> tensor<*xf32>
         // CHECK: "tf.Yield"{{.+}}, {{.+}}) : (tensor<i32>, tensor<1x2x3xf32>) -> ()
         "tf.Yield"(%sub, %neg) : (tensor<i32>, tensor<*xf32>) -> ()
-    // CHECK: {is_stateless = true} : (tensor<i32>, tensor<1x2x3xf32>) -> (tensor<i32>, tensor<1x2x3xf32>)
-    }) {is_stateless = true} : (tensor<i32>, tensor<*xf32>) -> (tensor<i32>, tensor<*xf32>)
+    // CHECK: (tensor<i32>, tensor<1x2x3xf32>) -> (tensor<i32>, tensor<1x2x3xf32>)
+    }) : (tensor<i32>, tensor<*xf32>) -> (tensor<i32>, tensor<*xf32>)
     // CHECK: return {{.+}}#1 : tensor<1x2x3xf32>
     func.return %0#1 : tensor<*xf32>
   }
@@ -752,7 +755,7 @@
 
   // CHECK-LABEL: replace_tensor_list_element_shape
   func.func @replace_tensor_list_element_shape() {
-    // CHECK: %[[ELEMENT_SHAPE:.*]] = "tf.Const"() {value = dense<[-1, 1]> : tensor<2xi32>}
+    // CHECK: %[[ELEMENT_SHAPE:.*]] = "tf.Const"() <{value = dense<[-1, 1]> : tensor<2xi32>}>
     %elem_shape = "tf.Const"() {value = dense<[-1, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
     %size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
     %elem = "tf._SomeOp"() : () -> tensor<16x1xf32>
@@ -767,7 +770,7 @@
     "tf._SomeOtherOp"(%shape_32, %shape_64) : (tensor<?xi32>, tensor<?xi64>) -> ()
     func.return
   }
-  
+
   // CHECK-LABEL: refine_pop_back_results_from_operands
   func.func @refine_pop_back_results_from_operands(%arg0: tensor<!tf_type.variant<tensor<2xi32>>>, %arg1: tensor<1xi32>) -> (tensor<!tf_type.variant>, tensor<*xi32>)  {
     %0, %1 = "tf.TensorListPopBack"(%arg0, %arg1) : (tensor<!tf_type.variant<tensor<2xi32>>>, tensor<1xi32>) -> (tensor<!tf_type.variant>, tensor<*xi32>)
@@ -876,7 +879,7 @@
       tf_device.return %2 : tensor<1x8x2xf32>
     // CHECK: () -> tensor<1x8x2xf32>
     }) {device = "/device:CPU:0"} : () -> tensor<*xf32>
-    // CHECK: "tf.Cast"(%{{.*}}) {Truncate = false} : (tensor<1x8x2xf32>) -> tensor<*xf32>
+    // CHECK: "tf.Cast"(%{{.*}}) <{Truncate = false}> : (tensor<1x8x2xf32>) -> tensor<*xf32>
     // CHECK: (tensor<i32>, tensor<1x8x2xf32>) -> (tensor<1x8x1xf32>, tensor<1x8x1xf32>)
     %3:2 = "tf.Split"(%0, %1) {device = ""} : (tensor<i32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
     %4 = tensor.cast %1 : tensor<*xf32> to tensor<?x?x?xf32>
@@ -891,7 +894,7 @@
       tf_device.return %2 : tensor<1x8x2xf32>
     // CHECK: () -> tensor<1x8x2xf32>
     }) : () -> tensor<*xf32>
-    // CHECK: "tf.Cast"(%{{.*}}) {Truncate = false} : (tensor<1x8x2xf32>) -> tensor<*xf32>
+    // CHECK: "tf.Cast"(%{{.*}}) <{Truncate = false}> : (tensor<1x8x2xf32>) -> tensor<*xf32>
     // CHECK: (tensor<i32>, tensor<1x8x2xf32>) -> (tensor<1x8x1xf32>, tensor<1x8x1xf32>)
     %3:2 = "tf.Split"(%0, %1) {device = ""} : (tensor<i32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
     %4 = tensor.cast %1 : tensor<*xf32> to tensor<?x?x?xf32>
@@ -958,6 +961,7 @@
     func.return %0 : tensor<*xi32>
   }
 
+  // Test fetch and yield are directly assigned to island and graph op results.
   // CHECK-LABEL: func @call_in_graph_func({{%.+}}: tensor<i32>) -> tensor<i32>
   func.func @call_in_graph_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
     // CHECK-NOT: tf.Cast
@@ -968,6 +972,27 @@
     func.return %0 : tensor<*xi32>
   }
 
+  // CHECK-LABEL: func @call_in_graph_1
+  func.func @call_in_graph_1(%arg0: tensor<?x?x?x?xbf16>, %arg1: tensor<5x5x1x32xbf16>) -> tensor<*xbf16> {
+    // CHECK: tf_executor.fetch %outputs : tensor<?x?x?x32xbf16>
+    %0 = tf_executor.graph {
+      %1:2 = tf_executor.island wraps "tf.PartitionedCall"(%arg0, %arg1) {
+        config = "", config_proto = "", executor_type = "", f = @call_in_graph_func_1} : (tensor<?x?x?x?xbf16>, tensor<5x5x1x32xbf16>) -> tensor<*xbf16>
+      tf_executor.fetch %1#0 : tensor<*xbf16>
+    }
+    func.return %0 : tensor<*xbf16>
+  }
+
+  // CHECK-LABEL: func @call_in_graph_func_1
+  func.func @call_in_graph_func_1(%arg0: tensor<?x28x28x1xbf16>, %arg1: tensor<5x5x1x32xbf16>) -> tensor<?x28x28x?xbf16> {
+    // CHECK: tf_executor.fetch %outputs : tensor<?x?x?x32xbf16>
+    %0 = tf_executor.graph {
+      %1:2 = tf_executor.island wraps "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}: (tensor<?x28x28x1xbf16>, tensor<5x5x1x32xbf16>) -> tensor<?x28x28x?xbf16>
+      tf_executor.fetch %1#0 : tensor<?x28x28x?xbf16>
+    }
+    func.return %0 : tensor<?x28x28x?xbf16>
+  }
+
   // Test shape invariant While only propagates operand handle types into
   // results and functions/regions.
   // CHECK-LABEL: func @while_shape_invariant_propagate
@@ -980,7 +1005,8 @@
     %0:4 = "tf.While"(%arg0, %arg1, %arg2, %arg3) {cond = @while_shape_invariant_cond_func_propagate, body = @while_shape_invariant_body_func_propagate, is_stateless = false, shape_invariant} : (tensor<4xf32>, tensor<!tf_type.resource<tensor<4xf32>>>, tensor<!tf_type.resource<tensor<8xf32>>>, tensor<1xi32>) -> (tensor<*xf32>, tensor<*x!tf_type.resource>, tensor<!tf_type.resource>, tensor<?xi32>)
 
     // CHECK: "tf.WhileRegion"
-    %1:4 = "tf.WhileRegion"(%arg0, %arg1, %arg2, %arg3) ({
+    // CHECK-SAME: shape_invariant
+    %1:4 = "tf.WhileRegion"(%arg0, %arg1, %arg2, %arg3) <{is_stateless = false, shape_invariant}> ({
     // CHECK-NEXT: ^{{.+}}({{%.+}}: tensor<*xf32>, {{%.+}}: tensor<*x!tf_type.resource<tensor<4xf32>>>, {{%.+}}: tensor<!tf_type.resource<tensor<8xf32>>>, {{%.+}}: tensor<?xi32>):
     ^cond(%carg0: tensor<*xf32>, %carg1: tensor<*x!tf_type.resource>, %carg2: tensor<!tf_type.resource>, %carg3: tensor<?xi32>):
       %2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
@@ -992,10 +1018,9 @@
       // CHECK: "tf.Yield"
       // CHECK-SAME: (tensor<*xf32>, tensor<*x!tf_type.resource<tensor<4xf32>>>, tensor<!tf_type.resource<tensor<8xf32>>>, tensor<?xi32>) -> ()
       "tf.Yield"(%barg0, %barg1, %barg2, %2) : (tensor<*xf32>, tensor<*x!tf_type.resource>, tensor<!tf_type.resource>, tensor<?xi32>) -> ()
-    // CHECK-NEXT: shape_invariant
-    // CHECK-SAME: (tensor<4xf32>, tensor<!tf_type.resource<tensor<4xf32>>>, tensor<!tf_type.resource<tensor<8xf32>>>, tensor<1xi32>)
+    // CHECK-NEXT: (tensor<4xf32>, tensor<!tf_type.resource<tensor<4xf32>>>, tensor<!tf_type.resource<tensor<8xf32>>>, tensor<1xi32>)
     // CHECK-SAME: -> (tensor<*xf32>, tensor<*x!tf_type.resource<tensor<4xf32>>>, tensor<!tf_type.resource<tensor<8xf32>>>, tensor<?xi32>)
-    }) {is_stateless = false, shape_invariant} : (tensor<4xf32>, tensor<!tf_type.resource<tensor<4xf32>>>, tensor<!tf_type.resource<tensor<8xf32>>>, tensor<1xi32>) -> (tensor<*xf32>, tensor<*x!tf_type.resource>, tensor<!tf_type.resource>, tensor<?xi32>)
+    }) : (tensor<4xf32>, tensor<!tf_type.resource<tensor<4xf32>>>, tensor<!tf_type.resource<tensor<8xf32>>>, tensor<1xi32>) -> (tensor<*xf32>, tensor<*x!tf_type.resource>, tensor<!tf_type.resource>, tensor<?xi32>)
 
     func.return %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, %1#2, %1#3 : tensor<*xf32>, tensor<*x!tf_type.resource>, tensor<!tf_type.resource>, tensor<?xi32>, tensor<*xf32>, tensor<*x!tf_type.resource>, tensor<!tf_type.resource>, tensor<?xi32>
   }
@@ -1028,7 +1053,8 @@
     %0 = "tf.While"(%arg0) {cond = @while_shape_invariant_cond_func_different_dims, body = @while_shape_invariant_body_func_different_dims, is_stateless = false, shape_invariant} : (tensor<1x2x3xf32>) -> tensor<1x8x3xf32>
 
     // CHECK: "tf.WhileRegion"
-    %1 = "tf.WhileRegion"(%arg0) ({
+    // CHECK-SAME: shape_invariant
+    %1 = "tf.WhileRegion"(%arg0) <{is_stateless = false, shape_invariant}> ({
     // CHECK-NEXT: ^{{.+}}({{%.+}}: tensor<1x?x3xf32>):
     ^cond(%carg0: tensor<*xf32>):
       %2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
@@ -1040,10 +1066,9 @@
       // CHECK: "tf.Yield"
       // CHECK-SAME: (tensor<1x?x3xf32>) -> ()
       "tf.Yield"(%2) : (tensor<*xf32>) -> ()
-    // CHECK-NEXT: shape_invariant
-    // CHECK-SAME: (tensor<1x2x3xf32>)
+    // CHECK-NEXT: (tensor<1x2x3xf32>)
     // CHECK-SAME: -> tensor<1x8x3xf32>
-    }) {is_stateless = false, shape_invariant} : (tensor<1x2x3xf32>) -> tensor<1x8x3xf32>
+    }) : (tensor<1x2x3xf32>) -> tensor<1x8x3xf32>
 
     func.return %0, %1 : tensor<1x8x3xf32>, tensor<1x8x3xf32>
   }
@@ -1076,7 +1101,8 @@
     %0 = "tf.While"(%arg0) {cond = @while_shape_invariant_cond_func_body_result_propagate, body = @while_shape_invariant_body_func_body_result_propagate, is_stateless = false, shape_invariant} : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource>
 
     // CHECK: "tf.WhileRegion"
-    %1 = "tf.WhileRegion"(%arg0) ({
+    // CHECK-SAME: shape_invariant
+    %1 = "tf.WhileRegion"(%arg0) <{is_stateless = false, shape_invariant}> ({
     // CHECK-NEXT: ^{{.+}}({{%.+}}: tensor<*x!tf_type.resource<tensor<f32>>>):
     ^cond(%carg0: tensor<*x!tf_type.resource<tensor<f32>>>):
       %2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
@@ -1088,10 +1114,9 @@
       // CHECK: "tf.Yield"
       // CHECK-SAME: (tensor<*x!tf_type.resource<tensor<f32>>>) -> ()
       "tf.Yield"(%2) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> ()
-    // CHECK-NEXT: shape_invariant
-    // CHECK-SAME: (tensor<*x!tf_type.resource<tensor<f32>>>)
+    // CHECK-NEXT: (tensor<*x!tf_type.resource<tensor<f32>>>)
     // CHECK-SAME: -> tensor<*x!tf_type.resource<tensor<f32>>>
-    }) {is_stateless = false, shape_invariant} : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource>
+    }) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource>
 
     func.return %0, %1 : tensor<*x!tf_type.resource>, tensor<*x!tf_type.resource>
   }
@@ -1333,7 +1358,7 @@
     %cst = "tf.Const"() {value = dense<0> : tensor<4x2xi32>} : () -> tensor<4x2xi32>
     %cst_0 = "tf.Const"() {value = dense<[2, 2, 1, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
     %cst_1 = "tf.Const"() {value = dense<[2, 3, 1, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
-    // CHECK: %0 = "tf.XlaSelectAndScatter"(%arg0, %cst_1, %cst_0, %cst, %arg1, %arg2) {scatter = @add_scatter, select = @ge_select} : (tensor<4x5x1x1xbf16>, tensor<4xi32>, tensor<4xi32>, tensor<4x2xi32>, tensor<2x2x1x1xbf16>, tensor<bf16>) -> tensor<4x5x1x1xbf16>
+    // CHECK: %0 = "tf.XlaSelectAndScatter"(%arg0, %cst_1, %cst_0, %cst, %arg1, %arg2) <{scatter = @add_scatter, select = @ge_select}> : (tensor<4x5x1x1xbf16>, tensor<4xi32>, tensor<4xi32>, tensor<4x2xi32>, tensor<2x2x1x1xbf16>, tensor<bf16>) -> tensor<4x5x1x1xbf16>
     %0 = "tf.XlaSelectAndScatter"(%arg0, %cst_1, %cst_0, %cst, %arg1, %arg2) {scatter = @add_scatter, select = @ge_select} : (tensor<4x5x1x1xbf16>, tensor<4xi32>, tensor<4xi32>, tensor<4x2xi32>, tensor<2x2x1x1xbf16>, tensor<bf16>) -> tensor<?x?x?x?xbf16>
     func.return %0 : tensor<?x?x?x?xbf16>
   }
@@ -1373,7 +1398,7 @@
     %cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
     %cst_2 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
     %cst_3 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32>
-    // CHECK: 0 = "tf.XlaReduceWindow"(%arg0, %arg1, %cst_0, %cst_1, %cst_2, %cst_3, %cst) {computation = @sum_reducer3} : (tensor<7xf32>, tensor<f32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<10xf32>
+    // CHECK: 0 = "tf.XlaReduceWindow"(%arg0, %arg1, %cst_0, %cst_1, %cst_2, %cst_3, %cst) <{computation = @sum_reducer3}> : (tensor<7xf32>, tensor<f32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<10xf32>
     %0 = "tf.XlaReduceWindow"(%arg0, %arg1, %cst_0, %cst_1, %cst_2, %cst_3, %cst) {computation = @sum_reducer3} : (tensor<7xf32>, tensor<f32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<?xf32>
     func.return %0 : tensor<?xf32>
   }
@@ -1813,7 +1838,7 @@
   func.func @infer_var_handle_op_from_assigns() -> tensor<1xi8> {
     %cst = arith.constant dense<1> : tensor<1xi8>
     %0 = "tf.VarHandleOp"() {container = "", shared_name = "bar"} : () -> tensor<!tf_type.resource<tensor<*xi8>>>
-    // CHECK: "tf.VarHandleOp"() {container = "", shared_name = "bar"} : () -> tensor<!tf_type.resource<tensor<1xi8>>>
+    // CHECK: "tf.VarHandleOp"() <{container = "", shared_name = "bar"}> : () -> tensor<!tf_type.resource<tensor<1xi8>>>
     "tf.AssignVariableOp"(%0, %cst) : (tensor<!tf_type.resource<tensor<*xi8>>>, tensor<1xi8>) -> ()
     func.return %cst : tensor<1xi8>
   }
@@ -1822,7 +1847,7 @@
   func.func @infer_var_handle_op_from_read() -> tensor<1xi8> {
     %cst = arith.constant dense<1> : tensor<1xi8>
     %0 = "tf.VarHandleOp"() {container = "", shared_name = "bar"} : () -> tensor<!tf_type.resource<tensor<*xi8>>>
-    // CHECK: "tf.VarHandleOp"() {container = "", shared_name = "bar"} : () -> tensor<!tf_type.resource<tensor<1xi8>>>
+    // CHECK: "tf.VarHandleOp"() <{container = "", shared_name = "bar"}> : () -> tensor<!tf_type.resource<tensor<1xi8>>>
     %read = "tf.ReadVariableOp"(%0) : (tensor<!tf_type.resource<tensor<*xi8>>>) -> tensor<1xi8>
     func.return %read : tensor<1xi8>
   }
@@ -1831,7 +1856,7 @@
   func.func @do_not_infer_var_handle_op_when_custom_op_uses_it() -> tensor<1xi8> {
     %cst = arith.constant dense<1> : tensor<1xi8>
     %0 = "tf.VarHandleOp"() {container = "", shared_name = "bar"} : () -> tensor<!tf_type.resource<tensor<*xi8>>>
-    // CHECK: "tf.VarHandleOp"() {container = "", shared_name = "bar"} : () -> tensor<!tf_type.resource<tensor<*xi8>>>
+    // CHECK: "tf.VarHandleOp"() <{container = "", shared_name = "bar"}> : () -> tensor<!tf_type.resource<tensor<*xi8>>>
     %read = "tf.ReadVariableOp"(%0) : (tensor<!tf_type.resource<tensor<*xi8>>>) -> tensor<1xi8>
     %1 = "tf.MyCustomOp"(%0) : (tensor<!tf_type.resource<tensor<*xi8>>>) -> tensor<4xi8>
     func.return %read : tensor<1xi8>
@@ -1916,7 +1941,7 @@
     %lhs_dilation = "tf.Const"() {value = dense<[4, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
     %padding = "tf.Const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
     %strides = "tf.Const"() {value = dense<[3, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
-    // CHECK: %0 = "tf.XlaConvV2"(%arg0, %arg1, %cst_3, %cst_2, %cst_1, %cst_0, %cst) {dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x?x?x?x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
+    // CHECK: %0 = "tf.XlaConvV2"(%arg0, %arg1, %cst_3, %cst_2, %cst_1, %cst_0, %cst) <{dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""}> : (tensor<8x?x?x?x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
     %0 = "tf.XlaConvV2"(%lhs, %rhs, %strides, %padding, %lhs_dilation, %rhs_dilation, %feature_group_count) {dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x?x?x?x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
     func.return %0 : tensor<8x4x14x14x16xf32>
   }
@@ -1928,7 +1953,7 @@
     %lhs_dilation = "tf.Const"() {value = dense<[4, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
     %padding = "tf.Const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
     %strides = "tf.Const"() {value = dense<[3, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
-    // CHECK: %0 = "tf.XlaConvV2"(%arg0, %arg1, %cst_3, %cst_2, %cst_1, %cst_0, %cst) {dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<?x?x?x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
+    // CHECK: %0 = "tf.XlaConvV2"(%arg0, %arg1, %cst_3, %cst_2, %cst_1, %cst_0, %cst) <{dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""}> : (tensor<8x4x16x16x16xf32>, tensor<?x?x?x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
     %0 = "tf.XlaConvV2"(%lhs, %rhs, %strides, %padding, %lhs_dilation, %rhs_dilation, %feature_group_count) {dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<?x?x?x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
     func.return %0 : tensor<8x4x14x14x16xf32>
   }
@@ -1941,7 +1966,7 @@
     %cst_2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
     %cst_3 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
     %0 = tf_executor.graph {
-      // CHECK: "tf.XlaConvV2"(%arg0, %arg1, %cst, %cst_0, %cst_1, %cst_2, %cst_3) {_XlaHasReferenceVars = false, device = "/job:localhost/replica:0/task:0/device:XLA_CPU:0", dimension_numbers = "\18\012\01\02@\01P\01Z\01\02b\01\02", precision_config = "\0A\02\01\01"} : (tensor<*xf32>, tensor<*xf32>, tensor<1xi32>, tensor<1x2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<i32>) -> tensor<*xf32>
+      // CHECK: "tf.XlaConvV2"(%arg0, %arg1, %cst, %cst_0, %cst_1, %cst_2, %cst_3) <{dimension_numbers = "\18\012\01\02@\01P\01Z\01\02b\01\02", precision_config = "\0A\02\01\01"}> {_XlaHasReferenceVars = false, device = "/job:localhost/replica:0/task:0/device:XLA_CPU:0"} : (tensor<*xf32>, tensor<*xf32>, tensor<1xi32>, tensor<1x2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<i32>) -> tensor<*xf32>
       %outputs, %control = tf_executor.island wraps "tf.XlaConvV2"(%arg0, %arg1, %cst, %cst_0, %cst_1, %cst_2, %cst_3) {_XlaHasReferenceVars = false, device = "/job:localhost/replica:0/task:0/device:XLA_CPU:0", dimension_numbers = "\18\012\01\02@\01P\01Z\01\02b\01\02", precision_config = "\0A\02\01\01"} : (tensor<*xf32>, tensor<*xf32>, tensor<1xi32>, tensor<1x2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<i32>) -> tensor<*xf32>
       tf_executor.fetch %outputs : tensor<*xf32>
     }
@@ -1955,7 +1980,7 @@
     %lhs_dilation = "tf.Const"() {value = dense<[4, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
     %padding = "tf.Const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
     %strides = "tf.Const"() {value = dense<[3, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
-    // CHECK: %0 = "tf.XlaConvV2"(%arg0, %arg1, %cst_3, %cst_2, %cst_1, %cst_0, %cst) {dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
+    // CHECK: %0 = "tf.XlaConvV2"(%arg0, %arg1, %cst_3, %cst_2, %cst_1, %cst_0, %cst) <{dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""}> : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
     %0 = "tf.XlaConvV2"(%lhs, %rhs, %strides, %padding, %lhs_dilation, %rhs_dilation, %feature_group_count) {dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
     func.return %0 : tensor<8x4x14x14x16xf32>
   }
@@ -1967,7 +1992,7 @@
     %lhs_dilation = "tf.Const"() {value = dense<[4, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
     %padding = "tf.Const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
     %strides = "tf.Const"() {value = dense<[3, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
-    // CHECK: %0 = "tf.XlaConvV2"(%arg0, %arg1, %cst_3, %cst_2, %cst_1, %cst_0, %cst) {dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
+    // CHECK: %0 = "tf.XlaConvV2"(%arg0, %arg1, %cst_3, %cst_2, %cst_1, %cst_0, %cst) <{dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""}> : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
     %0 = "tf.XlaConvV2"(%lhs, %rhs, %strides, %padding, %lhs_dilation, %rhs_dilation, %feature_group_count) {dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<?x?x?x?x?xf32>
     func.return %0 : tensor<?x?x?x?x?xf32>
   }
@@ -1979,7 +2004,7 @@
     %lhs_dilation = "tf.Const"() {value = dense<[4, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
     %padding = "tf.Const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
     %strides = "tf.Const"() {value = dense<[3, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
-    // CHECK: %0 = "tf.XlaConvV2"(%arg0, %arg1, %cst_3, %cst_2, %cst_1, %cst_0, %cst) {dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf16>, tensor<4x3x3x16x16xf16>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
+    // CHECK: %0 = "tf.XlaConvV2"(%arg0, %arg1, %cst_3, %cst_2, %cst_1, %cst_0, %cst) <{dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""}> : (tensor<8x4x16x16x16xf16>, tensor<4x3x3x16x16xf16>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
     %0 = "tf.XlaConvV2"(%lhs, %rhs, %strides, %padding, %lhs_dilation, %rhs_dilation, %feature_group_count) {dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf16>, tensor<4x3x3x16x16xf16>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<?x?x?x?x?xf32>
     func.return %0 : tensor<?x?x?x?x?xf32>
   }
@@ -1991,7 +2016,7 @@
     %lhs_dilation = "tf.Const"() {value = dense<[4, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
     %padding = "tf.Const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
     %strides = "tf.Const"() {value = dense<[3, 1, 1]> : tensor<3xi64>} : () -> tensor<3xi64>
-    // CHECK: %0 = "tf.XlaConvV2"(%arg0, %arg1, %cst_3, %cst_2, %cst_1, %cst_0, %cst) {dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi64>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
+    // CHECK: %0 = "tf.XlaConvV2"(%arg0, %arg1, %cst_3, %cst_2, %cst_1, %cst_0, %cst) <{dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""}> : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi64>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<8x4x14x14x16xf32>
     %0 = "tf.XlaConvV2"(%lhs, %rhs, %strides, %padding, %lhs_dilation, %rhs_dilation, %feature_group_count) {dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi64>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<?x?x?x?x?xf32>
     func.return %0 : tensor<?x?x?x?x?xf32>
   }
@@ -2173,4 +2198,4 @@
     %6 = "tf.TensorListSetItem"(%if49, %4, %5) {device = ""} : (!tf_variant, tensor<i32>, tensor<2x2xf32>)-> tensor<*x!tf_type.variant>
     func.return
   }
-}
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/split_into_island_per_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/split_into_island_per_op.mlir
index 7307bff..4428811 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/split_into_island_per_op.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/split_into_island_per_op.mlir
@@ -116,7 +116,7 @@
 // CHECK:  %[[GRAPH:.*]]:2 = tf_executor.graph {
 // CHECK:    %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
 // CHECK:    %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1)
-// CHECK:    %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2]]) {message = "add result"}
+// CHECK:    %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2]]) <{message = "add result"}>
 // CHECK:    tf_executor.fetch %[[ADD1]], %[[ADD2]] :
 // CHECK:  }
 // CHECK:  return %[[GRAPH]]#0, %[[GRAPH]]#1
@@ -186,7 +186,7 @@
 // CHECK:   %[[READ0:.*]], %[[READ0_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg0)
 // CHECK:   %[[ASSIGN0_CONTROL:.*]] = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2)
 // CHECK:   %[[READ1:.*]], %[[READ1_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg1)
-// CHECK:   %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v0"}
+// CHECK:   %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() <{container = "c", shared_name = "v0"}>
 // CHECK:   %[[READ2:.*]], %[[READ2_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[VH0]])
 // CHECK:   %[[ASSIGN1_CONTROL:.*]] = tf_executor.island wraps "tf.AssignVariableOp"(%arg1, %[[READ0]])
 // CHECK:   %[[ASSIGN2_CONTROL:.*]] = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %[[READ2]])
@@ -214,8 +214,8 @@
 
 // CHECK-LABEL: func @unknown_side_effecting_op
 // CHECK: tf_executor.graph {
-// CHECK:   %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v0"}
-// CHECK:   %[[VH1:.*]], %[[VH1_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v1"}
+// CHECK:   %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() <{container = "c", shared_name = "v0"}>
+// CHECK:   %[[VH1:.*]], %[[VH1_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() <{container = "c", shared_name = "v1"}>
 // CHECK:   %[[READ0:.*]], %[[READ0_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[VH0]])
 // CHECK:   %[[ASSIGN0_CONTROL:.*]] = tf_executor.island wraps "tf.AssignVariableOp"(%[[VH1]], %arg0)
 // CHECK:   %[[UNKNOWN_CONTROL:.*]] = tf_executor.island wraps "tf._UnknownSideEffectingOp_"()
@@ -432,4 +432,4 @@
     tf_executor.fetch
   }
   func.return
-}
\ No newline at end of file
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir
index ee663e9..907d512 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir
@@ -4,15 +4,15 @@
 
 // CHECK-LABEL: func @main
 func.func @main() -> tensor<f32> {
-  // CHECK-NEXT: "tf.Const"() {value = dense<10> : tensor<i32>}
+  // CHECK-NEXT: "tf.Const"() <{value = dense<10> : tensor<i32>}>
   %max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-NEXT: %[[ZERO_SCALAR:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: %[[ZERO_SCALAR:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-NEXT: %[[CAST_ZERO:.*]] = "tf.Cast"(%[[ZERO_SCALAR]]) : (tensor<i32>) -> tensor<f32>
-  // CHECK-NEXT: %[[CONST10:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[CONST10:.*]] = "tf.Const"() <{value = dense<10> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[CAST_ZERO]], %[[CONST10]]) : (tensor<f32>, tensor<1xi32>) -> tensor<10xf32>
   // CHECK-NEXT: %[[BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf_type.resource<tensor<10xf32>>>
   // CHECK-NEXT: %[[SIZE:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf_type.resource<tensor<1xi32>>>
-  // CHECK-NEXT: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[ZERO]])
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[BUFFER]], %[[BROADCAST]])
   %stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf_type.resource>
@@ -21,22 +21,22 @@
   %elem = "tf._SomeOp"() : () -> tensor<f32>
   // CHECK-NEXT: %[[READ_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]])
   // CHECK-NEXT: %[[READ_SIZE:.*]] = "tf.ReadVariableOp"(%[[SIZE]])
-  // CHECK-NEXT: %[[UPDATE_SHAPE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[UPDATE_SHAPE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %[[UPDATE_SLICE:.*]] = "tf.Reshape"(%[[PUSHVAL]], %[[UPDATE_SHAPE]]) : (tensor<f32>, tensor<1xi32>) -> tensor<1xf32>
   // CHECK-NEXT: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ_VAL]], %[[UPDATE_SLICE]], %[[READ_SIZE]]) : (tensor<10xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<10xf32>
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[BUFFER]], %[[UPDATE]]) : (tensor<!tf_type.resource<tensor<10xf32>>>, tensor<10xf32>) -> ()
-  // CHECK-NEXT: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[CONST1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %[[NEW_SIZE:.*]] = "tf.AddV2"(%[[READ_SIZE]], %[[CONST1]]) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[NEW_SIZE]]) : (tensor<!tf_type.resource<tensor<1xi32>>>, tensor<1xi32>) -> ()
   %push = "tf.StackPushV2"(%id, %elem) {swap_memory = false} : (tensor<!tf_type.resource>, tensor<f32>) -> tensor<f32>
   %pop = "tf.StackPopV2"(%stack) : (tensor<!tf_type.resource>) -> tensor<f32>
   // CHECK-NEXT: %[[READ_VAL1:.*]] = "tf.ReadVariableOp"(%[[BUFFER]])
   // CHECK-NEXT: %[[READ_SIZE1:.*]] = "tf.ReadVariableOp"(%[[SIZE]])
-  // CHECK-NEXT: %[[CONST1_1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[CONST1_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %[[SUB:.*]] = "tf.Sub"(%[[READ_SIZE1]], %[[CONST1_1]])
-  // CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %[[SLICE:.*]] = "tf.Slice"(%[[READ_VAL1]], %[[SUB]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32>
-  // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
   // CHECK-NEXT: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor<f32>
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[SUB]]) : (tensor<!tf_type.resource<tensor<1xi32>>>, tensor<1xi32>) -> ()
   "tf.StackCloseV2"(%stack) : (tensor<!tf_type.resource>) -> ()
@@ -50,14 +50,14 @@
 
 // CHECK-LABEL: func @main
 func.func @main() -> tensor<2xi32> {
-  // CHECK-NEXT: "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: "tf.Const"() <{value = dense<10> : tensor<i32>}> : () -> tensor<i32>
   %size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-NEXT: %[[ZERO_CONST:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-NEXT: %[[STACK_SHAPE:.*]] = "tf.Const"() {value = dense<[10, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK-NEXT: %[[ZERO_CONST:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-NEXT: %[[STACK_SHAPE:.*]] = "tf.Const"() <{value = dense<[10, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
   // CHECK-NEXT: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[ZERO_CONST]], %[[STACK_SHAPE]]) : (tensor<i32>, tensor<2xi32>) -> tensor<10x2xi32>
   // CHECK-NEXT: %[[BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf_type.resource<tensor<10x2xi32>>>
   // CHECK-NEXT: %[[SIZE:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf_type.resource<tensor<1xi32>>>
-  // CHECK-NEXT: %[[ZERO_SIZE:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[ZERO_SIZE:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[ZERO_SIZE]]) : (tensor<!tf_type.resource<tensor<1xi32>>>, tensor<1xi32>) -> ()
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[BUFFER]], %[[BROADCAST]]) : (tensor<!tf_type.resource<tensor<10x2xi32>>>, tensor<10x2xi32>) -> ()
   %stack = "tf.StackV2"(%size) {elem_type = i32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf_type.resource>
@@ -65,14 +65,14 @@
   %elem = "tf._SomeOp"() : () -> tensor<2xi32>
   // CHECK-NEXT: %[[STACK_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]]) : (tensor<!tf_type.resource<tensor<10x2xi32>>>) -> tensor<10x2xi32>
   // CHECK-NEXT: %[[STACK_SIZE:.*]] = "tf.ReadVariableOp"(%[[SIZE]]) : (tensor<!tf_type.resource<tensor<1xi32>>>) -> tensor<1xi32>
-  // CHECK-NEXT: %[[UPDATE_SHAPE:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK-NEXT: %[[UPDATE_SHAPE:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
   // CHECK-NEXT: %[[UPDATE_SLICE:.*]] = "tf.Reshape"(%[[PUSH_VAL]], %[[UPDATE_SHAPE]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<1x2xi32>
-  // CHECK-NEXT: %[[ZERO_INDS:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
-  // CHECK-NEXT: %[[CONCAT_DIM:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: %[[ZERO_INDS:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[CONCAT_DIM:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-NEXT: %[[CONCAT_OFFETS:.*]] = "tf.ConcatV2"(%[[STACK_SIZE]], %[[ZERO_INDS]], %[[CONCAT_DIM]]) : (tensor<1xi32>, tensor<1xi32>, tensor<i32>) -> tensor<2xi32>
   // CHECK-NEXT: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[STACK_VAL]], %[[UPDATE_SLICE]], %[[CONCAT_OFFETS]]) : (tensor<10x2xi32>, tensor<1x2xi32>, tensor<2xi32>) -> tensor<10x2xi32>
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[BUFFER]], %[[UPDATE]]) : (tensor<!tf_type.resource<tensor<10x2xi32>>>, tensor<10x2xi32>) -> ()
-  // CHECK-NEXT: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[CONST1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %[[NEW_SIZE:.*]] = "tf.AddV2"(%[[STACK_SIZE]], %[[CONST1]]) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[NEW_SIZE]]) : (tensor<!tf_type.resource<tensor<1xi32>>>, tensor<1xi32>) -> ()
   %push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor<!tf_type.resource>, tensor<2xi32>) -> tensor<2xi32>
@@ -102,7 +102,7 @@
 }
 // CHECK: func @while_body(%[[BARG0:.*]]: tensor<!tf_type.resource<tensor<10xf32>>>, %[[BARG1:.*]]: tensor<i32>, %[[BARG2:.*]]: tensor<!tf_type.resource<tensor<1xi32>>>)
 func.func @while_body(%arg0: tensor<!tf_type.resource>, %arg1: tensor<i32>) -> (tensor<!tf_type.resource>, tensor<i32>) {
-  // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[CONST1:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
   %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
   // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG1]], %[[CONST1]])
   %sub = "tf.Sub"(%arg1, %const1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
@@ -143,7 +143,7 @@
   }, {
     // CHECK: ^bb0(%[[BARG0:.*]]: tensor<i32>
     ^bb0(%barg0: tensor<i32>):
-    // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    // CHECK: %[[CONST1:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
     %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
     // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG0]], %[[CONST1]])
     %sub = "tf.Sub"(%barg0, %const1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
@@ -185,7 +185,7 @@
   // CHECK: tf.AssignVariableOp
   // CHECK: tf.AssignVariableOp
   %stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf_type.resource>
-  // CHECK: %[[CASE_OUTPUT:.*]] = "tf.CaseRegion"(%[[BRANCH_INDEX]]) ({
+  // CHECK: %[[CASE_OUTPUT:.*]] = "tf.CaseRegion"(%[[BRANCH_INDEX]]) {{.*}} ({
   %case_op = "tf.CaseRegion"(%arg0) ({
     %elem = "tf._SomeOp"() : () -> tensor<f32>
     // CHECK-NOT: tf.StackPushV2
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir
index 06ecfd4..6adc4329 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir
@@ -10,9 +10,9 @@
   // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf_type.resource<tensor<5x3xf32>>>
   // CHECK: "tf.AssignVariableOp"(%[[VAR]], %[[BUFFER]])
   %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf_type.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf_type.resource>, tensor<f32>)
-  // CHECK: %[[IND:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[IND:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
   %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: %[[VAL:.*]] = "tf.Const"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
+  // CHECK: %[[VAL:.*]] = "tf.Const"() <{value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>}> : () -> tensor<3xf32>
   %value = "tf.Const"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf32>} : () -> tensor<3xf32>
   // CHECK: %[[READ_VAR:.*]] = "tf.ReadVariableOp"(%[[VAR]])
   // CHECK: %[[UPDATE_SLICE:.*]] = "tf.Reshape"(%[[VAL]]
@@ -46,7 +46,7 @@
   %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
   %value = "tf.Const"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf32>} : () -> tensor<3xf32>
   %write = "tf.TensorArrayWriteV3"(%ta#0, %index, %value, %ta#1) : (tensor<!tf_type.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
-  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SIZE:.*]] = "tf.Const"() <{value = dense<5> : tensor<i32>}> : () -> tensor<i32>
   %size_out = "tf.TensorArraySizeV3"(%ta#0, %write) : (tensor<!tf_type.resource>, tensor<f32>) -> tensor<i32>
   // CHECK: return %[[SIZE]] : tensor<i32>
   func.return %size_out : tensor<i32>
@@ -110,7 +110,7 @@
   // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf_type.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
   // CHECK: %[[CONCAT_RESHAPE:.*]] = "tf.Reshape"(%[[READ]],
   // CHECK-SAME: -> tensor<15xf32>
-  // CHECK: %[[LENS:.*]] = "tf.Const"() {value = dense<3> : tensor<5xi64>} : () -> tensor<5xi64>
+  // CHECK: %[[LENS:.*]] = "tf.Const"() <{value = dense<3> : tensor<5xi64>}> : () -> tensor<5xi64>
   %concat:2 = "tf.TensorArrayConcatV3"(%ta#0, %ta#1) {element_shape_except0 = #tf_type.shape<*>} : (tensor<!tf_type.resource>, tensor<f32>) -> (tensor<*xf32>, tensor<*xi64>)
   // CHECK: %[[SPLIT_RESHAPE:.*]] = "tf.Reshape"(%[[CONCAT_RESHAPE]],
   // CHECK-SAME: -> tensor<5x3xf32>
@@ -153,33 +153,33 @@
   // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf_type.resource<tensor<5x3xf32>>>
   // CHECK: "tf.AssignVariableOp"
   %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf_type.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf_type.resource>, tensor<f32>)
-  // CHECK: %[[INDS:.*]] = "tf.Const"() {value = dense<[2, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: %[[INDS:.*]] = "tf.Const"() <{value = dense<[2, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
   %indices = "tf.Const"() {value = dense<[2, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
   // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf_type.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
-  // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
   // CHECK: %[[GATHER:.*]] = "tf.GatherV2"(%[[READ]], %[[INDS]], %[[AXIS]]) : (tensor<5x3xf32>, tensor<2xi32>, tensor<i32>) -> tensor<2x3xf32>
   %gather = "tf.TensorArrayGatherV3"(%ta#0, %indices, %ta#1) {element_shape = #tf_type.shape<*>} : (tensor<!tf_type.resource>, tensor<2xi32>, tensor<f32>) -> tensor<*xf32>
   // CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf_type.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
-  // CHECK-DAG: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> tensor<2xi32>
-  // CHECK-DAG: %[[IND_SLICE0_START:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
-  // CHECK-DAG: %[[IND_SLICE0_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-DAG: %[[SLICE_SIZE:.*]] = "tf.Const"() <{value = dense<[1, 3]> : tensor<2xi32>}> : () -> tensor<2xi32>
+  // CHECK-DAG: %[[IND_SLICE0_START:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  // CHECK-DAG: %[[IND_SLICE0_SIZE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK: %[[IND_SLICE0:.*]] = "tf.Slice"(%[[INDS]], %[[IND_SLICE0_START]], %[[IND_SLICE0_SIZE]]) : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
   // CHECK: %[[SLICE0_START:.*]] = "tf.ConcatV2"(%[[IND_SLICE0]],
   // CHECK: %[[OLD_SLICE0:.*]] = "tf.Slice"(%[[READ2]], %[[SLICE0_START]],
   // CHECK-SAME: (tensor<5x3xf32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x3xf32>
-  // CHECK: %[[UPDATE_SLICE0_START:.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: %[[UPDATE_SLICE0_START:.*]] = "tf.Const"() <{value = dense<0> : tensor<2xi32>}> : () -> tensor<2xi32>
   // CHECK: %[[UPDATE_SLICE0:.*]] = "tf.Slice"(%[[GATHER]], %[[UPDATE_SLICE0_START]], %[[SLICE_SIZE]]) : (tensor<2x3xf32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x3xf32>
   // CHECK: %[[ADD0:.*]] = "tf.AddV2"(%[[OLD_SLICE0]], %[[UPDATE_SLICE0]])
   // CHECK: %[[UPDATE0:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]], %[[ADD0]]
   // CHECK-SAME: (tensor<5x3xf32>, tensor<1x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
 
-  // CHECK-DAG: %[[IND_SLICE1_START:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
-  // CHECK-DAG: %[[IND_SLICE1_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-DAG: %[[IND_SLICE1_START:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
+  // CHECK-DAG: %[[IND_SLICE1_SIZE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK: %[[IND_SLICE1:.*]] = "tf.Slice"(%[[INDS]], %[[IND_SLICE1_START]], %[[IND_SLICE1_SIZE]]) : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
   // CHECK: %[[SLICE1_START:.*]] = "tf.ConcatV2"(%[[IND_SLICE1]],
   // CHECK: %[[OLD_SLICE1:.*]] = "tf.Slice"(%[[UPDATE0]], %[[SLICE1_START]],
   // CHECK-SAME: (tensor<5x3xf32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x3xf32>
-  // CHECK: %[[UPDATE_SLICE1_START:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: %[[UPDATE_SLICE1_START:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
   // CHECK: %[[UPDATE_SLICE1:.*]] = "tf.Slice"(%[[GATHER]], %[[UPDATE_SLICE1_START]], %[[SLICE_SIZE]]) : (tensor<2x3xf32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x3xf32>
   // CHECK: %[[ADD1:.*]] = "tf.AddV2"(%[[OLD_SLICE1]], %[[UPDATE_SLICE1]])
   // CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[UPDATE0]], %[[ADD1]]
@@ -200,7 +200,7 @@
   // CHECK: "tf.AssignVariableOp"(%[[VAR]],
   %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf_type.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf_type.resource>, tensor<f32>)
   %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: %[[VALUE:.*]] = "tf.Const"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
+  // CHECK: %[[VALUE:.*]] = "tf.Const"() <{value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>}> : () -> tensor<3xf32>
   %value = "tf.Const"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf32>} : () -> tensor<3xf32>
   // CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf_type.resource<tensor<5x3xf32>>>
   // CHECK: "tf.AssignVariableOp"(%[[GVAR1]],
@@ -240,7 +240,7 @@
 
 // CHECK-LABEL: func @main
 func.func @main() -> () {
-  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SIZE:.*]] = "tf.Const"() <{value = dense<5> : tensor<i32>}> : () -> tensor<i32>
   %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
   %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
   // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf_type.resource<tensor<5x3xf32>>>
@@ -257,7 +257,7 @@
 }
 // CHECK: func @while_body(%[[BARG0:.*]]: tensor<!tf_type.resource<tensor<5x3xf32>>>, %[[BARG1:.*]]: tensor<i32>, %[[BARG2:.*]]: tensor<!tf_type.resource<tensor<5x3xf32>>>)
 func.func @while_body(%arg0: tensor<!tf_type.resource>, %arg1: tensor<i32>) -> (tensor<!tf_type.resource>, tensor<i32>) {
-  // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[CONST1:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
   %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
   // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG1]], %[[CONST1]])
   %sub = "tf.Sub"(%arg1, %const1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
@@ -288,7 +288,7 @@
 
 // CHECK-LABEL: func @main
 func.func @main() -> () {
-  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SIZE:.*]] = "tf.Const"() <{value = dense<5> : tensor<i32>}> : () -> tensor<i32>
   %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
   %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
   // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf_type.resource<tensor<5x3xf32>>>
@@ -328,7 +328,7 @@
 }
 // CHECK: func @else_branch(%[[EARG0:.*]]: tensor<!tf_type.resource<tensor<5x3xf32>>>, %[[EARG1:.*]]: tensor<!tf_type.resource<tensor<5x3xf32>>>, %[[EARG2:.*]]: tensor<!tf_type.resource<tensor<5x3xf32>>>)
 func.func @else_branch(%arg0: tensor<!tf_type.resource>) -> tensor<!tf_type.resource> {
-  // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[CONST1:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
   %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
   %elem = "tf._SomeOp"() : () -> tensor<3xf32>
   %flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
@@ -348,14 +348,14 @@
 
 // CHECK-LABEL: func @main
 func.func @main() -> () {
-  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>}
+  // CHECK: %[[SIZE:.*]] = "tf.Const"() <{value = dense<5> : tensor<i32>}>
   %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
   %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
   // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf_type.resource<tensor<5x3xf32>>>
   // CHECK-NOT: tf.TensorArrayV3
   %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf_type.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf_type.resource>, tensor<f32>)
-  // CHECK: %[[FLOW_INIT:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
-  // CHECK: %[[WHILE:.*]]:2 = "tf.WhileRegion"(%[[FLOW_INIT]], %[[SIZE]]) ({
+  // CHECK: %[[FLOW_INIT:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}>
+  // CHECK: %[[WHILE:.*]]:2 = "tf.WhileRegion"(%[[FLOW_INIT]], %[[SIZE]]) {{.*}} ({
   %while:2 = "tf.WhileRegion"(%ta#1, %size) ({
   // CHECK: ^bb0(%[[BARG0:.*]]: tensor<f32>, %[[BARG1:.*]]: tensor<i32>):
   ^bb0(%barg0: tensor<f32>, %barg1: tensor<i32>):
@@ -402,8 +402,8 @@
   // CHECK: "tf.AssignVariableOp"(%[[TA_BUFFER]]
   // CHECK-NOT: tf.TensorArrayV3
   %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf_type.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf_type.resource>, tensor<f32>)
-  // CHECK: "tf.IfRegion"(%[[PRED]]) ({
-  %case_op = "tf.IfRegion"(%arg0) ({
+  // CHECK: "tf.IfRegion"(%[[PRED]]) <{is_stateless = false}> ({
+  %case_op = "tf.IfRegion"(%arg0) <{is_stateless = false}> ({
       // CHECK: %[[TA_VAL:.*]] = "tf.ReadVariableOp"(%[[TA_BUFFER]])
       // CHECK: "tf.Slice"(%[[TA_VAL]]
       // CHECK-NOT: tf.TensorArrayReadV3
@@ -420,8 +420,8 @@
       %elem = "tf._SomeOp"() : () -> tensor<3xf32>
       %write = "tf.TensorArrayWriteV3"(%ta#0, %idx, %elem, %ta#1) : (tensor<!tf_type.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
       "tf.Yield"(%write) : (tensor<f32>) -> ()
-    // CHECK: }) {is_stateless = false} : (tensor<i1>) -> tensor<f32>
-    }) {is_stateless = false} : (tensor<i1>) -> tensor<f32>
+    // CHECK: }) : (tensor<i1>) -> tensor<f32>
+    }) : (tensor<i1>) -> tensor<f32>
   %idx = "tf.Const"() {value = dense<6> : tensor<i32>} : () -> tensor<i32>
   // CHECK-NOT: tf.TensorArrayReadV3
   %read_val = "tf.TensorArrayReadV3"(%ta#0, %idx, %case_op) : (tensor<!tf_type.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
@@ -436,7 +436,7 @@
 
 // CHECK-LABEL: func @main
 func.func @main() -> () {
-  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SIZE:.*]] = "tf.Const"() <{value = dense<5> : tensor<i32>}> : () -> tensor<i32>
   %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
   %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
   // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf_type.resource<tensor<5x3xf32>>>
@@ -486,7 +486,7 @@
 
 // CHECK-LABEL: func @main
 func.func @main() -> () {
-  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SIZE:.*]] = "tf.Const"() <{value = dense<5> : tensor<i32>}> : () -> tensor<i32>
   %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
   %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
   // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf_type.resource<tensor<5x3xf32>>>
@@ -543,7 +543,7 @@
   // CHECK: "tf.MlirLocalVarOp"() : () -> tensor<!tf_type.resource<tensor<5xf32>>>
   // CHECK: "tf.AssignVariableOp"
   %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf_type.shape<>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf_type.resource>, tensor<f32>)
-  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SIZE:.*]] = "tf.Const"() <{value = dense<5> : tensor<i32>}> : () -> tensor<i32>
   %size_out = "tf.TensorArraySizeV3"(%ta#0, %ta#1) : (tensor<!tf_type.resource>, tensor<f32>) -> tensor<i32>
   // CHECK: return %[[SIZE]] : tensor<i32>
   func.return %size_out : tensor<i32>
@@ -553,7 +553,7 @@
 
 // CHECK-LABEL: func @main
 func.func @main() -> () {
-  // CHECK: "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> tensor<*xf32>
+  // CHECK: "tf.PartitionedCall"() <{config = "", config_proto = "", executor_type = "", f = @callee}> : () -> tensor<*xf32>
   %call = "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> (tensor<*xf32>)
   func.return
 }
@@ -567,7 +567,7 @@
   // CHECK: "tf.AssignVariableOp"(%[[LOCAL_VAR]], %[[UPDATE]]) : (tensor<!tf_type.resource<tensor<5x3xf32>>>, tensor<5x3xf32>) -> ()
   %flow = "tf.TensorArrayWriteV3"(%ta#0, %index, %value, %ta#1) : (tensor<!tf_type.resource<tensor<*xf32>>>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
   // CHECK: %[[SLICE:.*]] = "tf.Slice"
-  // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<3> : tensor<1xi32>}
+  // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() <{value = dense<3> : tensor<1xi32>}>
   // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]])
   %val = "tf.TensorArrayReadV3"(%ta#0, %index, %ta#1) : (tensor<!tf_type.resource<tensor<*xf32>>>, tensor<i32>, tensor<f32>) -> tensor<*xf32>
   // CHECK: %[[CAST:.*]] = tensor.cast %[[ELEM]] : tensor<3xf32> to tensor<*xf32>
@@ -604,7 +604,7 @@
   %flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
   // CHECK: %[[BR_INDEX:.*]] = "tf.SomeOp"() : () -> tensor<i32>
   %branch_index = "tf.SomeOp"() : () -> tensor<i32>
-  // CHECK: "tf.CaseRegion"(%[[BR_INDEX]]) ({
+  // CHECK: "tf.CaseRegion"(%[[BR_INDEX]]) {{.*}} ({
   "tf.CaseRegion"(%branch_index) ({
     // CHECK: %[[READ_GVAR:.*]] = "tf.ReadVariableOp"(%[[GVAR]])
     // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ_GVAR]],
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir
index 1177f9f..6fb9598 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir
@@ -4,34 +4,34 @@
 
 // CHECK-LABEL: func @main
 func.func @main() -> (tensor<f32>, tensor<i32>) {
-  // CHECK-NEXT: "tf.Const"() {value = dense<> : tensor<0xi32>}
+  // CHECK-NEXT: "tf.Const"() <{value = dense<> : tensor<0xi32>}>
   %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
-  // CHECK-NEXT: "tf.Const"() {value = dense<10> : tensor<i32>}
+  // CHECK-NEXT: "tf.Const"() <{value = dense<10> : tensor<i32>}>
   %max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-NEXT: %[[ZERO_SCALAR:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: %[[ZERO_SCALAR:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-NEXT: %[[CAST_ZERO:.*]] = "tf.Cast"(%[[ZERO_SCALAR]]) : (tensor<i32>) -> tensor<f32>
-  // CHECK-NEXT: %[[CONST10:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[CONST10:.*]] = "tf.Const"() <{value = dense<10> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[CAST_ZERO]], %[[CONST10]]) : (tensor<f32>, tensor<1xi32>) -> tensor<10xf32>
-  // CHECK-NEXT: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
   %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor<i32>) -> tensor<!tf_type.variant<tensor<f32>>>
   %id = "tf.Identity"(%tl) : (tensor<!tf_type.variant<tensor<f32>>>) -> tensor<!tf_type.variant<tensor<f32>>>
   // CHECK-NEXT: %[[PUSHVAL:.*]] = "tf._SomeOp"()
   %elem = "tf._SomeOp"() : () -> tensor<f32>
-  // CHECK-NEXT: %[[UPDATE_SHAPE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[UPDATE_SHAPE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %[[UPDATE_SLICE:.*]] = "tf.Reshape"(%[[PUSHVAL]], %[[UPDATE_SHAPE]]) : (tensor<f32>, tensor<1xi32>) -> tensor<1xf32>
   // CHECK-NEXT: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[BROADCAST]], %[[UPDATE_SLICE]], %[[ZERO]]) : (tensor<10xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<10xf32>
-  // CHECK-NEXT: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[CONST1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %[[NEW_SIZE:.*]] = "tf.AddV2"(%[[ZERO]], %[[CONST1]]) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
   %push = "tf.TensorListPushBack"(%id, %elem) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>) -> tensor<!tf_type.variant<tensor<f32>>>
   // CHECK-NEXT: %[[COPY:.*]] = "tf.Identity"(%[[UPDATE]])
-  // CHECK-NEXT: %[[CONST1_1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[CONST1_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %[[SUB:.*]] = "tf.Sub"(%[[NEW_SIZE]], %[[CONST1_1]])
-  // CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32>
-  // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
   // CHECK-NEXT: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor<f32>
   %pop:2 = "tf.TensorListPopBack"(%push, %elem_shape) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<0xi32>) -> (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>)
-  // CHECK-NEXT: %[[SCALAR_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>}
+  // CHECK-NEXT: %[[SCALAR_SHAPE:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}>
   // CHECK-NEXT: %[[LENGTH:.*]] = "tf.Reshape"(%[[NEW_SIZE]], %[[SCALAR_SHAPE]])
   %length = "tf.TensorListLength"(%push) : (tensor<!tf_type.variant<tensor<f32>>>) -> tensor<i32>
   // CHECK-NEXT: return %[[ELEM]], %[[LENGTH]] : tensor<f32>, tensor<i32>
@@ -46,30 +46,30 @@
 // CHECK-LABEL: func @main
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<i32>) -> (tensor<f32>, tensor<10xf32>, tensor<i32>)
 func.func @main(%arg0: tensor<i32>) -> (tensor<f32>, tensor<10xf32>, tensor<i32>) {
-  // CHECK-NEXT: "tf.Const"() {value = dense<> : tensor<0xi32>}
+  // CHECK-NEXT: "tf.Const"() <{value = dense<> : tensor<0xi32>}>
   %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
-  // CHECK-NEXT: %[[NUM:.*]] = "tf.Const"() {value = dense<10> : tensor<i32>}
+  // CHECK-NEXT: %[[NUM:.*]] = "tf.Const"() <{value = dense<10> : tensor<i32>}>
   %num = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-NEXT: %[[ZERO_SCALAR:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: %[[ZERO_SCALAR:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
   // CHECK-NEXT: %[[CAST_ZERO:.*]] = "tf.Cast"(%[[ZERO_SCALAR]]) : (tensor<i32>) -> tensor<f32>
-  // CHECK-NEXT: %[[CONST10:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[CONST10:.*]] = "tf.Const"() <{value = dense<10> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[CAST_ZERO]], %[[CONST10]]) : (tensor<f32>, tensor<1xi32>) -> tensor<10xf32>
-  // CHECK-NEXT: %[[SIZE_SHAPE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
+  // CHECK-NEXT: %[[SIZE_SHAPE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
   // CHECK-NEXT: %[[SIZE:.*]] = "tf.Reshape"(%[[NUM]], %[[SIZE_SHAPE]])
   %tl = "tf.TensorListReserve"(%elem_shape, %num) : (tensor<0xi32>, tensor<i32>) -> tensor<!tf_type.variant<tensor<f32>>>
   // CHECK-NEXT: %[[SETVAL:.*]] = "tf._SomeOp"()
   %elem = "tf._SomeOp"() : () -> tensor<f32>
-  // CHECK-NEXT: %[[SIZE_SHAPE1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
+  // CHECK-NEXT: %[[SIZE_SHAPE1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
   // CHECK-NEXT: %[[SET_INDEX:.*]] = "tf.Reshape"(%[[ARG0]], %[[SIZE_SHAPE1]]) : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
-  // CHECK-NEXT: %[[UPDATE_SHAPE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[UPDATE_SHAPE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %[[UPDATE_SLICE:.*]] = "tf.Reshape"(%[[SETVAL]], %[[UPDATE_SHAPE]]) : (tensor<f32>, tensor<1xi32>) -> tensor<1xf32>
   // CHECK-NEXT: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[BROADCAST]], %[[UPDATE_SLICE]], %[[SET_INDEX]]) : (tensor<10xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<10xf32>
   %set = "tf.TensorListSetItem"(%tl, %arg0, %elem) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<i32>, tensor<f32>) -> tensor<!tf_type.variant<tensor<f32>>>
-  // CHECK-NEXT: %[[SIZE_SHAPE2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
+  // CHECK-NEXT: %[[SIZE_SHAPE2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
   // CHECK-NEXT: %[[GET_INDEX:.*]] = "tf.Reshape"(%[[ARG0]], %[[SIZE_SHAPE2]]) : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
-  // CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %[[SLICE:.*]] = "tf.Slice"(%[[UPDATE]], %[[GET_INDEX]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32>
-  // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
   // CHECK-NEXT: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor<f32>
   %get = "tf.TensorListGetItem"(%set, %arg0, %elem_shape) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<i32>, tensor<0xi32>) -> tensor<f32>
   // CHECK-NEXT: %[[ADDN:.*]] = "tf.AddN"(%[[UPDATE]], %[[BROADCAST]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
@@ -79,7 +79,7 @@
   // CHECK-NEXT: %[[ADDN2:.*]] = "tf.AddN"(%[[ADDN]], %[[ZEROS_LIKE]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
   %addn2 = "tf.AddN"(%addn, %zeros-like) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<!tf_type.variant<tensor<f32>>>) -> tensor<!tf_type.variant<tensor<f32>>>
   %stack = "tf.TensorListStack"(%addn2, %elem_shape) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<0xi32>) -> tensor<10xf32>
-  // CHECK-NEXT: %[[LEN:.*]] = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: %[[LEN:.*]] = "tf.Const"() <{value = dense<10> : tensor<i32>}> : () -> tensor<i32>
   %length = "tf.TensorListLength"(%addn2) : (tensor<!tf_type.variant<tensor<f32>>>) -> tensor<i32>
   // CHECK-NEXT: return %[[ELEM]], %[[ADDN2]], %[[LEN]] : tensor<f32>, tensor<10xf32>, tensor<i32>
   func.return %get, %stack, %length : tensor<f32>, tensor<10xf32>, tensor<i32>
@@ -92,16 +92,16 @@
 // CHECK-LABEL: func @main
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<i32>, %[[ARG1:.*]]: tensor<10xf32>) -> tensor<f32>
 func.func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32>) -> tensor<f32> {
-  // CHECK-NEXT: "tf.Const"() {value = dense<> : tensor<0xi32>}
+  // CHECK-NEXT: "tf.Const"() <{value = dense<> : tensor<0xi32>}>
   %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
   // CHECK-NEXT: %[[BUFFER:.*]] = "tf.Identity"(%[[ARG1]]) : (tensor<10xf32>) -> tensor<10xf32>
-  // CHECK-NEXT: %[[SIZE:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[SIZE:.*]] = "tf.Const"() <{value = dense<10> : tensor<1xi32>}> : () -> tensor<1xi32>
   %tl = "tf.TensorListFromTensor"(%arg1, %elem_shape) : (tensor<10xf32>, tensor<0xi32>) -> tensor<!tf_type.variant<tensor<f32>>>
-  // CHECK-NEXT: %[[SIZE_SHAPE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
+  // CHECK-NEXT: %[[SIZE_SHAPE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
   // CHECK-NEXT: %[[GET_INDEX:.*]] = "tf.Reshape"(%[[ARG0]], %[[SIZE_SHAPE]]) : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
-  // CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %[[SLICE:.*]] = "tf.Slice"(%[[BUFFER]], %[[GET_INDEX]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32>
-  // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
   // CHECK-NEXT: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor<f32>
   %get = "tf.TensorListGetItem"(%tl, %arg0, %elem_shape) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<i32>, tensor<0xi32>) -> tensor<f32>
   // CHECK-NEXT: return %[[ELEM]] : tensor<f32>
@@ -116,7 +116,7 @@
 func.func @main(%arg0: tensor<10x8x9xf32>) -> tensor<2xi64> {
   %elem_shape = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi32>} : () -> tensor<2xi32>
   %tl = "tf.TensorListFromTensor"(%arg0, %elem_shape) : (tensor<10x8x9xf32>, tensor<2xi32>) -> tensor<!tf_type.variant<tensor<8x9xf32>>>
-  // CHECK: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi64>} : () -> tensor<2xi64>
+  // CHECK: %[[SHAPE:.*]] = "tf.Const"() <{value = dense<[8, 9]> : tensor<2xi64>}> : () -> tensor<2xi64>
   %shape = "tf.TensorListElementShape"(%tl) : (tensor<!tf_type.variant<tensor<8x9xf32>>>) -> tensor<2xi64>
   // CHECK-NEXT: return %[[SHAPE]] : tensor<2xi64>
   func.return %shape: tensor<2xi64>
@@ -132,7 +132,7 @@
   %elem_shape = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi32>} : () -> tensor<2xi32>
   // CHECK: %[[BUFFER:.*]] = "tf.Identity"(%[[ARG0]]) : (tensor<10x8x9xf32>) -> tensor<10x8x9xf32>
   %tl = "tf.TensorListFromTensor"(%arg0, %elem_shape) : (tensor<10x8x9xf32>, tensor<2xi32>) -> tensor<!tf_type.variant<tensor<8x9xf32>>>
-  // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
   // CHECK: %[[GATHER:.*]] = "tf.GatherV2"(%[[BUFFER]], %[[ARG1]], %[[AXIS]]) : (tensor<10x8x9xf32>, tensor<3xi32>, tensor<i32>) -> tensor<3x8x9xf32>
   %gather = "tf.TensorListGather"(%tl, %arg1, %elem_shape) : (tensor<!tf_type.variant<tensor<8x9xf32>>>, tensor<3xi32>, tensor<2xi32>) -> tensor<3x8x9xf32>
   // CHECK-NEXT: return %[[GATHER]] : tensor<3x8x9xf32>
@@ -149,7 +149,7 @@
   %elem_shape = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi32>} : () -> tensor<2xi32>
   // CHECK: %[[BUFFER:.*]] = "tf.Identity"(%[[ARG0]]) : (tensor<10x8x9xf32>) -> tensor<10x8x9xf32>
   %tl = "tf.TensorListFromTensor"(%arg0, %elem_shape) : (tensor<10x8x9xf32>, tensor<2xi32>) -> tensor<!tf_type.variant<tensor<8x9xf32>>>
-  // CHECK: %[[IND_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: %[[IND_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
   // CHECK: %[[IND_RESHPE:.*]] = "tf.Reshape"(%[[ARG1]], %[[IND_SHAPE]]) : (tensor<5xi32>, tensor<2xi32>) -> tensor<5x1xi32>
   // CHECK: %[[SC:.*]] = "tf.TensorScatterUpdate"(%[[BUFFER]], %[[IND_RESHPE]], %[[ARG2]]) : (tensor<10x8x9xf32>, tensor<5x1xi32>, tensor<5x8x9xf32>) -> tensor<10x8x9xf32>
   %scatter = "tf.TensorListScatterIntoExistingList"(%tl, %arg2, %arg1) : (tensor<!tf_type.variant<tensor<8x9xf32>>>, tensor<5x8x9xf32>, tensor<5xi32>) -> tensor<!tf_type.variant<tensor<8x9xf32>>>
@@ -179,14 +179,14 @@
 }
 // CHECK: func @while_body(%[[BARG0:.*]]: tensor<10xf32>, %[[BARG1:.*]]: tensor<i32>, %[[BARG2:.*]]: tensor<1xi32>)
 func.func @while_body(%arg0: tensor<!tf_type.variant<tensor<f32>>>, %arg1: tensor<i32>) -> (tensor<!tf_type.variant<tensor<f32>>>, tensor<i32>) {
-  // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[CONST1:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
   %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
   // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG1]], %[[CONST1]])
   %sub = "tf.Sub"(%arg1, %const1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
   %elem = "tf._SomeOp"() : () -> tensor<f32>
   // CHECK-NOT: "tf.TensorListPushBack"
   // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
-  // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: %[[CONST1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[BARG2]], %[[CONST1]])
   // CHECK-NOT: "tf.TensorListPushBack"
   %push = "tf.TensorListPushBack"(%arg0, %elem) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>) -> tensor<!tf_type.variant<tensor<f32>>>
@@ -222,7 +222,7 @@
   %elem = "tf._SomeOp"() : () -> tensor<f32>
   // CHECK-NOT: "tf.TensorListPushBack"
   // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
-  // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: %[[CONST1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[TARG1]], %[[CONST1]])
   // CHECK-NOT: "tf.TensorListPushBack"
   %push = "tf.TensorListPushBack"(%arg0, %elem) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>) -> tensor<!tf_type.variant<tensor<f32>>>
@@ -234,11 +234,11 @@
   %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
   // CHECK-NOT: "tf.TensorListPopBack"
   // CHECK: %[[COPY:.*]] = "tf.Identity"(%[[EARG0]])
-  // CHECK: %[[CONST1_1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: %[[CONST1_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[EARG1]], %[[CONST1_1]])
-  // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32>
-  // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
   // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor<f32>
   // CHECK-NOT: "tf.TensorListPopBack"
   %pop:2 = "tf.TensorListPopBack"(%arg0, %elem_shape) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<0xi32>) -> (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>)
@@ -269,7 +269,7 @@
   %elem = "tf._SomeOp"() : () -> tensor<f32>
   // CHECK-NOT: "tf.TensorListPushBack"
   // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
-  // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: %[[CONST1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[TARG1]], %[[CONST1]])
   // CHECK-NOT: "tf.TensorListPushBack"
   %push = "tf.TensorListPushBack"(%arg0, %elem) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>) -> tensor<!tf_type.variant<tensor<f32>>>
@@ -281,11 +281,11 @@
   %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
   // CHECK-NOT: "tf.TensorListPopBack"
   // CHECK: %[[COPY:.*]] = "tf.Identity"(%[[EARG0]])
-  // CHECK: %[[CONST1_1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: %[[CONST1_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[EARG1]], %[[CONST1_1]])
-  // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32>
-  // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
   // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor<f32>
   // CHECK-NOT: "tf.TensorListPopBack"
   %pop:2 = "tf.TensorListPopBack"(%arg0, %elem_shape) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<0xi32>) -> (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>)
@@ -297,11 +297,11 @@
   %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
   // CHECK-NOT: "tf.TensorListPopBack"
   // CHECK: %[[COPY:.*]] = "tf.Identity"(%[[EARG0]])
-  // CHECK: %[[CONST1_1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: %[[CONST1_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[EARG1]], %[[CONST1_1]])
-  // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32>
-  // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
   // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor<f32>
   // CHECK-NOT: "tf.TensorListPopBack"
   %pop:2 = "tf.TensorListPopBack"(%arg0, %elem_shape) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<0xi32>) -> (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>)
@@ -317,7 +317,7 @@
   %size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
   // CHECK-NOT: tf.EmptyTensorList
   %tl = "tf.EmptyTensorList"(%elem_shape, %size) : (tensor<0xi32>, tensor<i32>) -> tensor<!tf_type.variant<tensor<f32>>>
-  %while_op:2 = "tf.WhileRegion"(%tl, %size) ({
+  %while_op:2 = "tf.WhileRegion"(%tl, %size) <{is_stateless = false}> ({
   // CHECK: ^bb0(%[[CARG0:.*]]: tensor<10xf32>, %[[CARG1:.*]]: tensor<i32>, %[[CARG2:.*]]: tensor<1xi32>):
   ^bb0(%arg0: tensor<!tf_type.variant<tensor<f32>>>, %arg1: tensor<i32>):
     // CHECK:   %[[PRED:.*]] = "tf._SomeOp"()
@@ -327,7 +327,7 @@
   },  {
   // CHECK: ^bb0(%[[CARG0:.*]]: tensor<10xf32>, %[[CARG1:.*]]: tensor<i32>, %[[CARG2:.*]]: tensor<1xi32>):
   ^bb0(%arg0: tensor<!tf_type.variant<tensor<f32>>>, %arg1: tensor<i32>):
-    // CHECK:   %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    // CHECK:   %[[CST:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
     // CHECK:   %[[SUB:.*]] = "tf.Sub"(%[[CARG1]], %[[CST]])
     // CHECK:   %[[ELEM:.*]] = "tf._SomeOp"() : () -> tensor<f32>
     %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
@@ -335,14 +335,14 @@
     %elem = "tf._SomeOp"() : () -> tensor<f32>
     // CHECK-NOT: "tf.TensorListPushBack"
     // CHECK:   %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[CARG0]]
-    // CHECK:   %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
+    // CHECK:   %[[ONE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
     // CHECK:   %[[ADD:.*]] = "tf.AddV2"(%[[CARG2]], %[[ONE]])
     // CHECK-NOT: "tf.TensorListPushBack"
     // CHECK:   "tf.Yield"(%[[UPDATE]], %[[SUB]], %[[ADD]])
-    // CHECK: }) {is_stateless = false}
+    // CHECK: })
     %push = "tf.TensorListPushBack"(%arg0, %elem) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>) -> tensor<!tf_type.variant<tensor<f32>>>
     "tf.Yield"(%push, %sub) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<i32>) -> ()
-  }) {is_stateless = false} : (tensor<!tf_type.variant<tensor<f32>>>, tensor<i32>) -> (tensor<!tf_type.variant<tensor<f32>>>, tensor<i32>)
+  }) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<i32>) -> (tensor<!tf_type.variant<tensor<f32>>>, tensor<i32>)
   // CHECK: "tf.Slice"
   // CHECK-NOT: tf.TensorListPopBack
   %pop:2 = "tf.TensorListPopBack"(%while_op#0, %elem_shape) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<0xi32>) -> (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>)
@@ -356,27 +356,27 @@
   %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
   %max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
   %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor<i32>) -> tensor<!tf_type.variant<tensor<f32>>>
-  // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
+  // CHECK: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
   // CHECK: %[[ZERO_F32:.*]] = "tf.Cast"(%[[ZERO]])
-  // CHECK: %[[MAX_SIZE:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>}
+  // CHECK: %[[MAX_SIZE:.*]] = "tf.Const"() <{value = dense<10> : tensor<1xi32>}>
   // CHECK: %[[BUFFER:.*]] = "tf.BroadcastTo"(%[[ZERO_F32]], %[[MAX_SIZE]])
-  // CHECK: %[[BUFFER_SIZE:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>}
+  // CHECK: %[[BUFFER_SIZE:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi32>}>
   // CHECK-NOT: tf.EmptyTensorList
   %if_op = "tf.IfRegion"(%arg0) ({
       %elem = "tf._SomeOp"() : () -> tensor<f32>
       // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
-      // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+      // CHECK: %[[ONE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
       // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[BUFFER_SIZE]], %[[ONE]])
       // CHECK-NOT: "tf.TensorListPushBack"
       %push = "tf.TensorListPushBack"(%tl, %elem) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>) -> tensor<!tf_type.variant<tensor<f32>>>
       "tf.Yield" (%push) : (tensor<!tf_type.variant<tensor<f32>>>) -> ()
     }, {
       // CHECK:   %[[COPY:.*]] = "tf.Identity"(%[[BUFFER]])
-      // CHECK:   %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
+      // CHECK:   %[[ONE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
       // CHECK:   %[[SUB:.*]] = "tf.Sub"(%[[BUFFER_SIZE]], %[[ONE]])
-      // CHECK:   %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
+      // CHECK:   %[[SLICE_SIZE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
       // CHECK:   %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]])
-      // CHECK:   %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>}
+      // CHECK:   %[[ELEM_SHAPE:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}>
       // CHECK:   %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]])
       // CHECK-NOT: "tf.TensorListPopBack"
       %pop:2 = "tf.TensorListPopBack"(%tl, %elem_shape) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<0xi32>) -> (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>)
@@ -397,28 +397,28 @@
 func.func @main(%arg0: tensor<i32>) -> () {
   %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
   %max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
+  // CHECK: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
   // CHECK: %[[ZERO_F32:.*]] = "tf.Cast"(%[[ZERO]])
-  // CHECK: %[[MAX_SIZE:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>}
+  // CHECK: %[[MAX_SIZE:.*]] = "tf.Const"() <{value = dense<10> : tensor<1xi32>}>
   // CHECK: %[[BUFFER:.*]] = "tf.BroadcastTo"(%[[ZERO_F32]], %[[MAX_SIZE]])
-  // CHECK: %[[BUFFER_SIZE:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>}
+  // CHECK: %[[BUFFER_SIZE:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi32>}>
   // CHECK-NOT: tf.EmptyTensorList
   %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor<i32>) -> tensor<!tf_type.variant<tensor<f32>>>
   %case_op = "tf.CaseRegion"(%arg0) ({
       %elem = "tf._SomeOp"() : () -> tensor<f32>
       // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
-      // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+      // CHECK: %[[ONE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
       // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[BUFFER_SIZE]], %[[ONE]])
       // CHECK-NOT: "tf.TensorListPushBack"
       %push = "tf.TensorListPushBack"(%tl, %elem) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>) -> tensor<!tf_type.variant<tensor<f32>>>
       "tf.Yield" (%push) : (tensor<!tf_type.variant<tensor<f32>>>) -> ()
     }, {
       // CHECK:   %[[COPY:.*]] = "tf.Identity"(%[[BUFFER]])
-      // CHECK:   %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
+      // CHECK:   %[[ONE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
       // CHECK:   %[[SUB:.*]] = "tf.Sub"(%[[BUFFER_SIZE]], %[[ONE]])
-      // CHECK:   %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
+      // CHECK:   %[[SLICE_SIZE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
       // CHECK:   %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]])
-      // CHECK:   %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>}
+      // CHECK:   %[[ELEM_SHAPE:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}>
       // CHECK:   %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]])
       // CHECK-NOT: "tf.TensorListPopBack"
       %pop:2 = "tf.TensorListPopBack"(%tl, %elem_shape) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<0xi32>) -> (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>)
@@ -426,11 +426,11 @@
       "tf.Yield" (%pop#0) : (tensor<!tf_type.variant<tensor<f32>>>) -> ()
     }, {
       // CHECK:   %[[COPY:.*]] = "tf.Identity"(%[[BUFFER]])
-      // CHECK:   %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
+      // CHECK:   %[[ONE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
       // CHECK:   %[[SUB:.*]] = "tf.Sub"(%[[BUFFER_SIZE]], %[[ONE]])
-      // CHECK:   %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
+      // CHECK:   %[[SLICE_SIZE:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
       // CHECK:   %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]])
-      // CHECK:   %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>}
+      // CHECK:   %[[ELEM_SHAPE:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}>
       // CHECK:   %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]])
       // CHECK-NOT: "tf.TensorListPopBack"
       %pop:2 = "tf.TensorListPopBack"(%tl, %elem_shape) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<0xi32>) -> (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>)
@@ -482,7 +482,7 @@
 // CHECK: func private @callee_tensorlist_decomposed(%[[ARG0:.*]]: tensor<10xf32>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>)
 // CHECK-NOT: "tf.TensorListPushBack"
 // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
-// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+// CHECK: %[[CONST1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[ARG2]], %[[CONST1]])
 // CHECK-NOT: "tf.TensorListPushBack"
 // CHECK: return %[[UPDATE]], %[[ADD]]
@@ -520,7 +520,7 @@
 
   // CHECK-NOT: "tf.TensorListPushBack"
   // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
-  // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: %[[CONST1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[ARG2]], %[[CONST1]])
   // CHECK-NOT: "tf.TensorListPushBack"
   %push = "tf.TensorListPushBack"(%arg0, %elem) : (tensor<!tf_type.variant<tensor<f32>>>, tensor<f32>) -> tensor<!tf_type.variant<tensor<f32>>>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index 22de9ee..7605f03 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -1185,7 +1185,7 @@
 
 // Test invalid tf.Yield operation (parent should be IfRegion)
 func.func @testInvalidYieldOp(%arg0: f32) -> () {
-  // expected-error @+1 {{'tf.Yield' op expects parent op to be one of 'tf.CaseRegion, tf.IfRegion, tf.WhileRegion'}}
+  // expected-error @+1 {{'tf.Yield' op expects parent op to be one of 'tf.CaseRegion, tf.IfRegion, tf.WhileRegion, tf.GeneratorDatasetRegion'}}
   "tf.Yield"(%arg0) : (f32) -> ()
 }
 
@@ -5180,3 +5180,41 @@
   "tf.XlaCallModule"() {Sout = [], device = "", dim_args_spec = [], function_list = [@undefined_function], module = "", platforms = [], version = 4 : i64} : () -> ()
   func.return
 }
+
+// -----
+
+func.func @init(%arg0: tensor<4xf32>) -> tensor<7xf32> {
+    %0 = builtin.unrealized_conversion_cast to tensor<7xf32>
+    return %0 : tensor<7xf32>
+}
+
+func.func @next(%arg0: tensor<7xf32>, %arg1: tensor<3xf32>) -> tensor<6xf32> {
+    %0 = builtin.unrealized_conversion_cast to tensor<6xf32>
+    return %0 : tensor<6xf32>
+}
+
+func.func @finalize(%arg0: tensor<6xf32>, %arg1: tensor<2xf32>) -> tensor<5xf32> {
+    %0 = builtin.unrealized_conversion_cast to tensor<5xf32>
+    return %0 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func @testGeneratorDataset
+func.func @testGeneratorDataset(%arg0: tensor<4xf32>,
+                                %arg1: tensor<3xf32>,
+                                %arg2: tensor<!tf_type.resource>,
+                                %arg3: tensor<2xf32>) -> tensor<!tf_type.variant> {
+  %0 = "tf.GeneratorDataset"(%arg0, %arg1, %arg2, %arg3) {
+      device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0",
+      finalize_func = @finalize,
+      init_func = @init,
+      next_func = @next,
+      operandSegmentSizes = array<i32: 1, 2, 1>,
+      output_shapes = [#tf_type.shape<>],
+      output_types = [!tf_type.string],
+      metadata = ""} : (
+              tensor<4xf32>,
+              tensor<3xf32>,
+              tensor<!tf_type.resource>,
+              tensor<2xf32>) -> tensor<!tf_type.variant>
+  return %0 : tensor<!tf_type.variant>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_map_and_batch.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_map_and_batch.mlir
index af605e9..c6f63db 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_map_and_batch.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_map_and_batch.mlir
@@ -6,7 +6,7 @@
   %0 = "tf.Const"() {value = dense<5> : tensor<i64>} : () -> tensor<i64>
   %1 = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
   %2 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
-  // CHECK: %[[NPC:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>}
+  // CHECK: %[[NPC:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}>
   // CHECK: %[[TSLICE:.*]] = "tf.TensorSliceDataset"
   %3 = "tf.TensorSliceDataset"(%2) {device = "", output_shapes = [#tf_type.shape<>], metadata = ""} : (tensor<3xi32>) -> tensor<*x!tf_type.variant>
   // CHECK: "tf.MapAndBatchDataset"(%[[TSLICE]], %[[BSIZE:.*]], %[[NPC]]
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_optimize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_optimize.mlir
index 3dacd11..d0a5a74 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_optimize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_optimize.mlir
@@ -12,7 +12,7 @@
   // CHECK-SAME: [1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00], [5.000000e+00, 1.200000e+01]
   // CHECK-SAME: [7.000000e+00, 1.600000e+01], [9.000000e+00, 2.000000e+01], [1.100000e+01, 2.400000e+01]
   // CHECK-SAME: [1.300000e+01, 2.800000e+01], [1.500000e+01, 3.200000e+01], [1.700000e+01, 3.600000e+01]
-  // CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[CST]]) {data_format = "NHWC", dilations = [1, 2, 3, 1], explicit_paddings = [], padding = "SAME", strides = [1, 4, 5, 1], use_cudnn_on_gpu = true}
+  // CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[CST]]) <{data_format = "NHWC", dilations = [1, 2, 3, 1], explicit_paddings = [], padding = "SAME", strides = [1, 4, 5, 1], use_cudnn_on_gpu = true}>
   // CHECK: return %[[CONV]] : tensor<1x28x23x2xf32>
 }
 
@@ -26,7 +26,7 @@
 
   func.return %1 : tensor<1x28x23x2xf32>
   // CHECK: %cst_0 = arith.constant dense<3.000000e+00> : tensor<23x2xf32>
-  // CHECK: %0 = "tf.Conv2D"(%arg0, %cst) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]}
+  // CHECK: %0 = "tf.Conv2D"(%arg0, %cst) <{data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]}> {T = "tfdtype$DT_FLOAT"}
   // CHECK: %1 = "tf.Mul"(%0, %cst_0) : (tensor<1x28x23x2xf32>, tensor<23x2xf32>) -> tensor<1x28x23x2xf32>
   // CHECK: return %1 : tensor<1x28x23x2xf32>
 }
@@ -40,8 +40,8 @@
   %98 = "tf.Reshape"(%97, %cst_2) : (tensor<1x8x6x1x6x1x1x18xbf16>, tensor<4xi64>) -> tensor<8x6x6x18xbf16>
   func.return %98 : tensor<8x6x6x18xbf16>
 
-  // CHECK-DAG: %[[CST:.*]] = "tf.Const"() {value = dense<[8, 1, 1, 18]> : tensor<4xi64>} : () -> tensor<4xi64>
-  // CHECK-DAG: %[[CST1:.*]] =  "tf.Const"() {value = dense<[8, 6, 6, 18]> : tensor<4xi64>} : () -> tensor<4xi64>
+  // CHECK-DAG: %[[CST:.*]] = "tf.Const"() <{value = dense<[8, 1, 1, 18]> : tensor<4xi64>}> : () -> tensor<4xi64>
+  // CHECK-DAG: %[[CST1:.*]] =  "tf.Const"() <{value = dense<[8, 6, 6, 18]> : tensor<4xi64>}> : () -> tensor<4xi64>
   // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[CST]]) : (tensor<1x8x1x1x1x1x1x18xbf16>, tensor<4xi64>) -> tensor<8x1x1x18xbf16>
   // CHECK: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[RESHAPE]], %[[CST1]]) : (tensor<8x1x1x18xbf16>, tensor<4xi64>) -> tensor<8x6x6x18xbf16>
   // CHECK: return %[[BROADCAST]] : tensor<8x6x6x18xbf16>
@@ -55,8 +55,8 @@
   %98 = "tf.Reshape"(%97, %cst_2) : (tensor<7x1x8x6x1x6x1x1x18xbf16>, tensor<5xi64>) -> tensor<7x8x6x6x18xbf16>
   func.return %98 : tensor<7x8x6x6x18xbf16>
 
-  // CHECK-DAG: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 8, 1, 1, 18]> : tensor<5xi64>} : () -> tensor<5xi64>
-  // CHECK-DAG: %[[CST1:.*]] =  "tf.Const"() {value = dense<[7, 8, 6, 6, 18]> : tensor<5xi64>} : () -> tensor<5xi64>
+  // CHECK-DAG: %[[CST:.*]] = "tf.Const"() <{value = dense<[1, 8, 1, 1, 18]> : tensor<5xi64>}> : () -> tensor<5xi64>
+  // CHECK-DAG: %[[CST1:.*]] =  "tf.Const"() <{value = dense<[7, 8, 6, 6, 18]> : tensor<5xi64>}> : () -> tensor<5xi64>
   // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[CST]]) : (tensor<1x8x1x1x1x1x1x18xbf16>, tensor<5xi64>) -> tensor<1x8x1x1x18xbf16>
   // CHECK: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[RESHAPE]], %[[CST1]]) : (tensor<1x8x1x1x18xbf16>, tensor<5xi64>) -> tensor<7x8x6x6x18xbf16>
   // CHECK: return %[[BROADCAST]] : tensor<7x8x6x6x18xbf16>
@@ -70,8 +70,8 @@
   %98 = "tf.Reshape"(%97, %cst_2) : (tensor<1x1x6x1x6x1x1x18xbf16>, tensor<5xi64>) -> tensor<1x6x1x6x18xbf16>
   func.return %98 : tensor<1x6x1x6x18xbf16>
 
-  // CHECK-DAG: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 1, 1, 1, 18]> : tensor<5xi64>} : () -> tensor<5xi64>
-  // CHECK-DAG: %[[CST1:.*]] = "tf.Const"() {value = dense<[1, 6, 1, 6, 18]> : tensor<5xi64>} : () -> tensor<5xi64>
+  // CHECK-DAG: %[[CST:.*]] = "tf.Const"() <{value = dense<[1, 1, 1, 1, 18]> : tensor<5xi64>}> : () -> tensor<5xi64>
+  // CHECK-DAG: %[[CST1:.*]] = "tf.Const"() <{value = dense<[1, 6, 1, 6, 18]> : tensor<5xi64>}> : () -> tensor<5xi64>
   // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%arg0, %[[CST]]) : (tensor<1x1x1x1x1x1x1x18xbf16>, tensor<5xi64>) -> tensor<1x1x1x1x18xbf16>
   // CHECK: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[RESHAPE]], %[[CST1]]) : (tensor<1x1x1x1x18xbf16>, tensor<5xi64>) -> tensor<1x6x1x6x18xbf16>
   // CHECK: return %[[BROADCAST]] : tensor<1x6x1x6x18xbf16>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py
index 1303480..118d7f3 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py
@@ -39,9 +39,9 @@
     self.c43 = tf.constant(43.0)
 
   # During serialization, the constants are given internal (non-user-accessible, non-semantically-load-bearing) exported names.
-  # CHECK: "tf_saved_model.global_tensor"() {sym_name = "[[CONST:[a-zA-Z_0-9.]+]]", tf_saved_model.exported_names = [{{.*}}], type = tensor<f32>, value = dense<4.300000e+01> : tensor<f32>} : () -> ()
+  # CHECK: "tf_saved_model.global_tensor"() <{sym_name = "[[CONST:[a-zA-Z_0-9.]+]]", type = tensor<f32>, value = dense<4.300000e+01> : tensor<f32>}> {tf_saved_model.exported_names = [{{.*}}]} : () -> ()
 
-  # CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = ["v42"], type = tensor<f32>, value = dense<4.200000e+01> : tensor<f32>} : () -> ()
+  # CHECK: "tf_saved_model.global_tensor"() <{is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", type = tensor<f32>, value = dense<4.200000e+01> : tensor<f32>}> {tf_saved_model.exported_names = ["v42"]} : () -> ()
   # CHECK:      func {{@[a-zA-Z_0-9]+}}(
   # CHECK-SAME:   %arg0: tensor<f32> {tf._user_specified_name = "x", tf_saved_model.index_path = [0]},
   # CHECK-SAME:   %arg1: tensor<!tf_type.resource<tensor<f32>>>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic_v1.py
index 9b71d46..2fc0670 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic_v1.py
@@ -29,7 +29,7 @@
 # CHECK-SAME: min_consumer
 # CHECK-SAME: producer
 
-# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>} : () -> ()
+# CHECK: "tf_saved_model.global_tensor"() <{is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>}> : () -> ()
 
 # CHECK:      func {{@[a-zA-Z_0-9]+}}(
 # CHECK-SAME:   [[ARG0:%.*]]: tensor<3x1xf32> {tf_saved_model.index_path = ["x"]},
@@ -38,7 +38,7 @@
 # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"]
 
 # CHECK-NEXT: [[R0:%.*]] = "tf.ReadVariableOp"([[ARG1]]) {{{.*}}} : (tensor<!tf_type.resource<tensor<1x3xf32>>>) -> tensor<1x3xf32>
-# CHECK-NEXT: [[R1:%.*]] = "tf.MatMul"([[ARG0]], [[R0]]) {{{.*}}} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
+# CHECK-NEXT: [[R1:%.*]] = "tf.MatMul"([[ARG0]], [[R0]]) <{{{.*}}}> {device = ""} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
 # CHECK-NEXT: return [[R1]] : tensor<3x3xf32>
 
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py
index bd48dfd..1f533ec 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py
@@ -22,9 +22,9 @@
 import tensorflow.compat.v1 as tf
 from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
 
-# CHECK: "tf_saved_model.session_initializer"() {initializers = [@[[init:.*]]]} : () -> ()
-# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset1:__tf_saved_model_asset1_.*]]"}
-# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset0:__tf_saved_model_asset0_.*]]"}
+# CHECK: "tf_saved_model.session_initializer"() <{initializers = [@[[init:.*]]]}> : () -> ()
+# CHECK: "tf_saved_model.asset"() <{filename = {{.*}}, sym_name = "[[asset1:__tf_saved_model_asset1_.*]]"}>
+# CHECK: "tf_saved_model.asset"() <{filename = {{.*}}, sym_name = "[[asset0:__tf_saved_model_asset0_.*]]"}>
 
 # CHECK:      func @[[init]]
 # CHECK-SAME: [[ARG0:%.*]]: tensor<!tf_type.string> {tf_saved_model.bound_input = @[[asset0]]}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py
index b3cd46c..4383b1b 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py
@@ -30,12 +30,12 @@
 # CHECK-SAME: producer
 
 # CHECK: "tf_saved_model.global_tensor"()
-# CHECK: "tf_saved_model.session_initializer"() {initializers = [@[[init:.*]]]} : () -> ()
+# CHECK: "tf_saved_model.session_initializer"() <{initializers = [@[[init:.*]]]}> : () -> ()
 
 # CHECK:      func @[[init]]
 # CHECK-SAME: tf_saved_model.initializer_type = "init_op"
 # CHECK-NEXT: [[R6:%.*]] = "tf.Const"()
-# CHECK-NEXT: [[R5:%.*]] = "tf.Const"() {device = "", value = dense<[1, 2,
+# CHECK-NEXT: [[R5:%.*]] = "tf.Const"() <{value = dense<[1, 2,
 # CHECK-NEXT: [[R7:%.*]] = "tf.HashTableV2"()
 # CHECK-SAME: shared_name = "[[hash_table:.*]]"
 # CHECK-NEXT: "tf.LookupTableImportV2"([[R7]], [[R5]], [[R6]])
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/import_restore_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/import_restore_v1.py
index cc8b5ad..fe2c692 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/import_restore_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/import_restore_v1.py
@@ -33,7 +33,7 @@
 # CHECK-SAME: initializers = [@[[restore:.*]]]
 
 # CHECK: "tf_saved_model.asset"()
-# CHECK-SAME: {filename = [[filename:.*]], sym_name = "[[sym_name:.*]]"} : () -> ()
+# CHECK-SAME: <{filename = [[filename:.*]], sym_name = "[[sym_name:.*]]"}> : () -> ()
 
 # CHECK:      func @[[restore]](
 # CHECK-SAME:   [[variable_path:%.*]]: tensor<!tf_type.string> {tf_saved_model.bound_input = @[[sym_name]]}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/include_variables_in_init_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/include_variables_in_init_v1.py
index 2f99fae..c7957a6 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/include_variables_in_init_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/include_variables_in_init_v1.py
@@ -30,7 +30,7 @@
 # CHECK-SAME: producer
 
 # CHECK: "tf_saved_model.global_tensor"()
-# CHECK: "tf_saved_model.session_initializer"() {initializers = [@[[INIT_FUNC:[a-zA-Z_0-9]+]]]} : () -> ()
+# CHECK: "tf_saved_model.session_initializer"() <{initializers = [@[[INIT_FUNC:[a-zA-Z_0-9]+]]]}> : () -> ()
 
 # Initializer function. This should contain the initialization sequence for the
 # variable.
@@ -38,7 +38,7 @@
 # CHECK-SAME: tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_init"]
 # CHECK-SAME: tf_saved_model.initializer_type = "init_op"
 # CHECK-SAME: }
-# CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {{{.*dense<.*> : tensor<2xi32>.*}}} : () -> tensor<2xi32>
+# CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{{{.*dense<.*> : tensor<2xi32>.*}}}> {{{.*}}} : () -> tensor<2xi32>
 # CHECK: %[[RAND_STD_NORMAL:.*]] = "tf.RandomStandardNormal"(%[[CST_0]])
 # CHECK: "tf.AssignVariableOp"(%[[ARG_0]], %[[RAND_STD_NORMAL]]){{.*}}: (tensor<!tf_type.resource<tensor<1x3xf32>>>, tensor<1x3xf32>) -> ()
 # CHECK: return
@@ -50,7 +50,7 @@
 # CHECK-SAME: -> (tensor<3x3xf32> {tf_saved_model.index_path = ["r"]})
 # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"]
 # CHECK-NEXT: %[[READ_VAR_0:.*]] = "tf.ReadVariableOp"(%[[ARG_2]]) {{{.*}}} : (tensor<!tf_type.resource<tensor<1x3xf32>>>) -> tensor<1x3xf32>
-# CHECK-NEXT: %[[MATMUL_0:.*]] = "tf.MatMul"(%[[ARG_1]], %[[READ_VAR_0]]) {{{.*}}} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
+# CHECK-NEXT: %[[MATMUL_0:.*]] = "tf.MatMul"(%[[ARG_1]], %[[READ_VAR_0]]) <{{{.*}}}> {{{.*}}} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
 # CHECK-NEXT: return %[[MATMUL_0]] : tensor<3x3xf32>
 
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_variables_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_variables_v1.py
index 4ce47a3..b3c3a20 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_variables_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/multi_variables_v1.py
@@ -19,8 +19,8 @@
 import tensorflow.compat.v1 as tf
 from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
 
-# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR0:[a-zA-Z_0-9]+]]", type = tensor<5x3xf32>, value = {{.*}} : tensor<5x3xf32>} : () -> ()
-# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR1:[a-zA-Z_0-9]+]]", type = tensor<3x5xf32>, value = {{.*}} : tensor<3x5xf32>} : () -> ()
+# CHECK: "tf_saved_model.global_tensor"() <{is_mutable, sym_name = "[[VAR0:[a-zA-Z_0-9]+]]", type = tensor<5x3xf32>, value = {{.*}} : tensor<5x3xf32>}> : () -> ()
+# CHECK: "tf_saved_model.global_tensor"() <{is_mutable, sym_name = "[[VAR1:[a-zA-Z_0-9]+]]", type = tensor<3x5xf32>, value = {{.*}} : tensor<3x5xf32>}> : () -> ()
 # CHECK:      func {{@[a-zA-Z_0-9]+}}(
 # CHECK-SAME:   [[ARG0:%.*]]: tensor<!tf_type.resource<tensor<5x3xf32>>> {tf_saved_model.bound_input = @[[VAR0]]},
 # CHECK-SAME:   [[ARG1:%.*]]: tensor<!tf_type.resource<tensor<3x5xf32>>> {tf_saved_model.bound_input = @[[VAR1]]})
@@ -29,7 +29,7 @@
 
 # CHECK-NEXT: [[R0:%.*]] = "tf.ReadVariableOp"([[ARG0]]) {{{.*}}} : (tensor<!tf_type.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
 # CHECK-NEXT: [[R1:%.*]] = "tf.ReadVariableOp"([[ARG1]]) {{{.*}}} : (tensor<!tf_type.resource<tensor<3x5xf32>>>) -> tensor<3x5xf32>
-# CHECK-NEXT: [[R2:%.*]] = "tf.MatMul"([[R0]], [[R1]]) {{{.*}}} : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32>
+# CHECK-NEXT: [[R2:%.*]] = "tf.MatMul"([[R0]], [[R1]]) <{{{.*}}}> {{{.*}}} : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32>
 
 
 def Test():
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/partially_shaped_variables.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/partially_shaped_variables.py
index 8b01372..a4238d5 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/partially_shaped_variables.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/partially_shaped_variables.py
@@ -24,8 +24,8 @@
 
   def __init__(self):
     super(TestModule, self).__init__()
-    # CHECK: "tf_saved_model.global_tensor"() {is_mutable, {{.*}} tf_saved_model.exported_names = ["v0"], type = tensor<*xf32>, value = dense<0.000000e+00> : tensor<1xf32>} : () -> ()
-    # CHECK: "tf_saved_model.global_tensor"() {is_mutable, {{.*}} tf_saved_model.exported_names = ["v1"], type = tensor<?xf32>, value = dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>} : () -> ()
+    # CHECK: "tf_saved_model.global_tensor"() <{is_mutable, {{.*}} type = tensor<*xf32>, value = dense<0.000000e+00> : tensor<1xf32>}> {tf_saved_model.exported_names = ["v0"]} : () -> ()
+    # CHECK: "tf_saved_model.global_tensor"() <{is_mutable, {{.*}} type = tensor<?xf32>, value = dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>}> {tf_saved_model.exported_names = ["v1"]} : () -> ()
     self.v0 = tf.Variable([0.], shape=tf.TensorShape(None))
     self.v1 = tf.Variable([0., 1.], shape=[None])
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/remove_init_variable_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/remove_init_variable_v1.py
index 123a11a..b2be608 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/remove_init_variable_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/remove_init_variable_v1.py
@@ -29,7 +29,7 @@
 # CHECK-SAME: min_consumer
 # CHECK-SAME: producer
 
-# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>} : () -> ()
+# CHECK: "tf_saved_model.global_tensor"() <{is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>}> : () -> ()
 # CHECK-NOT: session_initializer
 
 # CHECK:      func {{@[a-zA-Z_0-9]+}}(
@@ -39,7 +39,7 @@
 # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"]
 
 # CHECK-NEXT: [[R0:%.*]] = "tf.ReadVariableOp"([[ARG1]]) {{{.*}}} : (tensor<!tf_type.resource<tensor<1x3xf32>>>) -> tensor<1x3xf32>
-# CHECK-NEXT: [[R1:%.*]] = "tf.MatMul"([[ARG0]], [[R0]]) {{{.*}}} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
+# CHECK-NEXT: [[R1:%.*]] = "tf.MatMul"([[ARG0]], [[R0]]) <{{{.*}}}> {{{.*}}} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
 # CHECK-NEXT: return [[R1]] : tensor<3x3xf32>
 
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shared_variable_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shared_variable_v1.py
index 6c47dd5..2382a37 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shared_variable_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shared_variable_v1.py
@@ -19,7 +19,7 @@
 import tensorflow.compat.v1 as tf
 from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
 
-# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>} : () -> ()
+# CHECK: "tf_saved_model.global_tensor"() <{is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>}> : () -> ()
 
 # CHECK:      func {{@[a-zA-Z_0-9]+}}(
 # CHECK-SAME:   [[ARG0:%.*]]: tensor<3x1xf32> {tf_saved_model.index_path = ["x"]},
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_asset_sinking.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_asset_sinking.mlir
index 2638aab..94e0da3 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_asset_sinking.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_asset_sinking.mlir
@@ -10,8 +10,8 @@
 
   // CHECK: func @init()
   func.func @init(%arg0: tensor<!tf_type.string> {tf_saved_model.bound_input = @asset0}, %arg1: tensor<!tf_type.string> {tf_saved_model.bound_input = @asset1}) attributes {tf_saved_model.exported_names = ["init"]} {
-    // CHECK-DAG: %[[ASSET0:.*]] = "tf.Const"() {value = dense<"foo/bar/assets/test0.txt"> : tensor<!tf_type.string>}
-    // CHECK-DAG: %[[ASSET1:.*]] = "tf.Const"() {value = dense<"foo/bar/assets/test1.txt"> : tensor<!tf_type.string>}
+    // CHECK-DAG: %[[ASSET0:.*]] = "tf.Const"() <{value = dense<"foo/bar/assets/test0.txt"> : tensor<!tf_type.string>}>
+    // CHECK-DAG: %[[ASSET1:.*]] = "tf.Const"() <{value = dense<"foo/bar/assets/test1.txt"> : tensor<!tf_type.string>}>
 
     // CHECK: %[[VAR0:.*]] = "tf.VarHandleOp"()
     %0 = "tf.VarHandleOp"() {container = "", shared_name = "var0"} : () -> tensor<!tf_type.resource<tensor<!tf_type.string>>>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_assets.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_assets.mlir
index eb4aed8..982ace4 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_assets.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_assets.mlir
@@ -11,7 +11,7 @@
   attributes {tf_saved_model.exported_names = ["f"]} {
     %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf_type.resource>
     "tf.InitializeTableFromTextFileV2"(%0, %arg0) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor<!tf_type.resource>, tensor<!tf_type.string>) -> ()
-    // CHECK: [[CST:%.+]] = "tf.Const"() {value = dense<"assets/table.txt"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string>
+    // CHECK: [[CST:%.+]] = "tf.Const"() <{value = dense<"assets/table.txt"> : tensor<1x!tf_type.string>}> : () -> tensor<1x!tf_type.string>
     // CHECK: [[HASHTABLE:%.+]] = "tf.HashTableV2"()
     // CHECK: "tf.InitializeTableFromTextFileV2"([[HASHTABLE]], [[CST]])
     func.return
@@ -69,8 +69,8 @@
     "tf.InitializeTableFromTextFileV2"(%0, %arg0) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor<!tf_type.resource>, tensor<!tf_type.string>) -> ()
     %1 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf_type.resource>
     "tf.InitializeTableFromTextFileV2"(%1, %arg1) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor<!tf_type.resource>, tensor<!tf_type.string>) -> ()
-    // CHECK-DAG: [[CST_1:%.+]] = "tf.Const"() {value = dense<"assets/table2.txt"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string>
-    // CHECK-DAG: [[CST:%.+]] = "tf.Const"() {value = dense<"assets/table.txt"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string>
+    // CHECK-DAG: [[CST_1:%.+]] = "tf.Const"() <{value = dense<"assets/table2.txt"> : tensor<1x!tf_type.string>}> : () -> tensor<1x!tf_type.string>
+    // CHECK-DAG: [[CST:%.+]] = "tf.Const"() <{value = dense<"assets/table.txt"> : tensor<1x!tf_type.string>}> : () -> tensor<1x!tf_type.string>
     // CHECK: [[HASHTABLE:%.+]] = "tf.HashTableV2"()
     // CHECK: "tf.InitializeTableFromTextFileV2"([[HASHTABLE]], [[CST]])
     // CHECK: [[HASHTABLE_1:%.+]] = "tf.HashTableV2"()
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir
index 67a5439..9c0f9b2 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir
@@ -12,7 +12,7 @@
   func.func @f(%arg0: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
   attributes {tf_saved_model.exported_names = ["f"]} {
     %val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf_type.resource<tensor<f32>>>) -> tensor<f32>
-    // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
+    // CHECK: "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}>
     func.return
   }
 }
@@ -67,7 +67,7 @@
   }
 
   func.func private @f_callee(%arg0: tensor<!tf_type.resource<tensor<f32>>>) {
-    // CHECK: "tf.Const"() {value = dense<2.100000e+01> : tensor<f32>}
+    // CHECK: "tf.Const"() <{value = dense<2.100000e+01> : tensor<f32>}>
     func.return
   }
 }
@@ -90,7 +90,7 @@
 
   func.func private @g_callee(%arg0: tensor<!tf_type.resource<tensor<f32>>>) {
     %val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf_type.resource<tensor<f32>>>) -> tensor<f32>
-    // CHECK: "tf.Const"() {value = dense<3.200000e+01> : tensor<f32>}
+    // CHECK: "tf.Const"() <{value = dense<3.200000e+01> : tensor<f32>}>
     func.return
   }
 }
@@ -146,10 +146,10 @@
 
   func.func @f(%arg1: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @"v"}, %arg2: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @"v2"})
   attributes {tf_saved_model.exported_names = ["f"]} {
-    // CHECK-DAG: "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
+    // CHECK-DAG: "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}>
     %0 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<!tf_type.resource<tensor<f32>>>) -> tensor<f32>
 
-    // CHECK-DAG: "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>}
+    // CHECK-DAG: "tf.Const"() <{value = dense<2.000000e+00> : tensor<f32>}>
     %1 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<!tf_type.resource<tensor<f32>>>) -> tensor<f32>
     func.return
   }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_mark_initialized_variables.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_mark_initialized_variables.mlir
index 51368f2..4f69679 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_mark_initialized_variables.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_mark_initialized_variables.mlir
@@ -15,14 +15,14 @@
     func.return %4 : tensor<100x50xf32>
   }
   // CHECK: "tf.VarHandleOp"
-  // CHECK-SAME: _is_initialized = true
   // CHECK-SAME: shared_name = "var1"
-  // CHECK: "tf.VarHandleOp"
   // CHECK-SAME: _is_initialized = true
-  // CHECK-SAME: shared_name = "var2"
   // CHECK: "tf.VarHandleOp"
-  // CHECK-SAME: _is_initialized = false
+  // CHECK-SAME: shared_name = "var2"
+  // CHECK-SAME: _is_initialized = true
+  // CHECK: "tf.VarHandleOp"
   // CHECK-SAME: shared_name = "var3"
+  // CHECK-SAME: _is_initialized = false
 
   // INVALID-NOT: _is_initialized
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir
index ae61fc9..ea26be8 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir
@@ -8,9 +8,9 @@
 
   // Test case: Basic test of marking immutable.
 
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-NOT: is_mutable
-  // CHECK-SAME: } : () -> ()
+  // CHECK-SAME: }> : () -> ()
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   func.func @f(%arg0: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
@@ -27,9 +27,9 @@
 
   // Test case: Don't mark immutable if the variable is mutated.
 
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-SAME: is_mutable
-  // CHECK-SAME: } : () -> ()
+  // CHECK-SAME: }> : () -> ()
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   func.func @f(%arg0: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
@@ -47,9 +47,9 @@
 
   // Test case: Don't mark immutable if the variable is exported.
 
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK: is_mutable
-  // CHECK-SAME: } : () -> ()
+  // CHECK-SAME: }> {{{.*}}} : () -> ()
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", tf_saved_model.exported_names = ["v"], type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   func.func @f(%arg0: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
@@ -148,9 +148,9 @@
 // Test use as an input in unhandled op
 module attributes {tf_saved_model.semantics} {
 
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-SAME: is_mutable
-  // CHECK-SAME: } : () -> ()
+  // CHECK-SAME: }> : () -> ()
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   func.func @f(%arg0: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
@@ -166,9 +166,9 @@
 // Test use as a region capture in an unhandled op
 module attributes {tf_saved_model.semantics} {
 
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-SAME: is_mutable
-  // CHECK-SAME: } : () -> ()
+  // CHECK-SAME: }> : () -> ()
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   func.func @f(%arg0: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
@@ -187,14 +187,14 @@
 // to the unhandled op.
 module attributes {tf_saved_model.semantics} {
 
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-SAME: is_mutable
-  // CHECK-SAME: } : () -> ()
+  // CHECK-SAME: }> : () -> ()
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-SAME: is_mutable
-  // CHECK-SAME: } : () -> ()
+  // CHECK-SAME: }> : () -> ()
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "u", type = tensor<f32>, value = dense<22.> : tensor<f32> } : () -> ()
 
   func.func @f(%arg0: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}, %arg1: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @u})
@@ -212,14 +212,14 @@
 // Test multiple global tensors uses as operands for an unhandled op.
 module attributes {tf_saved_model.semantics} {
 
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-SAME: is_mutable
-  // CHECK-SAME: } : () -> ()
+  // CHECK-SAME: }> : () -> ()
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-SAME: is_mutable
-  // CHECK-SAME: } : () -> ()
+  // CHECK-SAME: }> : () -> ()
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "u", type = tensor<f32>, value = dense<22.> : tensor<f32> } : () -> ()
 
   func.func @f(%arg0: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}, %arg1: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @u})
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir
index c05698c..d773e38 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir
@@ -9,9 +9,9 @@
   // Test case: This test exercises marking a global tensor as immutable after it propagates
   // via set of chained calls -> f -> f_callee -> f_callee_callee
 
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-NOT: is_mutable
-  // CHECK-SAME: } : () -> ()
+  // CHECK-SAME: }> : () -> ()
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   func.func @f(%arg0: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
@@ -39,11 +39,11 @@
   // Test case:
   // This test exercises trying to mark immutable when same func is called by multiple callers
   // with different global tensors.
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-NOT: is_mutable
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-NOT: is_mutable
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v2", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
@@ -73,9 +73,9 @@
   // Test case: This test exercises immutability without explicit use
   // via ReadVariableOp
 
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-NOT: is_mutable
-  // CHECK-SAME: } : () -> ()
+  // CHECK-SAME: }> : () -> ()
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   func.func @f(%arg0: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
@@ -99,9 +99,9 @@
 
 // CHECK-LABEL: module attributes {tf_saved_model.semantics}
 module attributes {tf_saved_model.semantics} {
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-SAME: is_mutable
-  // CHECK-SAME: } : () -> ()
+  // CHECK-SAME: }> : () -> ()
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   // CHECK: func @f(%arg0: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
@@ -133,9 +133,9 @@
   // Test case: The inter-procedural analysis with different types of
   // TF call ops
 
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-SAME: is_mutable
-  // CHECK-SAME: } : () -> ()
+  // CHECK-SAME: }> : () -> ()
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   // CHECK: func @f(%arg0: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
@@ -166,9 +166,9 @@
 
   // Test case: The inter-procedural analysis does not recurse infinitely
 
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-NOT: is_mutable
-  // CHECK-SAME: } : () -> ()
+  // CHECK-SAME: }> : () -> ()
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   func.func @exported_f(%arg0: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
@@ -199,9 +199,9 @@
   // Test case: Inter-procedural analysis with resource usage in an
   // unknown op, we assume mutating behavior and propagate that.
 
-  // CHECK: "tf_saved_model.global_tensor"() {
+  // CHECK: "tf_saved_model.global_tensor"() <{
   // CHECK-SAME: is_mutable
-  // CHECK-SAME: } : () -> ()
+  // CHECK-SAME: }> : () -> ()
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
 
   func.func @exported_f(%arg0: tensor<!tf_type.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir
index 6d655ab..75bf23d 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir
@@ -14,8 +14,8 @@
       mlir_module = "..."} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
     tf_device.return %1#0, %1#1 : tensor<!tf_type.string>, tensor<2x!tf_type.string>
   }) {device = "/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
-  // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
-  // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
+  // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) <{index = 0 : i64, is_output = false}>
+  // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) <{index = 1 : i64, is_output = false}>
   // CHECK: %[[ITER:.*]]:2 = "tf.IteratorGetNext"
   %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
     : (tensor<*x!tf_type.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
@@ -249,8 +249,8 @@
       mlir_module = "..."} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
     tf_device.return %1#0, %1#1 : tensor<!tf_type.string>, tensor<2x!tf_type.string>
   }) {device = "/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
-  // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
-  // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
+  // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) <{index = 0 : i64, is_output = false}>
+  // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) <{index = 1 : i64, is_output = false}>
   // CHECK: %[[ITER1:.*]]:2 = "tf.IteratorGetNext"
   %3:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
     : (tensor<*x!tf_type.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
@@ -294,8 +294,8 @@
       mlir_module = "..."} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
     tf_device.return %1#0, %1#1 : tensor<!tf_type.string>, tensor<2x!tf_type.string>
   }) {device = "/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
-  // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
-  // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
+  // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) <{index = 0 : i64, is_output = false}>
+  // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) <{index = 1 : i64, is_output = false}>
 
   // CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"}
   // CHECK-DAG: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#1, %[[LAYOUT1]]) {device = "/device:TPU:0"}
@@ -332,8 +332,8 @@
       mlir_module = "..."} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
     tf_device.return %1#0, %1#1 : tensor<!tf_type.string>, tensor<2x!tf_type.string>
   }) {device = "/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
-  // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
-  // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
+  // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) <{index = 0 : i64, is_output = false}>
+  // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) <{index = 1 : i64, is_output = false}>
   // CHECK: %[[ITER1:.*]] = "tf.IteratorGetNext"
   %3 = "tf.IteratorGetNext"(%arg1) {device = "/device:CPU:0"}
     : (tensor<*x!tf_type.resource>) -> tensor<3x3x1x32xf32>
@@ -415,8 +415,8 @@
     %1:3 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\01 \02", mlir_module = "..."} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>)
     tf_device.return %1#0, %1#1, %1#2 : tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>
   }) {device = "/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>)
-  // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
-  // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#2) {index = 0 : i64, is_output = false}
+  // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) <{index = 0 : i64, is_output = false}>
+  // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#2) <{index = 0 : i64, is_output = false}>
   // CHECK: %[[ITER:.*]]:2 = "tf.IteratorGetNext"
   %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"} : (tensor<*x!tf_type.resource>) -> (tensor<128xf32>, tensor<128xf32>)
   // CHECK: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
@@ -429,9 +429,9 @@
     // CHECK-NEXT: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#0, %[[LAYOUT0]])
     // CHECK-SAME: device = "/device:TPU:0"
     // CHECK-NEXT: "tf_device.launch"
+    // CHECK-SAME: device = "/device:TPU:0"
     // CHECK-NEXT: "tf.TPUExecute"(%[[COPY0]], %[[COMPILE]]#1)
     // CHECK-NEXT: tf_device.return
-    // CHECK-NEXT: device = "/device:TPU:0"
     "tf_device.launch"() ({
       "tf.TPUExecute"(%2#0, %compile#1) : (tensor<128xf32>, tensor<2x!tf_type.string>) -> ()
       tf_device.return
@@ -442,9 +442,9 @@
     // CHECK: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#1, %[[LAYOUT1]])
     // CHECK-SAME: device = "/device:TPU:1"
     // CHECK-NEXT: "tf_device.launch"
+    // CHECK-SAME: device = "/device:TPU:1"
     // CHECK-NEXT: "tf.TPUExecute"(%[[COPY1]], %[[COMPILE]]#2)
     // CHECK-NEXT: tf_device.return
-    // CHECK-NEXT: device = "/device:TPU:1"
     "tf_device.launch"() ({
       "tf.TPUExecute"(%2#1, %compile#2) : (tensor<128xf32>, tensor<2x!tf_type.string>) -> ()
       tf_device.return
@@ -481,8 +481,8 @@
     %1:3 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\02 \02", mlir_module = "..."} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>)
     tf_device.return %1#0, %1#1, %1#2 : tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>
   }) {device = "/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>)
-  // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
-  // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#2) {index = 0 : i64, is_output = false}
+  // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) <{index = 0 : i64, is_output = false}>
+  // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#2) <{index = 0 : i64, is_output = false}>
   // CHECK-DAG: %[[ITER0:.*]]:2 = "tf.IteratorGetNext"(%[[ARG0]])
   // CHECK-DAG: %[[ITER1:.*]]:2 = "tf.IteratorGetNext"(%[[ARG1]])
   %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"} : (tensor<*x!tf_type.resource>) -> (tensor<128xf32>, tensor<128xf32>)
@@ -501,9 +501,10 @@
   tf_device.replicate([%2#0, %3#0] as %r0: tensor<128xf32>, [%2#1, %3#1] as %r1: tensor<128xf32>) {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"], TPU_REPLICATED_CORE_1 = ["/device:TPU:2", "/device:TPU:3"]}} {
     // CHECK: "tf_device.parallel_execute"
     "tf_device.parallel_execute"() ({
+      // CHECK: "tf_device.launch"
+      // CHECK-SAME: device = "TPU_REPLICATED_CORE_0"
       // CHECK: "tf.TPUExecute"(%[[R0]], %[[COMPILE]]#1)
       // CHECK-NEXT: tf_device.return
-      // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
       "tf_device.launch"() ({
         "tf.TPUExecute"(%r0, %compile#1) : (tensor<128xf32>, tensor<2x!tf_type.string>) -> ()
         tf_device.return
@@ -511,9 +512,10 @@
       tf_device.return
     },
     {
+      // CHECK: "tf_device.launch"
+      // CHECK-SAME: device = "TPU_REPLICATED_CORE_1"
       // CHECK: "tf.TPUExecute"(%[[R1]], %[[COMPILE]]#2)
       // CHECK-NEXT: tf_device.return
-      // CHECK-NEXT: device = "TPU_REPLICATED_CORE_1"
       "tf_device.launch"() ({
         "tf.TPUExecute"(%r1, %compile#2) : (tensor<128xf32>, tensor<2x!tf_type.string>) -> ()
         tf_device.return
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-merge-variables-with-execute.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-merge-variables-with-execute.mlir
index e3191b5..880703a 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-merge-variables-with-execute.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-merge-variables-with-execute.mlir
@@ -27,6 +27,7 @@
       tf_device.return %0#0, %0#1 : tensor<!tf_type.string>, tensor<2x!tf_type.string>
     }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
   // CHECK: %[[EXE:.*]] = "tf_device.launch"
+  // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:TPU:0"}>
   // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ID_0]], %[[ARG_1]], %[[READ_2]], %[[COMPILE]]#1)
   // CHECK-SAME: device_var_reads_indices = [0, 1],
   // CHECK-SAME: device_var_updates_indices = [0, -1]
@@ -38,7 +39,7 @@
     tf_device.return %0#0, %0#1 : tensor<32xf32>, tensor<16xf32>
   }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<16xf32>)
   // CHECK-NEXT: tf_device.return
-  // CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"}
+  // CHECK-NEXT: })
   "tf.AssignVariableOp"(%id0, %execute#0) : (tensor<*x!tf_type.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_2]], %[[EXE]])
   "tf.AssignVariableOp"(%arg2, %execute#1) : (tensor<*x!tf_type.resource<tensor<16xf32>>>, tensor<16xf32>) -> ()
@@ -71,6 +72,7 @@
   // CHECK: tf_device.replicate([%[[ARG_1]], %[[ARG_2]]] as %[[R_ARG:.*]]: tensor<*x!tf_type.resource<tensor<32xf32>>>)
   tf_device.replicate([%arg1, %arg2] as %r: tensor<*x!tf_type.resource<tensor<32xf32>>>) {n = 2 : i32} {
     // CHECK-NEXT: "tf_device.launch"
+    // CHECK-SAME: <{device = ""}>
     // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[R_ARG]], %[[COMPILE]]#1)
     // CHECK-SAME: device_var_reads_indices = [1],
     // CHECK-SAME: device_var_updates_indices = [0]
@@ -81,7 +83,7 @@
       tf_device.return %0 : tensor<32xf32>
     }) {device = ""} : () -> tensor<32xf32>
     // CHECK-NEXT: tf_device.return
-    // CHECK-NEXT: }) {device = ""}
+    // CHECK-NEXT: })
     "tf.AssignVariableOp"(%r, %execute) : (tensor<*x!tf_type.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
     // CHECK-NEXT: tf_device.return
     tf_device.return
@@ -130,6 +132,7 @@
     tf_device.return %0#0, %0#1 : tensor<!tf_type.string>, tensor<2x!tf_type.string>
   }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
   // CHECK: %[[EXE:.*]]:2 = "tf_device.launch"
+  // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:TPU:0"}>
   // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[ARG_1]], %[[ARG_4]], %[[READ_5]], %[[COMPILE]]#1)
   // CHECK-SAME: device_var_reads_indices = [1, 2],
   // CHECK-SAME: device_var_updates_indices = [1, -1]
@@ -142,7 +145,7 @@
     tf_device.return %0#0, %0#1, %0#2 : tensor<32xf32>, tensor<64xf32>, tensor<8xf32>
   }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>)
   // CHECK-NEXT: tf_device.return
-  // CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"}
+  // CHECK-NEXT: })
   "tf.AssignVariableOp"(%arg1, %execute#1) : (tensor<*x!tf_type.resource<tensor<64xf32>>>, tensor<64xf32>) -> ()
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]#0)
   "tf.AssignVariableOp"(%arg0, %execute#0) : (tensor<*x!tf_type.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
@@ -197,6 +200,7 @@
     tf_device.return %0#0, %0#1 : tensor<!tf_type.string>, tensor<2x!tf_type.string>
   }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
   // CHECK: %[[EXE:.*]] = "tf_device.launch"
+  // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:TPU:0"}>
   // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ARG_0]], %[[ARG_1]], %[[ARG_3]], %[[ARG_4]], %[[COMPILE]]#1)
   // CHECK-SAME: device_var_reads_indices = [0, 1, 2, 3],
   // CHECK-SAME: device_var_updates_indices = [0, 1, -1, -1]
@@ -209,7 +213,7 @@
     tf_device.return %0#0, %0#1, %0#2 : tensor<32xf32>, tensor<64xf32>, tensor<8xf32>
   }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>)
   // CHECK-NEXT: tf_device.return
-  // CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"}
+  // CHECK-NEXT: })
   // CHECK-NEXT: %[[READ:.*]] = "tf.ReadVariableOp"(%[[ARG_3]])
   %read3 = "tf.ReadVariableOp"(%arg3) : (tensor<*x!tf_type.resource<tensor<8xf32>>>) -> tensor<8xf32>
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_3]], %[[EXE]])
@@ -236,6 +240,7 @@
   // CHECK-NEXT: %[[READ_1:.*]] = "tf.ReadVariableOp"(%[[ARG_0]])
   %read1 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<32xf32>>>) -> tensor<32xf32>
   // CHECK-NEXT: %[[EXE:.*]] = "tf_device.launch"
+  // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:TPU:0"}>
   // CHECK-NEXT: "tf.TPUExecute"(%[[READ_0]], %[[READ_1]], %[[ARG_1]])
   %execute = "tf_device.launch"() ({
     %0 = "tf.TPUExecute"(%read0, %read1, %arg1) {
@@ -244,7 +249,7 @@
     tf_device.return %0 : tensor<32xf32>
   }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<32xf32>
   // CHECK-NEXT: tf_device.return
-  // CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"}
+  // CHECK-NEXT: })
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]])
   "tf.AssignVariableOp"(%arg0, %execute) : (tensor<*x!tf_type.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
   // CHECK-NEXT: return
@@ -265,6 +270,7 @@
   // CHECK-NEXT: %[[READ_0:.*]] = "tf.ReadVariableOp"(%[[ARG_0]])
   %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<32xf32>>>) -> tensor<32xf32>
   // CHECK-NEXT: %[[EXE:.*]]:2 = "tf_device.launch"
+  // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:TPU:0"}>
   // CHECK-NEXT: "tf.TPUExecute"(%[[READ_0]], %[[ARG_1]])
   %execute:2 = "tf_device.launch"() ({
     %0:2 = "tf.TPUExecute"(%read0, %arg1) {
@@ -273,7 +279,7 @@
     tf_device.return %0#0, %0#1 : tensor<32xf32>, tensor<32xf32>
   }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<32xf32>)
   // CHECK-NEXT: tf_device.return
-  // CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"}
+  // CHECK-NEXT: })
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]#0)
   "tf.AssignVariableOp"(%arg0, %execute#0) : (tensor<*x!tf_type.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
   // CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]#1)
@@ -301,22 +307,22 @@
   // CHECK: "tf_device.parallel_execute"
   %pe:2 = "tf_device.parallel_execute"() ({
     // CHECK: "tf_device.launch"
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:0"
     %execute0 = "tf_device.launch"() ({
       // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ARG_0]], %[[ARG_2]])
       %0 = "tf.TPUExecute"(%read0, %arg2) : (tensor<32xf32>, tensor<!tf_type.string>) -> tensor<32xf32>
       // CHECK-NEXT: tf_device.return
       tf_device.return %0 : tensor<32xf32>
-    // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
     }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<32xf32>
     tf_device.return %execute0 : tensor<32xf32>
   }, {
     // CHECK: "tf_device.launch"
+    // CHECK-SAME: device = "/job:localhost/replica:0/task:0/device:TPU:1"
     %execute1 = "tf_device.launch"() ({
       // CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ARG_1]], %[[ARG_2]])
       %1 = "tf.TPUExecute"(%read1, %arg2) : (tensor<64xf32>, tensor<!tf_type.string>) -> tensor<64xf32>
       // CHECK-NEXT: tf_device.return
       tf_device.return %1 : tensor<64xf32>
-    // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:TPU:1"
     }) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : () -> tensor<64xf32>
     tf_device.return %execute1 : tensor<64xf32>
   }) : () -> (tensor<32xf32>, tensor<64xf32>)
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-multiple-while-body-func.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-multiple-while-body-func.mlir
index 1762e9f..ae3498a 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-multiple-while-body-func.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-multiple-while-body-func.mlir
@@ -1,5 +1,5 @@
-// RUN: tf-opt %s -tf-tpu-bridge 2>&1 | FileCheck %s
-// RUN: tf-opt %s -tf-cluster-tpu-bridge-v1 2>&1 | FileCheck %s
+// RUN: tf-opt %s -tf-cluster-tpu-bridge-v2 -tfrt-lower-cluster-to-runtime-ops-tpu 2>&1 | FileCheck %s
+// RUN: tf-opt %s -tf-cluster-tpu-bridge-v1 -tfrt-lower-cluster-to-runtime-ops-tpu -tf-dialect-to-executor-v1 2>&1 | FileCheck %s
 
 // This test verifies there is no warning about shape inference failure in TPU
 // bridge in handling multiple usage of the same function.
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir
index 5594918..9a903d7 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir
@@ -61,9 +61,9 @@
             // CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
             %id = "tf.Identity"(%arg30) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource<tensor<f32>>>
             // CHECK: "tf_device.launch"
+            // CHECK-SAME: device = "TPU_REPLICATED_CORE_0"
             // CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
             // CHECK-NEXT: tf_device.return
-            // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
             // CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
             "tf_device.launch"() ({
               "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
@@ -84,9 +84,9 @@
     // CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[STATE:.*]]: tensor<!tf_type.resource<tensor<2x!tf_type.string>>>
     // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
     // CHECK: "tf_device.launch"
+    // CHECK-SAME: device = "TPU_REPLICATED_CORE_0"
     // CHECK-NEXT: "tf.TPUReshardVariables"(%[[V0]], %[[V1]], %[[DEFAULT]], %[[STATE]])
     // CHECK-NEXT: tf_device.return
-    // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
     func.return
   }
 }
@@ -296,9 +296,9 @@
             %id = "tf.Identity"(%arg30) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource<tensor<f32>>>
             // CHECK: "tf_device.parallel_execute"
             // CHECK: "tf_device.launch"
+            // CHECK-SAME: device = "TPU_REPLICATED_CORE_0"
             // CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
             // CHECK-NEXT: tf_device.return
-            // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
             // CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
 	    "tf_device.parallel_execute"() ({
               "tf_device.launch"() ({
@@ -324,9 +324,9 @@
     // CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[STATE:.*]]: tensor<!tf_type.resource<tensor<2x!tf_type.string>>>
     // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
     // CHECK: "tf_device.launch"
+    // CHECK-SAME: device = "TPU_REPLICATED_CORE_0"
     // CHECK-NEXT: "tf.TPUReshardVariables"(%[[V0]], %[[V1]], %[[DEFAULT]], %[[STATE]])
     // CHECK-NEXT: tf_device.return
-    // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
     func.return
   }
 }
@@ -391,9 +391,9 @@
             // CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
             %id = "tf.Identity"(%arg30) : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<*x!tf_type.resource<tensor<f32>>>
             // CHECK: "tf_device.launch"
+            // CHECK-SAME: device = "TPU_REPLICATED_CORE_0"
             // CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
             // CHECK-NEXT: tf_device.return
-            // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
             // CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
             "tf_device.launch"() ({
               "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
@@ -414,9 +414,9 @@
     // CHECK-SAME: %[[ARG2]] as %[[V1:.*]]: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>
     // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
     // CHECK: "tf_device.launch"
+    // CHECK-SAME: device = "TPU_REPLICATED_CORE_0"
     // CHECK-NEXT: "tf.TPUReshardVariables"(%[[V0]], %[[V1]], %[[DEFAULT]], %[[STATE]])
     // CHECK-NEXT: tf_device.return
-    // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
     func.return
   }
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir
index ed43750..1d3c1b6 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir
@@ -942,7 +942,7 @@
 
 // CHECK-LABEL: func @const_with_attrs
 func.func @const_with_attrs(%arg0: tensor<*xi32>, %arg1: tensor<?xi64>) -> (tensor<?xi32>, tensor<?xi64>) {
-  // CHECK: %{{[a-z0-9_]*}} = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: %{{[a-z0-9_]*}} = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
   // CHECK-NEXT: %{{[a-z0-9_]*}} = "tf.Reshape"(%arg0
   // CHECK-NEXT: %{{.*}} = "tf_device.cluster"() ({
   %minus_one = "tf.Const"() {_replication_info = "cluster",
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_composite_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_composite_resource_ops.mlir
index b2896fa..62fe231 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_composite_resource_ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_composite_resource_ops.mlir
@@ -13,9 +13,9 @@
     devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]},
     n = 2 : i32} {
      // CHECK:      %[[RESOURCE_OUT:.*]] = "tf_device.launch"()
+     // CHECK-SAME: TPU_REPLICATED_CORE_0
      // CHECK-NEXT:   %[[READ_OUT:.*]] = "tf.ReadVariableOp"(%[[RI_0]])
      // CHECK-NEXT:   tf_device.return %[[READ_OUT]]
-     // CHECK-NEXT: TPU_REPLICATED_CORE_0
      %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource<tensor<4xf32>>>) -> tensor<4xf32>
      %1 = "tf.A"() : () -> (tensor<2x!tf_type.string>)
      "tf_device.launch"() ({
@@ -43,9 +43,9 @@
     n = 2 : i32} {
      // CHECK:      %[[IDENTITY_OUT:.*]] = "tf.Identity"(%[[RI_0]])
      // CHECK:      %[[RESOURCE_OUT:.*]] = "tf_device.launch"()
+     // CHECK-SAME: TPU_REPLICATED_CORE_0
      // CHECK-NEXT:   %[[READ_OUT:.*]] = "tf.ReadVariableOp"(%[[IDENTITY_OUT]])
      // CHECK-NEXT:   tf_device.return %[[READ_OUT]]
-     // CHECK-NEXT: TPU_REPLICATED_CORE_0
      %0 = "tf.Identity"(%arg1) : (tensor<*x!tf_type.resource<tensor<4xf32>>>) -> tensor<*x!tf_type.resource<tensor<4xf32>>>
      %1 = "tf.ReadVariableOp"(%0) : (tensor<*x!tf_type.resource<tensor<4xf32>>>) -> tensor<4xf32>
      %2 = "tf.A"() : () -> (tensor<2x!tf_type.string>)
@@ -77,9 +77,9 @@
     n = 2 : i32} {
      // CHECK:      %[[VAL_OUT:.*]] = "tf.A"() : () -> tensor<4xf32>
      // CHECK:      "tf_device.launch"()
+     // CHECK-SAME: TPU_REPLICATED_CORE_0
      // CHECK-NEXT:   "tf.AssignVariableOp"(%[[RI_0]], %[[VAL_OUT]])
-     // CHECK-NEXT:   tf_device.return
-     // CHECK-NEXT: TPU_REPLICATED_CORE_0
+     // CHECK:   tf_device.return
      %1 = "tf.A"() : () -> (tensor<4xf32>)
      "tf.AssignVariableOp"(%arg1, %1) : (tensor<*x!tf_type.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
      %2 = "tf.B"() : () -> (tensor<2x!tf_type.string>)
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
index 8b128e5..9796913 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
@@ -611,9 +611,9 @@
   // CHECK-LABEL: func @no_replication_device
   func.func @no_replication_device() {
     "tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "__no_replication_cluster", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device = "/job:worker/replica:0/task:0/device:TPU:1", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
+    // CHECK: "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:TPU:1"}>
     // CHECK: tf.TPUExecute
     // CHECK-NEXT: tf_device.return
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:TPU:1"
     func.return
   }
   func.func @empty_func() {
@@ -629,9 +629,9 @@
   // CHECK-LABEL: func @no_replication_device
   func.func @no_replication_device() {
     "tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "__no_replication_cluster", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device = "/job:worker/replica:0/task:0/device:CPU:0", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
+    // CHECK: "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:TPU:0"}>
     // CHECK: tf.TPUExecute
     // CHECK-NEXT: tf_device.return
-    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:TPU:0"
     func.return
   }
   func.func @empty_func() {
@@ -709,20 +709,18 @@
 
     %1 = "tf_device.cluster_func"(%0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
     // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
-    // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
+    // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:CPU:0"}>
     // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
     // CHECK-SAME: metadata
     // CHECK-SAME: mlir_module
     // CHECK-SAME: func @main
     // CHECK-SAME: tf.B
     // CHECK-NOT: func = @tpu0_func
-    // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
-    // CHECK: "tf_device.launch"
+    // CHECK: "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:CPU:0"}>
     // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
-    // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
     // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
+    // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:TPU:0"
     // CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
-    // CHECK: device = "/job:worker/replica:0/task:0/device:TPU:0"
 
     %2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
     // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
@@ -756,17 +754,15 @@
     // CHECK-SAME: n = 2
     %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
       // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]])
-      // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
+      // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
       // CHECK-SAME: metadata
       // CHECK-SAME: mlir_module
       // CHECK-SAME: func @main
       // CHECK-SAME: tf.B
       // CHECK-NOT: func = @tpu0_func
-      // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
-      // CHECK: "tf_device.launch"
+      // CHECK: "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
-      // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
       // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
       // CHECK-NEXT: "tf.TPUExecute"(%[[RI_0]], %[[COMPILE_OUTPUT]]#1)
       %2 = "tf_device.cluster_func"(%ri_0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
@@ -799,8 +795,8 @@
 
     %1 = "tf_device.cluster_func"(%0) {device = "gpu0", func = @gpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
     // CHECK: tf_device.cluster_func
-    // CHECK-SAME: device = "gpu0"
     // CHECK-SAME: func = @gpu0_func
+    // CHECK-SAME: device = "gpu0"
     // CHECK-SAME: num_cores_per_replica = 1
     // CHECK-SAME: step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP"
     // CHECK-NOT: metadata
@@ -826,7 +822,7 @@
 
     %1 = "tf_device.cluster_func"(%0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
     // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
-    // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
+    // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:CPU:0"}>
     // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
     // CHECK-SAME: metadata
     // CHECK-SAME: mlir_module
@@ -835,13 +831,10 @@
     // CHECK-SAME: func private @nested_func
     // CHECK-SAME: tf.D
     // CHECK-NOT: func = @tpu0_func
-    // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
-    // CHECK: "tf_device.launch"
+    // CHECK: "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:CPU:0"}>
     // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
-    // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
-    // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
+    // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:TPU:0"}>
     // CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
-    // CHECK: device = "/job:worker/replica:0/task:0/device:TPU:0"
 
     %2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
     // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
@@ -1198,14 +1191,12 @@
     // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:3 = "tf_device.launch"
     // CHECK-NEXT: "tf._TPUCompileMlir"()
     // CHECK: "tf_device.launch"
-    // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
+    // CHECK: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
     // CHECK: [[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
-    // CHECK: "tf_device.launch"
+    // CHECK: "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:TPU:0"}>
     // CHECK-NEXT: "tf.TPUExecute"(%[[READ_VAR_0]], %[[COMPILE_OUTPUT]]#1)
-    // CHECK: device = "/job:worker/replica:0/task:0/device:TPU:0"
-    // CHECK: "tf_device.launch"
+    // CHECK: "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:TPU:1"}>
     // CHECK-NEXT: "tf.TPUExecute"(%[[READ_VAR_1]], %[[COMPILE_OUTPUT]]#2)
-    // CHECK: device = "/job:worker/replica:0/task:0/device:TPU:1"
     %computation = "tf_device.cluster_func"(%partitioned_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = [""], output_sharding_configuration = [""], use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
     // CHECK-NOT: tf.TPUPartitionedOutputV2
     %partitioned_output:2 = "tf.TPUPartitionedOutputV2"(%computation) {N = 2 : i64, partition_dims = []} : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
@@ -1238,12 +1229,10 @@
     // CHECK: "tf_device.launch"
     // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
     // CHECK: [[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
-    // CHECK: "tf_device.launch"
+    // CHECK: "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:TPU:0"}>
     // CHECK-NEXT: "tf.TPUExecute"(%[[READ_VAR_0]], %[[COMPILE_OUTPUT]]#1)
-    // CHECK: device = "/job:worker/replica:0/task:0/device:TPU:0"
-    // CHECK: "tf_device.launch"
+    // CHECK: "tf_device.launch"() <{device = "/job:worker/replica:0/task:0/device:TPU:1"}>
     // CHECK-NEXT: "tf.TPUExecute"(%[[READ_VAR_1]], %[[COMPILE_OUTPUT]]#2)
-    // CHECK: device = "/job:worker/replica:0/task:0/device:TPU:1"
     %computation = "tf_device.cluster_func"(%partitioned_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01"], output_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01"], use_spmd_for_xla_partitioning = true} : (tensor<3x4xf32>) -> tensor<3x4xf32>
     // CHECK-NOT: tf.TPUPartitionedOutputV2
     %partitioned_output:2 = "tf.TPUPartitionedOutputV2"(%computation) {_XlaSharding = "\08\03\1A\02\01\02\22\02\00\01", partition_dims = [1, 2]} : (tensor<3x4xf32>) -> (tensor<3x2xf32>, tensor<3x2xf32>)
@@ -1443,10 +1432,10 @@
       // CHECK: "tf_device.parallel_execute"
       // CHECK-NOT:"tf._XlaCompileMlirPlaceholderProgramKey"
       // CHECK:    "tf.D"(%[[COMPILE_OUTPUT]]#1
+      // CHECK:    "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}>
       // CHECK:    "tf.TPUExecute"
-      // CHECK:      device = "TPU_REPLICATED_CORE_0"
+      // CHECK:     "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}>
       // CHECK:    "tf.TPUExecute"
-      // CHECK:      device = "TPU_REPLICATED_CORE_1"
       // CHECK-NOT:    "tf.TPUExecute"
       %3 = "tf_device.parallel_execute"() ({
          %program = "tf._XlaCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf_type.string>
@@ -1485,10 +1474,10 @@
       // CHECK: "tf._TPUCompileMlir"
       // CHECK: "tf.TPUCompileSucceededAssert"
       // CHECK: "tf_device.parallel_execute"
+      // CHECK:    "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}>
       // CHECK:    "tf.TPUExecute"
-      // CHECK:      device = "TPU_REPLICATED_CORE_0"
+      // CHECK:    "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}>
       // CHECK:    "tf.TPUExecute"
-      // CHECK:      device = "TPU_REPLICATED_CORE_1"
       // CHECK-NOT:    "tf.TPUExecute"
       // CHECK-NOT:"tf._XlaCompileMlirPlaceholderProgramKey"
       // CHECK:    "tf.D"(%[[COMPILE_OUTPUT]]#1
@@ -1524,23 +1513,19 @@
 module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
   // CHECK-LABEL: func @non_replicated_parallel_execute
   func.func @non_replicated_parallel_execute(%arg0: tensor<8xi32>) -> tensor<8xi32> {
-    // CHECK:      %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
+    // CHECK:      %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
     // CHECK-NEXT:   "tf._TPUCompileMlir"()
     // CHECK-NEXT:   tf_device.return
-    // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
-    // CHECK:      "tf_device.launch"
+    // CHECK:      "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
     // CHECK-NEXT:   "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
     // CHECK-NEXT:   tf_device.return
-    // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
     // CHECK:      "tf_device.parallel_execute"
-    // CHECK-NEXT:   "tf_device.launch"
+    // CHECK-NEXT:   "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:TPU:0"}>
     // CHECK-NEXT:     "tf.TPUExecute"
     // CHECK-NEXT:     tf_device.return
-    // CHECK-NEXT:   device = "/job:localhost/replica:0/task:0/device:TPU:0"
-    // CHECK:        "tf_device.launch"
+    // CHECK:        "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:TPU:1"}>
     // CHECK-NEXT:     "tf.TPUExecute"
     // CHECK-NEXT:     tf_device.return
-    // CHECK-NEXT:   device = "/job:localhost/replica:0/task:0/device:TPU:1"
     %0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> tensor<8xi32>
     func.return %0 : tensor<8xi32>
   }
@@ -1587,23 +1572,19 @@
     // CHECK: tf_device.replicate
     // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"], TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"], TPU_REPLICATED_HOST_0 = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:1/device:CPU:0"], TPU_REPLICATED_HOST_1 = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:1/device:CPU:0"]}
     %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} {
-      // CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
+      // CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf._TPUCompileMlir"()
       // CHECK-NEXT:   tf_device.return
-      // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
-      // CHECK:      "tf_device.launch"
+      // CHECK:      "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
       // CHECK-NEXT:   tf_device.return
-      // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
       // CHECK:      "tf_device.parallel_execute"
-      // CHECK-NEXT:   "tf_device.launch"
+      // CHECK-NEXT:   "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}>
       // CHECK-NEXT:     "tf.TPUExecute"
       // CHECK-NEXT:     tf_device.return
-      // CHECK-NEXT:   device = "TPU_REPLICATED_CORE_0"
-      // CHECK:        "tf_device.launch"
+      // CHECK:        "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}>
       // CHECK-NEXT:     "tf.TPUExecute"
       // CHECK-NEXT:     tf_device.return
-      // CHECK-NEXT:   device = "TPU_REPLICATED_CORE_1"
       %1 = "tf_device.cluster_func"(%ri) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> tensor<8xi32>
       tf_device.return %1 : tensor<8xi32>
     }
@@ -1632,13 +1613,11 @@
       // CHECK:      %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
       // CHECK:      "tf._TPUCompileMlir"
       // CHECK:      %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
-      // CHECK-NEXT:   %[[LAUNCH_0_OUTPUT:[0-9]*]] = "tf_device.launch"
+      // CHECK-NEXT:   %[[LAUNCH_0_OUTPUT:[0-9]*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}>
       // CHECK-NEXT:     %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[RI_0]], %[[RI_1]], %[[RI_2]], %[[COMPILE]]#1)
       // CHECK-NEXT:     tf_device.return %[[EXECUTE_OUTPUT]]
-      // CHECK-NEXT:   device = "TPU_REPLICATED_CORE_0"
-      // CHECK:        "tf_device.launch"
+      // CHECK:        "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}>
       // CHECK-NEXT:     "tf.TPUExecute"(%[[RI_1]], %[[RI_2]], %[[COMPILE]]#2)
-      // CHECK:        device = "TPU_REPLICATED_CORE_1"
       %1 = "tf_device.cluster_func"(%ri, %ri2, %ri3) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "", ""], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>, tensor<*xi1>, tensor<*xi32>) -> tensor<8xi32>
       tf_device.return %1 : tensor<8xi32>
     }
@@ -1663,20 +1642,16 @@
     // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
     // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
     %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} {
-      // CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
+      // CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf._TPUCompileMlir"()
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
-      // CHECK:      "tf_device.launch"
+      // CHECK:      "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
       // CHECK:      %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
-      // CHECK-NEXT:   %[[LAUNCH_0_OUTPUT:[0-9]*]] = "tf_device.launch"
+      // CHECK-NEXT:   %[[LAUNCH_0_OUTPUT:[0-9]*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}>
       // CHECK-NEXT:     %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"
       // CHECK-NEXT:     tf_device.return %[[EXECUTE_OUTPUT]]
-      // CHECK-NEXT:   device = "TPU_REPLICATED_CORE_0"
-      // CHECK:        "tf_device.launch"
+      // CHECK:        "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}>
       // CHECK-NEXT:     "tf.TPUExecute"
-      // CHECK:        device = "TPU_REPLICATED_CORE_1"
       %1 = "tf_device.cluster_func"(%ri) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> tensor<8xi32>
       tf_device.return %1 : tensor<8xi32>
     }
@@ -1700,21 +1675,17 @@
     // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
     // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
     %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} {
-      // CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
+      // CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf._TPUCompileMlir"()
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
-      // CHECK:      "tf_device.launch"
+      // CHECK:      "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
       // CHECK:      %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute"
-      // CHECK-NEXT:   %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
+      // CHECK-NEXT:   %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}>
       // CHECK-NEXT:     %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"
       // CHECK-NEXT:     tf_device.return %[[EXECUTE_0_OUTPUT]]
-      // CHECK-NEXT:   device = "TPU_REPLICATED_CORE_0"
-      // CHECK:        %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
+      // CHECK:        %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}>
       // CHECK-NEXT:     %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"
       // CHECK-NEXT:     tf_device.return %[[EXECUTE_1_OUTPUT]]
-      // CHECK:        device = "TPU_REPLICATED_CORE_1"
       %1, %2 = "tf_device.cluster_func"(%ri) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>)
       tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
     }
@@ -1763,25 +1734,21 @@
     // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
     // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
     %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
-      // CHECK:      %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
+      // CHECK:      %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf._TPUCompileMlir"
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
-      // CHECK:      "tf_device.launch"
+      // CHECK:      "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
       //
       // CHECK:      %[[CONST_SPLIT_DIM:.*]] = "tf.Const"()
       // CHECK:      %[[SPLIT_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_DIM]], %[[RI_0]])
       // CHECK:      %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute"
-      // CHECK-NEXT:   %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
+      // CHECK-NEXT:   %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}>
       //
       // CHECK-NEXT:     %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_OUT]]#0, %[[COMPILE]]#1)
       // CHECK-NEXT:     tf_device.return %[[EXECUTE_0_OUTPUT]]
-      // CHECK-NEXT:   device = "TPU_REPLICATED_CORE_0"
-      // CHECK:        %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
+      // CHECK:        %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}>
       // CHECK-NEXT:     %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2)
       // CHECK-NEXT:     tf_device.return %[[EXECUTE_1_OUTPUT]]
-      // CHECK:        device = "TPU_REPLICATED_CORE_1"
       %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
       tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
     }
@@ -1830,22 +1797,18 @@
     // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
     // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
     %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
-      // CHECK:      %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
+      // CHECK:      %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf._TPUCompileMlir"
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
-      // CHECK:      "tf_device.launch"
+      // CHECK:      "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
       //
       // CHECK:      %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute"
-      // CHECK-NEXT:   %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
+      // CHECK-NEXT:   %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}>
       // CHECK-NEXT:     %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"
       // CHECK-NEXT:     tf_device.return %[[EXECUTE_0_OUTPUT]]
-      // CHECK-NEXT:   device = "TPU_REPLICATED_CORE_0"
-      // CHECK:        %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
+      // CHECK:        %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}>
       // CHECK-NEXT:     %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"
       // CHECK-NEXT:     tf_device.return %[[EXECUTE_1_OUTPUT]]
-      // CHECK:        device = "TPU_REPLICATED_CORE_1"
       //
       // CHECK:     %[[CONST_CONCAT_DIM:.*]] = "tf.Const"()
       // CHECK:     %[[CONCAT_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#0, %[[PARALLEL_EXECUTE_OUTPUT]]#2
@@ -1899,22 +1862,18 @@
     // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
     // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
     %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
-      // CHECK:      %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
+      // CHECK:      %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf._TPUCompileMlir"
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
-      // CHECK:      "tf_device.launch"
+      // CHECK:      "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
       //
       // CHECK:      %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute"
-      // CHECK-NEXT:   %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
+      // CHECK-NEXT:   %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}>
       // CHECK-NEXT:     %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"
       // CHECK-NEXT:     tf_device.return %[[EXECUTE_0_OUTPUT]]
-      // CHECK-NEXT:   device = "TPU_REPLICATED_CORE_0"
-      // CHECK:        %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
+      // CHECK:        %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}>
       // CHECK-NEXT:     %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"
       // CHECK-NEXT:     tf_device.return %[[EXECUTE_1_OUTPUT]]
-      // CHECK:        device = "TPU_REPLICATED_CORE_1"
       //
       // CHECK:     %[[CONST_CONCAT_DIM:.*]] = "tf.Const"()
       // CHECK:     %[[CONCAT_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#1, %[[PARALLEL_EXECUTE_OUTPUT]]#2
@@ -2091,12 +2050,10 @@
     // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
     // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
     %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
-      // CHECK:      %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
+      // CHECK:      %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf._TPUCompileMlir"
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
-      // CHECK:      "tf_device.launch"
+      // CHECK:      "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
       // CHECK:      %[[CONST_SPLIT_0_DIM:.*]] = "tf.Const"()
       // CHECK:      %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]])
       // CHECK:      %[[CONST_SPLIT_1_DIM:.*]] = "tf.Const"()
@@ -2198,12 +2155,10 @@
     // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
     // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
     %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
-      // CHECK:      %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
+      // CHECK:      %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf._TPUCompileMlir"
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
-      // CHECK:      "tf_device.launch"
+      // CHECK:      "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
       // CHECK:      %[[CONST_SPLIT_0_DIM:.*]] = "tf.Const"()
       // CHECK:      %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]])
       // CHECK:      %[[CONST_SPLIT_1_DIM:.*]] = "tf.Const"()
@@ -2282,12 +2237,10 @@
     // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
     // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
     %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
-      // CHECK:      %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
+      // CHECK:      %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf._TPUCompileMlir"
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
-      // CHECK:      "tf_device.launch"
+      // CHECK:      "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
       // CHECK:      %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute"
       // CHECK-NEXT:   %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
       // CHECK-NEXT:     %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"
@@ -2367,12 +2320,10 @@
     // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
     // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
     %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
-      // CHECK:      %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
+      // CHECK:      %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf._TPUCompileMlir"
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
-      // CHECK:      "tf_device.launch"
+      // CHECK:      "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
       // CHECK:      %[[CONST_SPLIT_0_DIM:.*]] = "tf.Const"()
       // CHECK:      %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]])
       // CHECK:      %[[CONST_SPLIT_1_DIM:.*]] = "tf.Const"()
@@ -2451,12 +2402,10 @@
     // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
     // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
     %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
-      // CHECK:      %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
+      // CHECK:      %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf._TPUCompileMlir"
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
-      // CHECK:      "tf_device.launch"
+      // CHECK:      "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK-NEXT:   "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
-      // CHECK:      device = "/job:localhost/replica:0/task:0/device:CPU:0"
       // CHECK:      %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute"
       // CHECK-NEXT:   %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
       // CHECK-NEXT:     %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"
@@ -2613,14 +2562,12 @@
 module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1199 : i32}} {
   func.func @return_from_host_and_tpu() -> (tensor<?xi32>, tensor<?x!tf_type.string>) attributes {tf._construction_context = "kEagerRuntime", tf.signature.is_stateful} {
       // CHECK:     %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
-      // CHECK:       %[[LAUNCH_0_OUTPUT:[0-9]*]] = "tf_device.launch"
+      // CHECK:       %[[LAUNCH_0_OUTPUT:[0-9]*]] = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
       // CHECK:         %[[B_OUTPUT:[0-9]*]] = "tf.B"
       // CHECK:         tf_device.return %[[B_OUTPUT:[0-9]*]]
-      // CHECK:       device = "/job:localhost/replica:0/task:0/device:CPU:0"
-      // CHECK:       %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
+      // CHECK:       %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:TPU:0"}>
       // CHECK-NEXT:    %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"
       // CHECK:         tf_device.return %[[EXECUTE_1_OUTPUT]]
-      // CHECK:       device = "/job:localhost/replica:0/task:0/device:TPU:0"
       // CHECK:    return %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]#1, %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]#0
     %0:2 = "tf_device.parallel_execute"() ({
       %1 = "tf_device.launch"() ({
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir
index 2125f28..4d290b7 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir
@@ -15,7 +15,7 @@
     %0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
     // CHECK: %[[INPUT:.*]] = "tf.IteratorGetNext"
     %1 = "tf.IteratorGetNext"(%arg5) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<!tf_type.resource>) -> tensor<2x224x224x3xf32>
-    // CHECK-DAG: %[[SPACETODEPTH0:.*]] = "tf.SpaceToDepth"([[INPUT:.*]]) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
+    // CHECK-DAG: %[[SPACETODEPTH0:.*]] = "tf.SpaceToDepth"([[INPUT:.*]]) <{block_size = 2 : i64, data_format = "NHWC"}> : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
     %2 = "tf.AddV2"(%arg2, %arg3) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
     %3 = "tf.ReadVariableOp"(%arg6) : (tensor<!tf_type.resource<tensor<7x7x3x64xf32>>>) -> tensor<7x7x3x64xf32>
     %4 = "tf.ReadVariableOp"(%arg8) : (tensor<!tf_type.resource<tensor<f32>>>) -> tensor<f32>
@@ -61,18 +61,18 @@
     // CHECK-SAME: strides = [1, 1, 1, 1]
     // CHECK-SAME: (tensor<2x115x115x12xf32>, tensor<4xi32>, tensor<2x112x112x64xf32>) -> tensor<4x4x12x64xf32>
     %7 = "tf.Conv2DBackpropFilter"(%5, %2, %6) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<2x230x230x3xf32>, tensor<4xi32>, tensor<2x112x112x64xf32>) -> tensor<7x7x3x64xf32>
-    // CHECK: %[[CONST0:.*]] = "tf.Const"() {value = dense<
+    // CHECK: %[[CONST0:.*]] = "tf.Const"() <{value = dense<
     // CHECK-SAME: [4, 4, 2, 2, 3, 64]
     // CHECK: %[[RESHAPE0:.*]] = "tf.Reshape"(%[[BACKPROP:.*]], %[[CONST0:.*]]) : (tensor<4x4x12x64xf32>, tensor<6xi64>) -> tensor<4x4x2x2x3x64xf32>
-    // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<
+    // CHECK: %[[CONST1:.*]] = "tf.Const"() <{value = dense<
     // CHECK-SAME: [0, 2, 1, 3, 4, 5]
     // CHECK: %[[TRANSPOSE0:.*]] = "tf.Transpose"(%[[RESHAPE0:.*]], %[[CONST1:.*]]) : (tensor<4x4x2x2x3x64xf32>, tensor<6xi32>) -> tensor<4x2x4x2x3x64xf32>
-    // CHECK: %[[CONST2:.*]] = "tf.Const"() {value = dense<
+    // CHECK: %[[CONST2:.*]] = "tf.Const"() <{value = dense<
     // CHECK-SAME: [8, 8, 3, 64]
     // CHECK: %[[RESHAPE1:.*]] = "tf.Reshape"(%[[TRANSPOSE1:.*]], %[[CONST2:.*]]) : (tensor<4x2x4x2x3x64xf32>, tensor<4xi64>) -> tensor<8x8x3x64xf32>
-    // CHECK: %[[CONST3:.*]] = "tf.Const"() {value = dense<
+    // CHECK: %[[CONST3:.*]] = "tf.Const"() <{value = dense<
     // CHECK-SAME: [7, 7, 3, 64]
-    // CHECK: %[[CONST4:.*]] = "tf.Const"() {value = dense<
+    // CHECK: %[[CONST4:.*]] = "tf.Const"() <{value = dense<
     // CHECK-SAME: 0
     // CHECK: %[[SLICE0:.*]] = "tf.Slice"(%[[RESHAPE1:.*]], %[[CONST4:.*]], %[[CONST3:.*]]) : (tensor<8x8x3x64xf32>, tensor<4xi64>, tensor<4xi32>) -> tensor<7x7x3x64xf32>
     %8 = "tf.CrossReplicaSum"(%7, %1) : (tensor<7x7x3x64xf32>, tensor<1x1xi32>) -> tensor<7x7x3x64xf32>
@@ -90,10 +90,10 @@
 module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:COMPOSITE:0" = {}, "/job:localhost/replica:0/task:0/device:CPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:1" = {}, "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0" = {}}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 458 : i32}} {
   func.func @main(%arg0: tensor<*x!tf_type.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg1: tensor<!tf_type.variant> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf_type.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg3: tensor<!tf_type.variant> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf_type.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg5: tensor<!tf_type.variant> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf_type.resource<tensor<7x7x3x64xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg7: tensor<*x!tf_type.resource<tensor<64x1001xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg8: tensor<*x!tf_type.resource<tensor<1001xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg9: tensor<*x!tf_type.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg10: tensor<*x!tf_type.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg11: tensor<*x!tf_type.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg12: tensor<*x!tf_type.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}) attributes {tf.entry_function = {control_outputs = "IteratorGetNext,IteratorGetNext_1,CrossReplicaSum,AssignAddVariableOp,CrossReplicaSum_1,AssignAddVariableOp_1,CrossReplicaSum_2,AssignAddVariableOp_2,CrossReplicaSum_3,AssignAddVariableOp_3", inputs = "iterator,iterator_1,iterator_2,iterator_3,iterator_4,iterator_5,resnet50_conv1_conv2d_conv1_kernel_140365606309224_handle_inputs_0,resnet50_fc1000_matmul_fc1000_kernel_140365944145960_handle_inputs_0,resnet50_fc1000_biasadd_fc1000_bias_140365944146240_handle_inputs_0,total_140366323758976_handle_inputs_0,count_140366323759312_handle_inputs_0,total_140366323760264_handle_inputs_0,count_140366323760600_handle_inputs_0", outputs = ""}} {
     // CHECK: %[[INPUT00:.*]] = "tf.IteratorGetNext"
-    // CHECK-DAG: %[[SPACETODEPTH00:.*]] = "tf.SpaceToDepth"([[INPUT00:.*]]#0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
+    // CHECK-DAG: %[[SPACETODEPTH00:.*]] = "tf.SpaceToDepth"([[INPUT00:.*]]#0) <{block_size = 2 : i64, data_format = "NHWC"}> : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
     %0:2 = "tf.IteratorGetNext"(%arg2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*x!tf_type.resource>) -> (tensor<2x224x224x3xf32>, tensor<2x1xf32>)
     // CHECK: %[[INPUT01:.*]] = "tf.IteratorGetNext"
-    // CHECK-DAG: %[[SPACETODEPTH01:.*]] = "tf.SpaceToDepth"([[INPUT01:.*]]#0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
+    // CHECK-DAG: %[[SPACETODEPTH01:.*]] = "tf.SpaceToDepth"([[INPUT01:.*]]#0) <{block_size = 2 : i64, data_format = "NHWC"}> : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
     %1:2 = "tf.IteratorGetNext"(%arg4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*x!tf_type.resource>) -> (tensor<2x224x224x3xf32>, tensor<2x1xf32>)
     tf_device.replicate([%0#0, %1#0] as %arg13: tensor<2x224x224x3xf32>, [%0#1, %1#1] as %arg14: tensor<2x1xf32>, %arg6 as %arg15: tensor<*x!tf_type.resource<tensor<7x7x3x64xf32>>>, %arg8 as %arg16: tensor<*x!tf_type.resource<tensor<1001xf32>>>, %arg7 as %arg17: tensor<*x!tf_type.resource<tensor<64x1001xf32>>>, %arg9 as %arg18: tensor<*x!tf_type.resource<tensor<f32>>>, %arg10 as %arg19: tensor<*x!tf_type.resource<tensor<f32>>>, %arg11 as %arg20: tensor<*x!tf_type.resource<tensor<f32>>>, %arg12 as %arg21: tensor<*x!tf_type.resource<tensor<f32>>>) {_mirrored_variable_indices = [2, 3, 4, 5, 6, 7, 8], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
       %2 = "tf.ReadVariableOp"(%arg15) : (tensor<*x!tf_type.resource<tensor<7x7x3x64xf32>>>) -> tensor<7x7x3x64xf32>
@@ -167,7 +167,7 @@
 module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:COMPOSITE:0" = {}, "/job:localhost/replica:0/task:0/device:CPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:1" = {}, "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0" = {}}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 458 : i32}} {
   func.func @main(%arg0: tensor<*x!tf_type.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg1: tensor<!tf_type.variant> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf_type.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg3: tensor<!tf_type.variant> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf_type.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg5: tensor<!tf_type.variant> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf_type.resource<tensor<7x7x3x64xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg7: tensor<*x!tf_type.resource<tensor<64x1001xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg8: tensor<*x!tf_type.resource<tensor<1001xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg9: tensor<*x!tf_type.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg10: tensor<*x!tf_type.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg11: tensor<*x!tf_type.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg12: tensor<*x!tf_type.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}) attributes {tf.entry_function = {control_outputs = "IteratorGetNext,IteratorGetNext_1,CrossReplicaSum,AssignAddVariableOp,CrossReplicaSum_1,AssignAddVariableOp_1,CrossReplicaSum_2,AssignAddVariableOp_2,CrossReplicaSum_3,AssignAddVariableOp_3", inputs = "iterator,iterator_1,iterator_2,iterator_3,iterator_4,iterator_5,resnet50_conv1_conv2d_conv1_kernel_140365606309224_handle_inputs_0,resnet50_fc1000_matmul_fc1000_kernel_140365944145960_handle_inputs_0,resnet50_fc1000_biasadd_fc1000_bias_140365944146240_handle_inputs_0,total_140366323758976_handle_inputs_0,count_140366323759312_handle_inputs_0,total_140366323760264_handle_inputs_0,count_140366323760600_handle_inputs_0", outputs = ""}} {
     // CHECK: %[[INPUT00:.*]] = "tf.IteratorGetNext"
-    // CHECK-DAG: %[[SPACETODEPTH00:.*]] = "tf.SpaceToDepth"([[INPUT00:.*]]#0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
+    // CHECK-DAG: %[[SPACETODEPTH00:.*]] = "tf.SpaceToDepth"([[INPUT00:.*]]#0) <{block_size = 2 : i64, data_format = "NHWC"}> : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
     %0:2 = "tf.IteratorGetNext"(%arg2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*x!tf_type.resource>) -> (tensor<2x224x224x3xf32>, tensor<2x1xf32>)
     tf_device.replicate(%0#0 as %arg13: tensor<2x224x224x3xf32>, %0#1 as %arg14: tensor<2x1xf32>, %arg6 as %arg15: tensor<*x!tf_type.resource<tensor<7x7x3x64xf32>>>, %arg8 as %arg16: tensor<*x!tf_type.resource<tensor<1001xf32>>>, %arg7 as %arg17: tensor<*x!tf_type.resource<tensor<64x1001xf32>>>, %arg9 as %arg18: tensor<*x!tf_type.resource<tensor<f32>>>, %arg10 as %arg19: tensor<*x!tf_type.resource<tensor<f32>>>, %arg11 as %arg20: tensor<*x!tf_type.resource<tensor<f32>>>, %arg12 as %arg21: tensor<*x!tf_type.resource<tensor<f32>>>) {_mirrored_variable_indices = [2, 3, 4, 5, 6, 7, 8], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
       %2 = "tf.ReadVariableOp"(%arg15) : (tensor<*x!tf_type.resource<tensor<7x7x3x64xf32>>>) -> tensor<7x7x3x64xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_tail_with_tobool_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_tail_with_tobool_op.mlir
index 403bc64..5dcb27a 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_tail_with_tobool_op.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_tail_with_tobool_op.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-opt %s -tf-tpu-bridge 2>&1 | FileCheck %s
+// RUN: tf-opt %s -tf-cluster-tpu-bridge-v2 -tfrt-lower-cluster-to-runtime-ops-tpu 2>&1 | FileCheck %s
 
 // This test verifies that the tail extraction is not terminated prematurely
 // in handling tf.If op which would end up with excessive host-device
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir
index d3276c2..d1437a6 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir
@@ -15,7 +15,7 @@
   // CHECK: %[[CONST_0:.*]] = "tf.Const"()
   %0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
 
-  // CHECK: %[[CONST_MODE:.*]] = "tf.Const"() {_xla_outside_compilation = "0", value = dense<"inference"> : tensor<!tf_type.string>} : () -> tensor<!tf_type.string>
+  // CHECK: %[[CONST_MODE:.*]] = "tf.Const"() <{value = dense<"inference"> : tensor<!tf_type.string>}> {_xla_outside_compilation = "0"} : () -> tensor<!tf_type.string>
   // CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[CONST_MODE]])
   "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %arg7) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf_type.string>) -> ()
   %2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
@@ -43,7 +43,7 @@
   %3 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
   "tf.SendTPUEmbeddingGradients"(%2, %3) {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D", operandSegmentSizes = array<i32: 2, 0>} : (tensor<2x2xf32>, tensor<4x4xf32>) -> ()
 
-  // CHECK: %[[CONST_MODE:.*]] = "tf.Const"() {_xla_outside_compilation = "0", value = dense<"train"> : tensor<!tf_type.string>} : () -> tensor<!tf_type.string>
+  // CHECK: %[[CONST_MODE:.*]] = "tf.Const"() <{value = dense<"train"> : tensor<!tf_type.string>}> {_xla_outside_compilation = "0"} : () -> tensor<!tf_type.string>
   // CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[CONST_MODE]])
   "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %arg7) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf_type.string>) -> ()
   %4:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir
index ec2d36d..4333e79 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir
@@ -7,12 +7,12 @@
   func.return %0 : tensor<2x3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulTwoDim
-  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
-  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
-  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
+  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[6, 4, 5]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[6, 5, 6]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[4, 5]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 6]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() <{value = dense<[2, 3, 4, 6]> : tensor<4xi64>}>
 
   // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
   // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
@@ -32,14 +32,14 @@
   // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
   // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
   // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
   // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
 }
@@ -51,12 +51,12 @@
   func.return %0 : tensor<2x3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulTwoDimAdjXY
-  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 4]> : tensor<3xi64>}
-  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 6, 5]> : tensor<3xi64>}
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 4]> : tensor<2xi64>}
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5]> : tensor<2xi64>}
-  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
+  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[6, 5, 4]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[6, 6, 5]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 4]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[6, 5]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() <{value = dense<[2, 3, 4, 6]> : tensor<4xi64>}>
 
   // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x4xf32>, tensor<3xi64>) -> tensor<6x5x4xf32>
   // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x4xf32>) -> (tensor<1x5x4xf32>, tensor<1x5x4xf32>, tensor<1x5x4xf32>, tensor<1x5x4xf32>, tensor<1x5x4xf32>, tensor<1x5x4xf32>)
@@ -76,14 +76,14 @@
   // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x6x5xf32>, tensor<2xi64>) -> tensor<6x5xf32>
   // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x6x5xf32>, tensor<2xi64>) -> tensor<6x5xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
   // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
   // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
 }
@@ -95,9 +95,9 @@
   func.return %0 : tensor<3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulOneDim
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[4, 5]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 6]> : tensor<2xi64>}>
 
   // CHECK: %[[LHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
   // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
@@ -109,11 +109,11 @@
   // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
   // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
   // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
 }
 
@@ -124,16 +124,16 @@
   func.return %0 : tensor<1x4x6xf32>
 
   // CHECK-LABEL: batchMatMulSingleBatch
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} : () -> tensor<2xi64>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[4, 5]> : tensor<2xi64>}> : () -> tensor<2xi64>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 6]> : tensor<2xi64>}> : () -> tensor<2xi64>
 
   // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%arg0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
 
   // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%arg1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]]) {axis = 0 : i64} : (tensor<4x6xf32>) -> tensor<1x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]]) <{axis = 0 : i64}> : (tensor<4x6xf32>) -> tensor<1x4x6xf32>
   // CHECK: return %[[MATMUL_PACKED]] : tensor<1x4x6xf32>
 }
 
@@ -144,19 +144,19 @@
   func.return %0 : tensor<3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulUnbatchedLeft
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 6]> : tensor<2xi64>}>
 
   // CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
   // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
   // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
   // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
   // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
 }
 
@@ -167,19 +167,19 @@
   func.return %0 : tensor<3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulUnbatchedRight
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[4, 5]> : tensor<2xi64>}> : () -> tensor<2xi64>
 
   // CHECK: %[[LHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
   // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
   // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
   // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
   // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
 }
 
@@ -190,7 +190,7 @@
   func.return %0 : tensor<4x6xf32>
 
   // CHECK-LABEL: batchMatMulMatrix
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
   // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32>
 }
 
@@ -201,7 +201,7 @@
   func.return %0 : tensor<4x6xf32>
 
   // CHECK-LABEL: batchMatMulMatrixAdjXY
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
   // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32>
 }
 
@@ -213,12 +213,12 @@
   func.return %0 : tensor<2x3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV2TwoDim
-  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
-  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
-  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
+  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[6, 4, 5]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[6, 5, 6]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[4, 5]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 6]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() <{value = dense<[2, 3, 4, 6]> : tensor<4xi64>}>
 
   // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
   // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
@@ -238,14 +238,14 @@
   // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
   // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
   // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
   // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
 }
@@ -257,12 +257,12 @@
   func.return %0 : tensor<2x3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV2TwoDimAdjXY
-  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 4]> : tensor<3xi64>}
-  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 6, 5]> : tensor<3xi64>}
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 4]> : tensor<2xi64>}
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5]> : tensor<2xi64>}
-  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
+  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[6, 5, 4]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[6, 6, 5]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 4]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[6, 5]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() <{value = dense<[2, 3, 4, 6]> : tensor<4xi64>}>
 
   // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x4xf32>, tensor<3xi64>) -> tensor<6x5x4xf32>
   // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x4xf32>) -> (tensor<1x5x4xf32>, tensor<1x5x4xf32>, tensor<1x5x4xf32>, tensor<1x5x4xf32>, tensor<1x5x4xf32>, tensor<1x5x4xf32>)
@@ -282,14 +282,14 @@
   // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x6x5xf32>, tensor<2xi64>) -> tensor<6x5xf32>
   // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x6x5xf32>, tensor<2xi64>) -> tensor<6x5xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
   // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
   // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
 }
@@ -301,12 +301,12 @@
   func.return %0 : tensor<2x3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV2Broadcast
-  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 4, 5]> : tensor<3xi64>}
-  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>}
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
-  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
+  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[2, 4, 5]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[3, 5, 6]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[4, 5]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 6]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() <{value = dense<[2, 3, 4, 6]> : tensor<4xi64>}>
 
   // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x1x4x5xf32>, tensor<3xi64>) -> tensor<2x4x5xf32>
   // CHECK: %[[LHS_SPLIT:.*]]:2 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<2x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>)
@@ -319,14 +319,14 @@
   // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
   // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
   // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
   // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
 }
@@ -338,9 +338,9 @@
   func.return %0 : tensor<3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV2OneDim
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[4, 5]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 6]> : tensor<2xi64>}>
 
   // CHECK: %[[LHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
   // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
@@ -352,11 +352,11 @@
   // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
   // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
   // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
 }
 
@@ -367,16 +367,16 @@
   func.return %0 : tensor<1x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV2SingleBatch
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} : () -> tensor<2xi64>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[4, 5]> : tensor<2xi64>}> : () -> tensor<2xi64>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 6]> : tensor<2xi64>}> : () -> tensor<2xi64>
 
   // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%arg0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
 
   // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%arg1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]]) {axis = 0 : i64} : (tensor<4x6xf32>) -> tensor<1x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]]) <{axis = 0 : i64}> : (tensor<4x6xf32>) -> tensor<1x4x6xf32>
   // CHECK: return %[[MATMUL_PACKED]] : tensor<1x4x6xf32>
 }
 
@@ -387,19 +387,19 @@
   func.return %0 : tensor<3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV2UnbatchedLeft
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 6]> : tensor<2xi64>}>
 
   // CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
   // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
   // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
   // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
   // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
 }
 
@@ -410,19 +410,19 @@
   func.return %0 : tensor<3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV2UnbatchedRight
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[4, 5]> : tensor<2xi64>}> : () -> tensor<2xi64>
 
   // CHECK: %[[LHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
   // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
   // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
   // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
   // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
 }
 
@@ -433,7 +433,7 @@
   func.return %0 : tensor<4x6xf32>
 
   // CHECK-LABEL: batchMatMulV2Matrix
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
   // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32>
 }
 
@@ -444,7 +444,7 @@
   func.return %0 : tensor<4x6xf32>
 
   // CHECK-LABEL: batchMatMulV2MatrixAdjXY
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
   // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32>
 }
 
@@ -455,7 +455,7 @@
   func.return %0 : tensor<?x4xf32>
 
   // CHECK-LABEL: batchMatMulV2DynamicSize
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<?x?xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<?x?xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
   // CHECK: return %[[MATMUL_1]] : tensor<?x4xf32>
 }
 
@@ -467,12 +467,12 @@
   func.return %0 : tensor<2x3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV3TwoDim
-  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
-  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
-  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
+  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[6, 4, 5]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[6, 5, 6]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[4, 5]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 6]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() <{value = dense<[2, 3, 4, 6]> : tensor<4xi64>}>
 
   // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
   // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
@@ -492,14 +492,14 @@
   // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
   // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
   // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
   // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
 }
@@ -511,12 +511,12 @@
   func.return %0 : tensor<2x3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV3TwoDimAdjXY
-  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5, 4]> : tensor<3xi64>}
-  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 6, 5]> : tensor<3xi64>}
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 4]> : tensor<2xi64>}
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[6, 5]> : tensor<2xi64>}
-  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
+  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[6, 5, 4]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[6, 6, 5]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 4]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[6, 5]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() <{value = dense<[2, 3, 4, 6]> : tensor<4xi64>}>
 
   // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x3x5x4xf32>, tensor<3xi64>) -> tensor<6x5x4xf32>
   // CHECK: %[[LHS_SPLIT:.*]]:6 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<6x5x4xf32>) -> (tensor<1x5x4xf32>, tensor<1x5x4xf32>, tensor<1x5x4xf32>, tensor<1x5x4xf32>, tensor<1x5x4xf32>, tensor<1x5x4xf32>)
@@ -536,14 +536,14 @@
   // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x6x5xf32>, tensor<2xi64>) -> tensor<6x5xf32>
   // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x6x5xf32>, tensor<2xi64>) -> tensor<6x5xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
   // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
   // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
 }
@@ -555,12 +555,12 @@
   func.return %0 : tensor<2x3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV3Broadcast
-  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 4, 5]> : tensor<3xi64>}
-  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>}
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
-  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
+  // CHECK-DAG: %[[LHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[2, 4, 5]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[RHS_RESHAPED_SHAPE:.*]] = "tf.Const"() <{value = dense<[3, 5, 6]> : tensor<3xi64>}>
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[4, 5]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 6]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[RESULT_SHAPE:.*]] = "tf.Const"() <{value = dense<[2, 3, 4, 6]> : tensor<4xi64>}>
 
   // CHECK: %[[LHS_RESHAPED:.*]] = "tf.Reshape"(%arg0, %[[LHS_RESHAPED_SHAPE]]) : (tensor<2x1x4x5xf32>, tensor<3xi64>) -> tensor<2x4x5xf32>
   // CHECK: %[[LHS_SPLIT:.*]]:2 = "tf.Split"(%[[SPLITTING_AXIS]], %[[LHS_RESHAPED]]) : (tensor<i32>, tensor<2x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>)
@@ -573,14 +573,14 @@
   // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
   // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
   // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
   // CHECK: return %[[RESULT]] : tensor<2x3x4x6xf32>
 }
@@ -592,9 +592,9 @@
   func.return %0 : tensor<3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV3OneDim
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[4, 5]> : tensor<2xi64>}>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 6]> : tensor<2xi64>}>
 
   // CHECK: %[[LHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
   // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_RESHAPED]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
@@ -606,11 +606,11 @@
   // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
   // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
   // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
 }
 
@@ -621,16 +621,16 @@
   func.return %0 : tensor<1x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV3SingleBatch
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} : () -> tensor<2xi64>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[4, 5]> : tensor<2xi64>}> : () -> tensor<2xi64>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 6]> : tensor<2xi64>}> : () -> tensor<2xi64>
 
   // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%arg0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
 
   // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%arg1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]]) {axis = 0 : i64} : (tensor<4x6xf32>) -> tensor<1x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]]) <{axis = 0 : i64}> : (tensor<4x6xf32>) -> tensor<1x4x6xf32>
   // CHECK: return %[[MATMUL_PACKED]] : tensor<1x4x6xf32>
 }
 
@@ -641,19 +641,19 @@
   func.return %0 : tensor<3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV3UnbatchedLeft
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
-  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG: %[[MATMUL_RHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 6]> : tensor<2xi64>}>
 
   // CHECK: %[[RHS_RESHAPED:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg1) : (tensor<i32>, tensor<3x5x6xf32>) -> (tensor<1x5x6xf32>, tensor<1x5x6xf32>, tensor<1x5x6xf32>)
   // CHECK: %[[RHS_1:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#0, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
   // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
   // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
   // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
 }
 
@@ -664,19 +664,19 @@
   func.return %0 : tensor<3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV3UnbatchedRight
-  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} : () -> tensor<2xi64>
+  // CHECK-DAG: %[[SPLITTING_AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+  // CHECK-DAG: %[[MATMUL_LHS_SHAPE:.*]] = "tf.Const"() <{value = dense<[4, 5]> : tensor<2xi64>}> : () -> tensor<2xi64>
 
   // CHECK: %[[LHS_SPLIT:.*]]:3 = "tf.Split"(%[[SPLITTING_AXIS]], %arg0) : (tensor<i32>, tensor<3x4x5xf32>) -> (tensor<1x4x5xf32>, tensor<1x4x5xf32>, tensor<1x4x5xf32>)
   // CHECK: %[[LHS_1:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#0, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
   // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
   // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
 
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
-  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
 
-  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
+  // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
   // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32>
 }
 
@@ -687,7 +687,7 @@
   func.return %0 : tensor<4x6xf32>
 
   // CHECK-LABEL: batchMatMulV3Matrix
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
   // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32>
 }
 
@@ -698,7 +698,7 @@
   func.return %0 : tensor<4x6xf32>
 
   // CHECK-LABEL: batchMatMulV3MatrixAdjXY
-  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = true, transpose_b = true} : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
+  // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32>
   // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32>
 }
 
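The FileCheck churn in this file (and in the test files below) is mechanical: MLIR now prints an op's inherent attributes as a properties segment, `<{...}>`, emitted before the discardable attribute dictionary, so patterns that previously matched `{...}` right after the operand list have to match `<{...}>` instead. A minimal before/after sketch of the printed form, reusing the same tf.MatMul case as the tests above (the SSA names are illustrative):

  // Old printed form, previously matched with "{...}":
  %0 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
  // New printed form, matched with "<{...}>"; discardable attributes, if any, still follow in "{...}":
  %0 = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>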
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/update_control_dependencies.mlir b/tensorflow/compiler/mlir/tensorflow/tests/update_control_dependencies.mlir
index 263a676..09931d9 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/update_control_dependencies.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/update_control_dependencies.mlir
@@ -80,7 +80,7 @@
 // CHECK:   %[[GRAPH:.*]]:2 = tf_executor.graph {
 // CHECK:     %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
 // CHECK:     %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1)
-// CHECK:     %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2]]) {message = "add2 result"}
+// CHECK:     %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2]]) <{message = "add2 result"}>
 // CHECK:     tf_executor.fetch %[[ADD1]], %[[ADD2]], %[[PRINT_control]] :
 // CHECK:   }
 // CHECK:  return %[[GRAPH]]#0, %[[GRAPH]]#1
@@ -99,7 +99,7 @@
 // CHECK:   %[[GRAPH:.*]]:2 = tf_executor.graph {
 // CHECK:     %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
 // CHECK:     %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1)
-// CHECK:     %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2]]) {message = "add2 result"}
+// CHECK:     %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2]]) <{message = "add2 result"}>
 // CHECK:     tf_executor.fetch %[[ADD1]], %[[ADD2]], %[[PRINT_control]] :
 // CHECK:   }
 // CHECK:  return %[[GRAPH]]#0, %[[GRAPH]]#1
@@ -127,7 +127,7 @@
 // CHECK-DAG:   %[[READ0:.*]], %[[READ0_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg0)
 // CHECK-DAG:   %[[ASSIGN0_CONTROL:.*]] = tf_executor.island(%[[READ0_CONTROL]]) wraps "tf.AssignVariableOp"(%arg0, %arg2)
 // CHECK-DAG:   %[[READ1:.*]], %[[READ1_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg1)
-// CHECK:   %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v0"}
+// CHECK:   %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() <{container = "c", shared_name = "v0"}>
 // CHECK:   %[[READ2:.*]], %[[READ2_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[VH0]])
 // CHECK:   %[[ASSIGN1_CONTROL:.*]] = tf_executor.island(%[[READ1_CONTROL]]) wraps "tf.AssignVariableOp"(%arg1, %[[READ0]])
 // CHECK:   %[[ASSIGN2_CONTROL:.*]] = tf_executor.island(%[[ASSIGN0_CONTROL]]) wraps "tf.AssignVariableOp"(%arg0, %[[READ2]])
@@ -151,8 +151,8 @@
 }
 // CHECK-LABEL: func @unknown_side_effecting_op
 // CHECK: tf_executor.graph {
-// CHECK:   %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v0"}
-// CHECK:   %[[VH1:.*]], %[[VH1_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v1"}
+// CHECK:   %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() <{container = "c", shared_name = "v0"}>
+// CHECK:   %[[VH1:.*]], %[[VH1_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() <{container = "c", shared_name = "v1"}>
 // CHECK:   %[[READ0:.*]], %[[READ0_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[VH0]])
 // CHECK:   %[[ASSIGN0_CONTROL:.*]] = tf_executor.island wraps "tf.AssignVariableOp"(%[[VH1]], %arg0)
 // CHECK:   %[[UNKNOWN_CONTROL:.*]] = tf_executor.island(%[[READ0_CONTROL]], %[[ASSIGN0_CONTROL]]) wraps "tf._UnknownSideEffectingOp_"()
@@ -170,7 +170,7 @@
 }
 // CHECK-LABEL: func @single_op_island_forward_block_arg
 // CHECK: tf_executor.graph {
-// CHECK:   %[[outputs:.*]], %[[control:.*]] = tf_executor.island wraps "tf.Const"() {value = dense<0.000000e+00> : tensor<2048xf32>} : () -> tensor<2048xf32>
+// CHECK:   %[[outputs:.*]], %[[control:.*]] = tf_executor.island wraps "tf.Const"() <{value = dense<0.000000e+00> : tensor<2048xf32>}> : () -> tensor<2048xf32>
 // CHECK:   tf_executor.fetch %[[outputs]], %arg0 : tensor<2048xf32>, tensor<?x?x?x?xbf16>
 
 func.func @tpu_load_embedding_ops_sink_controls(%arg0: tensor<*x!tf_type.resource<tensor<8xf32>>>, %arg1: tensor<*x!tf_type.resource<tensor<8xf32>>>, %arg2: tensor<*x!tf_type.resource<tensor<8xf32>>>, %arg3: tensor<*x!tf_type.resource<tensor<8xf32>>>) {
@@ -194,13 +194,13 @@
 // CHECK:    %[[outputs:.*]], %[[control:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf_type.resource<tensor<8xf32>>>) -> tensor<8xf32>
 // CHECK:    %[[outputs_0:.*]], %[[control_1:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf_type.resource<tensor<8xf32>>>) -> tensor<8xf32>
 // CHECK:    %[[outputs_2:.*]], %[[control_3:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf_type.resource<tensor<8xf32>>>) -> tensor<8xf32>
-// CHECK:    %[[control_4:.*]] = tf_executor.island wraps "tf.LoadTPUEmbeddingAdagradParameters"(%[[outputs]], %[[outputs_0]]) {config = "", num_shards = 1 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "table1"} : (tensor<8xf32>, tensor<8xf32>) -> ()
+// CHECK:    %[[control_4:.*]] = tf_executor.island wraps "tf.LoadTPUEmbeddingAdagradParameters"(%[[outputs]], %[[outputs_0]]) <{config = "", num_shards = 1 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "table1"}> : (tensor<8xf32>, tensor<8xf32>) -> ()
 // CHECK:    %[[outputs_5:.*]], %[[control_6:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor<*x!tf_type.resource<tensor<8xf32>>>) -> tensor<8xf32>
-// CHECK:    %[[control_7:.*]] = tf_executor.island wraps "tf.LoadTPUEmbeddingAdagradParameters"(%[[outputs_2]], %[[outputs_5]]) {config = "", num_shards = 1 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "table2"} : (tensor<8xf32>, tensor<8xf32>) -> ()
+// CHECK:    %[[control_7:.*]] = tf_executor.island wraps "tf.LoadTPUEmbeddingAdagradParameters"(%[[outputs_2]], %[[outputs_5]]) <{config = "", num_shards = 1 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "table2"}> : (tensor<8xf32>, tensor<8xf32>) -> ()
 // CHECK:    %[[control_8:.*]] = tf_executor.island(%[[control]], %[[control_1]], %[[control_3]], %[[control_4]], %[[control_6]], %[[control_7]]) wraps "tf.UnknownOp"() : () -> ()
 // CHECK:    %[[control_9:.*]] = tf_executor.island(%[[control_8]]) wraps "tf.UnknownOp"() : () -> ()
-// CHECK:    %[[control_10:.*]] = tf_executor.island(%[[control_9]]) wraps "tf.LoadTPUEmbeddingAdagradParameters"(%[[outputs]], %[[outputs_0]]) {config = "", num_shards = 1 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "table3"} : (tensor<8xf32>, tensor<8xf32>) -> ()
-// CHECK:    %[[control_11:.*]] = tf_executor.island(%[[control_9]]) wraps "tf.LoadTPUEmbeddingAdagradParameters"(%[[outputs_2]], %[[outputs_5]]) {config = "", num_shards = 1 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "table4"} : (tensor<8xf32>, tensor<8xf32>) -> ()
+// CHECK:    %[[control_10:.*]] = tf_executor.island(%[[control_9]]) wraps "tf.LoadTPUEmbeddingAdagradParameters"(%[[outputs]], %[[outputs_0]]) <{config = "", num_shards = 1 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "table3"}> : (tensor<8xf32>, tensor<8xf32>) -> ()
+// CHECK:    %[[control_11:.*]] = tf_executor.island(%[[control_9]]) wraps "tf.LoadTPUEmbeddingAdagradParameters"(%[[outputs_2]], %[[outputs_5]]) <{config = "", num_shards = 1 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "table4"}> : (tensor<8xf32>, tensor<8xf32>) -> ()
 // CHECK:    tf_executor.fetch %[[control_10]], %[[control_11]] : !tf_executor.control, !tf_executor.control
 
 // -----
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_deserialization.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_deserialization.mlir
index 982b44f..eb3e078 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_deserialization.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_deserialization.mlir
@@ -18,9 +18,9 @@
   // CHECK-SAME:    %[[ARG0:.*]]: tensor<10xi32>, %[[ARG1:.*]]: tensor<10xi32>
   func.func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32> {
     // CHECK:      %[[RESULT:.*]] = "tf.XlaCallModule"(%[[ARG0]], %[[ARG1]])
-    // CHECK-SAME:   _entry_function = @main0,
     // CHECK-NOT:    function_list
     // CHECK-SAME:   module = ""
+    // CHECK-SAME:   _entry_function = @main_0,
 
     // `module` is stablehlo bytecode for:
     //  func.func @main(%arg0: tensor<?xi32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}, %arg1: tensor<*xi32>) -> (tensor<?xi32> {jax.result_info = ""}) {
@@ -36,9 +36,9 @@
   // CHECK-SAME:    %[[ARG0:.*]]: tensor<10xi32>, %[[ARG1:.*]]: tensor<10xi32>
   func.func @foo(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32> {
     // CHECK:      %[[RESULT:.*]] = "tf.XlaCallModule"(%[[ARG0]], %[[ARG1]])
-    // CHECK-SAME:   _entry_function = @main1,
     // CHECK-NOT:    function_list
     // CHECK-SAME:   module = ""
+    // CHECK-SAME:   _entry_function = @main_1,
 
     // `module` is stablehlo bytecode for:
     //  func.func @main(%arg0: tensor<?xi32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}, %arg1: tensor<*xi32>) -> (tensor<?xi32> {jax.result_info = ""}) {
@@ -50,13 +50,13 @@
     func.return %0 : tensor<10xi32>
   }
 
-  // CHECK-LABEL: func private @main0
+  // CHECK-LABEL: func private @main_0
   // CHECK-SAME:    (%[[ARG0:.*]]: tensor<?xi32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}, %[[ARG1:.*]]: tensor<*xi32>) -> (tensor<?xi32> {jax.result_info = ""}) attributes {_from_xla_call_module} {
   // CHECK:         stablehlo.custom_call @tf.call_tf_function(%[[ARG0]], %[[ARG1]]) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {called_func = @_tf_func}} : (tensor<?xi32>, tensor<*xi32>) -> ()
   // CHECK:         return %arg0 : tensor<?xi32>
   // CHECK:       }
 
-  // CHECK-LABEL: func private @main1
+  // CHECK-LABEL: func private @main_1
   // CHECK-SAME:    (%[[ARG0:.*]]: tensor<?xi32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}, %[[ARG1:.*]]: tensor<*xi32>) -> (tensor<?xi32> {jax.result_info = ""}) attributes {_from_xla_call_module} {
   // CHECK:         stablehlo.custom_call @tf.call_tf_function(%[[ARG0]], %[[ARG1]]) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {called_func = @_tf_func}} : (tensor<?xi32>, tensor<*xi32>) -> ()
   // CHECK:         return %arg0 : tensor<?xi32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_round_trip.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_round_trip.mlir
index 83be1fd..4061484 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_round_trip.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_round_trip.mlir
@@ -13,14 +13,14 @@
   func.func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32> {
     // CHECK:      %[[RESULT:.*]] = "tf.XlaCallModule"(%[[ARG0]], %[[ARG1]])
     // CHECK-SAME:   Sout = [#tf_type.shape<?>]
-    // CHECK-SAME:   _entry_function = @main0
-    // CHECK-SAME:   _stablehlo_module_attrs = {}
     // CHECK-NOT:    function_list
     // CHECK-SAME:   module = ""
     // CHECK-SAME:   platforms = []
     // CHECK-SAME:   version = 5
+    // CHECK-SAME:   _entry_function = @main_0
+    // CHECK-SAME:   _stablehlo_module_attrs = {}
 
-    %0 = "tf.XlaCallModule"(%arg0, %arg1) {Sout = [#tf_type.shape<?>], dim_args_spec = [], _entry_function = @main0, module = "", platforms = [], version = 5 : i64} : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+    %0 = "tf.XlaCallModule"(%arg0, %arg1) {Sout = [#tf_type.shape<?>], dim_args_spec = [], _entry_function = @main_0, module = "", platforms = [], version = 5 : i64} : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
     // CHECK: return %[[RESULT]]
     func.return %0 : tensor<10xi32>
   }
@@ -34,12 +34,12 @@
     func.return
   }
 
-  // CHECK-LABEL: func private @main0
+  // CHECK-LABEL: func private @main_0
   // CHECK-SAME:    %[[ARG0:.*]]: tensor<?xi32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}
   // CHECK-SAME:    %[[ARG1:.*]]: tensor<*xi32>)
   // CHECK-SAME:    (tensor<?xi32> {jax.result_info = ""})
   // CHECK-SAME:    attributes {_from_xla_call_module}
-  func.func private @main0(%arg0: tensor<?xi32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}, %arg1: tensor<*xi32>) -> (tensor<?xi32> {jax.result_info = ""}) attributes {_from_xla_call_module} {
+  func.func private @main_0(%arg0: tensor<?xi32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}, %arg1: tensor<*xi32>) -> (tensor<?xi32> {jax.result_info = ""}) attributes {_from_xla_call_module} {
     // CHECK:      stablehlo.custom_call @tf.call_tf_function(%[[ARG0]], %[[ARG1]])
     // CHECK-SAME: {
     // CHECK-SAME:  api_version = 2 : i32,
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_cluster_formation.mlir
index 282588a..6b0f700 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/xla_cluster_formation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/xla_cluster_formation.mlir
@@ -91,8 +91,8 @@
 // Check that we encapsulate the function body of entry functions with compilation markers, and not the included partitioned calls with the markers.
 // CHECK-LABEL:   func.func @entry_function_with_compilation_markers(%arg0: tensor<i32>) -> tensor<i32> attributes {_xla_compile_device_type = "GPU", allow_soft_placement = true, device = "/device:GPU:0", tf.entry_function = {}} {
 // CHECK:           %0 = "tf_device.cluster"() ({
-// CHECK:             %1 = "tf.StatefulPartitionedCall"(%arg0) {_xla_compile_device_type = "GPU", config = "", config_proto = "", executor_type = "", f = @stateful_pcall_func} : (tensor<i32>) -> tensor<i32>
-// CHECK:             %cst = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+// CHECK:             %1 = "tf.StatefulPartitionedCall"(%arg0) <{config = "", config_proto = "", executor_type = "", f = @stateful_pcall_func}> {_xla_compile_device_type = "GPU"} : (tensor<i32>) -> tensor<i32>
+// CHECK:             %cst = "tf.Const"() <{value = dense<5> : tensor<i32>}> : () -> tensor<i32>
 // CHECK:             %2 = "tf.Add"(%1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK:             tf_device.return %2 : tensor<i32>
 // CHECK:           }) {_cluster_outlined_function_name = "entry_function_with_compilation_markers_cluster_func", _xla_compile_device_type = "GPU", allow_soft_placement = true, device = "/device:GPU:0"} : () -> tensor<i32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite.mlir
index 914c753..4e02848 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite.mlir
@@ -4,7 +4,7 @@
 module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:GPU:0"]} {
   // CHECK-LABEL: func.func @convert_cluster_func
   func.func @convert_cluster_func(%arg0: tensor<i32>) -> tensor<i32> {
-    // CHECK: "tf.XlaLaunch"(%arg0) {function = @func, operandSegmentSizes = array<i32: 0, 1, 0>} : (tensor<i32>) -> tensor<i32>
+    // CHECK: "tf.XlaLaunch"(%arg0) <{function = @func, operandSegmentSizes = array<i32: 0, 1, 0>}> : (tensor<i32>) -> tensor<i32>
     %0 = "tf_device.cluster_func"(%arg0) {func = @func} : (tensor<i32>) -> tensor<i32>
     func.return %0 : tensor<i32>
   }
@@ -19,7 +19,7 @@
 module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:GPU:0"]} {
   // CHECK-LABEL: func.func @convert_cluster_func_with_resources_in_order
   func.func @convert_cluster_func_with_resources_in_order(%arg0: tensor<!tf_type.resource>, %arg1: tensor<i32>) -> tensor<i32> {
-    // CHECK: "tf.XlaLaunch"(%arg1, %arg0) {function = @func_with_resources_in_order, operandSegmentSizes = array<i32: 0, 1, 1>} : (tensor<i32>, tensor<!tf_type.resource>) -> tensor<i32>
+    // CHECK: "tf.XlaLaunch"(%arg1, %arg0) <{function = @func_with_resources_in_order, operandSegmentSizes = array<i32: 0, 1, 1>}> : (tensor<i32>, tensor<!tf_type.resource>) -> tensor<i32>
     %0 = "tf_device.cluster_func"(%arg1, %arg0) {func = @func_with_resources_in_order} : (tensor<i32>, tensor<!tf_type.resource>) -> (tensor<i32>)
     func.return %0 : tensor<i32>
   }
@@ -34,9 +34,9 @@
 module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:GPU:0"]} {
   // CHECK-LABEL: func.func @convert_cluster_func_with_resources
   func.func @convert_cluster_func_with_resources(%arg0: tensor<!tf_type.resource>, %arg1: tensor<i32>) -> tensor<i32> {
-    // CHECK: "tf.XlaLaunch"(%arg1, %arg0) {function = @func_with_resources, operandSegmentSizes = array<i32: 0, 1, 1>} : (tensor<i32>, tensor<!tf_type.resource>) -> tensor<i32>
+    // CHECK: "tf.XlaLaunch"(%arg1, %arg0) <{function = @func_with_resources, operandSegmentSizes = array<i32: 0, 1, 1>}> : (tensor<i32>, tensor<!tf_type.resource>) -> tensor<i32>
     %0 = "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_resources} : (tensor<!tf_type.resource>, tensor<i32>) -> tensor<i32>
-    // CHECK: "tf.XlaLaunch"(%arg1, %arg0) {function = @func_with_resources, operandSegmentSizes = array<i32: 0, 1, 1>} : (tensor<i32>, tensor<!tf_type.resource>) -> tensor<i32>
+    // CHECK: "tf.XlaLaunch"(%arg1, %arg0) <{function = @func_with_resources, operandSegmentSizes = array<i32: 0, 1, 1>}> : (tensor<i32>, tensor<!tf_type.resource>) -> tensor<i32>
     %1 = "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_resources} : (tensor<!tf_type.resource>, tensor<i32>) -> tensor<i32>
     return %0 : tensor<i32>
   }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite_v2.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite_v2.mlir
index 7d225ce..e36bdaa 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite_v2.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite_v2.mlir
@@ -5,11 +5,11 @@
   // CHECK-LABEL: func.func @convert_cluster_func
   func.func @convert_cluster_func(%arg0: tensor<i32>) -> tensor<i32> {
     // CHECK: "tf_device.launch"()
-    // CHECK: "tf._XlaCompile"(%arg0) {function = @func, must_compile = true, operandSegmentSizes = array<i32: 0, 1, 0>} : (tensor<i32>) -> (tensor<3x!tf_type.string>, tensor<!tf_type.boolref>)
-    // CHECK: {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
+    // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}>
+    // CHECK: "tf._XlaCompile"(%arg0) <{function = @func, must_compile = true, operandSegmentSizes = array<i32: 0, 1, 0>}> : (tensor<i32>) -> (tensor<3x!tf_type.string>, tensor<!tf_type.boolref>)
     // CHECK: "tf_device.launch"()
+    // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}>
     // CHECK: "tf._XlaRun"(%arg0, %0#0) : (tensor<i32>, tensor<3x!tf_type.string>) -> tensor<i32>
-    // CHECK: {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : () -> tensor<i32>
     %0 = "tf_device.cluster_func"(%arg0) {func = @func, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<i32>) -> tensor<i32>
     func.return %0 : tensor<i32>
   }
@@ -25,11 +25,11 @@
   // CHECK-LABEL: func.func @convert_cluster_func_with_resources_in_order
   func.func @convert_cluster_func_with_resources_in_order(%arg0: tensor<!tf_type.resource>, %arg1: tensor<i32>) -> tensor<i32> {
     // CHECK: "tf_device.launch"()
-    // CHECK: "tf._XlaCompile"(%arg1, %arg0) {function = @func_with_resources_in_order, must_compile = true, operandSegmentSizes = array<i32: 0, 1, 1>} : (tensor<i32>, tensor<!tf_type.resource>)
-    // CHECK: {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
+    // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}>
+    // CHECK: "tf._XlaCompile"(%arg1, %arg0) <{function = @func_with_resources_in_order, must_compile = true, operandSegmentSizes = array<i32: 0, 1, 1>}> : (tensor<i32>, tensor<!tf_type.resource>)
     // CHECK: "tf_device.launch"()
+    // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}>
     // CHECK: "tf._XlaRun"(%arg1, %arg0, %0#0) : (tensor<i32>, tensor<!tf_type.resource>, tensor<3x!tf_type.string>) -> tensor<i32>
-    // CHECK: {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : () -> tensor<i32>
     %0 = "tf_device.cluster_func"(%arg1, %arg0) {func = @func_with_resources_in_order, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<i32>, tensor<!tf_type.resource>) -> (tensor<i32>)
     func.return %0 : tensor<i32>
   }
@@ -45,18 +45,18 @@
   // CHECK-LABEL: func.func @convert_cluster_func_with_resources
   func.func @convert_cluster_func_with_resources(%arg0: tensor<!tf_type.resource>, %arg1: tensor<i32>) -> tensor<i32> {
     // CHECK: "tf_device.launch"()
-    // CHECK: "tf._XlaCompile"(%arg1, %arg0) {function = @func_with_resources_1, must_compile = true, operandSegmentSizes = array<i32: 0, 1, 1>} : (tensor<i32>, tensor<!tf_type.resource>) -> (tensor<3x!tf_type.string>, tensor<!tf_type.boolref>)
-    // CHECK: {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
+    // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}>
+    // CHECK: "tf._XlaCompile"(%arg1, %arg0) <{function = @func_with_resources_1, must_compile = true, operandSegmentSizes = array<i32: 0, 1, 1>}> : (tensor<i32>, tensor<!tf_type.resource>) -> (tensor<3x!tf_type.string>, tensor<!tf_type.boolref>)
     // CHECK: "tf_device.launch"()
+    // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}>
     // CHECK: "tf._XlaRun"(%arg1, %arg0, %0#0) : (tensor<i32>, tensor<!tf_type.resource>, tensor<3x!tf_type.string>) -> tensor<i32>
-    // CHECK: {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : () -> tensor<i32>
     %0 = "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_resources_1, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<!tf_type.resource>, tensor<i32>) -> tensor<i32>
     // CHECK: "tf_device.launch"()
-    // CHECK: "tf._XlaCompile"(%arg1, %arg0) {function = @func_with_resources_2, must_compile = true, operandSegmentSizes = array<i32: 0, 1, 1>} : (tensor<i32>, tensor<!tf_type.resource>) -> (tensor<3x!tf_type.string>, tensor<!tf_type.boolref>)
-    // CHECK: {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
+    // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}>
+    // CHECK: "tf._XlaCompile"(%arg1, %arg0) <{function = @func_with_resources_2, must_compile = true, operandSegmentSizes = array<i32: 0, 1, 1>}> : (tensor<i32>, tensor<!tf_type.resource>) -> (tensor<3x!tf_type.string>, tensor<!tf_type.boolref>)
     // CHECK: "tf_device.launch"()
+    // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}>
     // CHECK: "tf._XlaRun"(%arg1, %arg0, %2#0) : (tensor<i32>, tensor<!tf_type.resource>, tensor<3x!tf_type.string>) -> tensor<i32>
-    // CHECK: {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : () -> tensor<i32>
     %1 = "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_resources_2, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<!tf_type.resource>, tensor<i32>) -> tensor<i32>
     return %0 : tensor<i32>
   }
@@ -77,16 +77,16 @@
 module attributes {tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0"], tf.versions = {producer = 888 : i32}} {
   func.func @outside_compilation_in_generic_pipeline(%arg0: tensor<2xi32>) -> tensor<2xi32> {
     // CHECK: tf_device.launch
-    // CHECK: "tf._XlaCompile"() {function = @func, must_compile = true, operandSegmentSizes = array<i32: 0, 0, 0>}
-    // CHECK: {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
+    // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}>
+    // CHECK: "tf._XlaCompile"() <{function = @func, must_compile = true, operandSegmentSizes = array<i32: 0, 0, 0>}>
     // CHECK: tf_device.parallel_execute
     // CHECK: tf_device.launch
+    // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
     // CHECK: tf.B
     // CHECK: tf._XlaSendFromHost
-    // CHECK: {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
     // CHECK: tf_device.launch
+    // CHECK-SAME: <{device = "/job:localhost/replica:0/task:0/device:GPU:0"}>
     // CHECK: tf._XlaRun
-    // CHECK: {device = "/job:localhost/replica:0/task:0/device:GPU:0"}
     %0 = "tf_device.parallel_execute"() ({
       "tf_device.launch"() ({
         %1 = "tf._XlaCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf_type.string>
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD
index d69de74..58c3338 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD
@@ -76,7 +76,7 @@
         ":bridge",
         ":tensorflow_passes",
         ":tf_saved_model_passes",
-        "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util",
+        "//tensorflow/compiler/mlir/tf2xla/internal:clustering_bridge_passes",
         "//tensorflow/compiler/mlir/tf2xla/transforms:legalize_tf",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
@@ -566,7 +566,6 @@
         ":cluster_formation",
         ":decompose_resource_ops",
         ":decompose_resource_ops_inc_gen",
-        ":extract_head_tail_outside_compilation",
         ":extract_outside_compilation",
         ":lower_tf_lib",
         ":mark_ops_for_outside_compilation",
@@ -577,7 +576,6 @@
         ":tf_pass_inc_gen",
         ":tf_savedmodel_pass_inc_gen",
         ":tfe_legalize_tfg",
-        ":tpu_cluster_formation",
         ":unroll_batch_matmul_pass",
         ":verify_no_outside_compilation_markers_pass",
         ":xla_cluster_formation",
@@ -681,31 +679,6 @@
 )
 
 cc_library(
-    name = "tpu_cluster_formation",
-    srcs = ["tpu_cluster_formation.cc"],
-    textual_hdrs = [
-        "tf_passes.h.inc",
-    ],
-    deps = [
-        ":tf_pass_inc_gen",
-        "//tensorflow/compiler/mlir/tensorflow",
-        "//tensorflow/compiler/mlir/tensorflow:attribute_utils",
-        "//tensorflow/compiler/mlir/tensorflow:string_util",
-        "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis",
-        "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util",
-        "//tensorflow/core:framework",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/strings",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:FuncDialect",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:Support",
-        "@llvm-project//mlir:TransformUtils",
-    ],
-)
-
-cc_library(
     name = "cluster_formation",
     srcs = ["cluster_formation.cc"],
     textual_hdrs = [
@@ -765,39 +738,6 @@
 )
 
 cc_library(
-    name = "extract_head_tail_outside_compilation",
-    srcs = ["extract_head_tail_outside_compilation.cc"],
-    textual_hdrs = [
-        "tf_passes.h.inc",
-    ],
-    deps = [
-        ":lower_tf_lib",
-        ":tf_pass_inc_gen",
-        "//tensorflow/compiler/mlir/tensorflow",
-        "//tensorflow/compiler/mlir/tensorflow:attribute_utils",
-        "//tensorflow/compiler/mlir/tensorflow:device_util",
-        "//tensorflow/compiler/mlir/tensorflow:string_util",
-        "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis",
-        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
-        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
-        "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util",
-        "//tensorflow/compiler/mlir/tf2xla/transforms:legalization_op_config",
-        "//tensorflow/compiler/mlir/tf2xla/transforms:legalize_tf",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/strings",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:FuncDialect",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:Rewrite",
-        "@llvm-project//mlir:Support",
-        "@llvm-project//mlir:TransformUtils",
-    ],
-)
-
-cc_library(
     name = "extract_outside_compilation",
     srcs = ["extract_outside_compilation.cc"],
     textual_hdrs = [
@@ -982,22 +922,6 @@
 )
 
 cc_library(
-    name = "bridge_pass_test_pipeline_registration",
-    testonly = True,  # Ensure alwayslink does not leak in the codebase.
-    srcs = [
-        "bridge_pass.cc",
-    ],
-    deps = [
-        ":bridge",
-        ":tensorflow_passes",
-        "//tensorflow/compiler/mlir/tensorflow:error_util",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:Transforms",
-    ],
-    alwayslink = 1,
-)
-
-cc_library(
     name = "tensorflow_test_passes",
     testonly = True,  # Ensure alwayslink does not leak in the codebase.
     srcs = [
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
index d657fc9..a7f1037 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
@@ -31,10 +31,6 @@
 #include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
-#include "tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.h"
-#include "tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h"
-#include "tensorflow/compiler/mlir/tf2xla/api/v2/device_type.pb.h"
-#include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h"
 #include "tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.h"
 #include "tensorflow/compiler/mlir/tf2xla/internal/inference/inference_passes.h"
 #include "tensorflow/compiler/mlir/tf2xla/internal/logging_hooks.h"
@@ -47,81 +43,6 @@
 #include "tsl/platform/error_logging.h"
 
 namespace mlir {
-namespace TFTPU {
-namespace {
-
-constexpr char kBridgeComponent[] = "TFXLABridge";
-
-// Run the TF XLA Bridge based on the input pipeline, which can be either TPU
-// bridge pipeline or non TPU bridge pipeline.
-tensorflow::Status RunTFXLABridge(
-    ModuleOp module,
-    llvm::function_ref<void(OpPassManager &pm)> pipeline_builder,
-    llvm::StringRef module_name = llvm::StringRef()) {
-  // Explicitly check that the TensorFlow dialect can constant fold ops.
-  // Constant folding is essential for the bridge. Without this check, the
-  // bridge may fail with an error that is difficult to understand and not
-  // actionable.
-  if (!TF::TensorFlowDialect::HasConstantFoldHook()) {
-    return tensorflow::errors::Internal(
-        "TensorFlow dialect missing constant fold hook in TFXLA bridge phase "
-        "1; this could happen if the binary doesn't link the constant fold "
-        "hook registration library.");
-  }
-
-  PassManager bridge(module.getContext());
-  ::tensorflow::applyTensorflowAndCLOptions(bridge);
-
-  // Populate a passmanager with the list of passes that implement the bridge.
-  pipeline_builder(bridge);
-
-  mlir::StatusScopedDiagnosticHandler diag_handler(
-      module.getContext(), /*propagate=*/false,
-      /*filter_stack=*/!VLOG_IS_ON(1));
-
-  if (VLOG_IS_ON(1) ||
-      DEBUG_DATA_DUMPER()->ShouldDump(module_name.str(), kDebugGroupMain)) {
-    ::tensorflow::DumpMlirOpToFile(
-        DEBUG_DATA_DUMPER()->GetDumpFilename(module_name.str(), kDebugGroupMain,
-                                             "tf_xla_bridge_before"),
-        module, llvm::StringRef(), &bridge);
-  }
-
-  if (VLOG_IS_ON(2) ||
-      DEBUG_DATA_DUMPER()->ShouldDump(module_name.str(),
-                                      kDebugGroupBridgePhase1Clustering)) {
-    ::tensorflow::tf2xla::internal::EnablePassIRPrinting(
-        bridge, kDebugGroupBridgePhase1Clustering, module_name);
-  }
-
-  LogicalResult result = bridge.run(module);
-  (void)result;
-
-  if (VLOG_IS_ON(1) ||
-      DEBUG_DATA_DUMPER()->ShouldDump(module_name.str(), kDebugGroupMain)) {
-    ::tensorflow::DumpMlirOpToFile(
-        DEBUG_DATA_DUMPER()->GetDumpFilename(module_name.str(), kDebugGroupMain,
-                                             "tf_xla_bridge_after"),
-        module, llvm::StringRef(), &bridge);
-  }
-
-  return diag_handler.ConsumeStatus();
-}
-
-}  // namespace
-
-void CreateTPUBridgePipeline(OpPassManager &pm, llvm::StringRef module_name) {
-  pm.addPass(CreateTPUValidateInputsPass());
-  pm.addNestedPass<func::FuncOp>(
-      TF::CreateCanonicalizeCompileAndReplicateAttributesPass());
-  tensorflow::tf2xla::internal::AddBridgeClusteringPipelinePasses(pm,
-                                                                  module_name);
-  tensorflow::tfrt_compiler::AddTPULowerClusterToRuntimeOpsPassPipeline(
-      pm, module_name);
-}
-
-}  // namespace TFTPU
-
 namespace TF {
 
 tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
@@ -137,11 +58,12 @@
       module.getContext(), /*propagate=*/false,
       /*filter_stack=*/!VLOG_IS_ON(1));
 
+  constexpr char kBridgeComponent[] = "TFXLABridge";
   if (enable_logging || VLOG_IS_ON(1)) {
     tensorflow::DumpMlirOpToFile(kStandardPipelineBefore, module, "", &bridge);
     if (VLOG_IS_ON(2)) {
-      tensorflow::tf2xla::internal::EnablePassIRPrinting(
-          bridge, TFTPU::kBridgeComponent);
+      tensorflow::tf2xla::internal::EnablePassIRPrinting(bridge,
+                                                         kBridgeComponent);
     }
   }
   LogicalResult result = bridge.run(module);
@@ -151,17 +73,5 @@
   return diag_handler.ConsumeStatus();
 }
 
-void CreateTFXLABridgePipeline(OpPassManager &pm) {
-  tensorflow::tf2xla::internal::AddNonTPUBridgeClusteringPipelinePasses(pm);
-}
-
-tensorflow::Status RunTFXLABridge(ModuleOp module,
-                                  llvm::StringRef module_name) {
-  // CPU == GPU here, so both are equivalent.
-  return tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge(
-      module, tensorflow::tf2xla::v2::XLA_GPU_JIT,
-      /*is_in_fallback_enabled_mode=*/false, module_name);
-}
-
 }  // namespace TF
 }  // namespace mlir
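The deleted mlir::TF::RunTFXLABridge wrapper simply forwarded to the tf2xla v2 clustering bridge, so callers are expected to invoke that entry point directly. A minimal sketch of the replacement call path, mirroring the body of the removed wrapper; the enclosing helper function here is illustrative and not part of this change:

  // Includes taken from the removed bridge.cc lines above.
  #include "tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h"
  #include "tensorflow/compiler/mlir/tf2xla/api/v2/device_type.pb.h"

  // Illustrative helper: runs phase-1 clustering for a non-TPU graph.
  tensorflow::Status RunClusteringForCpuOrGpu(mlir::ModuleOp module,
                                              llvm::StringRef module_name) {
    // CPU == GPU here, so both are equivalent (same note as the removed code).
    return tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge(
        module, tensorflow::tf2xla::v2::XLA_GPU_JIT,
        /*is_in_fallback_enabled_mode=*/false, module_name);
  }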
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h
index 72875f2..6712350 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h
@@ -40,11 +40,6 @@
 tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
                                                  bool enable_logging,
                                                  bool enable_inliner);
-
-// Runs all passes for non TPU (GPU and CPU) graph.
-ABSL_DEPRECATED("Use tf2xla::v2::RunFunctionTf2xlaClusteringBridge instead.")
-tensorflow::Status RunTFXLABridge(
-    ModuleOp module, llvm::StringRef module_name = llvm::StringRef());
 }  // namespace TF
 
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc
deleted file mode 100644
index d2e94b5..0000000
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc
+++ /dev/null
@@ -1,34 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "mlir/Pass/PassManager.h"  // from @llvm-project
-#include "mlir/Transforms/Passes.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
-#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
-
-namespace {
-
-// Registers a pipeline builder function for TF TPU bridge.
-mlir::PassPipelineRegistration<> tpu_pipeline(
-    "tf-tpu-bridge",
-    "Run all the passes involved in transforming the graph before execution so "
-    "that it is suitable for targeting TPUs.",
-    [](mlir::OpPassManager& pm) {
-      return mlir::TFTPU::CreateTPUBridgePipeline(pm);
-    });
-
-}  // anonymous namespace
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD
index da79d85..359dd5c 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD
@@ -4,10 +4,12 @@
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
     default_visibility = [
+        "//learning/serving/contrib/tfrt/mlir/saved_model_analysis:__pkg__",
         "//tensorflow/compiler/mlir:__pkg__",
         "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__",
         "//tensorflow/compiler/mlir/tf2xla/api:__subpackages__",
         "//tensorflow/compiler/mlir/tfrt:__subpackages__",
+        "//tensorflow/compiler/tf2xla:__pkg__",
     ],
     licenses = ["notice"],
 )
@@ -27,14 +29,20 @@
         "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib_proto_parsing",
+        "//tensorflow/core/platform:error_payloads",
+        "//tensorflow/core/platform:status",
         "//tensorflow/core/tpu:tpu_defs",
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Transforms",
+        "@local_tsl//tsl/platform:error_logging",
+        "@local_tsl//tsl/platform:errors",
     ],
 )
 
@@ -54,6 +62,7 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
+        "//tensorflow/core/lib/monitoring:cell_reader",
         "//tensorflow/core/platform:resource_loader",
         "//tensorflow/core/tpu:tpu_defs",
         "@com_google_absl//absl/status",
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc
index 4dfe7c1..cba2b05 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc
@@ -17,6 +17,7 @@
 #include <string>
 
 #include "absl/log/log.h"
+#include "absl/status/status.h"
 #include "llvm/ADT/StringRef.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
@@ -29,10 +30,15 @@
 #include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
+#include "tensorflow/core/framework/metrics.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/error_payloads.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/tpu/tpu_defs.h"
 #include "tensorflow/core/util/debug_data_dumper.h"
 #include "tsl/framework/device_type.h"
+#include "tsl/platform/error_logging.h"
+#include "tsl/platform/errors.h"
 
 namespace tensorflow {
 namespace tfrt_compiler {
@@ -111,6 +117,39 @@
   AddNonTPULowerClusterToRuntimeOpsPassPipeline(pm, /*module_name=*/"");
 }
 
+// TODO(b/306728216): Move this out of the Bridge component and into a Host
+// runtime component.
+tensorflow::Status RecordIfErrorStatus(const std::string error_prefix,
+                                       tsl::DeviceType device_type,
+                                       absl::Status status) {
+  if (status.ok()) {
+    return status;
+  }
+
+  VLOG(2) << error_prefix << " " << status;
+  tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
+      device_type.type_string(), /*bridge_version=*/"v2",
+      /*fallback_enabled=*/false,
+      /*result=*/"failure");
+
+  constexpr char kBridgeComponent[] = "TFXLABridge";
+  std::string bridge_subcomponent = "TFXLA_PHASE_ONE_MLIR_TPU_BRIDGE";
+
+  tsl::OkOrSetErrorCounterPayload(
+      tensorflow::core::platform::ErrorSourceProto::MLIR_BRIDGE_PHASE_1,
+      status);
+
+  if (device_type != DeviceType(DEVICE_TPU_XLA_JIT)) {
+    bridge_subcomponent = "TFXLA_PHASE_ONE_MLIR_CPU/GPU_BRIDGE";
+  }
+
+  tsl::error_logging::Log(kBridgeComponent, bridge_subcomponent,
+                          status.ToString())
+      .IgnoreError();
+
+  return status;
+}
+
 absl::Status RunLowerClusterToRuntimeOpsPassPipeline(
     mlir::ModuleOp module, tsl::DeviceType xla_device_type,
     llvm::StringRef module_name) {
@@ -154,7 +193,12 @@
         module, llvm::StringRef(), &runtime_lowering);
   }
 
-  return diag_handler.ConsumeStatus();
+  auto result_status = diag_handler.ConsumeStatus();
+  TF_RETURN_IF_ERROR(
+      RecordIfErrorStatus(/*error_prefix=*/"lower_cluster_to_runtime",
+                          xla_device_type, result_status));
+
+  return absl::OkStatus();
 }
 
 // TODO(b/305211853): Unify the CPU/TPU/GPU Execution Ops and thus these two
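
With `RecordIfErrorStatus` in place, a failed lowering run is now also visible through the first-phase bridge streamz counter. A hedged sketch of observing this from a test, assuming `<cstdint>`, `cell_reader.h`, `tpu_defs.h`, and the pipeline header are included; the helper name is illustrative and mirrors the updated test further down:

```c++
using ::tensorflow::monitoring::testing::CellReader;

// Streamz name matches kCompilationStreamz in the updated test below.
constexpr char kFirstPhaseStreamz[] =
    "/tensorflow/core/tf_mlir_bridge_first_phase_count";

int64_t CountedFailuresAfterLowering(mlir::ModuleOp module) {
  CellReader<int64_t> counter(kFirstPhaseStreamz);
  absl::Status status =
      tensorflow::tfrt_compiler::RunLowerClusterToRuntimeOpsPassPipeline(
          module, tsl::DeviceType(tensorflow::DEVICE_TPU_XLA_JIT));
  if (status.ok()) return 0;
  // On failure, the labels below match the arguments RecordIfErrorStatus
  // passes to UpdateTfMlirBridgeFirstPhaseCounter: device type, "v2",
  // fallback disabled, "failure".
  return counter.Delta("XLA_TPU_JIT", "v2", "fallback_disabled", "failure");
}
```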
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h
index 49f9448..7d4d04c 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h
@@ -42,19 +42,6 @@
     mlir::ModuleOp module, tsl::DeviceType xla_device_type,
     llvm::StringRef module_name = llvm::StringRef());
 
-// TODO(b/305738491): Remove these exposed runtime passes.
-ABSL_DEPRECATED(
-    "Temporary placeholder that will be deleted. Used as a temporary migration "
-    "hack.")
-void AddTPULowerClusterToRuntimeOpsPassPipeline(
-    mlir::OpPassManager& pm, llvm::StringRef module_name = llvm::StringRef());
-
-ABSL_DEPRECATED(
-    "Temporary placeholder that will be deleted. Used as a temporary migration "
-    "hack.")
-void AddNonTPULowerClusterToRuntimeOpsPassPipeline(
-    mlir::OpPassManager& pm, llvm::StringRef module_name = llvm::StringRef());
-
 // The same API as RunLowerClusterToRuntimeOpsPassPipeline but as an MLIR pass
 // pipeline.
 void RegisterTPULowerClusterToRuntimeOpsPassPipeline();
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc
index b58a13d..ab9a56e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h"
 
+#include <cstdint>
 #include <string>
 #include <vector>
 
@@ -33,6 +34,7 @@
 #include "tensorflow/compiler/mlir/register_common_dialects.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/lib/monitoring/cell_reader.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/resource_loader.h"
 #include "tensorflow/core/platform/test.h"
@@ -51,6 +53,7 @@
 using mlir::OpPassManager;
 using mlir::OwningOpRef;
 using mlir::func::FuncOp;
+using ::tensorflow::monitoring::testing::CellReader;
 using tsl::DeviceType;
 
 std::string TestDataPath() {
@@ -58,6 +61,9 @@
       "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/testdata/");
 }
 
+static constexpr char kCompilationStreamz[] =
+    "/tensorflow/core/tf_mlir_bridge_first_phase_count";
+
 class LowerClusterToRuntimeOpsTest : public ::testing::Test {
  public:
   LowerClusterToRuntimeOpsTest() {
@@ -154,11 +160,17 @@
 }
 
 TEST_F(LowerClusterToRuntimeOpsTest, ErrorsWithBadCluster) {
+  CellReader<int64_t> compilation_status(kCompilationStreamz);
+
   TF_ASSERT_OK(CreateMlirModule("malformed_cluster.mlir"));
 
   EXPECT_FALSE(RunLowerClusterToRuntimeOpsPassPipeline(
                    *mlir_module_, DeviceType(DEVICE_TPU_XLA_JIT))
                    .ok());
+
+  EXPECT_EQ(compilation_status.Delta("XLA_TPU_JIT", "v2", "fallback_disabled",
+                                     "failure"),
+            1);
 }
 
 TEST_F(LowerClusterToRuntimeOpsTest, DumpsPipelinePasses) {
@@ -178,20 +190,6 @@
   EXPECT_THAT(files, ::testing::SizeIs(15));
 }
 
-TEST_F(LowerClusterToRuntimeOpsTest, AddsTPUPipelinePasses) {
-  OpPassManager pass_manager;
-  AddTPULowerClusterToRuntimeOpsPassPipeline(pass_manager);
-
-  EXPECT_EQ(pass_manager.size(), 8);
-}
-
-TEST_F(LowerClusterToRuntimeOpsTest, AddsNonTPUPipelinePasses) {
-  OpPassManager pass_manager;
-  AddNonTPULowerClusterToRuntimeOpsPassPipeline(pass_manager);
-
-  EXPECT_EQ(pass_manager.size(), 4);
-}
-
 }  // namespace
 }  // namespace tfrt_compiler
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc
index aa1a74a..523cdf2 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc
@@ -24,13 +24,14 @@
 #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
+#include "tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.h"
 #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h"
 #include "xla/mlir_hlo/mhlo/transforms/passes.h"
 
 namespace tensorflow {
 
 void PopulateLowerToMlProgramAndHloPipeline(mlir::OpPassManager& pm) {
-  mlir::TF::CreateTFXLABridgePipeline(pm);
+  tensorflow::tf2xla::internal::AddNonTPUBridgeClusteringPipelinePasses(pm);
 
   // Remove unused global tensors, or make them immutable if possible.
   pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
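
For context, the populated pipeline is still driven through a regular pass manager. A minimal usage sketch, assuming the corresponding `mlprogram.h` header and an `MLIRContext` with the needed dialects registered:

```c++
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.h"

// Builds and runs the TF -> ml_program/HLO lowering pipeline on `module`.
mlir::LogicalResult LowerToMlProgramAndHlo(mlir::MLIRContext &context,
                                           mlir::ModuleOp module) {
  mlir::PassManager pm(&context);
  tensorflow::PopulateLowerToMlProgramAndHloPipeline(pm);
  return pm.run(module);
}
```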
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index 247ea10..12cf30e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -318,10 +318,6 @@
 // Moves TPUCompileMlir ops as far to the front as possible.
 std::unique_ptr<OperationPass<func::FuncOp>> CreateMoveTpuCompileToFrontPass();
 
-// Populates the supplied passmanager with the passes required to run the
-// CPU/GPU bridge.
-void CreateTFXLABridgePipeline(OpPassManager& pm);
-
 //===----------------------------------------------------------------------===//
 // XlaCallModule
 //===----------------------------------------------------------------------===//
@@ -467,11 +463,6 @@
 std::unique_ptr<OperationPass<ModuleOp>>
 CreateMarkOpsForOutsideCompilationPass();
 
-// Creates a pass that extracts outside compilation (Host ops inside device
-// cluster) at head/tail of Device cluster to run before/after XLA computation.
-std::unique_ptr<OperationPass<ModuleOp>>
-CreateExtractHeadTailOutsideCompilationPass();
-
 // Creates a pass that extracts outside compilation (Host ops inside device
 // cluster) ops to a separate parallel_execute region to run on CPU.
 std::unique_ptr<OperationPass<ModuleOp>> CreateExtractOutsideCompilationPass();
@@ -534,11 +525,6 @@
 std::unique_ptr<OperationPass<func::FuncOp>>
 CreateTPUPartitionedOpConversionPass();
 
-// Creates a pass that forms clusters from operations of the same
-// `_replication_info` attribute.
-std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass(
-    bool strict_clusters = false);
-
 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUValidateInputsPass();
 
 // Creates a pass that cleans up `_replication_info` attribute on operations
@@ -627,11 +613,6 @@
 // Create a pass that colocates each `Split` with its predecessor.
 std::unique_ptr<OperationPass<func::FuncOp>> CreateTPUColocateSplitsPass();
 
-// Populates the supplied passmanager with the passes required to run the
-// bridge.
-void CreateTPUBridgePipeline(OpPassManager& pm,
-                             llvm::StringRef module_name = llvm::StringRef());
-
 // Creates a pass that replicates the tf._TPUCompileMlir op on each host that
 // needs the compiled program. It helps avoid transferring the compiled binary
 // between hosts.
@@ -725,7 +706,6 @@
 #define GEN_PASS_DECL_TPUCOLOCATECOMPOSITERESOURCEOPSPASS
 #define GEN_PASS_DECL_TPUDEVICEPROPAGATIONPASS
 #define GEN_PASS_DECL_TPUDYNAMICLAYOUTPASS
-#define GEN_PASS_DECL_TPUEXTRACTHEADTAILOUTSIDECOMPILATIONPASS
 #define GEN_PASS_DECL_TPUEXTRACTOUTSIDECOMPILATIONPASS
 #define GEN_PASS_DECL_TPUHOSTCOMPUTATIONEXPANSIONPASS
 #define GEN_PASS_DECL_TPUIDENTITYPRUNINGPASS
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc
index d81ee1e..78f2c3e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc
@@ -17,9 +17,13 @@
 // the TensorFlow dialect to their functional counterparts, i.e.,
 // tf.IfRegion ->  tf.If and tf.WhileRegion -> tf.While
 
+#include <iterator>
+#include <memory>
 #include <optional>
+#include <string>
 
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -27,18 +31,21 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/Region.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/IR/Verifier.h"  // from @llvm-project
 #include "mlir/IR/Visitors.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
 
@@ -69,6 +76,8 @@
                               CaseRegionOp case_region);
   LogicalResult ConvertWhileOp(SymbolTableCollection& symbol_table,
                                WhileRegionOp while_region);
+  LogicalResult ConvertGeneratorDatasetOp(SymbolTableCollection& symbol_table,
+                                          GeneratorDatasetRegionOp regional);
 
   // Get unique name by using the loc to name mapping.
   std::string GetName(Operation* op, StringRef suffix);
@@ -124,6 +133,37 @@
   dst->setAttr(kXlaPropagateCompileTimeConsts, builder->getBoolAttr(true));
 }
 
+// If the region only does a single function call whose operands / returns match
+// exactly the block args and results, return the name of the called function.
+std::optional<StringRef> UnwrapSingleFunctionCall(Region& region) {
+  // The pattern we're matching is
+  // ^block(arg0, arg1, ..., argN):
+  //   r0, r1, ..., rN = func.call @foo(arg0, arg1, ..., argN)
+  //   "tf.yield"(r0, r1, ..., rN)
+  if (!region.hasOneBlock()) return std::nullopt;
+  Block& block = region.front();
+  if (std::distance(block.begin(), block.end()) != 2) return std::nullopt;
+  TF::YieldOp yield =
+      llvm::dyn_cast_or_null<TF::YieldOp>(block.getTerminator());
+  if (!yield) return std::nullopt;
+  func::CallOp call = llvm::dyn_cast_or_null<func::CallOp>(*block.begin());
+  if (!call) return std::nullopt;
+  if (block.getNumArguments() != call.getNumOperands() ||
+      call.getNumResults() != yield.getNumOperands())
+    return std::nullopt;
+  for (auto [arg, operand] :
+       llvm::zip(block.getArguments(), call.getOperands())) {
+    if (arg != operand) return std::nullopt;
+  }
+  for (auto [ret, operand] :
+       llvm::zip(call.getResults(), yield.getOperands())) {
+    if (ret != operand) return std::nullopt;
+  }
+  SymbolRefAttr symbol = call.getCallableForCallee().get<SymbolRefAttr>();
+  if (!symbol) return std::nullopt;
+  return symbol.getLeafReference();
+}
+
 // Extracts the contents of a region with a single block into a new function.
 // `extern_values` is the set of external values that the region refers to.
 // Returns the name of the newly created function.
@@ -135,7 +175,13 @@
     SymbolTableCollection& symbol_table, Region& region, StringRef name,
     llvm::SmallVectorImpl<Value>& extern_values,
     llvm::SmallVectorImpl<func::FuncOp>& worklist,
-    bool extern_values_passthrough, bool only_one_return_value) {
+    bool extern_values_passthrough, bool only_one_return_value,
+    bool allow_return_of_existing = false) {
+  if (allow_return_of_existing && extern_values.empty()) {
+    auto existing = UnwrapSingleFunctionCall(region);
+    if (existing) return *existing;
+  }
+
   ModuleOp module = region.getParentOfType<ModuleOp>();
   auto builder = OpBuilder::atBlockBegin(module.getBody());
   auto loc = region.getParentOp()->getLoc();
@@ -524,6 +570,52 @@
   return success();
 }
 
+// Transform GeneratorDatasetRegion to GeneratorDatasetOp.
+LogicalResult RegionControlFlowToFunctional::ConvertGeneratorDatasetOp(
+    SymbolTableCollection& symbol_table, GeneratorDatasetRegionOp regional) {
+  mlir::MLIRContext* ctx = regional.getContext();
+  std::string init_name, next_name, finalize_name;
+
+  llvm::SmallVector<Value, 4> extern_values =
+      CollectExternValues(regional.getRegions());
+
+  if (!extern_values.empty()) return failure();
+
+  init_name = GetName(regional, "_init");
+  init_name = ExtractSingleBlockRegion(symbol_table, regional.getInit(),
+                                       init_name, extern_values, worklist,
+                                       /*extern_values_passthrough=*/false,
+                                       /*only_one_return_value=*/false,
+                                       /*allow_return_of_existing=*/true);
+
+  next_name = GetName(regional, "_next");
+  next_name = ExtractSingleBlockRegion(symbol_table, regional.getNext(),
+                                       next_name, extern_values, worklist,
+                                       /*extern_values_passthrough=*/false,
+                                       /*only_one_return_value=*/false,
+                                       /*allow_return_of_existing=*/true);
+
+  finalize_name = GetName(regional, "_finalize");
+  finalize_name =
+      ExtractSingleBlockRegion(symbol_table, regional.getFinalize(),
+                               finalize_name, extern_values, worklist,
+                               /*extern_values_passthrough=*/false,
+                               /*only_one_return_value=*/false,
+                               /*allow_return_of_existing=*/true);
+
+  auto new_op = OpBuilder(regional).create<TF::GeneratorDatasetOp>(
+      regional.getLoc(), regional->getResultTypes(),
+      regional.getInitFuncOtherArgs(), regional.getNextFuncOtherArgs(),
+      regional.getFinalizeFuncOtherArgs(), SymbolRefAttr::get(ctx, init_name),
+      SymbolRefAttr::get(ctx, next_name),
+      SymbolRefAttr::get(ctx, finalize_name), regional.getOutputTypes(),
+      regional.getOutputShapes(), regional.getMetadata());
+
+  regional->replaceAllUsesWith(new_op->getResults());
+  regional->erase();
+  return success();
+}
+
 void RegionControlFlowToFunctional::runOnOperation() {
   ModuleOp module = getOperation();
   SymbolTableCollection symbol_table;
@@ -549,6 +641,11 @@
           op->emitOpError() << "failed to convert to functional form";
           return WalkResult::interrupt();
         }
+      } else if (auto gen = llvm::dyn_cast<GeneratorDatasetRegionOp>(op)) {
+        if (failed(ConvertGeneratorDatasetOp(symbol_table, gen))) {
+          op->emitOpError() << "failed to convert to functional form";
+          return WalkResult::interrupt();
+        }
       }
       return WalkResult::advance();
     });
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 581e04a..c458b2c 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -835,6 +835,11 @@
   // Returns whether it was able to compute constant values.
   LogicalResult TryToFold(Operation* op);
 
+  // Forcibly assign operand types to result types (the i-th operand type will
+  // be assigned to the i-th result type). Returns true if anything is changed.
+  bool ForceTypeForPassThroughOperands(Operation* op, OperandRange operands,
+                                       ResultRange results);
+
   // Makes result types match the operand types (the i-th result type will
   // match the i-th operand type). Returns true if anything is changed.
   bool RefineTypeForPassThroughOperands(Operation* op, OperandRange operands,
@@ -1257,7 +1262,7 @@
 
   bool changed = false;
   for (auto [result, type] :
-       llvm::zip(op.getResults(), loader->output_types())) {
+       llvm::zip(op.getResults(), loader->OutputTypes())) {
     auto ranked = type.dyn_cast<RankedTensorType>();
     if (ranked == nullptr) {
       LLVM_DEBUG(llvm::dbgs()
@@ -2288,6 +2293,23 @@
   return ic->MakeShape(dims);
 }
 
+bool ShapeInference::ForceTypeForPassThroughOperands(Operation* op,
+                                                     OperandRange operands,
+                                                     ResultRange results) {
+  bool changed = false;
+  for (auto entry : llvm::zip(operands, results)) {
+    Type operand_type = std::get<0>(entry).getType();
+    Value result = std::get<1>(entry);
+    TensorType result_type = dyn_cast<TensorType>(result.getType());
+    if (result_type == operand_type) continue;
+
+    if (!UpdateTypeAndInsertIncompatibleUseCasts(operand_type, result))
+      continue;
+    changed = true;
+  }
+  return changed;
+}
+
 bool ShapeInference::RefineTypeForPassThroughOperands(Operation* op,
                                                       OperandRange operands,
                                                       ResultRange results) {
@@ -2323,14 +2345,14 @@
 
 bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) {
   if (auto graph_op = dyn_cast<tf_executor::GraphOp>(op)) {
-    return RefineTypeForPassThroughOperands(graph_op.GetFetch(),
-                                            graph_op.GetFetch().getFetches(),
-                                            op->getResults());
+    return ForceTypeForPassThroughOperands(graph_op.GetFetch(),
+                                           graph_op.GetFetch().getFetches(),
+                                           op->getResults());
   }
   if (auto island_op = dyn_cast<tf_executor::IslandOp>(op)) {
-    return RefineTypeForPassThroughOperands(island_op.GetYield(),
-                                            island_op.GetYield().getFetches(),
-                                            op->getResults());
+    return ForceTypeForPassThroughOperands(island_op.GetYield(),
+                                           island_op.GetYield().getFetches(),
+                                           op->getResults());
   }
   if (auto iter_sink = dyn_cast<tf_executor::NextIterationSinkOp>(op)) {
     auto iter_source = cast<tf_executor::NextIterationSourceOp>(
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
index 108d58f..782232c 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
@@ -1019,92 +1019,6 @@
   let constructor = "TFTPU::CreateTPUValidateInputsPass()";
 }
 
-def TPUClusterFormationPass : Pass<"tf-tpu-cluster-formation", "ModuleOp"> {
-  let summary = "Forms clusters from operations assigned to the same TPU computation";
-
-  let description = [{
-    TPU computations from the frontend are composed of a `tf.TPUReplicateMetadata`
-    op, a subgraph of ops (TensorFlow Dialect) each with a matching
-    `_replication_info` attribute relative to the associated
-    `tf.TPUReplicateMetadata` op, and optionally `tf.TPUReplicatedInput` and
-    `tf.TPUReplicatedOutput` ops feeding in inputs and outputs to and from a
-    replicated TPU computation. The number of times a TPU computation is
-    replicated is defined in the `tf.TPUReplicateMetadata` op (`num_replicas`
-    attribute) and operand and result sizes of `tf.TPUReplicatedInput` and
-    `tf.TPUReplicatedOutput` respectively must match, excluding packed tensors.
-    It is also assumed ops of the same TPU computation do not have ops outside
-    of the TPU computation that are both inputs and outputs to the same TPU
-    computation. Furthermore, we assume that every node has either none or both
-    of `_replication_info` and `_xla_compile_device_type` attributes defined.
-
-    This pass takes the TPU computation subgraph, moves them into a
-    `tf_device.cluster`, and copies over attributes from the associated
-    `tf.TPUReplicateMetadata` op to the newly created `tf_device.cluster`. If the
-    computation is replicated (`num_replicas` > 1), the `num_replicas` attribute is
-    not copied over but instead the `tf_device.cluster` is further wrapped with a
-    `tf_device.replicate`, and associated `tf.TPUReplicatedInput` and
-    `tf.TPUReplicatedOutput` ops are replaced as the `tf_device.replicate` operands
-    and results. Otherwise, the single operands and results of the associated
-    `tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` ops are simply forwarded to
-    the `tf_device.cluster`.
-
-    For example, the following non replicated computation:
-
-    ```mlir
-    func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
-      // Metadata op for cluster `cluster` with 1 replica, 1 core per replica and
-      // with topology `<topology>`.
-      "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_relicas = 1, num_cores_per_replica = 1, topology = "<topology>", device_assignment = [], padding_map = []} : () -> ()
-      %replicated_input = "tf.TPUReplicatedInput"(%arg0) : (tensor<i32>) -> tensor<i32>
-      %identity = "tf.Identity"(%replicated_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster"} : (tensor<i32>) -> tensor<i32>
-      %replicated_output = "tf.TPUReplicatedOutput(%identity) : (tensor<i32>) -> tensor<i32>
-      return %replicated_output : tensor<i32>
-    }
-    ```
-
-    will be transformed into:
-
-    ```mlir
-    func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
-      %cluster = "tf_device.cluster"() ( {
-        %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
-        tf_device.return %identity : tensor<i32>
-      }) {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
-      return %cluster : tensor<i32>
-    }
-    ```
-
-    The following replicated computation:
-
-    ```mlir
-    func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
-      "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_relicas = 2, num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> ()
-      %replicated_input = "tf.TPUReplicatedInput"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-      %identity = "tf.Identity"(%replicated_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster"} : (tensor<i32>) -> tensor<i32>
-      %replicated_output:2 = "tf.TPUReplicatedOutput(%identity) : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
-      return %replicated_output#0, %replicated_output#1 : tensor<i32>, tensor<i32>
-    }
-    ```
-
-    will be transformed into:
-
-    ```mlir
-    func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
-      %replicate:2 = tf_device.replicate([%arg0, %arg1] as %replicated_input) {n = 2 : i32} {
-        %cluster = "tf_device.cluster"() ( {
-          %identity = "tf.Identity"(%replicated_input) : (tensor<i32>) -> tensor<i32>
-          tf_device.return %identity : tensor<i32>
-        }) {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
-        tf_device.return %cluster : tensor<i32>
-      }
-      return %replicate#0, %replicate#1 : tensor<i32>, tensor<i32>
-    }
-    ```
-  }];
-
-  let constructor = "TFTPU::CreateTPUClusterFormationPass()";
-}
-
 def ClusterConstantSinkingPass : Pass<"tf-device-constant-sinking", "mlir::func::FuncOp"> {
   let summary = "Sinks constants implicitly captured in a tf_device.cluster region.";
 
@@ -2042,48 +1956,6 @@
   let constructor = "TFTPU::CreateTPUClusterCleanupAttributesPass()";
 }
 
-def ExtractHeadTailOutsideCompilationPass : Pass<"tf-extract-head-tail-outside-compilation", "ModuleOp"> {
-  let summary = "Extracts head or tail outside compilation to separate host launches before/after device cluster.";
-
-  let description = [{
-    This pass extracts a CPU computation cluster with `_xla_outside_compilation`
-    annotation from the head or tail of a Device cluster.
-
-    For example:
-
-    ```mlir
-      %cluster = "tf_device.cluster"() ( {
-        %a = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
-        %b = "tf.B"(%a) : (tensor<i32>) -> tensor<i32>
-        %c = "tf.C"(%b) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
-        tf_device.return %c : tensor<i32>
-      }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor<i32>
-      return %cluster : tensor<i32>
-    ```
-
-    becomes:
-
-    ```mlir
-    %0 = "tf_device.launch"() ( {
-      %3 = "tf.A"(%arg0) : (tensor<i32>) -> tensor<i32>
-      tf_device.return %3 : tensor<i32>
-    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> tensor<i32>
-    %1 = "tf_device.cluster"() ( {
-      %3 = "tf.B"(%0) : (tensor<i32>) -> tensor<i32>
-      tf_device.return %3 : tensor<i32>
-    }) {device_assignment = [], num_cores_per_replica = 1 : i64, padding_map = [], step_marker_location = "", topology = ""} : () -> tensor<i32>
-    %2 = "tf_device.launch"() ( {
-      %3 = "tf.C"(%1) : (tensor<i32>) -> tensor<i32>
-      tf_device.return %3 : tensor<i32>
-    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> tensor<i32>
-    return %2 : tensor<i32>
-
-    ```
-  }];
-
-  let constructor = "TFDevice::CreateExtractHeadTailOutsideCompilationPass()";
-}
-
 def TPUSpaceToDepthPass : Pass<"tf-tpu-space-to-depth-pass", "ModuleOp"> {
   let summary = "Applies automatic space to depth transform for the first or frontier convolutions consume host inputs on TPU.";
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc
index 536a394..42b0516 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc
@@ -18,16 +18,21 @@
 #include <utility>
 #include <vector>
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
-#include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"  // from @llvm-project
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project  // IWYU pragma: keep
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/OwningOpRef.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
+#include "mlir/IR/Visitors.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "stablehlo/dialect/ChloOps.h"  // from @stablehlo  // IWYU pragma: keep
@@ -81,19 +86,6 @@
   return std::move(*loader).module();
 }
 
-// If `func_name` exists in `symbol_table`, returns a new name that doesn't
-// exist. Otherwise, returns `func_name` as is.
-StringAttr NewFuncName(const SymbolTable &symbol_table, StringAttr func_name) {
-  int index = 0;
-  StringAttr new_func_name = func_name;
-  while (symbol_table.lookup(new_func_name)) {
-    new_func_name =
-        StringAttr::get(func_name.getContext(),
-                        llvm::formatv("{0}{1}", func_name.getValue(), index++));
-  }
-  return new_func_name;
-}
-
 // Renames functions in the stablehlo module to avoid naming conflicts with
 // existing functions in the tf module.
 // Sets _from_xla_call_module attribute for each stablehlo function.
@@ -108,20 +100,21 @@
     MLIRContext *context, SymbolTableCollection &symbol_tables,
     ModuleOp tf_module, ModuleOp stablehlo_module) {
   SymbolTable &tf_symbol_table = symbol_tables.getSymbolTable(tf_module);
+  SymbolTable &stablehlo_symbol_table =
+      symbol_tables.getSymbolTable(stablehlo_module);
   Builder builder(context);
   StringAttr main_func_name;
   for (auto func : stablehlo_module.getOps<func::FuncOp>()) {
-    StringAttr func_name = NewFuncName(tf_symbol_table, func.getSymNameAttr());
-    if (func.getSymName() == kStablehloMainFunctionName) {
-      main_func_name = func_name;
-    }
-    if (func_name != func.getSymNameAttr()) {
-      if (failed(SymbolTable::replaceAllSymbolUses(func, func_name,
-                                                   stablehlo_module))) {
+    const bool is_main_func = func.getSymName() == kStablehloMainFunctionName;
+    if (tf_symbol_table.lookup(func.getSymName())) {
+      if (failed(stablehlo_symbol_table.renameToUnique(
+              func, {&tf_symbol_table, &stablehlo_symbol_table}))) {
         return func.emitError()
                << "failed to rename StableHLO function " << func.getSymName();
       }
-      func.setName(func_name);
+    }
+    if (is_main_func) {
+      main_func_name = func.getSymNameAttr();
     }
     func->setAttr(kFromXlaCallModuleAttrName, builder.getUnitAttr());
   }
@@ -225,7 +218,7 @@
 
   // Translate `called_index` in TF function custom calls into symbol
   // references. `function_list` attribute is needed after that.
-  SmallVector<SymbolRefAttr> function_list(
+  llvm::SmallVector<SymbolRefAttr> function_list(
       op.getFunctionList().getAsRange<SymbolRefAttr>());
   if (failed(
           SymbolizeCustomCallCalledIndex(*stablehlo_module, function_list))) {
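
The removed `NewFuncName` helper is replaced by `mlir::SymbolTable::renameToUnique`, which both picks a collision-free name and rewrites the symbol's uses. A hedged sketch of that step in isolation; the helper name is illustrative:

```c++
// Renames `func` only if its symbol already exists in the outer TF module;
// renameToUnique also updates references tracked by the given symbol tables.
mlir::LogicalResult RenameIfCollides(mlir::SymbolTable &tf_symbol_table,
                                     mlir::SymbolTable &stablehlo_symbol_table,
                                     mlir::func::FuncOp func) {
  if (!tf_symbol_table.lookup(func.getSymName())) return mlir::success();
  if (mlir::failed(stablehlo_symbol_table.renameToUnique(
          func, {&tf_symbol_table, &stablehlo_symbol_table})))
    return func.emitError() << "failed to rename StableHLO function "
                            << func.getSymName();
  return mlir::success();
}
```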
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
index d478a8a..7b3c24c 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
@@ -821,9 +821,8 @@
   // Construct one in that case.
   if (configs.export_entry_func_to_flib) {
     graph = std::make_unique<Graph>(OpRegistry::Global());
-    // TODO(hinsu): Avoid Proto -> Memory -> Proto conversion here.
-    FunctionDefLibrary flib = flib_def.ToProto();
-    TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(flib));
+    TF_RETURN_IF_ERROR(
+        graph->mutable_flib_def()->AddLibrary(std::move(flib_def)));
   }
 
   auto graphdef = std::make_unique<GraphDef>();
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index f545005..999fe17 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -4326,7 +4326,7 @@
     mlir::MLIRContext* context) {
   tensorflow::GraphDebugInfo dummy_debug_info;
   tensorflow::GraphImportConfig specs;
-  specs.graph_func_name = fbody->fdef.signature().name();
+  specs.graph_func_name = fbody->record->fdef().signature().name();
   specs.enable_shape_inference = false;
   specs.graph_as_function = true;
   for (const auto* control_ret_node : fbody->control_ret_nodes)
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD
index 856adf4..d0653b9 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD
@@ -183,6 +183,7 @@
         "//tensorflow/core/platform:errors",
         "//tensorflow/core/platform:stacktrace",
         "//tensorflow/core/platform:status",
+        "//tensorflow/core/tpu:tpu_defs",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/status",
         "@llvm-project//llvm:Support",
@@ -232,14 +233,19 @@
         "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass",
         "//tensorflow/compiler/mlir/tf2xla/internal:logging_hooks",
         "//tensorflow/core:framework",
+        "//tensorflow/core/platform:error_payloads",
         "//tensorflow/core/platform:status",
         "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Transforms",
+        "@local_tsl//tsl/lib/monitoring:counter",
+        "@local_tsl//tsl/platform:error_logging",
         "@local_tsl//tsl/platform:status",
     ],
 )
@@ -254,6 +260,7 @@
     deps = [
         ":tf_dialect_to_executor",
         "//tensorflow/compiler/mlir:register_common_dialects",
+        "//tensorflow/core/lib/monitoring:cell_reader",
         "//tensorflow/core/platform:resource_loader",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc
index fa03af5..2f8469e 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc
@@ -42,7 +42,9 @@
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/stacktrace.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
 #include "tensorflow/core/util/debug_data_dumper.h"
+#include "tsl/framework/device_type.h"
 #include "tsl/platform/error_logging.h"
 #include "tsl/platform/errors.h"
 
@@ -142,6 +144,7 @@
 }  // namespace
 
 tensorflow::Status RecordStatusIfError(const std::string error_prefix,
+                                       bool is_in_fallback_enabled_mode,
                                        absl::Status status) {
   if (status.ok()) {
     return status;
@@ -149,7 +152,8 @@
 
   tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
       /*device_type=*/"tpu", /*bridge_version=*/"v1",
-      /*fallback_enabled=*/false, /*result=*/"failure");
+      /*fallback_enabled=*/is_in_fallback_enabled_mode,
+      /*result=*/"failure");
   tsl::error_logging::Log(kBridgeComponent,
                           "TFXLA_PHASE_ONE_MLIR_TPU_V1_COMPAT_BRIDGE",
                           status.ToString())
@@ -161,21 +165,20 @@
 // V1 Compat Bridge takes a TF Executor dialect and extracts the TF2 portion
 // and inserts it into a submodule. We just want to run the clustering
 // portion of the pipeline on just the single submodule.
-absl::Status RunClusteringPipelineOnSubmodule(ModuleOp parent_module) {
+absl::Status RunClusteringPipelineOnSubmodule(
+    ModuleOp parent_module, bool is_in_fallback_enabled_mode) {
   int num_submodules = 0;
-  mlir::WalkResult submodule_status = parent_module.walk([&](ModuleOp
-                                                                 submodule) {
+  absl::Status clustering_pipeline_status;
+  parent_module.walk([&](ModuleOp submodule) {
     if (submodule == parent_module) return mlir::WalkResult::advance();
     num_submodules++;
-    auto clustering_pipeline_status = RunTFXLABridge(
+    clustering_pipeline_status = RunTFXLABridge(
         submodule,
         [](OpPassManager &pm) {
           internal::AddBridgeClusteringPipelinePasses(pm);
-          tensorflow::tfrt_compiler::AddTPULowerClusterToRuntimeOpsPassPipeline(
-              pm);
         },
         /*module_name=*/"", /*dump_prefix=*/"tf_xla_clustering_bridge_v1");
-    if (!clustering_pipeline_status.ok()) {
+    if (num_submodules > 1) {
       return mlir::WalkResult::interrupt();
     }
 
@@ -187,22 +190,20 @@
         "V1 Compat Bridge has more than one submodule. Erroring out.");
     TF_RETURN_IF_ERROR(RecordStatusIfError(
         /*error_prefix=*/"Bridge has more than one submodule:",
-        num_submodules_error));
+        is_in_fallback_enabled_mode, num_submodules_error));
   }
 
-  if (submodule_status.wasInterrupted()) {
-    auto submodule_error = absl::InternalError(
-        "V1 Compat Bridge Errored running clustering pipeline. Erroring out.");
+  if (!clustering_pipeline_status.ok()) {
     TF_RETURN_IF_ERROR(RecordStatusIfError(
         /*error_prefix=*/"Bridge Errored running clustering pipeline:",
-        submodule_error));
+        is_in_fallback_enabled_mode, clustering_pipeline_status));
   }
 
   return absl::OkStatus();
 }
 
-// TODO(b/298390303): Pass in enable_fallback from caller to re-enable logging.
-tensorflow::Status RunSessionTf2xlaClusteringBridge(ModuleOp module) {
+tensorflow::Status RunSessionTf2xlaClusteringBridge(
+    ModuleOp module, bool is_in_fallback_enabled_mode) {
   VLOG(2) << "TPU Sessions Bridge called stack trace is "
           << "(NOTE: this is not an error; rather the stack trace for "
              "debugging) : "
@@ -212,36 +213,18 @@
       module, [](OpPassManager &pm) { CreateTPUBridgePipelineV1(pm); },
       /*module_name=*/"", /*dump_prefix=*/"tf_xla_functional_import_bridge_v1");
   TF_RETURN_IF_ERROR(RecordStatusIfError(
-      /*error_prefix=*/"Bridge Function Import V1", functional_import_status));
+      /*error_prefix=*/"Bridge Function Import V1", is_in_fallback_enabled_mode,
+      functional_import_status));
 
-  TF_RETURN_IF_ERROR(RunClusteringPipelineOnSubmodule(module));
-
-  Status export_preparation_status = RunTFXLABridge(
-      module,
-      [](OpPassManager &pm) {
-        pm.addPass(
-            mlir::tf_executor::CreateTFExecutorTPUV1IslandInliningPass());
-        // There are cases where we don't consume all compilation and
-        // replication attributes like we do for the V2 pipeline, so we need to
-        // convert them from unified to legacy attributes before they get
-        // exposed to outside of the bridge.
-        pm.addNestedPass<FuncOp>(
-            mlir::TFTPU::
-                CreateConvertToLegacyCompileAndReplicateAttributesPass());
-      },
-      /*module_name=*/"",
-      /*dump_prefix=*/"tf_xla_bridge_v1_export_preparation");
-  if (!export_preparation_status.ok()) {
-    TF_RETURN_IF_ERROR(RecordStatusIfError(
-        /*error_prefix=*/"Bridge Export Preparation Failed:",
-        export_preparation_status));
-  }
+  TF_RETURN_IF_ERROR(
+      RunClusteringPipelineOnSubmodule(module, is_in_fallback_enabled_mode));
 
   tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
       /*device_type=*/"tpu", /*bridge_version=*/"v1",
-      /*fallback_enabled=*/false, /*result=*/"success");
+      /*fallback_enabled=*/is_in_fallback_enabled_mode,
+      /*result=*/"success");
 
-  return tensorflow::tf2xla::v1::ExportFromTensorflowDialectToExecutor(module);
+  return absl::OkStatus();
 }
 
 // Registers a pipeline builder function for TF TPU V1 bridge.
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.h b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.h
index 02a68e7..e27ec14 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.h
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.h
@@ -29,12 +29,13 @@
 // These transformations take as input a Tensorflow Graph as an MLIR Module
 // and transforms the module in place to cluster the given ops for compilation
 // that is compatible with the given device_type. The MLIR should be in the TF
-// Executor Dialect for graph nodes and edges. Individual Op inside a node
-// should be the Tensorflow Dialect. The output MLIR is in the TF Executor
-// Dialect.  The input MLIR should not have infeed and outfeed ops, which are
-// unsupported via this API.
-// Returns OkStatus if passed, otherwise an error.
-tensorflow::Status RunSessionTf2xlaClusteringBridge(mlir::ModuleOp module);
+// Executor Dialect for graph nodes and edges or TF Functional. It will convert
+// to TF Functional internally. Individual Op inside a node should be the
+// Tensorflow Dialect. The output MLIR is in the TF Functional Dialect.  The
+// input MLIR should not have infeed and outfeed ops, which are unsupported via
+// this API. Returns OkStatus if passed, otherwise an error.
+tensorflow::Status RunSessionTf2xlaClusteringBridge(
+    mlir::ModuleOp module, bool is_in_fallback_enabled_mode);
 
 }  // namespace v1
 }  // namespace tf2xla
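
Since the session bridge no longer finishes with the executor export (see the change to `RunSessionTf2xlaClusteringBridge` above), a caller that needs TF Executor output now runs the two steps explicitly. A hedged sketch; the wrapper name is illustrative and assumes the v1 headers listed below:

```c++
#include "tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.h"
#include "tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.h"
#include "tsl/platform/errors.h"

tensorflow::Status ClusterAndExportV1(mlir::ModuleOp module) {
  TF_RETURN_IF_ERROR(tensorflow::tf2xla::v1::RunSessionTf2xlaClusteringBridge(
      module, /*is_in_fallback_enabled_mode=*/false));
  // Exporting back to the TF Executor dialect is now a separate, explicit step.
  return tensorflow::tf2xla::v1::ExportFromTensorflowDialectToExecutor(module);
}
```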
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc
index b19d224..44eafb2 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc
@@ -81,26 +81,22 @@
 
   TF_ASSERT_OK(CreateMlirModule("empty_func.mlir"));
 
-  TF_EXPECT_OK(RunSessionTf2xlaClusteringBridge(*mlir_module_));
+  TF_EXPECT_OK(
+      RunSessionTf2xlaClusteringBridge(*mlir_module_,
+                                       /*is_in_fallback_enabled_mode=*/false));
   EXPECT_EQ(
       compilation_status.Delta("tpu", "v1", "fallback_disabled", "success"), 1);
 }
 
-// Required for now due to the Bridge API, but this should be separated out
-// later.
-TEST_F(SessionClusterTensorflowDialectTest,
-       RunsTensorflowDialectToTensorflowExecutor) {
-  TF_ASSERT_OK(CreateMlirModule("invalid_executor.mlir"));
-
-  EXPECT_FALSE(RunSessionTf2xlaClusteringBridge(*mlir_module_).ok());
-}
-
 TEST_F(SessionClusterTensorflowDialectTest, FailsWithMultipleSubmodules) {
   CellReader<int64_t> compilation_status(kCompilationStreamz);
 
   TF_ASSERT_OK(CreateMlirModule("multiple_submodules.mlir"));
 
-  EXPECT_FALSE(RunSessionTf2xlaClusteringBridge(*mlir_module_).ok());
+  EXPECT_FALSE(
+      RunSessionTf2xlaClusteringBridge(*mlir_module_,
+                                       /*is_in_fallback_enabled_mode=*/false)
+          .ok());
 
   EXPECT_EQ(
       compilation_status.Delta("tpu", "v1", "fallback_disabled", "failure"), 1);
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc
index 941dd36..236282f 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc
@@ -19,6 +19,8 @@
 #include <string>
 #include <utility>
 
+#include "absl/log/log.h"
+#include "absl/status/status.h"
 #include "llvm/ADT/StringRef.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
@@ -32,8 +34,11 @@
 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 #include "tensorflow/compiler/mlir/tf2xla/internal/logging_hooks.h"
+#include "tensorflow/core/platform/error_payloads.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/util/debug_data_dumper.h"
+#include "tsl/lib/monitoring/counter.h"
+#include "tsl/platform/error_logging.h"
 #include "tsl/platform/status.h"
 
 namespace tensorflow {
@@ -47,6 +52,15 @@
 using mlir::PassManager;
 using mlir::func::FuncOp;
 
+auto *tf_dialect_to_executor_dialect_status = tsl::monitoring::Counter<1>::New(
+    "/tensorflow/core/tf2xla/api/v1/tf_dialect_to_executor_dialect_status",
+    "Counts how often a successful export from TF Dialect to Executor Dialect "
+    "is",
+    "status");
+
+constexpr char kExportSuccess[] = "success";
+constexpr char kExportFailed[] = "failed";
+
 namespace {
 
 void AddTfDialectToExecutorPasses(OpPassManager &pm) {
@@ -55,6 +69,14 @@
     pm.addPass(mlir::CreateBreakUpIslandsPass());
   };
 
+  pm.addPass(mlir::tf_executor::CreateTFExecutorTPUV1IslandInliningPass());
+  // There are cases where we don't consume all compilation and
+  // replication attributes like we do for the V2 pipeline, so we need to
+  // convert them from unified to legacy attributes before they get
+  // exposed to outside of the bridge.
+  pm.addNestedPass<FuncOp>(
+      mlir::TFTPU::CreateConvertToLegacyCompileAndReplicateAttributesPass());
+
   pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
   add_pass(mlir::CreateFunctionalToExecutorDialectConversionPass());
   add_pass(mlir::TFDevice::CreateReplicateToIslandPass(
@@ -75,6 +97,30 @@
   pm.addPass(mlir::TF::CreateVerifySuitableForExportPass());
 }
 
+tensorflow::Status RecordStatusIfError(absl::Status status) {
+  if (status.ok()) {
+    return absl::OkStatus();
+  }
+
+  VLOG(1) << "Failed to export from TF Dialect to TF Executor Dialect. "
+          << status;
+  tf_dialect_to_executor_dialect_status->GetCell(kExportFailed)->IncrementBy(1);
+
+  constexpr char bridge_subcomponent[] =
+      "TFXLA_TF_FUNCTIONAL_TO_EXECUTOR_EXPORT_v1";
+  constexpr char kBridgeComponent[] = "TFXLABridge";
+
+  tsl::OkOrSetErrorCounterPayload(
+      tensorflow::core::platform::ErrorSourceProto::MLIR_BRIDGE_PHASE_1,
+      status);
+
+  tsl::error_logging::Log(kBridgeComponent, bridge_subcomponent,
+                          status.ToString())
+      .IgnoreError();
+
+  return status;
+}
+
 }  // namespace
 
 tensorflow::Status ExportFromTensorflowDialectToExecutor(
@@ -116,6 +162,13 @@
         module, llvm::StringRef(), &tf_to_executor);
   }
 
+  if (result.failed()) {
+    return RecordStatusIfError(diag_handler.ConsumeStatus());
+  }
+
+  tf_dialect_to_executor_dialect_status->GetCell(kExportSuccess)
+      ->IncrementBy(1);
+
   return diag_handler.ConsumeStatus();
 }
 
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc
index 80a7701..38393d3 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.h"
 
+#include <cstdint>
 #include <string>
 
 #include <gtest/gtest.h>
@@ -26,6 +27,7 @@
 #include "mlir/IR/OwningOpRef.h"  // from @llvm-project
 #include "mlir/Parser/Parser.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/register_common_dialects.h"
+#include "tensorflow/core/lib/monitoring/cell_reader.h"
 #include "tensorflow/core/platform/resource_loader.h"
 #include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/status.h"
@@ -41,6 +43,10 @@
 using mlir::MLIRContext;
 using mlir::ModuleOp;
 using mlir::OwningOpRef;
+using tensorflow::monitoring::testing::CellReader;
+
+static constexpr char kCompilationStreamz[] =
+    "/tensorflow/core/tf2xla/api/v1/tf_dialect_to_executor_dialect_status";
 
 std::string TestDataPath() {
   return tensorflow::GetDataDependencyFilepath(
@@ -73,15 +79,23 @@
 };
 
 TEST_F(TensorflowDialectToExecutorTest, ConvertsToExecutor) {
+  CellReader<int64_t> compilation_status(kCompilationStreamz);
+
   TF_ASSERT_OK(CreateMlirModule("empty_func.mlir"));
 
   TF_EXPECT_OK(ExportFromTensorflowDialectToExecutor(*mlir_module_));
+
+  EXPECT_EQ(compilation_status.Delta("success"), 1);
 }
 
 TEST_F(TensorflowDialectToExecutorTest, ErrorsWhenCannotConvert) {
+  CellReader<int64_t> compilation_status(kCompilationStreamz);
+
   TF_ASSERT_OK(CreateMlirModule("invalid_executor.mlir"));
 
   EXPECT_FALSE(ExportFromTensorflowDialectToExecutor(*mlir_module_).ok());
+
+  EXPECT_EQ(compilation_status.Delta("failed"), 1);
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD
index c46bf99..70a84bcc 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD
@@ -18,6 +18,7 @@
         "//learning/serving/contrib/tfrt/mlir/saved_model_analysis",
         "//tensorflow/compiler/mlir/tfrt",
         "//tensorflow/compiler/tf2xla",
+        "//tensorflow/compiler/mlir",
         # Legacy due to where the bridge currently runs. This should go away.
         "//tensorflow/compiler/mlir/tensorflow/transforms",
     ],
@@ -140,6 +141,7 @@
     data = [
         "testdata/empty_func.mlir",
         "testdata/invalid_executor.mlir",
+        "testdata/outside_compilation.mlir",
     ],
     deps = [
         ":cluster_tf",
@@ -170,7 +172,9 @@
         "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass",
         "//tensorflow/compiler/mlir/tf2xla/internal:logging_hooks",
         "//tensorflow/core:framework",
+        "//tensorflow/core/platform:error_payloads",
         "//tensorflow/core/platform:status",
+        "@com_google_absl//absl/log",
         "@com_google_absl//absl/status",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:FuncDialect",
@@ -179,6 +183,8 @@
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Transforms",
         "@local_tsl//tsl/lib/monitoring:counter",
+        "@local_tsl//tsl/platform:error_logging",
+        "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:status",
     ],
 )
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc
index adf023c..24de1be 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc
@@ -15,6 +15,8 @@
 
 #include "tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h"
 
+#include <string>
+
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
@@ -22,14 +24,13 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
-#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 #include "tensorflow/compiler/mlir/tf2xla/api/v2/device_type.pb.h"
-#include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h"
 #include "tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.h"
 #include "tensorflow/compiler/mlir/tf2xla/internal/logging_hooks.h"
 #include "tensorflow/core/framework/metrics.h"
@@ -37,12 +38,9 @@
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/stacktrace.h"
 #include "tensorflow/core/platform/status.h"
-#include "tensorflow/core/tpu/tpu_defs.h"
 #include "tensorflow/core/util/debug_data_dumper.h"
-#include "tsl/framework/device_type.h"
 #include "tsl/platform/error_logging.h"
 #include "tsl/platform/errors.h"
-#include "tsl/platform/status.h"
 
 namespace tensorflow {
 namespace tf2xla {
@@ -113,16 +111,9 @@
   return diag_handler.ConsumeStatus();
 }
 
-void CreateTPUBridgePipeline(OpPassManager &pm, llvm::StringRef module_name) {
-  pm.addPass(mlir::TFTPU::CreateTPUValidateInputsPass());
-  pm.addNestedPass<mlir::func::FuncOp>(
-      mlir::TF::CreateCanonicalizeCompileAndReplicateAttributesPass());
-  tensorflow::tf2xla::internal::AddBridgeClusteringPipelinePasses(pm,
-                                                                  module_name);
-}
-
 tensorflow::Status RecordIfErrorStatus(const std::string error_prefix,
                                        bool fallback_enabled,
+                                       std::string device_type,
                                        absl::Status status) {
   if (status.ok()) {
     return status;
@@ -130,7 +121,7 @@
 
   VLOG(2) << error_prefix << " " << status;
   tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
-      /*device_type=*/"tpu", /*bridge_version=*/"v2",
+      device_type, /*bridge_version=*/"v2",
       /*fallback_enabled=*/fallback_enabled,
       /*result=*/"failure");
 
@@ -138,44 +129,50 @@
       tensorflow::core::platform::ErrorSourceProto::MLIR_BRIDGE_PHASE_1,
       status);
 
-  tsl::error_logging::Log(kBridgeComponent, "TFXLA_PHASE_ONE_MLIR_TPU_BRIDGE",
+  std::string bridge_subcomponent = "TFXLA_PHASE_ONE_MLIR_TPU_BRIDGE";
+  if (device_type != "tpu") {
+    bridge_subcomponent = "TFXLA_PHASE_ONE_MLIR_CPU/GPU_BRIDGE";
+  }
+
+  tsl::error_logging::Log(kBridgeComponent, bridge_subcomponent,
                           status.ToString())
       .IgnoreError();
 
   return status;
 }
 
+void CreateClusteringPipeline(OpPassManager &pm, llvm::StringRef module_name) {
+  pm.addPass(mlir::TFTPU::CreateTPUValidateInputsPass());
+  pm.addNestedPass<FuncOp>(
+      mlir::TF::CreateCanonicalizeCompileAndReplicateAttributesPass());
+  tensorflow::tf2xla::internal::AddBridgeClusteringPipelinePasses(pm,
+                                                                  module_name);
+}
+
+void CreateTPUClusteringPipelineV2(OpPassManager &pm) {
+  CreateClusteringPipeline(pm, /*module_name=*/"");
+}
+
 tensorflow::Status TPUBridge(ModuleOp module, bool fallback_enabled,
                              llvm::StringRef module_name) {
   VLOG(2)
       << "TPU Bridge called stack trace is "
       << "(NOTE: this is not an error; rather the stack trace for debugging) : "
       << tensorflow::CurrentStackTrace();
+  std::string device_type = "tpu";
   Status clustering_status = RunTFXLABridge(
       module,
       [module_name](OpPassManager &pm) {
-        CreateTPUBridgePipeline(pm, module_name);
+        CreateClusteringPipeline(pm, module_name);
       },
       module_name, /*dump_prefix=*/"tf_xla_bridge_v2_tpu");
 
   TF_RETURN_IF_ERROR(RecordIfErrorStatus(/*error_prefix=*/"clustering_v2",
-                                         fallback_enabled, clustering_status));
-
-  Status runtime_lowering_status =
-      tensorflow::tfrt_compiler::RunLowerClusterToRuntimeOpsPassPipeline(
-          module, tsl::DeviceType(DEVICE_TPU_XLA_JIT), module_name);
-  TF_RETURN_IF_ERROR(RecordIfErrorStatus(/*error_prefix=*/"runtime_lowering_v2",
-                                         fallback_enabled,
-                                         runtime_lowering_status));
-
-  Status export_status =
-      tensorflow::tf2xla::v2::ExportFromTensorflowDialectToExecutor(
-          module, module_name);
-  TF_RETURN_IF_ERROR(RecordIfErrorStatus(/*error_prefix=*/"export_to_executor",
-                                         fallback_enabled, export_status));
+                                         fallback_enabled, device_type,
+                                         clustering_status));
 
   tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
-      /*device_type=*/"tpu", /*bridge_version=*/"v2",
+      device_type, /*bridge_version=*/"v2",
       /*fallback_enabled=*/fallback_enabled,
       /*result=*/"success");
 
@@ -183,42 +180,31 @@
 }
 
 tensorflow::Status RunNonTPUBridge(ModuleOp module,
+                                   bool is_in_fallback_enabled_mode,
                                    llvm::StringRef module_name) {
   VLOG(2)
       << "CPU/GPU Bridge called stack trace is "
       << "(NOTE: this is not an error; rather the stack trace for debugging) : "
       << tensorflow::CurrentStackTrace();
-  Status status = RunTFXLABridge(
+
+  std::string device_type = "cpu/gpu";
+  Status clustering_status = RunTFXLABridge(
       module,
       [](OpPassManager &pm) {
         tensorflow::tf2xla::internal::AddNonTPUBridgeClusteringPipelinePasses(
             pm);
-        tensorflow::tfrt_compiler::
-            AddNonTPULowerClusterToRuntimeOpsPassPipeline(pm);
       },
       module_name, /*dump_prefix=*/"tf_xla_bridge_v2_nontpu");
+
+  TF_RETURN_IF_ERROR(RecordIfErrorStatus(/*error_prefix=*/"clustering_v2",
+                                         is_in_fallback_enabled_mode,
+                                         device_type, clustering_status));
+
   tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
-      /*device type*/ "cpu/gpu", /*bridge version*/ "tfxla",
-      /*fallback_enabled*/ false,
-      /*result*/ status.ok() ? "success" : "failure");
-  if (!status.ok()) {
-    tsl::error_logging::Log(kBridgeComponent,
-                            "TFXLA_PHASE_ONE_MLIR_CPU/GPU_BRIDGE",
-                            status.ToString())
-        .IgnoreError();
-  }
+      device_type, /*bridge_version=*/"v2", is_in_fallback_enabled_mode,
+      /*result=*/"success");
 
-  Status export_status =
-      tensorflow::tf2xla::v2::ExportFromTensorflowDialectToExecutor(
-          module, module_name);
-  if (!export_status.ok()) {
-    tsl::error_logging::Log(kBridgeComponent,
-                            "TFXLA_PHASE_ONE_MLIR_CPU_BRIDGE_EXPORT",
-                            export_status.ToString())
-        .IgnoreError();
-  }
-
-  return status;
+  return absl::OkStatus();
 }
 
 tensorflow::Status RunFunctionTf2xlaClusteringBridge(
@@ -229,9 +215,15 @@
                      /*module_name=*/module_name);
   }
 
-  return RunNonTPUBridge(module, module_name);
+  return RunNonTPUBridge(module, is_in_fallback_enabled_mode, module_name);
 }
 
+mlir::PassPipelineRegistration<> clustering_tpu_pipeline_v2(
+    "tf-cluster-tpu-bridge-v2",
+    "Run all the passes involved in transforming a TensorFlow 2 graph before "
+    "execution so that it is suitable for targeting TPUs.",
+    CreateTPUClusteringPipelineV2);
+
 }  // namespace v2
 }  // namespace tf2xla
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h
index d597ca8..e1298ac 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h
@@ -31,13 +31,14 @@
 // API. These transformations take as input a Tensorflow Graph as an MLIR Module
 // and transforms the module in place to cluster the given ops for compilation
 // that is compatible with the given device_type. The MLIR should be in the TF
-// Executor Dialect for graph nodes and edges. Individual Op inside a node
-// should be the Tensorflow Dialect. The output MLIR is in the TF Executor
-// Dialect. Returns OkStatus if passed, otherwise an error.
+// Executor Dialect for graph nodes and edges, or already in the TF Functional
+// Dialect. Individual Ops inside a node should be in the TensorFlow Functional
+// Dialect. The output MLIR is in the TF Functional Dialect. Returns OkStatus
+// on success, otherwise an error.
 //
 // Inputs:
 //   module - The MLIR Module that will be clustered. Expected to be in TF
-//   Executor Dialect
+//   Executor Dialect or TF Functional Dialect. Will convert to TF Functional.
 // . device_type - The device type to cluster for.
 //   is_in_fallback_enabled_mode - Whether this was called with fallback to the
 //   non-MLIR Bridge. This is just for logging purposes and doesn't affect
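
Note (illustrative, not part of the patch): with the header now accepting either TF Executor or TF Functional input, a caller could drive the v2 clustering entry point roughly as sketched below, mirroring the usage in cluster_tf_test.cc later in this diff. The parsing helper, the dialect-registration call, and the fully qualified `DeviceType` enum are assumptions inferred from the surrounding includes and tests.

```cpp
// Illustrative sketch only: parse a TF module and run the v2 clustering
// bridge on it for TPU, as the tests below do.
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/DialectRegistry.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/register_common_dialects.h"
#include "tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h"
#include "tensorflow/compiler/mlir/tf2xla/api/v2/device_type.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

tensorflow::Status ClusterModuleForTpu(llvm::StringRef module_text) {
  mlir::DialectRegistry registry;
  // Assumed registration helper from register_common_dialects.h.
  mlir::RegisterCommonToolingDialects(registry);
  mlir::MLIRContext context(registry);

  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString<mlir::ModuleOp>(module_text, &context);
  if (!module) {
    return tensorflow::errors::InvalidArgument("Could not parse MLIR module.");
  }

  // Clusters the module in place; metrics are recorded under device_type "tpu".
  return tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge(
      *module, tensorflow::tf2xla::v2::DeviceType::XLA_TPU_JIT,
      /*is_in_fallback_enabled_mode=*/false);
}
```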
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc
index c2b6e36..d00d8b4 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h"
 
+#include <cstdint>
 #include <string>
 
 #include <gtest/gtest.h>
@@ -28,6 +29,7 @@
 #include "mlir/IR/Visitors.h"  // from @llvm-project
 #include "mlir/Parser/Parser.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/register_common_dialects.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/core/lib/monitoring/cell_reader.h"
 #include "tensorflow/core/platform/resource_loader.h"
@@ -80,7 +82,7 @@
   OwningOpRef<mlir::ModuleOp> mlir_module_;
 };
 
-TEST_F(FunctionClusterTensorflowDialectTest, ClustersTf) {
+TEST_F(FunctionClusterTensorflowDialectTest, ClustersTfTPU) {
   CellReader<int64_t> compilation_status(kCompilationStreamz);
 
   TF_ASSERT_OK(CreateMlirModule("empty_func.mlir"));
@@ -92,13 +94,29 @@
   FuncOp main = mlir_module_->lookupSymbol<mlir::func::FuncOp>("main");
   ASSERT_TRUE(main);
 
-  bool has_graph_op = false;
-  main.walk([&](mlir::tf_executor::GraphOp graph) {
-    has_graph_op = true;
+  EXPECT_EQ(
+      compilation_status.Delta("tpu", "v2", "fallback_disabled", "success"), 1);
+}
+
+TEST_F(FunctionClusterTensorflowDialectTest, RunsOutsideCompilationTPU) {
+  CellReader<int64_t> compilation_status(kCompilationStreamz);
+
+  TF_ASSERT_OK(CreateMlirModule("outside_compilation.mlir"));
+
+  TF_EXPECT_OK(
+      RunFunctionTf2xlaClusteringBridge(*mlir_module_, DeviceType::XLA_TPU_JIT,
+                                        /*is_in_fallback_enabled_mode=*/false));
+
+  FuncOp main = mlir_module_->lookupSymbol<mlir::func::FuncOp>("main");
+  ASSERT_TRUE(main);
+
+  bool has_cluster_op = false;
+  main.walk([&](mlir::tf_device::ClusterFuncOp cluster_op) {
+    has_cluster_op = true;
     return WalkResult::advance();
   });
 
-  EXPECT_TRUE(has_graph_op);
+  EXPECT_TRUE(has_cluster_op);
   EXPECT_EQ(
       compilation_status.Delta("tpu", "v2", "fallback_disabled", "success"), 1);
 }
@@ -115,17 +133,9 @@
   FuncOp main = mlir_module_->lookupSymbol<mlir::func::FuncOp>("main");
   ASSERT_TRUE(main);
 
-  bool has_graph_op = false;
-  main.walk([&](mlir::tf_executor::GraphOp graph) {
-    has_graph_op = true;
-    return WalkResult::advance();
-  });
-
-  EXPECT_TRUE(has_graph_op);
-
-  EXPECT_EQ(compilation_status.Delta("cpu/gpu", "tfxla", "fallback_disabled",
-                                     "success"),
-            1);
+  EXPECT_EQ(
+      compilation_status.Delta("cpu/gpu", "v2", "fallback_disabled", "success"),
+      1);
 }
 
 TEST_F(FunctionClusterTensorflowDialectTest, ClustersTFGPU) {
@@ -140,17 +150,22 @@
   FuncOp main = mlir_module_->lookupSymbol<mlir::func::FuncOp>("main");
   ASSERT_TRUE(main);
 
-  bool has_graph_op = false;
-  main.walk([&](mlir::tf_executor::GraphOp graph) {
-    has_graph_op = true;
-    return WalkResult::advance();
-  });
+  EXPECT_EQ(
+      compilation_status.Delta("cpu/gpu", "v2", "fallback_disabled", "success"),
+      1);
+}
 
-  EXPECT_TRUE(has_graph_op);
+TEST_F(FunctionClusterTensorflowDialectTest, LogsFallbackMode) {
+  CellReader<int64_t> compilation_status(kCompilationStreamz);
 
-  EXPECT_EQ(compilation_status.Delta("cpu/gpu", "tfxla", "fallback_disabled",
-                                     "success"),
-            1);
+  TF_ASSERT_OK(CreateMlirModule("empty_func.mlir"));
+
+  TF_EXPECT_OK(
+      RunFunctionTf2xlaClusteringBridge(*mlir_module_, DeviceType::XLA_TPU_JIT,
+                                        /*is_in_fallback_enabled_mode=*/true));
+
+  EXPECT_EQ(
+      compilation_status.Delta("tpu", "v2", "fallback_enabled", "success"), 1);
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/outside_compilation.mlir b/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/outside_compilation.mlir
new file mode 100644
index 0000000..67434cf
--- /dev/null
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/outside_compilation.mlir
@@ -0,0 +1,135 @@
+module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1654 : i32}} {
+  func.func @main(%arg0: tensor<*x!tf_type.resource> {tf._user_specified_name = "input_1", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) attributes {allow_soft_placement = true, tf.entry_function = {control_outputs = "while,image_sample/write_summary/summary_cond", inputs = "image_sample_write_summary_summary_cond_input_1", outputs = ""}} {
+    tf_executor.graph {
+      %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<0> : tensor<i32>} : () -> tensor<i32>
+      %outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor<i32>} : () -> tensor<i32>
+      %outputs_2, %control_3 = tf_executor.island wraps "tf.TPUReplicatedInput"(%outputs, %outputs_0) {device = "", index = -1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
+      %outputs_4, %control_5 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<0> : tensor<i32>} : () -> tensor<i32>
+      %outputs_6, %control_7 = tf_executor.island wraps "tf.Const"() {_post_device_rewrite = true, device = "", value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+      %outputs_8, %control_9 = tf_executor.island wraps "tf.Const"() {_post_device_rewrite = true, device = "", value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+      %outputs_10, %control_11 = tf_executor.island wraps "tf.Pack"(%outputs_6, %outputs_8) {axis = 0 : i64, device = ""} : (tensor<0xi32>, tensor<0xi32>) -> tensor<*xi32>
+      %outputs_12, %control_13 = tf_executor.island wraps "tf.Max"(%outputs_10, %outputs_4) {device = "", keep_dims = false} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+      %control_14 = tf_executor.island wraps "tf.NoOp"() {_pivot_for_cluster = "cluster_sample_sequence", device = ""} : () -> ()
+      %control_15 = tf_executor.island(%control_14) wraps "tf.NoOp"() {_has_manual_control_dependencies = true, _tpu_replicate = "cluster_sample_sequence", device = ""} : () -> ()
+      %control_16 = tf_executor.island(%control_15) wraps "tf.NoOp"() {device = ""} : () -> ()
+      %control_17 = tf_executor.island(%control_15) wraps "tf.NoOp"() {device = ""} : () -> ()
+      %control_18 = tf_executor.island(%control_14) wraps "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _tpu_replicate = "cluster_sample_sequence", allow_soft_placement = true, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], num_cores_per_replica = 1 : i64, num_replicas = 2 : i64, padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : () -> ()
+      %outputs_19, %control_20 = tf_executor.island(%control_18) wraps "tf.Const"() {_tpu_replicate = "cluster_sample_sequence", device = "", value = dense<0> : tensor<i32>} : () -> tensor<i32>
+      %outputs_21, %control_22 = tf_executor.island(%control_18) wraps "tf.Const"() {_tpu_replicate = "cluster_sample_sequence", device = "", value = dense<5> : tensor<i64>} : () -> tensor<i64>
+      %outputs_23, %control_24 = tf_executor.island(%control_18) wraps "tf.Const"() {_tpu_replicate = "cluster_sample_sequence", device = "", value = dense<[3, 32, 32, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+      %outputs_25, %control_26 = tf_executor.island(%control_18) wraps "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster_sample_sequence", device = ""} : () -> tensor<!tf_type.string>
+      %outputs_27, %control_28 = tf_executor.island(%control_18) wraps "tf.Const"() {_tpu_replicate = "cluster_sample_sequence", device = "", value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+      %outputs_29, %control_30 = tf_executor.island(%control_18) wraps "tf.Const"() {_tpu_replicate = "cluster_sample_sequence", device = "", value = dense<[1, 1, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+      %outputs_31, %control_32 = tf_executor.island(%control_18) wraps "tf.Const"() {_tpu_replicate = "cluster_sample_sequence", device = "", value = dense<0> : tensor<i32>} : () -> tensor<i32>
+      %outputs_33, %control_34 = tf_executor.island(%control_18) wraps "tf.Const"() {_tpu_replicate = "cluster_sample_sequence", device = "", value = dense<true> : tensor<i1>} : () -> tensor<i1>
+      %outputs_35, %control_36 = tf_executor.island(%control_18) wraps "tf.Identity"(%outputs_2) {_tpu_input_identity = true, _tpu_replicate = "cluster_sample_sequence", device = ""} : (tensor<*xi32>) -> tensor<*xi32>
+      %outputs_37, %control_38 = tf_executor.island wraps "tf.Equal"(%outputs_35, %outputs_31) {_tpu_replicate = "cluster_sample_sequence", device = "", incompatible_shape_error = true} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
+      %outputs_39, %control_40 = tf_executor.island wraps "tf.LogicalAnd"(%outputs_37, %outputs_33) {_tpu_replicate = "cluster_sample_sequence", device = ""} : (tensor<*xi1>, tensor<i1>) -> tensor<*xi1>
+      %outputs_41, %control_42 = tf_executor.island(%control_18) wraps "tf.Const"() {_tpu_replicate = "cluster_sample_sequence", device = "", value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+      %outputs_43, %control_44 = tf_executor.island(%control_18) wraps "tf.Const"() {_tpu_replicate = "cluster_sample_sequence", device = "", value = dense<1024> : tensor<i32>} : () -> tensor<i32>
+      %outputs_45, %control_46 = tf_executor.island wraps "tf.TensorListReserve"(%outputs_41, %outputs_43) {_tpu_replicate = "cluster_sample_sequence", device = ""} : (tensor<1xi32>, tensor<i32>) -> tensor<!tf_type.variant<tensor<*xf32>>>
+      %outputs_47, %control_48 = tf_executor.island(%control_18) wraps "tf.Const"() {_tpu_replicate = "cluster_sample_sequence", device = "", value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+      %outputs_49, %control_50 = tf_executor.island(%control_18) wraps "tf.Const"() {_tpu_replicate = "cluster_sample_sequence", device = "", value = dense<0> : tensor<i32>} : () -> tensor<i32>
+      %outputs_51, %control_52 = tf_executor.island(%control_18) wraps "tf.Const"() {_tpu_replicate = "cluster_sample_sequence", device = "", value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+      %outputs_53:4, %control_54 = tf_executor.island wraps "tf.While"(%outputs_49, %outputs_51, %outputs_19, %outputs_45) {_num_original_outputs = 4 : i64, _read_only_resource_inputs = [], _tpu_replicate = "cluster_sample_sequence", _xla_propagate_compile_time_consts = true, body = @while_body_260, cond = @while_cond_250, device = "", is_stateless = false, parallel_iterations = 10 : i64, shape_invariant} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf_type.variant<tensor<*xf32>>>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf_type.variant>)
+      %outputs_55, %control_56 = tf_executor.island wraps "tf.TensorListStack"(%outputs_53#3, %outputs_27) {_tpu_replicate = "cluster_sample_sequence", device = "", num_elements = 1024 : i64} : (tensor<!tf_type.variant>, tensor<1xi32>) -> tensor<*xf32>
+      %outputs_57, %control_58 = tf_executor.island wraps "tf.Transpose"(%outputs_55, %outputs_47) {_tpu_replicate = "cluster_sample_sequence", device = ""} : (tensor<*xf32>, tensor<2xi32>) -> tensor<*xf32>
+      %outputs_59, %control_60 = tf_executor.island wraps "tf.Reshape"(%outputs_57, %outputs_23) {_tpu_replicate = "cluster_sample_sequence", device = ""} : (tensor<*xf32>, tensor<4xi32>) -> tensor<*xf32>
+      %outputs_61, %control_62 = tf_executor.island wraps "tf.Tile"(%outputs_59, %outputs_29) {_tpu_replicate = "cluster_sample_sequence", device = ""} : (tensor<*xf32>, tensor<4xi32>) -> tensor<*xf32>
+      %outputs_63, %control_64 = tf_executor.island wraps "tf.If"(%outputs_39, %outputs_61, %arg0, %outputs_21) {_read_only_resource_inputs = [], _tpu_replicate = "cluster_sample_sequence", _xla_propagate_compile_time_consts = true, device = "", else_branch = @image_sample_write_summary_summary_cond_false_710, is_stateless = false, then_branch = @image_sample_write_summary_summary_cond_true_700} : (tensor<*xi1>, tensor<*xf32>, tensor<*x!tf_type.resource>, tensor<i64>) -> tensor<*xi1>
+      %outputs_65, %control_66 = tf_executor.island wraps "tf.Identity"(%outputs_63) {_tpu_replicate = "cluster_sample_sequence", device = ""} : (tensor<*xi1>) -> tensor<*xi1>
+      tf_executor.fetch %control_54, %control_64 : !tf_executor.control, !tf_executor.control
+    }
+    return
+  }
+  func.func private @while_body_260(%arg0: tensor<i32> {tf._user_specified_name = "while/loop_counter"}, %arg1: tensor<i32> {tf._user_specified_name = "while/maximum_iterations"}, %arg2: tensor<i32>, %arg3: tensor<!tf_type.variant>) -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*x!tf_type.variant>) attributes {tf._construction_context = "kEagerRuntime", tf.signature.is_stateful} {
+    %0:4 = tf_executor.graph {
+      %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor<i32>} : () -> tensor<i32>
+      %outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor<i32>} : () -> tensor<i32>
+      %outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+      %outputs_4, %control_5 = tf_executor.island wraps "tf.RandomUniform"(%outputs_2) {device = "", seed = 87654321 : i64, seed2 = 0 : i64} : (tensor<1xi32>) -> tensor<*xf32>
+      %outputs_6, %control_7 = tf_executor.island wraps "tf.AddV2"(%arg2, %outputs) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
+      %outputs_8, %control_9 = tf_executor.island wraps "tf.Identity"(%outputs_6) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
+      %outputs_10, %control_11 = tf_executor.island wraps "tf.TensorListSetItem"(%arg3, %arg2, %outputs_4) {device = "", resize_if_index_out_of_bounds = false} : (tensor<!tf_type.variant>, tensor<i32>, tensor<*xf32>) -> tensor<*x!tf_type.variant>
+      %outputs_12, %control_13 = tf_executor.island wraps "tf.Identity"(%outputs_10) {device = ""} : (tensor<*x!tf_type.variant>) -> tensor<*x!tf_type.variant>
+      %outputs_14, %control_15 = tf_executor.island wraps "tf.AddV2"(%arg0, %outputs_0) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
+      %outputs_16, %control_17 = tf_executor.island wraps "tf.Identity"(%outputs_14) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
+      %outputs_18, %control_19 = tf_executor.island wraps "tf.Identity"(%arg1) {device = ""} : (tensor<i32>) -> tensor<*xi32>
+      tf_executor.fetch %outputs_16, %outputs_18, %outputs_8, %outputs_12 : tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*x!tf_type.variant>
+    }
+    return %0#0, %0#1, %0#2, %0#3 : tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*x!tf_type.variant>
+  }
+  func.func private @while_cond_250(%arg0: tensor<i32> {tf._user_specified_name = "while/loop_counter"}, %arg1: tensor<i32> {tf._user_specified_name = "while/maximum_iterations"}, %arg2: tensor<i32>, %arg3: tensor<!tf_type.variant>) -> tensor<*xi1> attributes {tf._construction_context = "kEagerRuntime"} {
+    %0 = tf_executor.graph {
+      %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1024> : tensor<i32>} : () -> tensor<i32>
+      %outputs_0, %control_1 = tf_executor.island wraps "tf.Less"(%arg2, %outputs) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<*xi1>
+      %outputs_2, %control_3 = tf_executor.island wraps "tf.Identity"(%outputs_0) {device = ""} : (tensor<*xi1>) -> tensor<*xi1>
+      tf_executor.fetch %outputs_2 : tensor<*xi1>
+    }
+    return %0 : tensor<*xi1>
+  }
+  func.func private @image_sample_write_summary_summary_cond_false_710(%arg0: tensor<3x32x32x3xf32>, %arg1: tensor<*x!tf_type.resource>, %arg2: tensor<i64>) -> tensor<*xi1> attributes {tf._construction_context = "kEagerRuntime"} {
+    %0 = tf_executor.graph {
+      %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<false> : tensor<i1>} : () -> tensor<i1>
+      %outputs_0, %control_1 = tf_executor.island wraps "tf.Identity"(%outputs) {device = ""} : (tensor<i1>) -> tensor<*xi1>
+      tf_executor.fetch %outputs_0 : tensor<*xi1>
+    }
+    return %0 : tensor<*xi1>
+  }
+  func.func private @image_sample_write_summary_summary_cond_true_700(%arg0: tensor<3x32x32x3xf32> {tf._user_specified_name = "Tile"}, %arg1: tensor<*x!tf_type.resource> {tf._user_specified_name = "writer"}, %arg2: tensor<i64> {tf._user_specified_name = "Const_3"}) -> tensor<*xi1> attributes {tf._construction_context = "kEagerRuntime", tf.signature.is_stateful} {
+    %0 = tf_executor.graph {
+      %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<[3, 32, 32, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+      %outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor<i32>} : () -> tensor<i32>
+      %outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<""> : tensor<!tf_type.string>} : () -> tensor<!tf_type.string>
+      %outputs_4, %control_5 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor<!tf_type.string>} : () -> tensor<!tf_type.string>
+      %outputs_6, %control_7 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<"x (image_sample/write_summary/summary_cond/assert_non_negative/x:0) = "> : tensor<!tf_type.string>} : () -> tensor<!tf_type.string>
+      %outputs_8, %control_9 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<""> : tensor<!tf_type.string>} : () -> tensor<!tf_type.string>
+      %outputs_10, %control_11 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor<!tf_type.string>} : () -> tensor<!tf_type.string>
+      %outputs_12, %control_13 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<"x (image_sample/write_summary/summary_cond/assert_non_negative/x:0) = "> : tensor<!tf_type.string>} : () -> tensor<!tf_type.string>
+      %outputs_14, %control_15 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor<i32>} : () -> tensor<i32>
+      %outputs_16, %control_17 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<1> : tensor<i32>} : () -> tensor<i32>
+      %outputs_18, %control_19 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor<i32>} : () -> tensor<i32>
+      %outputs_20, %control_21 = tf_executor.island wraps "tf.Range"(%outputs_18, %outputs_14, %outputs_16) {device = "/device:CPU:0"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<*xi32>
+      %outputs_22, %control_23 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<3> : tensor<i32>} : () -> tensor<i32>
+      %outputs_24, %control_25 = tf_executor.island wraps "tf.LessEqual"(%outputs_0, %outputs_22) {device = "/device:CPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<*xi1>
+      %outputs_26, %control_27 = tf_executor.island wraps "tf.All"(%outputs_24, %outputs_20) {device = "/device:CPU:0", keep_dims = false} : (tensor<*xi1>, tensor<*xi32>) -> tensor<*xi1>
+      %control_28 = tf_executor.island wraps "tf.Assert"(%outputs_26, %outputs_2, %outputs_4, %outputs_6, %outputs_22) {device = "/device:CPU:0", summarize = 3 : i64} : (tensor<*xi1>, tensor<!tf_type.string>, tensor<!tf_type.string>, tensor<!tf_type.string>, tensor<i32>) -> ()
+      %outputs_29, %control_30 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<[3, 32, 32, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+      %control_31 = tf_executor.island wraps "tf.NoOp"() {device = "/device:CPU:0"} : () -> ()
+      %outputs_32, %control_33 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<4> : tensor<i32>} : () -> tensor<i32>
+      %control_34 = tf_executor.island wraps "tf.NoOp"() {device = "/device:CPU:0"} : () -> ()
+      %outputs_35, %control_36 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor<i32>} : () -> tensor<i32>
+      %outputs_37, %control_38 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
+      %outputs_39, %control_40 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<2.550000e+02> : tensor<f32>} : () -> tensor<f32>
+      %outputs_41, %control_42 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<2.555000e+02> : tensor<f32>} : () -> tensor<f32>
+      %outputs_43, %control_44 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+      %outputs_45, %control_46 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+      %outputs_47, %control_48 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+      %outputs_49, %control_50 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+      %outputs_51, %control_52 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+      %outputs_53, %control_54 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+      %outputs_55, %control_56 = tf_executor.island wraps "tf.StridedSlice"(%outputs, %outputs_49, %outputs_51, %outputs_53) {begin_mask = 0 : i64, device = "/device:CPU:0", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
+      %outputs_57, %control_58 = tf_executor.island wraps "tf.AsString"(%outputs_55) {device = "/device:CPU:0", fill = "", precision = -1 : i64, scientific = false, shortest = false, width = -1 : i64} : (tensor<*xi32>) -> tensor<*x!tf_type.string>
+      %outputs_59, %control_60 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+      %outputs_61, %control_62 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+      %outputs_63, %control_64 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+      %outputs_65, %control_66 = tf_executor.island wraps "tf.StridedSlice"(%outputs, %outputs_59, %outputs_61, %outputs_63) {begin_mask = 0 : i64, device = "/device:CPU:0", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
+      %outputs_67, %control_68 = tf_executor.island wraps "tf.AsString"(%outputs_65) {device = "/device:CPU:0", fill = "", precision = -1 : i64, scientific = false, shortest = false, width = -1 : i64} : (tensor<*xi32>) -> tensor<*x!tf_type.string>
+      %outputs_69, %control_70 = tf_executor.island wraps "tf.Pack"(%outputs_57, %outputs_67) {axis = 0 : i64, device = "/device:CPU:0"} : (tensor<*x!tf_type.string>, tensor<*x!tf_type.string>) -> tensor<*x!tf_type.string>
+      %outputs_71, %control_72 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<"\0A\08\0A\06images"> : tensor<!tf_type.string>} : () -> tensor<!tf_type.string>
+      %outputs_73, %control_74 = tf_executor.island wraps "tf.Const"() {device = "/device:CPU:0", value = dense<"image_sample"> : tensor<!tf_type.string>} : () -> tensor<!tf_type.string>
+      %outputs_75, %control_76 = tf_executor.island wraps "tf.Mul"(%arg0, %outputs_41) {device = "/device:CPU:0"} : (tensor<3x32x32x3xf32>, tensor<f32>) -> tensor<*xf32>
+      %outputs_77, %control_78 = tf_executor.island wraps "tf.ClipByValue"(%outputs_75, %outputs_37, %outputs_39) {device = "/device:CPU:0"} : (tensor<*xf32>, tensor<f32>, tensor<f32>) -> tensor<*xf32>
+      %outputs_79, %control_80 = tf_executor.island wraps "tf.Cast"(%outputs_77) {Truncate = false, device = "/device:CPU:0"} : (tensor<*xf32>) -> tensor<*xui8>
+      %outputs_81, %control_82 = tf_executor.island wraps "tf.StridedSlice"(%outputs_79, %outputs_43, %outputs_45, %outputs_47) {begin_mask = 1 : i64, device = "/device:CPU:0", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<*xui8>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xui8>
+      %outputs_83, %control_84 = tf_executor.island wraps "tf.EncodePng"(%outputs_81) {compression = -1 : i64, device = "/device:CPU:0"} : (tensor<*xui8>) -> tensor<*x!tf_type.string>
+      %outputs_85, %control_86 = tf_executor.island wraps "tf.ConcatV2"(%outputs_69, %outputs_83, %outputs_35) {device = "/device:CPU:0"} : (tensor<*x!tf_type.string>, tensor<*x!tf_type.string>, tensor<i32>) -> tensor<*x!tf_type.string>
+      %control_87 = tf_executor.island wraps "tf.WriteSummary"(%arg1, %arg2, %outputs_85, %outputs_73, %outputs_71) {_has_manual_control_dependencies = true, device = "/device:CPU:0"} : (tensor<*x!tf_type.resource>, tensor<i64>, tensor<*x!tf_type.string>, tensor<!tf_type.string>, tensor<!tf_type.string>) -> ()
+      %outputs_88, %control_89 = tf_executor.island(%control_87) wraps "tf.Const"() {device = "/device:CPU:0", value = dense<true> : tensor<i1>} : () -> tensor<i1>
+      %control_90 = tf_executor.island(%control_28, %control_87) wraps "tf.NoOp"() {device = ""} : () -> ()
+      %outputs_91, %control_92 = tf_executor.island(%control_90) wraps "tf.Identity"(%outputs_88) {device = ""} : (tensor<i1>) -> tensor<*xi1>
+      tf_executor.fetch %outputs_91, %control_28, %control_87 : tensor<*xi1>, !tf_executor.control, !tf_executor.control
+    }
+    return %0 : tensor<*xi1>
+  }
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc
index da42af2..69f1c0e 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc
@@ -18,6 +18,7 @@
 #include <memory>
 #include <string>
 
+#include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "llvm/ADT/StringRef.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
@@ -31,9 +32,11 @@
 #include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/compiler/mlir/tf2xla/internal/logging_hooks.h"
+#include "tensorflow/core/platform/error_payloads.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/util/debug_data_dumper.h"
 #include "tsl/lib/monitoring/counter.h"
+#include "tsl/platform/error_logging.h"
 #include "tsl/platform/status.h"
 
 namespace tensorflow {
@@ -93,6 +96,30 @@
   pm.addPass(mlir::TF::CreateVerifySuitableForExportPass());
 }
 
+tensorflow::Status RecordStatusIfError(absl::Status status) {
+  if (status.ok()) {
+    return absl::OkStatus();
+  }
+
+  tf_dialect_to_executor_dialect_status->GetCell(kExportFailed)->IncrementBy(1);
+  VLOG(1) << "Failed to export from TF Dialect to TF Executor Dialect. "
+          << status;
+
+  constexpr char bridge_subcomponent[] =
+      "TFXLA_TF_FUNCTIONAL_TO_EXECUTOR_EXPORT_v2";
+  constexpr char kBridgeComponent[] = "TFXLABridge";
+
+  tsl::OkOrSetErrorCounterPayload(
+      tensorflow::core::platform::ErrorSourceProto::MLIR_BRIDGE_PHASE_1,
+      status);
+
+  tsl::error_logging::Log(kBridgeComponent, bridge_subcomponent,
+                          status.ToString())
+      .IgnoreError();
+
+  return status;
+}
+
 }  // namespace
 
 tensorflow::Status ExportFromTensorflowDialectToExecutor(
@@ -128,12 +155,10 @@
         module, llvm::StringRef(), &tf_to_executor);
   }
 
-  if (!result.succeeded()) {
-    tf_dialect_to_executor_dialect_status->GetCell(kExportFailed)
-        ->IncrementBy(1);
-
-    return absl::InternalError(
-        "Failed to export from TF Dialect to TF Executor Dialect.");
+  if (result.failed()) {
+    return RecordStatusIfError(
+        absl::InternalError("Failed to export from TF Dialect to TF Executor "
+                            "Dialect. Read LLVM Pipeline Error"));
   }
 
   tf_dialect_to_executor_dialect_status->GetCell(kExportSuccess)
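
Note (illustrative, not part of the patch): the v1 test earlier in this diff observes the export counters through `CellReader`; a v2 check would read the same "success"/"failed" cells. The streamz path below is an assumed name modeled on the v1 string, since the v2 counter name does not appear in this diff.

```cpp
// Hypothetical sketch: reading the v2 export counters the same way the v1
// test above does. kAssumedV2Streamz is an assumed name, not taken from this
// patch.
#include <cstdint>

#include "tensorflow/core/lib/monitoring/cell_reader.h"

constexpr char kAssumedV2Streamz[] =
    "/tensorflow/core/tf2xla/api/v2/tf_dialect_to_executor_dialect_status";

// Returns how many exports have failed since `reader` was constructed; the
// "failed" label matches the kExportFailed cell incremented above.
int64_t FailedExportsSince(
    tensorflow::monitoring::testing::CellReader<int64_t>& reader) {
  return reader.Delta("failed");
}
```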
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/BUILD
index c2516cb..b7b060b 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/BUILD
+++ b/tensorflow/compiler/mlir/tf2xla/internal/BUILD
@@ -187,6 +187,7 @@
         "//tensorflow/compiler/jit:flags_headers",
         "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
         "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass",
+        "//tensorflow/compiler/mlir/tf2xla/internal/passes:clustering_passes",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/log",
         "@llvm-project//llvm:Support",
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc
index 08fa8f4..1cad3d1 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc
+++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc
@@ -25,6 +25,7 @@
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h"
 
 namespace tensorflow {
 namespace tf2xla {
@@ -69,7 +70,8 @@
   // preserved and the sequencing rewrite will trigger.
   pm.addPass(mlir::TFDevice::CreateEmbeddingPipeliningPass());
   pm.addPass(mlir::TFDevice::CreateEmbeddingSequencingPass());
-  pm.addPass(mlir::TFTPU::CreateTPUClusterFormationPass(strict_clusters));
+  pm.addPass(tensorflow::tf2xla::internal::CreateTPUClusterFormationPass(
+      strict_clusters));
   // CreateEmbeddingPipeliningPass may have created more functions, but
   // TPUClusterCleanup and OutsideCompiledToHostLaunch need every function to be
   // only called from one cluster. Here, we choose to fix the all-funcs-one-use
@@ -142,7 +144,8 @@
   pm.addPass(mlir::TFDevice::CreateHostLaunchToOutsideCompiledPass());
 
   pm.addPass(mlir::TFDevice::CreateMarkOpsForOutsideCompilationPass());
-  pm.addPass(mlir::TFDevice::CreateExtractHeadTailOutsideCompilationPass());
+  pm.addPass(tensorflow::tf2xla::internal::
+                 CreateExtractHeadTailOutsideCompilationPass());
   pm.addPass(mlir::TFDevice::CreateExtractOutsideCompilationPass());
   pm.addNestedPass<FuncOp>(
       mlir::TFDevice::CreateVerifyNoOutsideCompilationMarkersPass());
@@ -160,6 +163,9 @@
   pm.addPass(mlir::TFTPU::CreateTPUAnnotateDynamicShapeInputsPass());
   pm.addNestedPass<FuncOp>(
       mlir::TF::CreateHoistReplicateInvariantResourceWritesPass());
+  // Verifies that clustering has conformed to the expected invariants.
+  pm.addNestedPass<FuncOp>(
+      tensorflow::tf2xla::internal::CreateVerifyClusteringPass());
 }
 
 void NoCanonicalization(OpPassManager& pm) {}
@@ -218,11 +224,15 @@
   if (tensorflow::GetMlirCommonFlags()
           ->tf_mlir_enable_generic_outside_compilation) {
     pm.addPass(mlir::TFDevice::CreateMarkOpsForOutsideCompilationPass());
-    pm.addPass(mlir::TFDevice::CreateExtractHeadTailOutsideCompilationPass());
+    pm.addPass(tensorflow::tf2xla::internal::
+                   CreateExtractHeadTailOutsideCompilationPass());
     pm.addPass(mlir::TFDevice::CreateExtractOutsideCompilationPass());
   }
   // Outline clusters into cluster functions.
   pm.addPass(mlir::TFDevice::CreateClusterOutliningPass());
+  // Verifies that clustering has conformed to the expected invariants.
+  pm.addNestedPass<FuncOp>(
+      tensorflow::tf2xla::internal::CreateVerifyClusteringPass());
 }
 
 };  // namespace internal
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc
index 985a354..91b80fa 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc
@@ -28,14 +28,14 @@
   OpPassManager pass_manager;
   AddBridgeClusteringPipelinePasses(pass_manager);
 
-  EXPECT_EQ(pass_manager.size(), 46);
+  EXPECT_EQ(pass_manager.size(), 47);
 }
 
 TEST(ClusteringBridgePassesTest, AddsNonTPUBridgePasses) {
   OpPassManager pass_manager;
   AddNonTPUBridgeClusteringPipelinePasses(pass_manager);
 
-  EXPECT_EQ(pass_manager.size(), 14);
+  EXPECT_EQ(pass_manager.size(), 15);
 }
 
 };  // namespace internal
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD
new file mode 100644
index 0000000..0e25e62
--- /dev/null
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD
@@ -0,0 +1,203 @@
+load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_portable")
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = [
+        "//learning/pathways/serving/transforms:__pkg__",
+        "//tensorflow/compiler/mlir:__pkg__",
+        "//tensorflow/compiler/mlir/tensorflow:__pkg__",
+        "//tensorflow/compiler/mlir/tf2xla/internal:__subpackages__",
+    ],
+    licenses = ["notice"],
+)
+
+cc_library(
+    name = "clustering_passes",
+    srcs = [
+        "verify_clustering_pass.cc",
+    ],
+    hdrs = [
+        "clustering_passes.h",
+    ],
+    textual_hdrs = [
+        "clustering_passes.h.inc",
+    ],
+    deps = [
+        ":clustering_passes_inc_gen",
+        ":extract_head_tail_outside_compilation",
+        ":tpu_cluster_formation",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:attribute_utils",
+        "//tensorflow/compiler/mlir/tensorflow:string_util",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis",
+        "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/transforms/toposort:Pass",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
+
+gentbl_cc_library(
+    name = "clustering_passes_inc_gen",
+    compatible_with = get_compatible_with_portable(),
+    tbl_outs = [
+        (
+            [
+                "-gen-pass-decls",
+                "-name=TFXLABridge",
+            ],
+            "clustering_passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "clustering_passes.td",
+    deps = [
+        "@llvm-project//mlir:PassBaseTdFiles",
+    ],
+)
+
+tf_cc_test(
+    name = "verify_clustering_pass_test",
+    srcs = ["verify_clustering_pass_test.cc"],
+    deps = [
+        ":clustering_passes",
+        "//tensorflow/compiler/mlir/tf2xla/transforms:test_utils",
+        "@com_google_absl//absl/strings",
+        "@com_google_googletest//:gtest_main",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+glob_lit_tests(
+    name = "all_tests",
+    data = [":test_utilities"],
+    driver = "@llvm-project//mlir:run_lit.sh",
+    test_file_exts = [
+        "mlir",
+    ],
+)
+
+# Bundle together all of the test utilities that are used by tests.
+filegroup(
+    name = "test_utilities",
+    testonly = True,
+    data = [
+        "//tensorflow/compiler/mlir:tf-opt",
+        "@llvm-project//llvm:FileCheck",
+    ],
+)
+
+cc_library(
+    name = "tpu_cluster_formation",
+    srcs = ["tpu_cluster_formation.cc"],
+    textual_hdrs = [
+        "clustering_passes.h.inc",
+    ],
+    deps = [
+        ":clustering_passes_inc_gen",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:attribute_utils",
+        "//tensorflow/compiler/mlir/tensorflow:string_util",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis",
+        "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util",
+        "//tensorflow/compiler/mlir/tensorflow/transforms:tf_pass_inc_gen",
+        "//tensorflow/core:framework",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformUtils",
+    ],
+)
+
+cc_library(
+    name = "extract_head_tail_outside_compilation",
+    srcs = ["extract_head_tail_outside_compilation.cc"],
+    textual_hdrs = [
+        "clustering_passes.h.inc",
+    ],
+    deps = [
+        ":clustering_passes_inc_gen",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:attribute_utils",
+        "//tensorflow/compiler/mlir/tensorflow:device_util",
+        "//tensorflow/compiler/mlir/tensorflow:string_util",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
+        "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Rewrite",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TransformUtils",
+    ],
+)
+
+cc_library(
+    name = "dialect_to_executor_passes",
+    srcs = [
+        "dialect_to_executor_passes.h",
+    ],
+    textual_hdrs = [
+        "dialect_to_executor_passes.h.inc",
+    ],
+    deps = [
+        ":dialect_to_executor_passes_inc_gen",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/transforms/toposort:Pass",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
+
+gentbl_cc_library(
+    name = "dialect_to_executor_passes_inc_gen",
+    compatible_with = get_compatible_with_portable(),
+    tbl_outs = [
+        (
+            [
+                "-gen-pass-decls",
+                "-name=TFXLABridge",
+            ],
+            "dialect_to_executor_passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "dialect_to_executor_passes.td",
+    deps = [
+        "@llvm-project//mlir:PassBaseTdFiles",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h
new file mode 100644
index 0000000..79721a0
--- /dev/null
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h
@@ -0,0 +1,50 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_CLUSTERING_PASSES_H_
+#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_CLUSTERING_PASSES_H_
+
+#include <memory>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+
+namespace tensorflow {
+namespace tf2xla {
+namespace internal {
+
+// Verifies that all MLIR Ops have the expected attributes.
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateVerifyClusteringPass();
+
+// Creates a pass that forms clusters from operations of the same
+// `_replication_info` attribute.
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+CreateTPUClusterFormationPass(bool strict_clusters = false);
+
+// Creates a pass that extracts outside compilation (Host ops inside device
+// cluster) at head/tail of Device cluster to run before/after XLA computation.
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+CreateExtractHeadTailOutsideCompilationPass();
+
+#define GEN_PASS_REGISTRATION
+#define GEN_PASS_DECL_TPUCLUSTERFORMATIONPASS
+#define GEN_PASS_DECL_TPUEXTRACTHEADTAILOUTSIDECOMPILATIONPASS
+#define GEN_PASS_DECL_VERIFYCLUSTERINGPASS
+#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc"
+
+}  // namespace internal
+}  // namespace tf2xla
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_CLUSTERING_PASSES_H_
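
Note (illustrative, not additional patch content): the three factory functions declared above are wired into the bridge pipeline by clustering_bridge_passes.cc later in this diff; a condensed sketch of that wiring is shown below.

```cpp
// How the factory functions exposed by clustering_passes.h plug into an MLIR
// pass pipeline, condensed from clustering_bridge_passes.cc in this change.
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h"

void AddInternalClusteringPasses(mlir::OpPassManager& pm) {
  // Group ops sharing a `_replication_info` attribute into tf_device.cluster.
  pm.addPass(tensorflow::tf2xla::internal::CreateTPUClusterFormationPass(
      /*strict_clusters=*/false));
  // Pull host-only ops at the head/tail of each cluster out for outside
  // compilation.
  pm.addPass(tensorflow::tf2xla::internal::
                 CreateExtractHeadTailOutsideCompilationPass());
  // Check that clustering left every function in the expected form.
  pm.addNestedPass<mlir::func::FuncOp>(
      tensorflow::tf2xla::internal::CreateVerifyClusteringPass());
}
```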
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td
new file mode 100644
index 0000000..4fc8af15
--- /dev/null
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td
@@ -0,0 +1,157 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+include "mlir/Pass/PassBase.td"
+
+def VerifyClusteringPass : Pass<"verify-clustering-pass", "mlir::func::FuncOp"> {
+
+  let summary = "Verify that the Bridge output is correct and errors if verification fails.";
+
+  let description = [{
+    Verifies that clustering has produced the expected invariants: clusters
+    have been created and outside compiled, the result is device agnostic and
+    in the TF functional dialect, and the device attribute exists.
+  }];
+
+  let constructor = "tensorflow::tf2xla::internal::CreateVerifyClusteringPass()";
+}
+
+def TPUClusterFormationPass : Pass<"tf-tpu-cluster-formation", "ModuleOp"> {
+  let summary = "Forms clusters from operations assigned to the same TPU computation";
+
+  let description = [{
+    TPU computations from the frontend are composed of a `tf.TPUReplicateMetadata`
+    op, a subgraph of ops (TensorFlow Dialect) each with a matching
+    `_replication_info` attribute relative to the associated
+    `tf.TPUReplicateMetadata` op, and optionally `tf.TPUReplicatedInput` and
+    `tf.TPUReplicatedOutput` ops feeding in inputs and outputs to and from a
+    replicated TPU computation. The number of times a TPU computation is
+    replicated is defined in the `tf.TPUReplicateMetadata` op (`num_replicas`
+    attribute) and operand and result sizes of `tf.TPUReplicatedInput` and
+    `tf.TPUReplicatedOutput` respectively must match, excluding packed tensors.
+    It is also assumed ops of the same TPU computation do not have ops outside
+    of the TPU computation that are both inputs and outputs to the same TPU
+    computation. Furthermore, we assume that every node has either none or both
+    of `_replication_info` and `_xla_compile_device_type` attributes defined.
+
+    This pass takes the ops of the TPU computation subgraph, moves them into a
+    `tf_device.cluster`, and copies over attributes from the associated
+    `tf.TPUReplicateMetadata` op to the newly created `tf_device.cluster`. If the
+    computation is replicated (`num_replicas` > 1), the `num_replicas` attribute is
+    not copied over but instead the `tf_device.cluster` is further wrapped with a
+    `tf_device.replicate`, and associated `tf.TPUReplicatedInput` and
+    `tf.TPUReplicatedOutput` ops are replaced as the `tf_device.replicate` operands
+    and results. Otherwise, the single operands and results of the associated
+    `tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` ops are simply forwarded to
+    the `tf_device.cluster`.
+
+    For example, the following non-replicated computation:
+
+    ```mlir
+    func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
+      // Metadata op for cluster `cluster` with 1 replica, 1 core per replica and
+      // with topology `<topology>`.
+      "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_relicas = 1, num_cores_per_replica = 1, topology = "<topology>", device_assignment = [], padding_map = []} : () -> ()
+      %replicated_input = "tf.TPUReplicatedInput"(%arg0) : (tensor<i32>) -> tensor<i32>
+      %identity = "tf.Identity"(%replicated_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster"} : (tensor<i32>) -> tensor<i32>
+      %replicated_output = "tf.TPUReplicatedOutput(%identity) : (tensor<i32>) -> tensor<i32>
+      return %replicated_output : tensor<i32>
+    }
+    ```
+
+    will be transformed into:
+
+    ```mlir
+    func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
+      %cluster = "tf_device.cluster"() ( {
+        %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
+        tf_device.return %identity : tensor<i32>
+      }) {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
+      return %cluster : tensor<i32>
+    }
+    ```
+
+    The following replicated computation:
+
+    ```mlir
+    func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+      "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_relicas = 2, num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> ()
+      %replicated_input = "tf.TPUReplicatedInput"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      %identity = "tf.Identity"(%replicated_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster"} : (tensor<i32>) -> tensor<i32>
+      %replicated_output:2 = "tf.TPUReplicatedOutput"(%identity) : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
+      return %replicated_output#0, %replicated_output#1 : tensor<i32>, tensor<i32>
+    }
+    ```
+
+    will be transformed into:
+
+    ```mlir
+    func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+      %replicate:2 = tf_device.replicate([%arg0, %arg1] as %replicated_input) {n = 2 : i32} {
+        %cluster = "tf_device.cluster"() ( {
+          %identity = "tf.Identity"(%replicated_input) : (tensor<i32>) -> tensor<i32>
+          tf_device.return %identity : tensor<i32>
+        }) {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
+        tf_device.return %cluster : tensor<i32>
+      }
+      return %replicate#0, %replicate#1 : tensor<i32>, tensor<i32>
+    }
+    ```
+  }];
+
+  let constructor = "tensorflow::tf2xla::internal::CreateTPUClusterFormationPass()";
+}
+
+def ExtractHeadTailOutsideCompilationPass : Pass<"tf-extract-head-tail-outside-compilation", "ModuleOp"> {
+  let summary = "Extracts head or tail outside compilation to separate host launches before/after device cluster.";
+
+  let description = [{
+    This pass extracts a CPU computation cluster with `_xla_outside_compilation`
+    annotation from the head or tail of a Device cluster.
+
+    For example:
+
+    ```mlir
+      %cluster = "tf_device.cluster"() ( {
+        %a = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
+        %b = "tf.B"(%a) : (tensor<i32>) -> tensor<i32>
+        %c = "tf.C"(%b) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
+        tf_device.return %c : tensor<i32>
+      }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor<i32>
+      return %cluster : tensor<i32>
+    ```
+
+    becomes:
+
+    ```mlir
+    %0 = "tf_device.launch"() ( {
+      %3 = "tf.A"(%arg0) : (tensor<i32>) -> tensor<i32>
+      tf_device.return %3 : tensor<i32>
+    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> tensor<i32>
+    %1 = "tf_device.cluster"() ( {
+      %3 = "tf.B"(%0) : (tensor<i32>) -> tensor<i32>
+      tf_device.return %3 : tensor<i32>
+    }) {device_assignment = [], num_cores_per_replica = 1 : i64, padding_map = [], step_marker_location = "", topology = ""} : () -> tensor<i32>
+    %2 = "tf_device.launch"() ( {
+      %3 = "tf.C"(%1) : (tensor<i32>) -> tensor<i32>
+      tf_device.return %3 : tensor<i32>
+    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> tensor<i32>
+    return %2 : tensor<i32>
+
+    ```
+  }];
+
+  let constructor = "tensorflow::tf2xla::internal::CreateExtractHeadTailOutsideCompilationPass()";
+}
+
+
+
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/dialect_to_executor_passes.h b/tensorflow/compiler/mlir/tf2xla/internal/passes/dialect_to_executor_passes.h
new file mode 100644
index 0000000..7424786
--- /dev/null
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/dialect_to_executor_passes.h
@@ -0,0 +1,32 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_DIALECT_TO_EXECUTOR_PASSES_H_
+#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_DIALECT_TO_EXECUTOR_PASSES_H_
+
+#include <memory>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+
+namespace tensorflow {
+namespace tf2xla {
+namespace internal {
+
+// Verifies that Executor input is of the expected format.
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateVerifyInputDialectToExecutorPass();
+
+#define GEN_PASS_DECL_VERIFYINPUTDIALECTTOEXECUTORPASS
+}  // namespace internal
+}  // namespace tf2xla
+}  // namespace tensorflow
+#endif  // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_DIALECT_TO_EXECUTOR_PASSES_H_
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/dialect_to_executor_passes.td b/tensorflow/compiler/mlir/tf2xla/internal/passes/dialect_to_executor_passes.td
new file mode 100644
index 0000000..9c7891d
--- /dev/null
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/dialect_to_executor_passes.td
@@ -0,0 +1,21 @@
+
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+include "mlir/Pass/PassBase.td"
+
+def VerifyInputDialectToExecutor : Pass<"verify-input-dialect-to-executor-pass", "mlir::func::FuncOp"> {
+  let summary = "Verify that TF dialect to executor converter receives the correct input.";
+  let description = [{
+    Verifies the input before exporting to the TF executor dialect. This includes checking that ops are in the TF functional dialect, have device attributes, and that there are no tf_device.cluster_func ops.
+  }];
+  let constructor = "tensorflow::tf2xla::internal::CreateVerifyInputDialectToExecutorPass()";
+}
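A short, hedged sketch of using the declaration from dialect_to_executor_passes.h: the verification runs per function, so it would be added as a nested pass; where it sits in the export pipeline is an assumption here.

```c++
// Illustrative only: schedule the executor-input verification declared in
// dialect_to_executor_passes.h before exporting to the TF executor dialect.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tf2xla/internal/passes/dialect_to_executor_passes.h"

void AddExecutorInputVerification(mlir::PassManager& pm) {
  pm.addNestedPass<mlir::func::FuncOp>(
      tensorflow::tf2xla::internal::CreateVerifyInputDialectToExecutorPass());
}
```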
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_head_tail_outside_compilation.cc
similarity index 83%
rename from tensorflow/compiler/mlir/tensorflow/transforms/extract_head_tail_outside_compilation.cc
rename to tensorflow/compiler/mlir/tf2xla/internal/passes/extract_head_tail_outside_compilation.cc
index 286a181..ad85310 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/extract_head_tail_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_head_tail_outside_compilation.cc
@@ -15,25 +15,23 @@
 
 #include <memory>
 #include <string>
-#include <tuple>
-#include <type_traits>
-#include <utility>
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Block.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/IR/Visitors.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
@@ -42,16 +40,30 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
-#include "tsl/util/device_name_utils.h"
 
-namespace mlir {
-namespace TFDevice {
+namespace tensorflow {
+namespace tf2xla {
+namespace internal {
 
 // This pass extracts a CPU computation cluster with `_xla_outside_compilation`
 // annotation from the head or tail of a TPU cluster.
 
 namespace {
 
+using mlir::Block;
+using mlir::BlockArgument;
+using mlir::BoolAttr;
+using mlir::ModuleOp;
+using mlir::OpBuilder;
+using mlir::Operation;
+using mlir::OperationPass;
+using mlir::Region;
+using mlir::StringAttr;
+using mlir::Type;
+using mlir::Value;
+using mlir::WalkResult;
+using mlir::func::FuncOp;
+
 constexpr char kXlaMapOutsideCompilationAttr[] = "_xla_map_outside_compilation";
 constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
 
@@ -90,9 +102,10 @@
 // values of the Launch and remapped to the Launch results. If `before` is set
 // to true, the Launch is created before `op`. Otherwise the Launch is created
 // after `op`.
-tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
-                                         bool before, Block* launch_block,
-                                         llvm::StringRef host_device) {
+mlir::tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder,
+                                               Operation* op, bool before,
+                                               Block* launch_block,
+                                               llvm::StringRef host_device) {
   // Find results and result types of ops in block that needs to returned.
   llvm::SmallVector<Value, 4> launch_results;
   llvm::SmallVector<Type, 4> launch_result_types;
@@ -112,23 +125,23 @@
   }
 
   before ? builder->setInsertionPoint(op) : builder->setInsertionPointAfter(op);
-  auto launch = builder->create<tf_device::LaunchOp>(
+  auto launch = builder->create<mlir::tf_device::LaunchOp>(
       op->getLoc(), builder->getStringAttr(host_device), launch_result_types);
   launch.getBody().push_back(launch_block);
 
   builder->setInsertionPointToEnd(&launch.GetBody());
-  builder->create<tf_device::ReturnOp>(op->getLoc(), launch_results);
+  builder->create<mlir::tf_device::ReturnOp>(op->getLoc(), launch_results);
 
   return launch;
 }
 
 // Checks if an operation is a supported TPU embedding op.
 bool IsEmbeddingOp(Operation* op) {
-  return isa<TF::EnqueueTPUEmbeddingRaggedTensorBatchOp,
-             TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
-             TF::EnqueueTPUEmbeddingArbitraryTensorBatchOp,
-             TF::RecvTPUEmbeddingActivationsOp,
-             TF::SendTPUEmbeddingGradientsOp>(op);
+  return llvm::isa<mlir::TF::EnqueueTPUEmbeddingRaggedTensorBatchOp,
+                   mlir::TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
+                   mlir::TF::EnqueueTPUEmbeddingArbitraryTensorBatchOp,
+                   mlir::TF::RecvTPUEmbeddingActivationsOp,
+                   mlir::TF::SendTPUEmbeddingGradientsOp>(op);
 }
 
 // Returns a set of ops that are outside compiled and can be extracted to before
@@ -136,10 +149,10 @@
 // computation or other ops that can be extracted, and have no operands from
 // other ops in the TPU computation that cannot be extracted.
 llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
-    const TF::SideEffectAnalysis& side_effect_analysis,
-    tf_device::ClusterOp cluster) {
+    const mlir::TF::SideEffectAnalysis& side_effect_analysis,
+    mlir::tf_device::ClusterOp cluster) {
   const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
-      cluster->getParentOfType<func::FuncOp>());
+      cluster->getParentOfType<mlir::func::FuncOp>());
   Region* cluster_region = &cluster.getBody();
   llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
 
@@ -188,7 +201,8 @@
 
 // Moves head outside compiled ops into its own `tf_device.LaunchOp`
 // computation before the cluster.
-void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
+void CreateHeadComputation(OpBuilder* builder,
+                           mlir::tf_device::ClusterOp cluster,
                            llvm::ArrayRef<Operation*> head_outside_compiled_ops,
                            llvm::StringRef host_device) {
   Block* launch_block = new Block;
@@ -197,7 +211,7 @@
     head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
   }
 
-  tf_device::LaunchOp launch = CreateLaunchForBlock(
+  mlir::tf_device::LaunchOp launch = CreateLaunchForBlock(
       builder, cluster, /*before=*/true, launch_block, host_device);
 
   for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
@@ -209,21 +223,22 @@
 // Extracts and move outside compiled ops that have no dependencies in the
 // cluster to before the cluster.
 mlir::LogicalResult LiftHeadOutsideCompiledOps(
-    OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
-    const mlir::TF::RuntimeDevices& devices, tf_device::ClusterOp cluster,
+    OpBuilder* builder,
+    const mlir::TF::SideEffectAnalysis& side_effect_analysis,
+    const mlir::TF::RuntimeDevices& devices, mlir::tf_device::ClusterOp cluster,
     std::string* host_device, bool* cluster_updated) {
   llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
       FindOutsideCompiledOpsAtHead(side_effect_analysis, cluster);
-  if (head_outside_compiled_ops.empty()) return success();
+  if (head_outside_compiled_ops.empty()) return mlir::success();
   if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, cluster,
                                                          host_device)))
-    return failure();
+    return mlir::failure();
 
   CreateHeadComputation(builder, cluster, head_outside_compiled_ops,
                         *host_device);
 
   *cluster_updated = true;
-  return success();
+  return mlir::success();
 }
 
 // Fills `tail_outside_compiled_ops` with ops that are outside compiled and
@@ -232,12 +247,12 @@
 // TPU computation or other ops that can be extracted, and have no results used
 // by other ops in the TPU computation that cannot be extracted.
 void FindOutsideCompiledOpsAtTailAndClusterResults(
-    const TF::SideEffectAnalysis& side_effect_analysis,
-    tf_device::ClusterOp cluster,
+    const mlir::TF::SideEffectAnalysis& side_effect_analysis,
+    mlir::tf_device::ClusterOp cluster,
     llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
     llvm::SmallVectorImpl<Value>* cluster_results) {
   const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
-      cluster->getParentOfType<func::FuncOp>());
+      cluster->getParentOfType<FuncOp>());
   Region* cluster_region = &cluster.getBody();
   llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
   Operation* terminator = cluster.GetBody().getTerminator();
@@ -300,7 +315,8 @@
 
 // Moves tail outside compiled ops into its own `tf_device.LaunchOp`
 // computation after the cluster.
-void CreateTailComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
+void CreateTailComputation(OpBuilder* builder,
+                           mlir::tf_device::ClusterOp cluster,
                            llvm::ArrayRef<Operation*> tail_outside_compiled_ops,
                            llvm::StringRef host_device) {
   Block* launch_block = new Block;
@@ -309,10 +325,10 @@
     tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin());
   }
 
-  tf_device::LaunchOp launch = CreateLaunchForBlock(
+  mlir::tf_device::LaunchOp launch = CreateLaunchForBlock(
       builder, cluster, /*before=*/false, launch_block, host_device);
 
-  auto operand_not_in_launch = [&](OpOperand& operand) {
+  auto operand_not_in_launch = [&](mlir::OpOperand& operand) {
     return !launch.getOperation()->isProperAncestor(operand.getOwner());
   };
   for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
@@ -323,13 +339,13 @@
 
 // Updates cluster with updated cluster results after extracting tail outside
 // compiled ops.
-tf_device::ClusterOp UpdateClusterResults(
-    OpBuilder* builder, tf_device::ClusterOp cluster,
+mlir::tf_device::ClusterOp UpdateClusterResults(
+    OpBuilder* builder, mlir::tf_device::ClusterOp cluster,
     llvm::ArrayRef<Value> new_cluster_results) {
   Operation* old_terminator = cluster.GetBody().getTerminator();
   builder->setInsertionPoint(old_terminator);
-  builder->create<tf_device::ReturnOp>(old_terminator->getLoc(),
-                                       new_cluster_results);
+  builder->create<mlir::tf_device::ReturnOp>(old_terminator->getLoc(),
+                                             new_cluster_results);
   old_terminator->erase();
 
   builder->setInsertionPoint(cluster);
@@ -338,12 +354,12 @@
   for (const auto& new_cluster_result : new_cluster_results)
     new_cluster_result_types.push_back(new_cluster_result.getType());
 
-  auto new_cluster = builder->create<tf_device::ClusterOp>(
+  auto new_cluster = builder->create<mlir::tf_device::ClusterOp>(
       cluster.getLoc(), new_cluster_result_types,
       /*operands=*/llvm::ArrayRef<Value>{}, cluster->getAttrs());
   new_cluster.getBody().takeBody(cluster.getBody());
 
-  auto operand_not_in_cluster = [&](OpOperand& operand) {
+  auto operand_not_in_cluster = [&](mlir::OpOperand& operand) {
     return !new_cluster.getOperation()->isProperAncestor(operand.getOwner());
   };
   for (auto result :
@@ -359,20 +375,21 @@
 // Extracts and move outside compiled ops that do not create dependencies in the
 // cluster to after the cluster.
 mlir::LogicalResult LiftTailOutsideCompiledOps(
-    OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
+    OpBuilder* builder,
+    const mlir::TF::SideEffectAnalysis& side_effect_analysis,
     const mlir::TF::RuntimeDevices& devices, std::string host_device,
-    tf_device::ClusterOp* cluster, bool* cluster_updated) {
+    mlir::tf_device::ClusterOp* cluster, bool* cluster_updated) {
   llvm::SmallVector<Operation*, 4> tail_outside_compiled_ops;
   llvm::SmallVector<Value, 4> cluster_results;
   FindOutsideCompiledOpsAtTailAndClusterResults(side_effect_analysis, *cluster,
                                                 &tail_outside_compiled_ops,
                                                 &cluster_results);
-  if (tail_outside_compiled_ops.empty()) return success();
+  if (tail_outside_compiled_ops.empty()) return mlir::success();
 
   if (host_device.empty())
     if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, *cluster,
                                                            &host_device)))
-      return failure();
+      return mlir::failure();
 
   // Forward all results of cluster first. These results will be remapped once
   // a new cluster is formed.
@@ -385,12 +402,12 @@
   *cluster = UpdateClusterResults(builder, *cluster, cluster_results);
 
   *cluster_updated = true;
-  return success();
+  return mlir::success();
 }
 
 // Removes aliased outputs in cluster from ops outside of cluster.
 void RemoveClusterAliasedOutputs(OpBuilder* builder,
-                                 tf_device::ClusterOp cluster) {
+                                 mlir::tf_device::ClusterOp cluster) {
   llvm::SmallVector<Value, 4> used_old_cluster_results;
   llvm::SmallVector<Value, 4> new_cluster_results;
   llvm::SmallVector<Type, 4> new_cluster_result_types;
@@ -412,7 +429,7 @@
   if (new_cluster_results.size() == cluster.getNumResults()) return;
 
   builder->setInsertionPoint(cluster);
-  auto new_cluster = builder->create<tf_device::ClusterOp>(
+  auto new_cluster = builder->create<mlir::tf_device::ClusterOp>(
       cluster.getLoc(), new_cluster_result_types,
       /*operands=*/llvm::ArrayRef<Value>{}, cluster->getAttrs());
   new_cluster.getBody().takeBody(cluster.getBody());
@@ -426,7 +443,7 @@
 }
 
 #define GEN_PASS_DEF_EXTRACTHEADTAILOUTSIDECOMPILATIONPASS
-#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
+#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc"
 
 struct ExtractHeadTailOutsideCompilationPass
     : public impl::ExtractHeadTailOutsideCompilationPassBase<
@@ -435,7 +452,7 @@
 };
 
 void ExtractHeadTailOutsideCompilationPass::runOnOperation() {
-  auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
+  auto& side_effect_analysis = getAnalysis<mlir::TF::SideEffectAnalysis>();
   // Get runtime devices information from the closest parent module.
   auto module = getOperation();
   mlir::TF::RuntimeDevices devices;
@@ -443,11 +460,11 @@
     return signalPassFailure();
 
   OpBuilder builder(&getContext());
-  llvm::SmallVector<tf_device::ClusterOp, 4> clusters;
+  llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters;
   module.walk(
-      [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });
+      [&](mlir::tf_device::ClusterOp cluster) { clusters.push_back(cluster); });
 
-  for (tf_device::ClusterOp cluster : clusters) {
+  for (mlir::tf_device::ClusterOp cluster : clusters) {
     std::string host_device;
     bool cluster_updated = false;
     if (failed(LiftHeadOutsideCompiledOps(&builder, side_effect_analysis,
@@ -468,5 +485,6 @@
   return std::make_unique<ExtractHeadTailOutsideCompilationPass>();
 }
 
-}  // namespace TFDevice
-}  // namespace mlir
+}  // namespace internal
+}  // namespace tf2xla
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc
similarity index 85%
rename from tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
rename to tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc
index c2ad182..b600c86 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,44 +13,44 @@
 limitations under the License.
 ==============================================================================*/
 
-#include <algorithm>
+#include <cassert>
 #include <cstdint>
 #include <iterator>
 #include <memory>
-#include <ostream>
 #include <set>
-#include <sstream>
 #include <string>
 #include <tuple>
 #include <unordered_map>
 #include <utility>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/log/log.h"
 #include "absl/strings/match.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Casting.h"
-#include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
+#include "mlir/IR/DialectRegistry.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
+#include "mlir/IR/Region.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/IR/ValueRange.h"  // from @llvm-project
+#include "mlir/IR/Visitors.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
-#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
@@ -61,11 +61,29 @@
 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
 #include "tensorflow/core/util/device_name_utils.h"
 
-namespace mlir {
-namespace TFTPU {
+namespace tensorflow {
+namespace tf2xla {
+namespace internal {
 
 namespace {
 
+using mlir::Block;
+using mlir::DialectRegistry;
+using mlir::LogicalResult;
+using mlir::ModuleOp;
+using mlir::NamedAttribute;
+using mlir::NamedAttrList;
+using mlir::OpBuilder;
+using mlir::Operation;
+using mlir::OpResult;
+using mlir::Region;
+using mlir::StringAttr;
+using mlir::success;
+using mlir::Type;
+using mlir::Value;
+using mlir::ValueRange;
+using mlir::WalkResult;
+
 constexpr llvm::StringRef kDeviceAttr = "device";
 constexpr llvm::StringRef kNameAttr = "name";
 constexpr llvm::StringRef kNumCoresPerReplicaAttr = "num_cores_per_replica";
@@ -88,7 +106,7 @@
 using ClusterMap = llvm::SmallDenseMap<llvm::StringRef, OpSetVector, 8>;
 
 #define GEN_PASS_DEF_TPUCLUSTERFORMATIONPASS
-#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
+#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc"
 
 class TPUClusterFormationPass
     : public impl::TPUClusterFormationPassBase<TPUClusterFormationPass> {
@@ -97,7 +115,7 @@
       : strict_clusters_(strict_clusters) {}
 
   void getDependentDialects(DialectRegistry& registry) const override {
-    registry.insert<tf_device::TensorFlowDeviceDialect>();
+    registry.insert<mlir::tf_device::TensorFlowDeviceDialect>();
   }
 
   void runOnOperation() override;
@@ -113,13 +131,13 @@
 LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) {
   // Just look at top-level operations in the block (not nested ones)
   for (Operation& op : llvm::make_early_inc_range(*block)) {
-    auto metadata_op = dyn_cast<TF::TPUReplicateMetadataOp>(op);
+    auto metadata_op = llvm::dyn_cast<mlir::TF::TPUReplicateMetadataOp>(op);
     if (!metadata_op) continue;
 
     NamedAttrList attrs(metadata_op->getAttrDictionary());
 
     // Missing or bad `_replication_info` attribute.
-    auto replication_info_attr = attrs.get(TF::kReplicationInfoAttr);
+    auto replication_info_attr = attrs.get(mlir::TF::kReplicationInfoAttr);
     if (!replication_info_attr)
       return metadata_op.emitError() << kBadReplicateInfoAttrMsg;
 
@@ -140,7 +158,7 @@
     if (!it.second) {
       return metadata_op.emitError()
              << "multiple TPUReplicateMetadata ops with the same '"
-             << TF::kReplicationInfoAttr << "' attribute '"
+             << mlir::TF::kReplicationInfoAttr << "' attribute '"
              << replication_info_attr_str.getValue() << "' found";
     }
     metadata_op.erase();
@@ -168,13 +186,14 @@
   std::set<llvm::StringRef> device_types;
   absl::flat_hash_map<std::string, OpDevice> devices;
   for (Operation& op : *block) {
-    LogicalResult result = TF::HasValidCompilationAndReplicationAttributes(op);
+    LogicalResult result =
+        mlir::TF::HasValidCompilationAndReplicationAttributes(op);
     if (failed(result)) return result;
 
     // Collect device types which currently must be consistent per block
     // (checked later).
     auto device_type_attr =
-        op.getAttrOfType<StringAttr>(TF::kCompileDeviceTypeAttr);
+        op.getAttrOfType<StringAttr>(mlir::TF::kCompileDeviceTypeAttr);
     if (device_type_attr) {
       // Some graphs in TPU bridge may have both tf.StatefulPartitionedCall
       // ops with and without _tpu_replicate attributes. As a result, the ops
@@ -185,20 +204,20 @@
       if (device_type_attr.getValue().empty()) continue;
       device_types.insert(device_type_attr);
       // Stop here for ops with non-TPU devices, they are handled elsewhere.
-      if (device_type_attr.getValue() != TF::kTpuDevice) continue;
+      if (device_type_attr.getValue() != mlir::TF::kTpuDevice) continue;
     }
 
-    if (op.hasAttr(TF::kReplicationInfoAttr)) {
+    if (op.hasAttr(mlir::TF::kReplicationInfoAttr)) {
       // For replicated case, borrow cluster structure from replication info.
       // Following condition is already checked in
       // `HasValidCompilationAndReplicationAttributes` above, assert here for
       // documentation and to avoid breakage when that function is changed.
-      assert(op.hasAttr(TF::kCompileDeviceTypeAttr));
+      assert(op.hasAttr(mlir::TF::kCompileDeviceTypeAttr));
       has_replicated_compiled_op = true;
-      auto attr = op.getAttrOfType<StringAttr>(TF::kReplicationInfoAttr);
+      auto attr = op.getAttrOfType<StringAttr>(mlir::TF::kReplicationInfoAttr);
       auto it = clusters->try_emplace(attr.getValue());
       it.first->getSecond().insert(&op);
-    } else if (op.hasAttr(TF::kCompileDeviceTypeAttr)) {
+    } else if (op.hasAttr(mlir::TF::kCompileDeviceTypeAttr)) {
       // For non-replicated case, assume one cluster per block (in line with
       // Framework behavior).
       has_non_replicated_compiled_op = true;
@@ -213,7 +232,7 @@
       if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(device_attr.str(),
                                                              &parsed)) {
         op.emitWarning() << "Invalid device name " << device_attr.str();
-        return failure();
+        return mlir::failure();
       }
 
       device_local_name =
@@ -270,7 +289,7 @@
   }
   if (device_types.size() > 1) {
     return block->getParentOp()->emitError()
-           << "found different '" << TF::kCompileDeviceTypeAttr
+           << "found different '" << mlir::TF::kCompileDeviceTypeAttr
            << "' attribute values (" << llvm::join(device_types, ",")
            << ") in same block which is not supported";
   }
@@ -303,7 +322,7 @@
 Operation* getOpClusterControlDependency(
     Operation* op, bool incoming, const OpSetVector& cluster_ops,
     const OpSetVector& cluster_dependent_ops,
-    const TF::SideEffectAnalysis::Info& side_effect_analysis) {
+    const mlir::TF::SideEffectAnalysis::Info& side_effect_analysis) {
   auto filter = [&](Operation* other_op) {
     return cluster_ops.contains(other_op) ||
            cluster_dependent_ops.contains(other_op);
@@ -349,9 +368,9 @@
 
 // Collects ops that need to be moved behind the cluster due to data or control
 // dependencies.
-FailureOr<llvm::SmallSetVector<Operation*, 8>> CollectClusterSuccessorOps(
+mlir::FailureOr<llvm::SmallSetVector<Operation*, 8>> CollectClusterSuccessorOps(
     Block* block, const OpSetVector& cluster_ops,
-    const TF::SideEffectAnalysis::Info& side_effect_analysis,
+    const mlir::TF::SideEffectAnalysis::Info& side_effect_analysis,
     bool strict_clusters) {
   OpSetVector cluster_predecessor_ops;
   OpSetVector cluster_successor_ops;
@@ -396,9 +415,9 @@
         // might have runtime impact for existing models.
         // We should make this message an error once there is such a contract
         // and once existing cases have been fixed.
-        InFlightDiagnostic error = strict_clusters
-                                       ? mlir::emitError(op.getLoc(), "")
-                                       : mlir::emitWarning(op.getLoc(), "");
+        mlir::InFlightDiagnostic error =
+            strict_clusters ? mlir::emitError(op.getLoc(), "")
+                            : mlir::emitWarning(op.getLoc(), "");
         error << "Op has cyclic dependency with a compilation cluster:\n";
         error << "The cluster depends on\n";
         error << op.getName() << "\n"
@@ -447,7 +466,7 @@
 }
 
 // Creates a `tf_device.cluster` to wrap cluster ops.
-tf_device::ClusterOp CreateClusterOp(
+mlir::tf_device::ClusterOp CreateClusterOp(
     Block* block, const OpSetVector& cluster_ops, llvm::ArrayRef<Value> results,
     llvm::ArrayRef<Operation*> cluster_successor_ops) {
   // `tf_device.cluster` will be placed at where the last op of the cluster is.
@@ -456,8 +475,8 @@
 
   llvm::SmallVector<Type, 8> result_types;
   for (Value result : results) result_types.push_back(result.getType());
-  auto cluster = builder.create<tf_device::ClusterOp>(last_cluster_op->getLoc(),
-                                                      result_types);
+  auto cluster = builder.create<mlir::tf_device::ClusterOp>(
+      last_cluster_op->getLoc(), result_types);
 
   Block* body = new Block;
   cluster.getBody().push_back(body);
@@ -469,8 +488,8 @@
   for (Operation* cluster_op : cluster_ops) {
     cluster_op->moveBefore(body, body->end());
     cluster_op->walk([&](Operation* inner_op) {
-      inner_op->removeAttr(TF::kReplicationInfoAttr);
-      inner_op->removeAttr(TF::kCompileDeviceTypeAttr);
+      inner_op->removeAttr(mlir::TF::kReplicationInfoAttr);
+      inner_op->removeAttr(mlir::TF::kCompileDeviceTypeAttr);
 
       if (auto attr = inner_op->getAttrOfType<StringAttr>(kDeviceAttr)) {
         // Preserve device attribute if the op is placed on a replicated core
@@ -488,7 +507,7 @@
 
   // Add terminator.
   builder.setInsertionPointToEnd(body);
-  builder.create<tf_device::ReturnOp>(last_cluster_op->getLoc(), results);
+  builder.create<mlir::tf_device::ReturnOp>(last_cluster_op->getLoc(), results);
 
   // Replaces uses of cluster ops results outside of cluster with the associated
   // `tf_device.cluster` results.
@@ -510,15 +529,16 @@
 // Returns an op of the given type that uses the result, along with
 // a list of identity ops along the way.
 template <typename T>
-std::tuple<T, llvm::SmallVector<TF::IdentityOp, 4>> GetSingleUserOfType(
+std::tuple<T, llvm::SmallVector<mlir::TF::IdentityOp, 4>> GetSingleUserOfType(
     OpResult result) {
-  llvm::SmallVector<TF::IdentityOp, 4> identity_ops;
+  llvm::SmallVector<mlir::TF::IdentityOp, 4> identity_ops;
 
   do {
     Operation* user = result.hasOneUse() ? *result.getUsers().begin() : nullptr;
     if (auto t = llvm::dyn_cast_or_null<T>(user)) {
       return std::make_tuple(t, identity_ops);
-    } else if (auto identity = llvm::dyn_cast_or_null<TF::IdentityOp>(user)) {
+    } else if (auto identity =
+                   llvm::dyn_cast_or_null<mlir::TF::IdentityOp>(user)) {
       identity_ops.emplace_back(identity);
       result = identity->getResult(0);
     } else {
@@ -529,27 +549,27 @@
   return std::make_tuple(T(), identity_ops);
 }
 
-using PartitionedClusterOutputMap =
-    absl::flat_hash_map<uint64_t,
-                        llvm::SmallVector<TF::TPUPartitionedOutputV2Op, 8>>;
+using PartitionedClusterOutputMap = absl::flat_hash_map<
+    uint64_t, llvm::SmallVector<mlir::TF::TPUPartitionedOutputV2Op, 8>>;
 
 // Returns the partitioned output ops from the cluster if there are any,
 // along with any single user identity ops between them. Not all outputs
 // of a cluster must be partitioned, so the output is a map from cluster
 // output ids to ops.
-std::tuple<PartitionedClusterOutputMap, llvm::SmallVector<TF::IdentityOp, 8>>
-GetPartitionedOutputsAndIdentityOps(tf_device::ClusterOp cluster) {
+std::tuple<PartitionedClusterOutputMap,
+           llvm::SmallVector<mlir::TF::IdentityOp, 8>>
+GetPartitionedOutputsAndIdentityOps(mlir::tf_device::ClusterOp cluster) {
   PartitionedClusterOutputMap partitioned_outputs;
-  llvm::SmallVector<TF::IdentityOp, 8> erase_list;
+  llvm::SmallVector<mlir::TF::IdentityOp, 8> erase_list;
 
   for (auto [cluster_result_id, cluster_result] :
        llvm::enumerate(cluster.getResults())) {
     auto [replicated_output, _] =
-        GetSingleUserOfType<TF::TPUReplicatedOutputOp>(cluster_result);
+        GetSingleUserOfType<mlir::TF::TPUReplicatedOutputOp>(cluster_result);
     if (replicated_output) {
       for (OpResult per_replica_result : replicated_output->getResults()) {
         auto [partitioned_output, id_ops] =
-            GetSingleUserOfType<TF::TPUPartitionedOutputV2Op>(
+            GetSingleUserOfType<mlir::TF::TPUPartitionedOutputV2Op>(
                 per_replica_result);
         if (partitioned_output) {
           erase_list.insert(erase_list.end(), id_ops.begin(), id_ops.end());
@@ -566,10 +586,10 @@
 // Inlines the partitioned output ops into the cluster, and updates
 // their users to point to the replicate op instead.
 Operation* BuildPartitionedOutputs(
-    OpBuilder& builder, tf_device::ClusterOp cluster,
-    tf_device::ReplicateOp replicate_op,
+    OpBuilder& builder, mlir::tf_device::ClusterOp cluster,
+    mlir::tf_device::ReplicateOp replicate_op,
     PartitionedClusterOutputMap& partitioned_outputs,
-    llvm::SmallVector<TF::IdentityOp, 8>& erase_list,
+    llvm::SmallVector<mlir::TF::IdentityOp, 8>& erase_list,
     llvm::SmallVector<Type, 8>& result_types, int num_replicas) {
   Operation* result_op;
   llvm::SmallVector<Value, 8> results;
@@ -586,7 +606,8 @@
     // Otherwise, "inline" the partitioned output ops by:
     // - Building a new op within the cluster.
     // - Replacing all the uses of the original ops with the cluster's outputs.
-    llvm::SmallVector<TF::TPUPartitionedOutputV2Op, 8>& ops = search->second;
+    llvm::SmallVector<mlir::TF::TPUPartitionedOutputV2Op, 8>& ops =
+        search->second;
     for (auto [replica_id, partitioned_output] : llvm::enumerate(ops)) {
       for (auto [core_id, result] :
            llvm::enumerate(partitioned_output->getResults())) {
@@ -600,11 +621,11 @@
     }
 
     // Assume all the replicas have the same structure.
-    TF::TPUPartitionedOutputV2Op first_op = *(ops.begin());
-    ArrayAttr dims = first_op.getPartitionDimsAttr();
+    mlir::TF::TPUPartitionedOutputV2Op first_op = *(ops.begin());
+    mlir::ArrayAttr dims = first_op.getPartitionDimsAttr();
     StringAttr sharding = first_op.get_XlaShardingAttr();
     Operation::result_type_range output_types = first_op.getResultTypes();
-    result_op = builder.create<TF::TPUPartitionedOutputV2Op>(
+    result_op = builder.create<mlir::TF::TPUPartitionedOutputV2Op>(
         replicate_op.getLoc(), output_types, cluster.getResult(result_id), dims,
         sharding);
 
@@ -613,16 +634,16 @@
   }
 
   // Once we've accumulated all the cluster's results, build a return op.
-  builder.create<tf_device::ReturnOp>(result_op->getLoc(), results);
+  builder.create<mlir::tf_device::ReturnOp>(result_op->getLoc(), results);
 
   // Then erase all the identity and partitioned output ops.
   for (auto [_, ops] : partitioned_outputs) {
-    for (TF::TPUPartitionedOutputV2Op op : ops) {
+    for (mlir::TF::TPUPartitionedOutputV2Op op : ops) {
       op->erase();
     }
   }
 
-  for (TF::IdentityOp to_erase : erase_list) {
+  for (mlir::TF::IdentityOp to_erase : erase_list) {
     to_erase->erase();
   }
 
@@ -632,7 +653,7 @@
 // Return the cluster's per-replica result type, converting any full-shaped
 // tensor types into sharded-shaped ones if they're partitioned.
 llvm::SmallVector<Type, 8> GetClusterResultTypes(
-    tf_device::ClusterOp cluster,
+    mlir::tf_device::ClusterOp cluster,
     const PartitionedClusterOutputMap& partitioned_outputs) {
   llvm::SmallVector<Type, 8> result_types;
   Operation::result_type_range cluster_result_types = cluster.getResultTypes();
@@ -663,8 +684,8 @@
 
 // Creates a `tf_device.replicate` to represent replication for the cluster, if
 // necessary. Erases Identity ops between partitioned and replicated output ops.
-LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas,
-                               int num_cores_per_replica) {
+LogicalResult ReplicateCluster(mlir::tf_device::ClusterOp cluster,
+                               int num_replicas, int num_cores_per_replica) {
   OpBuilder builder(cluster);
   auto [partitioned_outputs, erase_list] =
       GetPartitionedOutputsAndIdentityOps(cluster);
@@ -682,7 +703,7 @@
   if (num_replicas == 1) {
     // Collapse all the Identity ops between the TRO and TPO ops.
     if (!partitioned_outputs.empty()) {
-      for (TF::IdentityOp to_erase : erase_list) {
+      for (mlir::TF::IdentityOp to_erase : erase_list) {
         Value in = to_erase->getOperand(0);
         OpResult out = to_erase->getResult(0);
         out.replaceAllUsesWith(in);
@@ -699,12 +720,13 @@
 
   LogicalResult status = success();
   // Collect all used TPUReplicatedInput ops.
-  llvm::SmallVector<TF::TPUReplicatedInputOp, 8> replicated_input_ops;
-  llvm::SmallSet<TF::TPUReplicatedInputOp, 8> seen_ops;
+  llvm::SmallVector<mlir::TF::TPUReplicatedInputOp, 8> replicated_input_ops;
+  llvm::SmallSet<mlir::TF::TPUReplicatedInputOp, 8> seen_ops;
   mlir::visitUsedValuesDefinedAbove(
       cluster.getBody(), cluster.getBody(), [&](mlir::OpOperand* operand) {
         Operation* def = operand->get().getDefiningOp();
-        if (auto ri = llvm::dyn_cast_or_null<TF::TPUReplicatedInputOp>(def)) {
+        if (auto ri =
+                llvm::dyn_cast_or_null<mlir::TF::TPUReplicatedInputOp>(def)) {
           if (!seen_ops.contains(ri)) {
             seen_ops.insert(ri);
             replicated_input_ops.push_back(ri);
@@ -713,15 +735,16 @@
         // When model parallelism is used in conjunction with data parallelism
         // for resource inputs, we need to collect the per replica resource
         // inputs from input to `tf.TPUPartitionedInputV2` ops.
-        if (auto pi =
-                llvm::dyn_cast_or_null<TF::TPUPartitionedInputV2Op>(def)) {
+        if (auto pi = llvm::dyn_cast_or_null<mlir::TF::TPUPartitionedInputV2Op>(
+                def)) {
           if (pi->getNumOperands() != num_cores_per_replica)
             status = pi.emitOpError()
                      << "requires " << num_cores_per_replica
                      << " operands but found " << pi->getNumOperands();
           for (auto operand : pi.getInputs()) {
-            if (auto ri = llvm::dyn_cast_or_null<TF::TPUReplicatedInputOp>(
-                    operand.getDefiningOp())) {
+            if (auto ri =
+                    llvm::dyn_cast_or_null<mlir::TF::TPUReplicatedInputOp>(
+                        operand.getDefiningOp())) {
               if (!seen_ops.contains(ri)) {
                 seen_ops.insert(ri);
                 replicated_input_ops.push_back(ri);
@@ -731,7 +754,7 @@
         }
       });
 
-  if (failed(status)) return failure();
+  if (failed(status)) return mlir::failure();
 
   // Indices of the replicate op's arguments that are mirrored variables.
   llvm::SmallVector<int64_t, 8> mirrored_variable_indices;
@@ -741,8 +764,8 @@
   // creating the replicate op.
   llvm::SmallVector<std::pair<ValueRange, Type>, 8> replicated_inputs;
   llvm::SmallVector<Value, 8> packed_inputs;
-  llvm::SmallVector<TF::TPUReplicatedInputOp, 8> replicated_ops;
-  llvm::SmallVector<TF::TPUReplicatedInputOp, 8> packed_ops;
+  llvm::SmallVector<mlir::TF::TPUReplicatedInputOp, 8> replicated_ops;
+  llvm::SmallVector<mlir::TF::TPUReplicatedInputOp, 8> packed_ops;
   for (const auto& pos_and_input : llvm::enumerate(replicated_input_ops)) {
     auto input = pos_and_input.value();
     bool is_packed = input.getIsPacked();
@@ -763,8 +786,8 @@
   // Create `ordered_tpu_replicate_inputs` which contains the final ordered
   // replicate inputs. All packed arguments are moved to the end of the arg
   // list.
-  llvm::SmallVector<TF::TPUReplicatedInputOp, 8> ordered_tpu_replicate_inputs =
-      replicated_ops;
+  llvm::SmallVector<mlir::TF::TPUReplicatedInputOp, 8>
+      ordered_tpu_replicate_inputs = replicated_ops;
   ordered_tpu_replicate_inputs.append(packed_ops.begin(), packed_ops.end());
 
   // Assign `mirrored_variable_indices` based on the ordered replicated inputs.
@@ -778,9 +801,10 @@
 
   // Create replicate op.
   auto result_types = GetClusterResultTypes(cluster, partitioned_outputs);
-  auto replicate_op = builder.create<tf_device::ReplicateOp>(
+  auto replicate_op = builder.create<mlir::tf_device::ReplicateOp>(
       cluster.getLoc(), num_replicas,
-      llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
+      llvm::SmallDenseMap<llvm::StringRef,
+                          llvm::SmallVector<llvm::StringRef, 4>>(),
       replicated_inputs, packed_inputs, result_types);
 
   if (!mirrored_variable_indices.empty())
@@ -801,7 +825,7 @@
         std::next(replicate_op.result_begin(), offset + num_replicas));
     for (auto& use : llvm::make_early_inc_range(result.getUses())) {
       Operation* def = use.getOwner();
-      if (!llvm::isa<TF::TPUReplicatedOutputOp>(def)) {
+      if (!llvm::isa<mlir::TF::TPUReplicatedOutputOp>(def)) {
         // If user is not a `tf.TPUReplicatedOutput`, simply forward the first
         // replica output. Certain Graphs under V1 create `tf.Identity` users of
         // replicated ops to pin the TPU computation for execution.
@@ -825,13 +849,14 @@
   for (auto input_and_block_arg :
        llvm::zip(ordered_tpu_replicate_inputs,
                  replicate_op.GetBody().getArguments())) {
-    TF::TPUReplicatedInputOp input = std::get<0>(input_and_block_arg);
+    mlir::TF::TPUReplicatedInputOp input = std::get<0>(input_and_block_arg);
     Value block_arg = std::get<1>(input_and_block_arg);
     mlir::replaceAllUsesInRegionWith(input->getResult(0), block_arg,
                                      cluster.getBody());
     // Update replicated input use in tf.TPUPartitionedInputV2 op.
     for (auto& use : input->getUses()) {
-      auto pi = llvm::dyn_cast<TF::TPUPartitionedInputV2Op>(use.getOwner());
+      auto pi =
+          llvm::dyn_cast<mlir::TF::TPUPartitionedInputV2Op>(use.getOwner());
       if (pi) {
         pi.setOperand(use.getOperandNumber(), block_arg);
         partitioned_inputs.insert(pi.getOperation());
@@ -849,8 +874,8 @@
                                         partitioned_outputs, erase_list,
                                         result_types, num_replicas);
   } else {
-    result_op = builder.create<tf_device::ReturnOp>(replicate_op.getLoc(),
-                                                    cluster.getResults());
+    result_op = builder.create<mlir::tf_device::ReturnOp>(replicate_op.getLoc(),
+                                                          cluster.getResults());
   }
 
   for (auto pi : partitioned_inputs) pi->moveBefore(result_op);
@@ -860,13 +885,13 @@
   return success();
 }
 
-void SetNoReplicationClusterAttrs(tf_device::ClusterOp cluster,
+void SetNoReplicationClusterAttrs(mlir::tf_device::ClusterOp cluster,
                                   llvm::StringRef device_type,
                                   llvm::StringRef device) {
   OpBuilder builder(cluster);
-  cluster->setAttr(TF::kReplicationInfoAttr,
+  cluster->setAttr(mlir::TF::kReplicationInfoAttr,
                    builder.getStringAttr(kNoReplicationCluster));
-  cluster->setAttr(TF::kCompileDeviceTypeAttr,
+  cluster->setAttr(mlir::TF::kCompileDeviceTypeAttr,
                    builder.getStringAttr(device_type));
 
   if (!device.empty()) {
@@ -904,7 +929,8 @@
 //      attribute `num_replicas` is greater than 1.
 //   9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`.
 LogicalResult FormClustersInBlock(
-    Block* block, const TF::SideEffectAnalysis::Info& side_effect_analysis,
+    Block* block,
+    const mlir::TF::SideEffectAnalysis::Info& side_effect_analysis,
     bool strict_clusters) {
   MetadataMap metadata_map;
   LogicalResult result = CollectMetadata(block, &metadata_map);
@@ -919,7 +945,7 @@
           return op.emitOpError("Expected single block region");
         if (failed(FormClustersInBlock(&region.front(), side_effect_analysis,
                                        strict_clusters)))
-          return failure();
+          return mlir::failure();
       }
     }
   }
@@ -941,9 +967,9 @@
     // No TPUReplicateMetadata for a `_replication_info` attribute.
     if (has_replication && cluster_metadata == metadata_map.end()) {
       block->getParentOp()->emitWarning()
-          << "TPUReplicateMetadata for associated '" << TF::kReplicationInfoAttr
-          << "' attribute '" << cluster_metadata_and_ops.getFirst()
-          << "' is missing";
+          << "TPUReplicateMetadata for associated '"
+          << mlir::TF::kReplicationInfoAttr << "' attribute '"
+          << cluster_metadata_and_ops.getFirst() << "' is missing";
       continue;
     }
 
@@ -955,7 +981,7 @@
     llvm::SmallVector<Value, 8> results =
         CollectClusterResults(block, cluster_ops);
 
-    tf_device::ClusterOp cluster = CreateClusterOp(
+    mlir::tf_device::ClusterOp cluster = CreateClusterOp(
         block, cluster_ops, results, cluster_successor_ops.getArrayRef());
 
     if (!has_replication) {
@@ -979,7 +1005,7 @@
     if (num_cores_per_replica_attr)
       num_cores_per_replica = num_cores_per_replica_attr.getInt();
     if (failed(ReplicateCluster(cluster, num_replicas, num_cores_per_replica)))
-      return failure();
+      return mlir::failure();
 
     // Copy TPUReplicateMetadata attributes to `tf_device.cluster`.
     cluster->setAttrs(
@@ -992,18 +1018,20 @@
 }
 
 LogicalResult FormClustersInFunction(
-    func::FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis,
+    mlir::func::FuncOp func,
+    const mlir::TF::SideEffectAnalysis::Info& side_effect_analysis,
     bool strict_clusters) {
   if (!llvm::hasSingleElement(func))
     return func.emitOpError("Expecting a single block function");
 
   if (failed(FormClustersInBlock(&func.front(), side_effect_analysis,
                                  strict_clusters)))
-    return failure();
+    return mlir::failure();
 
   // Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
   auto remove_result = func.walk([&](Operation* op) {
-    if (!llvm::isa<TF::TPUReplicatedInputOp, TF::TPUReplicatedOutputOp>(op))
+    if (!llvm::isa<mlir::TF::TPUReplicatedInputOp,
+                   mlir::TF::TPUReplicatedOutputOp>(op))
       return WalkResult::advance();
 
     // Forward operand to result. When `num_replicas` attribute is 1, no
@@ -1026,7 +1054,7 @@
     return WalkResult::advance();
   });
 
-  return failure(remove_result.wasInterrupted());
+  return mlir::failure(remove_result.wasInterrupted());
 }
 
 void TPUClusterFormationPass::runOnOperation() {
@@ -1037,7 +1065,7 @@
   // TODO(kramm): Remove this once tf.Const's folder is aware of extra
   // attributes.
   auto value_str_attr = StringAttr::get(&getContext(), "value");
-  getOperation().walk([&](TF::ConstOp cst) {
+  getOperation().walk([&](mlir::TF::ConstOp cst) {
     auto dict = cst->getAttrDictionary();
     if (dict.size() == 1) {
       return;  // Optimization. Assume the one attribute is "value".
@@ -1048,8 +1076,8 @@
     cst->setAttrs(attributes.getDictionary(&getContext()));
   });
 
-  auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
-  for (auto func : getOperation().getOps<func::FuncOp>())
+  auto& side_effect_analysis = getAnalysis<mlir::TF::SideEffectAnalysis>();
+  for (auto func : getOperation().getOps<mlir::func::FuncOp>())
     if (!func.isExternal() &&
         failed(FormClustersInFunction(
             func, side_effect_analysis.GetAnalysisForFunc(func),
@@ -1058,10 +1086,11 @@
 }
 }  // anonymous namespace
 
-std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass(
+std::unique_ptr<mlir::OperationPass<ModuleOp>> CreateTPUClusterFormationPass(
     bool strict_clusters) {
   return std::make_unique<TPUClusterFormationPass>(strict_clusters);
 }
 
-}  // namespace TFTPU
-}  // namespace mlir
+}  // namespace internal
+}  // namespace tf2xla
+}  // namespace tensorflow
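One behavioral note worth an example: `strict_clusters` decides whether a cyclic dependency between a compilation cluster and surrounding ops is reported through `mlir::emitError` or only `mlir::emitWarning` (see `CollectClusterSuccessorOps` above). A hedged sketch of opting into the strict behavior:

```c++
// Illustrative only: with strict_clusters=true, cyclic cluster dependencies
// become hard errors and fail the pass instead of emitting a warning.
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h"

void AddStrictClusterFormation(mlir::PassManager& pm) {
  pm.addPass(tensorflow::tf2xla::internal::CreateTPUClusterFormationPass(
      /*strict_clusters=*/true));
}
```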
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass.cc
new file mode 100644
index 0000000..235a7ca
--- /dev/null
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass.cc
@@ -0,0 +1,69 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <set>
+#include <string>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/Visitors.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+
+namespace tensorflow {
+namespace tf2xla {
+namespace internal {
+
+namespace {
+
+#define GEN_PASS_DEF_VERIFYCLUSTERINGPASS
+#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc"
+
+class VerifyClusteringPass
+    : public impl::VerifyClusteringPassBase<VerifyClusteringPass> {
+ public:
+  void runOnOperation() override;
+};
+
+void VerifyClusteringPass::runOnOperation() {
+  std::set<std::string> valid_namespaces = {"tf", "func", "return", "tf_device",
+                                            "builtin"};
+  mlir::Operation* func_op = getOperation();
+
+  auto walk_result = func_op->walk([&](mlir::Operation* op) {
+    if (valid_namespaces.find(op->getDialect()->getNamespace().str()) ==
+        valid_namespaces.end()) {
+      std::string error = "op is in dialect " +
+                          op->getDialect()->getNamespace().str() +
+                          " not in tf functional dialect";
+      op->emitError() << error;
+      return mlir::WalkResult::interrupt();
+    }
+    return mlir::WalkResult::advance();
+  });
+
+  if (walk_result.wasInterrupted()) {
+    signalPassFailure();
+  }
+}
+}  // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateVerifyClusteringPass() {
+  return std::make_unique<VerifyClusteringPass>();
+}
+}  // namespace internal
+}  // namespace tf2xla
+}  // namespace tensorflow
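
For readers wanting to wire the new verifier into their own pipeline, here is a minimal usage sketch. It assumes only what this change adds (the `CreateVerifyClusteringPass` factory declared in `clustering_passes.h`); the wrapper function name is illustrative, and the rest is standard MLIR pass-manager plumbing, mirroring the unit test added below.

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"   // from @llvm-project
#include "mlir/IR/BuiltinOps.h"             // from @llvm-project
#include "mlir/Pass/PassManager.h"          // from @llvm-project
#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h"

// Runs the clustering verifier on every function in `module`.
// Returns true if all ops are in the allowed (tf functional) dialects.
bool RunClusteringVerifier(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // The verifier is an OperationPass<func::FuncOp>, so nest it under the module.
  pm.addNestedPass<mlir::func::FuncOp>(
      tensorflow::tf2xla::internal::CreateVerifyClusteringPass());
  return mlir::succeeded(pm.run(module));
}
```

The test fixture below exercises the pass the same way: a nested pass manager over `func::FuncOp`, run on a parsed module.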
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.cc
new file mode 100644
index 0000000..6767a00
--- /dev/null
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.cc
@@ -0,0 +1,87 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h"
+#include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h"
+#include "tsl/platform/statusor.h"
+
+namespace tensorflow {
+namespace tf2xla {
+namespace internal {
+
+namespace {
+
+using mlir::mhlo::test::GetMlirModuleFromString;
+
+class VerifyClusteringPassTest : public testing::Test {
+ protected:
+  void CreateModule(const char* module_string) {
+    TF_ASSERT_OK_AND_ASSIGN(module_,
+                            GetMlirModuleFromString(module_string, &context_));
+    pm_ = std::make_unique<mlir::PassManager>(&context_);
+    pm_->addNestedPass<mlir::func::FuncOp>(CreateVerifyClusteringPass());
+  }
+
+  mlir::LogicalResult Run() { return pm_->run(module_.get()); }
+
+ private:
+  mlir::MLIRContext context_;
+  mlir::OwningOpRef<mlir::ModuleOp> module_;
+  std::unique_ptr<mlir::PassManager> pm_;
+};
+
+TEST_F(VerifyClusteringPassTest, OnlyTfFunctionalPasses) {
+  static constexpr char kMlirModuleStr[] = R"(
+  module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
+    func.func @main() -> tensor<1xi32> {
+      %0 = "tf.Const"() {value = dense<1000> : tensor<1xi32>} : () -> tensor<1xi32>
+      return %0 : tensor<1xi32>
+    }
+  })";
+  CreateModule(kMlirModuleStr);
+
+  auto result = Run();
+
+  EXPECT_TRUE(result.succeeded());
+}
+
+TEST_F(VerifyClusteringPassTest, NotTfFunctionalFails) {
+  static constexpr char kMlirModuleStr[] = R"(
+  module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
+    func.func @main() -> tensor<3x32x32x3xf32> {
+      %0 = mhlo.constant dense<2.550000e+02> : tensor<3x32x32x3xf32>
+      return %0 : tensor<3x32x32x3xf32>
+    }
+  })";
+  CreateModule(kMlirModuleStr);
+
+  auto result = Run();
+
+  EXPECT_TRUE(result.failed());
+}
+
+}  // namespace
+}  // namespace internal
+}  // namespace tf2xla
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.mlir b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.mlir
new file mode 100644
index 0000000..23e6024
--- /dev/null
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.mlir
@@ -0,0 +1,16 @@
+// RUN: tf-opt -verify-clustering-pass  -split-input-file -verify-diagnostics %s | FileCheck %s
+// Tests the VerifyClusteringPass, ensuring that an error is emitted when validation fails.
+
+func.func @testNotTfDialect(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
+  // expected-error@below {{op is in dialect chlo not in tf functional dialect}}
+  %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
+  func.return %0 : tensor<1x32x10x32xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @testTFDialect
+func.func @testTFDialect(%arg0: tensor<4x?x!tf_type.stringref>) -> tensor<4x2x!tf_type.string> {
+  %0 = "tf.Identity"(%arg0) : (tensor<4x?x!tf_type.stringref>) -> tensor<4x2x!tf_type.string>
+  func.return %0 : tensor<4x2x!tf_type.string>
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor.cc
new file mode 100644
index 0000000..dd78c06
--- /dev/null
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor.cc
@@ -0,0 +1,44 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+
+namespace tensorflow {
+namespace tf2xla {
+namespace internal {
+
+namespace {
+
+#define GEN_PASS_DEF_VERIFYINPUTDIALECTTOEXECUTORPASS
+#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc"
+
+class VerifyInputDialectToExecutorPass
+    : public impl::VerifyInputDialectToExecutorPassBase<
+          VerifyInputDialectToExecutorPass> {
+ public:
+  void runOnOperation() override;
+};
+
+void VerifyInputDialectToExecutorPass::runOnOperation() {}
+
+}  // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
+CreateVerifyInputDialectToExecutorPass() {
+  return std::make_unique<VerifyInputDialectToExecutorPass>();
+}
+
+}  // namespace internal
+}  // namespace tf2xla
+}  // namespace tensorflow
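
The `runOnOperation` body above is left empty in this change. Purely as an illustration (not part of this PR), a dialect check analogous to the clustering verifier could look like the sketch below. The free-function name is hypothetical, and the allowlist (in particular including `tf_executor`) is an assumption, not taken from this diff.

```cpp
#include <set>
#include <string>

#include "mlir/IR/Operation.h"            // from @llvm-project
#include "mlir/IR/Visitors.h"             // from @llvm-project
#include "mlir/Support/LogicalResult.h"   // from @llvm-project

// Hypothetical check, modeled on VerifyClusteringPass above: walks `op` and
// fails if any nested op lives outside the allowed dialects. The allowlist
// is assumed for illustration only.
mlir::LogicalResult VerifyDialectsForExecutorExport(mlir::Operation* op) {
  const std::set<std::string> valid_namespaces = {"tf", "func", "builtin",
                                                  "tf_executor"};
  auto result = op->walk([&](mlir::Operation* nested) {
    if (valid_namespaces.count(nested->getDialect()->getNamespace().str()) ==
        0) {
      nested->emitError() << "op is in dialect "
                          << nested->getDialect()->getNamespace()
                          << " not expected in input to executor export";
      return mlir::WalkResult::interrupt();
    }
    return mlir::WalkResult::advance();
  });
  return mlir::failure(result.wasInterrupted());
}
```

A pass body would simply call such a check on `getOperation()` and invoke `signalPassFailure()` when it returns failure, matching the structure of the clustering verifier earlier in this change.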
diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir
index 30dc8ec..01dc470 100644
--- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir
+++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-binary-elementwise.mlir
@@ -266,28 +266,28 @@
 
 // CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error
 func.func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
-  // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = true}
+  // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) <{incompatible_shape_error = true}>
   %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
   func.return %0: tensor<1x2xi1>
 }
 
 // CHECK-LABEL: func @equal_incompatible_shape_broadcastable
 func.func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
-  // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = true}
+  // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) <{incompatible_shape_error = true}>
   %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
   func.return %0: tensor<?xi1>
 }
 
 // CHECK-LABEL: func @equal_incompatible_shape_dynamic
 func.func @equal_incompatible_shape_dynamic(%arg0: tensor<2xi32>, %arg1: tensor<?xi32>) -> tensor<*xi1> {
-  // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false}
+  // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) <{incompatible_shape_error = false}>
   %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<?xi32>) -> tensor<*xi1>
   func.return %0: tensor<*xi1>
 }
 
 // CHECK-LABEL: func @equal_incompatible_shape_both_dynamic
 func.func @equal_incompatible_shape_both_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<*xi1> {
-  // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false}
+  // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) <{incompatible_shape_error = false}>
   %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<?xi32>, tensor<?xi32>) -> tensor<*xi1>
   func.return %0: tensor<*xi1>
 }
diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir
index 32ef8cb..a732c6d 100644
--- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir
+++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir
@@ -701,9 +701,8 @@
 
   // Verifies that the following functions are added from xla_call_module. Note this must be at the end of the file.
   // CHECK: func.func private @main.2(%arg0: tensor<f32> {mhlo.sharding = "{replicated}"}) -> tensor<f32> {
-  // CHECK:   %0 = mhlo.bitcast_convert %arg0 : (tensor<f32>) -> tensor<f32> 
-  // CHECK:   %1 = mhlo.sine %0 : tensor<f32>
-  // CHECK:   return %1 : tensor<f32>
+  // CHECK:   %0 = mhlo.sine %arg0 : tensor<f32>
+  // CHECK:   return %0 : tensor<f32>
   // CHECK: }
 
 }
diff --git a/tensorflow/compiler/mlir/tf2xla/tests/tfxla_device_specific_transformations_cpu.mlir b/tensorflow/compiler/mlir/tf2xla/tests/tfxla_device_specific_transformations_cpu.mlir
index c06ba74..67aa69d 100644
--- a/tensorflow/compiler/mlir/tf2xla/tests/tfxla_device_specific_transformations_cpu.mlir
+++ b/tensorflow/compiler/mlir/tf2xla/tests/tfxla_device_specific_transformations_cpu.mlir
@@ -4,9 +4,9 @@
 
 // CHECK-LABEL: stateless_op
 func.func @stateless_op() -> tensor<i32> {
-  // CHECK: %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %cst = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
   %0 = "tf.StatelessRandomGetAlg"() {device = ""} : () -> tensor<i32>
   return %0 : tensor<i32>
 }
 
-}
\ No newline at end of file
+}
diff --git a/tensorflow/compiler/mlir/tf2xla/tests/tfxla_device_specific_transformations_gpu.mlir b/tensorflow/compiler/mlir/tf2xla/tests/tfxla_device_specific_transformations_gpu.mlir
index f051960..4d5da14 100644
--- a/tensorflow/compiler/mlir/tf2xla/tests/tfxla_device_specific_transformations_gpu.mlir
+++ b/tensorflow/compiler/mlir/tf2xla/tests/tfxla_device_specific_transformations_gpu.mlir
@@ -4,9 +4,9 @@
 
 // CHECK-LABEL: stateless_op
 func.func @stateless_op() -> tensor<i32> {
-  // CHECK: %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %cst = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
   %0 = "tf.StatelessRandomGetAlg"() {device = ""} : () -> tensor<i32>
   return %0 : tensor<i32>
 }
 
-}
\ No newline at end of file
+}
diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD
index b4f5d00..ed0429a 100644
--- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD
+++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD
@@ -1,10 +1,10 @@
 # Description:
 #    TF2XLA Bridge transforms
 
-load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
 load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
 load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static")
 
 package(
@@ -487,13 +487,12 @@
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
-        "//tensorflow/core/tpu:tpu_defs",
         "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest_main",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
         "@local_tsl//tsl/lib/core:status_test_util",
-        "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:statusor",
     ],
diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc
index 24f7711..ade2b5f 100644
--- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc
+++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc
@@ -20,21 +20,22 @@
 #include <string>
 #include <vector>
 
-#include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "absl/status/status.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/DialectRegistry.h"  // from @llvm-project
+#include "mlir/IR/OperationSupport.h"  // from @llvm-project
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Support/TypeID.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/register_common_dialects.h"
-#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/core/tpu/tpu_defs.h"
 #include "tsl/lib/core/status_test_util.h"
-#include "tsl/platform/errors.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/statusor.h"
 
@@ -131,7 +132,7 @@
   // a new op, we should expect these to change too.
   EXPECT_EQ(mlir_lowering_count, 67);
   EXPECT_EQ(tf2xla_fallback_count, 315);
-  EXPECT_EQ(non_categorized_count, 420);
+  EXPECT_EQ(non_categorized_count, 422);
 }
 
 // Just a counter test to see which ops have duplicate lowerings. This isn't a
diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc
index 4a454c3..f803230 100644
--- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc
+++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc
@@ -27,6 +27,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/mlprogram_util.h"
 #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h"
+#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h"
 #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tosa/tf_passes.h"
 #include "tensorflow/compiler/mlir/tosa/tf_tfl_passes.h"
@@ -55,6 +56,7 @@
   mlir::mhlo::registerLegalizeTfPasses();
   mlir::mhlo::registerTfXlaPasses();
   mlir::quant::stablehlo::registerBridgePasses();
+  tensorflow::tf2xla::internal::registerTFXLABridgePasses();
   mlir::tosa::registerLegalizeTosaPasses();
   mlir::tosa::registerTFtoTOSALegalizationPipeline();
   mlir::tosa::registerTFLtoTOSALegalizationPipeline();
diff --git a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD
index ee79b37..135bc20 100644
--- a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD
+++ b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD
@@ -81,6 +81,7 @@
         "notap",  # The test is too long to run as part of llvm presubmits (b/173661843).
         "notpu",  # Takes too long (b/192305423)
         "notsan",  # Not needed, and there were issues with timeouts.
+        "requires-net:external",
     ],
 
     # TODO(b/175056184): Re-enable xla_enable_strict_auto_jit once the issues
diff --git a/tensorflow/compiler/mlir/tfr/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tfr/tests/canonicalize.mlir
index 77508b6..912385f 100644
--- a/tensorflow/compiler/mlir/tfr/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tfr/tests/canonicalize.mlir
@@ -41,7 +41,7 @@
   %1 = "tfr.constant_tensor"(%0) : (!tfr.attr) -> !tfr.tensor
   func.return %1 : !tfr.tensor
 
-// CHECK-NEXT: %[[RES:.*]] = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi64>} : () -> tensor<3xi64>
+// CHECK-NEXT: %[[RES:.*]] = "tf.Const"() <{value = dense<[1, -1, 3]> : tensor<3xi64>}> : () -> tensor<3xi64>
 // CHECK-NEXT: "tfr.cast"(%[[RES]]) : (tensor<3xi64>) -> !tfr.tensor
 // CHECK-NEXT: return
 }
@@ -54,7 +54,7 @@
   %1 = "tfr.constant_tensor"(%0) : (i32) -> !tfr.tensor
   func.return %1 : !tfr.tensor
 
-// CHECK-NEXT: %[[RES:.*]] = "tf.Const"() {value = dense<42> : tensor<i32>} : () -> tensor<i32>
+// CHECK-NEXT: %[[RES:.*]] = "tf.Const"() <{value = dense<42> : tensor<i32>}> : () -> tensor<i32>
 // CHECK-NEXT: "tfr.cast"(%[[RES]]) : (tensor<i32>) -> !tfr.tensor
 // CHECK-NEXT: return
 }
@@ -83,7 +83,7 @@
 
 // CHECK-LABEL:  quant_raw_data_with_list
 func.func @quant_raw_data_with_list(%arg0: !tfr.tensor, %arg1: !tfr.tensor) -> !tfr.tensor {
-  %cst_1 = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+  %cst_1 = "tf.Const"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
   %1 = "tfr.cast"(%arg0) : (!tfr.tensor) -> tensor<1x4x4x3x!quant.uniform<i8:f32, 0.0078420601785182952:-1>>
   %2 = "tfr.cast"(%arg1) : (!tfr.tensor) -> tensor<1x3x4x3x!quant.uniform<i8:f32, 0.0078420601785182952:-1>>
   %3 = "tfr.cast"(%2) : (tensor<1x3x4x3x!quant.uniform<i8:f32, 0.0078420601785182952:-1>>) -> !tfr.tensor
@@ -94,7 +94,7 @@
   %8 = tfr.call @tf__concat(%7, %6) : (!tfr.tensor, !tfr.tensor_list) -> !tfr.tensor
   func.return %8 : !tfr.tensor
 
-// CHECK: %[[CONST_0:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
 // CHECK: %[[BUILD_LIST_0:.*]] = "tfr.build_list"(%arg1, %arg0) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
 // CHECK: %[[CAST_0:.*]] = "tfr.cast"(%[[CONST_0]]) : (tensor<i64>) -> !tfr.tensor
 // CHECK: %[[CONCAT_O:.*]] = tfr.call @tf__concat(%[[CAST_0]], %[[BUILD_LIST_0]]) : (!tfr.tensor, !tfr.tensor_list) -> !tfr.tensor
@@ -131,8 +131,8 @@
   %2 = "tfr.cast"(%zp) : (!tfr.tensor) -> tensor<i32>
   func.return %1, %2 : tensor<f32>, tensor<i32>
 
-// CHECK-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>}
-// CHECK-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<42> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG: %[[scale:.*]] = "tf.Const"() <{value = dense<1.000000e-01> : tensor<f32>}>
+// CHECK-DAG: %[[zp:.*]] = "tf.Const"() <{value = dense<42> : tensor<i32>}> : () -> tensor<i32>
 // CHECK: return %[[scale]], %[[zp]]
 }
 
@@ -144,8 +144,8 @@
   %2 = "tfr.cast"(%zp) : (!tfr.tensor) -> tensor<3xi32>
   func.return %1, %2 : tensor<3xf32>, tensor<3xi32>
 
-// CHECK-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<[1.000000e-01, 2.000000e-01, 3.000000e-01]> : tensor<3xf32>}
-// CHECK-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
+// CHECK-DAG: %[[scale:.*]] = "tf.Const"() <{value = dense<[1.000000e-01, 2.000000e-01, 3.000000e-01]> : tensor<3xf32>}>
+// CHECK-DAG: %[[zp:.*]] = "tf.Const"() <{value = dense<[1, 2, 3]> : tensor<3xi32>}> : () -> tensor<3xi32>
 // CHECK: return %[[scale]], %[[zp]]
 }
 
@@ -168,9 +168,9 @@
   %2 = "tfr.cast"(%0) : (!tfr.tensor) -> tensor<2xi32>
   func.return %1, %2 : tensor<*xi32>, tensor<2xi32>
 
-// CHECK: %[[tf_cast_unranked:.*]] = "tf.Cast"(%arg0) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi32>
-// CHECK: %[[ensure_shape:.*]] = "tf.EnsureShape"(%arg0) {shape = #tf_type.shape<2>} : (tensor<*xf32>) -> tensor<2xf32>
-// CHECK: %[[tf_cast_ranked:.*]] = "tf.Cast"(%[[ensure_shape]]) {Truncate = false} : (tensor<2xf32>) -> tensor<2xi32>
+// CHECK: %[[tf_cast_unranked:.*]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<*xf32>) -> tensor<*xi32>
+// CHECK: %[[ensure_shape:.*]] = "tf.EnsureShape"(%arg0) <{shape = #tf_type.shape<2>}> : (tensor<*xf32>) -> tensor<2xf32>
+// CHECK: %[[tf_cast_ranked:.*]] = "tf.Cast"(%[[ensure_shape]]) <{Truncate = false}> : (tensor<2xf32>) -> tensor<2xi32>
 // CHECK: return %[[tf_cast_unranked]], %[[tf_cast_ranked]] :  tensor<*xi32>, tensor<2xi32>
 }
 
@@ -185,7 +185,7 @@
   func.return %3 : tensor<10xi32>
 // CHECK: %[[CAST_0:.*]] = "tfr.cast"(%arg0) : (tensor<10x!quant.uniform<i8:f32, 0.0039133410900831223:-128>>) -> !tfr.tensor
 // CHECK: %[[CAST_1:.*]] = "tfr.cast"(%[[CAST_0]]) : (!tfr.tensor) -> tensor<10xi8>
-// CHECK: %[[CAST_2:.*]] = "tf.Cast"(%[[CAST_1]]) {Truncate = false} : (tensor<10xi8>) -> tensor<10xi32>
+// CHECK: %[[CAST_2:.*]] = "tf.Cast"(%[[CAST_1]]) <{Truncate = false}> : (tensor<10xi8>) -> tensor<10xi32>
 // CHECK: return %[[CAST_2]] : tensor<10xi32>
 }
 
diff --git a/tensorflow/compiler/mlir/tfr/tests/decompose.mlir b/tensorflow/compiler/mlir/tfr/tests/decompose.mlir
index 0dcd363..eb35c3a 100644
--- a/tensorflow/compiler/mlir/tfr/tests/decompose.mlir
+++ b/tensorflow/compiler/mlir/tfr/tests/decompose.mlir
@@ -227,8 +227,8 @@
   %list2 = "tfr.build_list"(%input_scale_tensor, %perchannel_scale_tensor) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
   %perchannel = "tfr.quant_scale_factor"(%output_scale, %list2) : (f32, !tfr.tensor_list) -> !tfr.tensor
   func.return %out, %perchannel : !tfr.tensor, !tfr.tensor
-// CHECK-DAG: %[[scale_factors:.*]] = "tf.Const"() {value = dense<[1.000000e+00, 1.000000e+01]> : tensor<2xf32>} : () -> tensor<2xf32>
-// CHECK-DAG: %[[scale_factor:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK-DAG: %[[scale_factors:.*]] = "tf.Const"() <{value = dense<[1.000000e+00, 1.000000e+01]> : tensor<2xf32>}> : () -> tensor<2xf32>
+// CHECK-DAG: %[[scale_factor:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[cast:.*]] = "tfr.cast"(%[[scale_factor]]) : (tensor<f32>) -> !tfr.tensor
 // CHECK: %[[cast_perchannel:.*]] = "tfr.cast"(%[[scale_factors]]) : (tensor<2xf32>) -> !tfr.tensor
 // CHECK: return %[[cast]], %[[cast_perchannel]] : !tfr.tensor, !tfr.tensor
@@ -245,8 +245,8 @@
   %out = "tfr.quant_scale_factor"(%output_scale, %list) : (f32, !tfr.tensor_list) -> !tfr.tensor
   func.return %out : !tfr.tensor
 // CHECK-DAG: %[[cst_0:.*]] = arith.constant 1.000000e-01 : f32
-// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() {value = dense<2.500000e-01> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[cst_2:.*]] = "tf.Const"() {value = dense<4.000000e-01> : tensor<f32>} : () -> tensor<f32>
+// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<2.500000e-01> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[cst_2:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[tfrcast0:.*]] = "tfr.cast"(%[[cst_1]]) : (tensor<f32>) -> !tfr.tensor
 // CHECK: %[[tfrcast1:.*]] = "tfr.cast"(%[[cst_2]]) : (tensor<f32>) -> !tfr.tensor
 // CHECK: %[[list:.*]] = "tfr.build_list"(%[[tfrcast0]], %[[tfrcast1]], %[[tfrcast0]]) : (!tfr.tensor, !tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
@@ -265,9 +265,9 @@
 
 // CHECK-DAG: %[[f32:.*]] = tfr.constant f32 -> !tfr.attr
 // CHECK-DAG: %[[i32:.*]] = tfr.constant i32 -> !tfr.attr
-// CHECK-DAG: %[[scale_cst:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK-DAG: %[[scale_cst:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK-DAG: %false = arith.constant false
-// CHECK-DAG: %[[zp_cst:.*]] = "tf.Const"() {value = dense<67> : tensor<i64>} : () -> tensor<i64>
+// CHECK-DAG: %[[zp_cst:.*]] = "tf.Const"() <{value = dense<67> : tensor<i64>}> : () -> tensor<i64>
 // CHECK: %[[zp:.*]] = "tfr.cast"(%[[zp_cst]]) : (tensor<i64>) -> !tfr.tensor
 // CHECK: %[[scale:.*]] = "tfr.cast"(%[[scale_cst]]) : (tensor<f32>) -> !tfr.tensor
 // CHECK: %[[input:.*]] = "tfr.cast"(%arg0) : (tensor<2xi32>) -> !tfr.tensor
diff --git a/tensorflow/compiler/mlir/tfr/tests/end2end.mlir b/tensorflow/compiler/mlir/tfr/tests/end2end.mlir
index 6a49a0d..0654b21 100644
--- a/tensorflow/compiler/mlir/tfr/tests/end2end.mlir
+++ b/tensorflow/compiler/mlir/tfr/tests/end2end.mlir
@@ -17,7 +17,7 @@
 
 // CHECK-NEXT: %[[RE:.*]] = "tf.RiscReciprocal"(%arg0) : (tensor<2x3xf32>) -> tensor<*xf32>
 // CHECK-NEXT: %[[SQRT:.*]] = "tf.RiscSqrt"(%[[RE]]) : (tensor<*xf32>) -> tensor<*xf32>
-// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[SQRT]]) {shape = #tf_type.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32>
+// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[SQRT]]) <{shape = #tf_type.shape<3x2x3>}> : (tensor<*xf32>) -> tensor<3x2x3xf32>
 // CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32>
 }
 
@@ -26,11 +26,11 @@
   %0 = "tf.MyLeakyRelu"(%arg0) {alpha=3.0 : f32} : (tensor<2x3xf32>) -> tensor<3x2x3xf32>
   func.return %0 : tensor<3x2x3xf32>
 
-// CHECK-NEXT: %[[ALPHA:.*]] = "tf.Const"() {value = dense<3.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK-NEXT: %[[ALPHA:.*]] = "tf.Const"() <{value = dense<3.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK-NEXT: %[[SHAPE:.*]] = "tf.RiscShape"(%arg0) {T = i32} : (tensor<2x3xf32>) -> tensor<*xi32>
 // CHECK-NEXT: %[[ALPHA1:.*]] = "tf.RiscBroadcast"(%[[ALPHA]], %[[SHAPE]]) : (tensor<f32>, tensor<*xi32>) -> tensor<*xf32>
 // CHECK-NEXT: %[[MAX:.*]] = "tf.RiscMaximum"(%arg0, %[[ALPHA1]]) : (tensor<2x3xf32>, tensor<*xf32>) -> tensor<*xf32>
-// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[MAX]]) {shape = #tf_type.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32>
+// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[MAX]]) <{shape = #tf_type.shape<3x2x3>}> : (tensor<*xf32>) -> tensor<3x2x3xf32>
 // CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32>
 }
 
@@ -39,11 +39,11 @@
   %0 = "tf.MyLeakyRelu"(%arg0) : (tensor<2x3xf32>) -> tensor<3x2x3xf32>
   func.return %0 : tensor<3x2x3xf32>
 
-// CHECK-NEXT: %[[ALPHA:.*]] = "tf.Const"() {value = dense<2.000000e-01> : tensor<f32>} : () -> tensor<f32>
+// CHECK-NEXT: %[[ALPHA:.*]] = "tf.Const"() <{value = dense<2.000000e-01> : tensor<f32>}> : () -> tensor<f32>
 // CHECK-NEXT: %[[SHAPE:.*]] = "tf.RiscShape"(%arg0) {T = i32} : (tensor<2x3xf32>) -> tensor<*xi32>
 // CHECK-NEXT: %[[ALPHA1:.*]] = "tf.RiscBroadcast"(%[[ALPHA]], %[[SHAPE]]) : (tensor<f32>, tensor<*xi32>) -> tensor<*xf32>
 // CHECK-NEXT: %[[MAX:.*]] = "tf.RiscMaximum"(%arg0, %[[ALPHA1]]) : (tensor<2x3xf32>, tensor<*xf32>) -> tensor<*xf32>
-// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[MAX]]) {shape = #tf_type.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32>
+// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[MAX]]) <{shape = #tf_type.shape<3x2x3>}> : (tensor<*xf32>) -> tensor<3x2x3xf32>
 // CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32>
 }
 
@@ -53,7 +53,7 @@
   func.return %0 : tensor<2x3xi32>
 
 // CHECK-NEXT: %[[CAST:.*]] = "tf.RiscCast"(%arg0) {Tout = i32} : (tensor<2x3xf32>) -> tensor<*xi32>
-// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[CAST]]) {shape = #tf_type.shape<2x3>} : (tensor<*xi32>) -> tensor<2x3xi32>
+// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[CAST]]) <{shape = #tf_type.shape<2x3>}> : (tensor<*xi32>) -> tensor<2x3xi32>
 // CHECK-NEXT: return %[[ES]] : tensor<2x3xi32>
 }
 
@@ -62,9 +62,9 @@
   %0 = "tf.MyPack"(%arg0) {N=1:i32, axis=0:i32} : (tensor<2x3xf32>) -> tensor<3x2x3xf32>
   func.return %0 : tensor<3x2x3xf32>
 
-// CHECK-NEXT: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-NEXT: %[[AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK-NEXT: %[[ED:.*]] = "tf.ExpandDims"(%arg0, %[[AXIS]]) : (tensor<2x3xf32>, tensor<i32>) -> tensor<*xf32>
-// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[ED]]) {shape = #tf_type.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32>
+// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[ED]]) <{shape = #tf_type.shape<3x2x3>}> : (tensor<*xf32>) -> tensor<3x2x3xf32>
 // CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32>
 }
 
@@ -73,13 +73,13 @@
   %0 = "tf.MyPack"(%arg0, %arg1, %arg2) {N=3:i32, axis=0:i32} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<3x2x3xf32>
   func.return %0 : tensor<3x2x3xf32>
 
-// CHECK-NEXT: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK-NEXT: %[[AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
 // CHECK-NEXT: %[[ED0:.*]] = "tf.ExpandDims"(%arg0, %[[AXIS]]) : (tensor<2x3xf32>, tensor<i32>) -> tensor<*xf32>
 // CHECK-NEXT: %[[ED1:.*]] = "tf.ExpandDims"(%arg1, %[[AXIS]]) : (tensor<2x3xf32>, tensor<i32>) -> tensor<*xf32>
 // CHECK-NEXT: %[[CC0:.*]] = "tf.RiscConcat"(%[[ED0]], %[[ED1]]) {axis = 0 : i32} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
 // CHECK-NEXT: %[[ED2:.*]] = "tf.ExpandDims"(%arg2, %[[AXIS]]) : (tensor<2x3xf32>, tensor<i32>) -> tensor<*xf32>
 // CHECK-NEXT: %[[CC1:.*]] = "tf.RiscConcat"(%[[CC0]], %[[ED2]]) {axis = 0 : i32} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[CC1]]) {shape = #tf_type.shape<3x2x3>} : (tensor<*xf32>) -> tensor<3x2x3xf32>
+// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[CC1]]) <{shape = #tf_type.shape<3x2x3>}> : (tensor<*xf32>) -> tensor<3x2x3xf32>
 // CHECK-NEXT: return %[[ES]] : tensor<3x2x3xf32>
 }
 
@@ -98,7 +98,7 @@
 
 // CHECK-NEXT: %[[ADD0:.*]] = "tf.RiscAdd"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<*xf32>
 // CHECK-NEXT: %[[ADD1:.*]] = "tf.RiscAdd"(%[[ADD0]], %arg2) : (tensor<*xf32>, tensor<2x3xf32>) -> tensor<*xf32>
-// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[ADD1]]) {shape = #tf_type.shape<2x3>} : (tensor<*xf32>) -> tensor<2x3xf32>
+// CHECK-NEXT: %[[ES:.*]] = "tf.EnsureShape"(%[[ADD1]]) <{shape = #tf_type.shape<2x3>}> : (tensor<*xf32>) -> tensor<2x3xf32>
 // CHECK-NEXT: return %[[ES]] : tensor<2x3xf32>
 }
 
@@ -112,10 +112,10 @@
     : (tensor<*x!tf_type.variant>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.variant>
   func.return %0 : tensor<*x!tf_type.variant>
 
-// CHECK-DAG: %[[BATCH:.*]] = "tf.Const"() {value = dense<1000> : tensor<i64>} : () -> tensor<i64>
-// CHECK-DAG: %[[PARAL:.*]] = "tf.Const"() {value = dense<8> : tensor<i64>} : () -> tensor<i64>
-// CHECK-DAG: %[[KEEP:.*]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
-// CHECK: %[[CAST:.*]] = "tf.Cast"(%arg2) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32>
+// CHECK-DAG: %[[BATCH:.*]] = "tf.Const"() <{value = dense<1000> : tensor<i64>}> : () -> tensor<i64>
+// CHECK-DAG: %[[PARAL:.*]] = "tf.Const"() <{value = dense<8> : tensor<i64>}> : () -> tensor<i64>
+// CHECK-DAG: %[[KEEP:.*]] = "tf.Const"() <{value = dense<false> : tensor<i1>}> : () -> tensor<i1>
+// CHECK: %[[CAST:.*]] = "tf.Cast"(%arg2) <{Truncate = false}> : (tensor<*xi32>) -> tensor<*xf32>
 // CHECK: %[[RET:.*]] = "tf.MapAndBatchDatasetV0"(%arg0, %[[BATCH]], %[[PARAL]], %[[KEEP]], %arg1, %[[CAST]])
 // CHECK-SAME: {f = @__some_func, output_shapes = [#tf_type.shape<>], output_types = [f32], preserve_cardinality = true} : (tensor<*x!tf_type.variant>, tensor<i64>, tensor<i64>, tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*x!tf_type.variant>
 // CHECK: return %[[RET]] : tensor<*x!tf_type.variant>
diff --git a/tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir b/tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir
index 4be59b5..14d2127 100644
--- a/tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir
+++ b/tensorflow/compiler/mlir/tfr/tests/raise_to_tf.mlir
@@ -17,7 +17,7 @@
   func.return %2 : tensor<1x2x3x4x!tf_type.string>
 
 // CHECK: %[[id:.*]] = "tf.RiscSame"(%arg0) : (tensor<1x2x3x4x!tf_type.string>) -> tensor<*x!tf_type.string>
-// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[id]]) {shape = #tf_type.shape<1x2x3x4>} : (tensor<*x!tf_type.string>) -> tensor<1x2x3x4x!tf_type.string>
+// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[id]]) <{shape = #tf_type.shape<1x2x3x4>}> : (tensor<*x!tf_type.string>) -> tensor<1x2x3x4x!tf_type.string>
 // CHECK: return %[[es]] : tensor<1x2x3x4x!tf_type.string>
 }
 
@@ -32,7 +32,7 @@
 
 // CHECK: %[[id0:.*]] = "tf.RiscSame"(%arg0) : (tensor<1x2x3x4x!tf_type.string>) -> tensor<*x!tf_type.string>
 // CHECK: %[[id2:.*]] = "tf.RiscSame"(%arg2) : (tensor<f32>) -> tensor<*xf32>
-// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[id2]]) {shape = #tf_type.shape<>} : (tensor<*xf32>) -> tensor<f32>
+// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[id2]]) <{shape = #tf_type.shape<>}> : (tensor<*xf32>) -> tensor<f32>
 // CHECK: return %[[es]] : tensor<f32>
 }
 
@@ -47,7 +47,7 @@
   func.return %4 : tensor<3xf32>
 
 // CHECK: %[[concat:.*]] = "tf.RiscConcat"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<*xf32>
-// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[concat]]) {shape = #tf_type.shape<3>} : (tensor<*xf32>) -> tensor<3xf32>
+// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[concat]]) <{shape = #tf_type.shape<3>}> : (tensor<*xf32>) -> tensor<3xf32>
 // CHECK: return %[[es]] : tensor<3xf32>
 }
 
@@ -62,7 +62,7 @@
   func.return %4 : tensor<f32>
 
 // CHECK: %[[split:.*]]:3 = "tf.RiscSplit"(%arg0) {N = 3 : i32} : (tensor<3xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
-// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[split]]#0) {shape = #tf_type.shape<>} : (tensor<*xf32>) -> tensor<f32>
+// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[split]]#0) <{shape = #tf_type.shape<>}> : (tensor<*xf32>) -> tensor<f32>
 // CHECK: return %[[es]] : tensor<f32>
 }
 
@@ -75,7 +75,7 @@
   func.return %4 : tensor<i32>
 
 // CHECK: %[[tfcast:.*]] = "tf.RiscCast"(%arg0) {K = i32} : (tensor<f32>) -> tensor<*xi32>
-// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[tfcast]]) {shape = #tf_type.shape<>} : (tensor<*xi32>) -> tensor<i32>
+// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[tfcast]]) <{shape = #tf_type.shape<>}> : (tensor<*xi32>) -> tensor<i32>
 // CHECK: return %[[es]] : tensor<i32>
 }
 
@@ -87,7 +87,7 @@
   %4 = "tfr.cast"(%cst) : (!tfr.tensor) -> tensor<f32>
   func.return %4 : tensor<f32>
 
-// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<3.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK: %[[cst:.*]] = "tf.Const"() <{value = dense<3.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: return %[[cst]] : tensor<f32>
 }
 
@@ -100,7 +100,7 @@
   func.return %4 : tensor<i32>
 
 // CHECK: %[[tfcast:.*]] = "tf.RiscCast"(%arg0) {K = i32, _tpu_replicate, device = "hello"} : (tensor<f32>) -> tensor<*xi32>
-// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[tfcast]]) {shape = #tf_type.shape<>} : (tensor<*xi32>) -> tensor<i32>
+// CHECK: %[[es:.*]] = "tf.EnsureShape"(%[[tfcast]]) <{shape = #tf_type.shape<>}> : (tensor<*xi32>) -> tensor<i32>
 // CHECK: return %[[es]] : tensor<i32>
 }
 
@@ -111,7 +111,7 @@
   %2 = "tfr.cast"(%1) : (!tfr.tensor) -> tensor<2xi1>
   func.return %2 : tensor<2xi1>
 // CHECK: %[[positive:.*]] = "tf.Positive"(%arg0) : (tensor<2xf32>) -> tensor<*xi1>
-// CHECK: %[[res:.*]] = "tf.EnsureShape"(%[[positive]]) {shape = #tf_type.shape<2>} : (tensor<*xi1>) -> tensor<2xi1>
+// CHECK: %[[res:.*]] = "tf.EnsureShape"(%[[positive]]) <{shape = #tf_type.shape<2>}> : (tensor<*xi1>) -> tensor<2xi1>
 // CHECK: return %[[res]] : tensor<2xi1>
 }
 
diff --git a/tensorflow/compiler/mlir/tfr/tests/rewrite_quantized_io.mlir b/tensorflow/compiler/mlir/tfr/tests/rewrite_quantized_io.mlir
index af2f843..05823bf 100644
--- a/tensorflow/compiler/mlir/tfr/tests/rewrite_quantized_io.mlir
+++ b/tensorflow/compiler/mlir/tfr/tests/rewrite_quantized_io.mlir
@@ -22,8 +22,8 @@
   %1 = "tf.Intermediate"(%arg1) : (tensor<1x5xf32>) -> tensor<1x5xf32>
   func.return %0, %1 : tensor<1x10x!quant.uniform<i8:f32, 0.2:42>>, tensor<1x5xf32>
 
-// CHECK-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
-// CHECK-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<-128> : tensor<i32>} : () -> tensor<i32>
+// CHECK-DAG: %[[scale:.*]] = "tf.Const"() <{value = dense<1.000000e-01> : tensor<f32>}> : () -> tensor<f32>
+// CHECK-DAG: %[[zp:.*]] = "tf.Const"() <{value = dense<-128> : tensor<i32>}> : () -> tensor<i32>
 // CHECK: %[[quant:.*]] = "tfr.cast"(%arg0) : (tensor<1x10xi8>) -> !tfr.tensor
 // CHECK: %[[scale_cast:.*]] = "tfr.cast"(%[[scale]])
 // CHECK: %[[zp_cast:.*]] = "tfr.cast"(%[[zp]])
diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD
index f38289a..797d966 100644
--- a/tensorflow/compiler/mlir/tfrt/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/BUILD
@@ -1,5 +1,5 @@
 load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
-load("//tensorflow:tensorflow.bzl", "if_google", "tf_cc_binary")
+load("//tensorflow:tensorflow.bzl", "if_google", "tf_cc_binary", "tf_cc_test")
 
 # Note: keep the following lines separate due to the way copybara works
 load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
@@ -22,6 +22,7 @@
         "//tensorflow/core/runtime_fallback/...",
         "//tensorflow/core/tfrt/experimental/data/...",
         "//tensorflow/core/tfrt/graph_executor/...",
+        "//tensorflow/core/tfrt/ifrt/...",
         "//tensorflow/core/tfrt/mlrt/...",
         "//tensorflow/core/tfrt/saved_model/...",
         "//tensorflow/core/tfrt/tfrt_session/...",
@@ -126,6 +127,116 @@
 )
 
 cc_library(
+    name = "tf_ifrt_passes",
+    srcs = [
+        "transforms/ifrt/rewrite_cluster_to_ifrt_call.cc",
+        "transforms/ifrt/tf_ifrt_passes.cc",
+    ],
+    hdrs = [
+        "transforms/ifrt/rewrite_cluster_to_ifrt_call.h",
+        "transforms/ifrt/tf_ifrt_passes.h",
+    ],
+    #compatible_with = get_compatible_with_portable(),  # copybara: comment
+    deps = [
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:bridge_logger",
+        "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
+        "//tensorflow/compiler/mlir/tensorflow:error_util",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
+        "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/platform:random",
+        "@com_google_absl//absl/base",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+    ],
+)
+
+cc_library(
+    name = "ifrt_serving_executable",
+    srcs = ["transforms/ifrt/ifrt_serving_executable.cc"],
+    hdrs = ["transforms/ifrt/ifrt_serving_executable.h"],
+    deps = [
+        "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@llvm-project//mlir:IR",
+    ],
+)
+
+cc_library(
+    name = "ifrt_backend_compiler",
+    srcs = ["transforms/ifrt/ifrt_backend_compiler.cc"],
+    hdrs = ["transforms/ifrt/ifrt_backend_compiler.h"],
+    deps = [
+        ":backend_compiler",
+        ":ifrt_serving_executable",
+        ":tf_ifrt_passes",
+        ":tpu_passes",
+        "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
+        "//tensorflow/compiler/mlir/tensorflow:error_util",
+        "//tensorflow/compiler/mlir/tensorflow:visitor",
+        "//tensorflow/compiler/mlir/tf2xla/api/v2:cluster_tf",
+        "//tensorflow/core/tfrt/ifrt:ifrt_executable_registry",
+        "//tensorflow/core/tfrt/ifrt:ifrt_model_context",
+        "//tensorflow/core/tfrt/runtime",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/profiler/lib:traceme",
+    ],
+)
+
+tf_cc_test(
+    name = "ifrt_backend_compiler_test",
+    srcs = [
+        "transforms/ifrt/ifrt_backend_compiler_test.cc",
+    ],
+    data = [
+        "//tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata",
+    ],
+    tags = ["no_oss"],
+    deps = [
+        ":ifrt_backend_compiler",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/core:test",
+        "//tensorflow/core/platform:resource_loader",
+        "//tensorflow/core/tfrt/graph_executor:graph_execution_options",
+        "//tensorflow/core/tfrt/ifrt:ifrt_model_context",
+        "//tensorflow/core/tfrt/runtime",
+        "//tensorflow/core/tfrt/saved_model:saved_model_testutil",
+        "@com_google_absl//absl/strings",
+        "@com_google_googletest//:gtest_main",
+        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Parser",
+        "@tf_runtime//:hostcontext",
+    ],
+)
+
+cc_library(
     name = "corert_converter",
     srcs = [
         "transforms/corert_converter.cc",
@@ -361,11 +472,15 @@
         "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
         "//tensorflow/compiler/mlir/tensorflow:error_util",
         "//tensorflow/compiler/mlir/tensorflow:import_model",
+        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
         "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils",
         "//tensorflow/compiler/mlir/tensorflow/transforms:bridge",
         "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
         "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_asset_sinking_pass",
+        "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops",
         "//tensorflow/compiler/mlir/tf2xla/api/v2:cluster_tf",
+        "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_dialect_to_executor",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core/common_runtime:function_body",
@@ -373,7 +488,13 @@
         "//tensorflow/core/platform:status",
         "//tensorflow/core/tfrt/fallback:fallback_state",
         "//tensorflow/core/tfrt/runtime",
+        "//tensorflow/core/tpu:tpu_defs",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/functional:function_ref",
+        "@com_google_absl//absl/log",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FuncExtensions",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
@@ -382,6 +503,7 @@
         "@local_tsl//tsl/platform:errors",
         "@tf_runtime//:bef",
         "@tf_runtime//:mlirtobef",
+        "@tf_runtime//:support",
     ],
 )
 
@@ -485,13 +607,14 @@
         ":test_cost_analysis_pass",
         ":test_opkernels",
         ":test_tensor_array_side_effect_analysis",
+        ":tf_ifrt_passes",
         ":tf_to_tfrt",
+        ":tpu_passes",
         ":transforms/gpu_passes",
         "//tensorflow/compiler/mlir:init_mlir",
         "//tensorflow/compiler/mlir:passes",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite",
         "//tensorflow/compiler/mlir/tensorflow",
-        "//tensorflow/compiler/mlir/tensorflow/transforms:bridge_pass_test_pipeline_registration",
         "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
         "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops",
         "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs",
@@ -500,11 +623,11 @@
         "//tensorflow/compiler/mlir/tfrt/ir/mlrt:mlrt_ops",
         "//tensorflow/compiler/mlir/tfrt/ir/mlrt:tf_mlrt_ops",
         "//tensorflow/compiler/mlir/tfrt/transforms/mlrt:passes",
-        "//tensorflow/core:lib",
         "//tensorflow/core:tensorflow",
         "@llvm-project//mlir:AllPassesAndDialects",
         "@llvm-project//mlir:MlirOptLib",
         "@llvm-project//mlir:ShapeDialect",
+        "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Transforms",
         "@local_xla//xla/mlir_hlo",
         "@local_xla//xla/mlir_hlo:gml_st",
@@ -667,9 +790,13 @@
 cc_library(
     name = "tpu_passes",
     hdrs = ["transforms/tpu_passes.h"],
+    visibility = [":friends"] + if_google([
+        "//learning/brain/tfrt/ifrt/pjrt/__subpackages__",
+    ]),
     deps = [
         ":fallback_converter",
         ":tfrt_compile_options",
+        "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
diff --git a/tensorflow/compiler/mlir/tfrt/tests/fuse_tpu_compile_and_execute_ops.mlir b/tensorflow/compiler/mlir/tfrt/tests/fuse_tpu_compile_and_execute_ops.mlir
index 823cbbc..75eff82 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/fuse_tpu_compile_and_execute_ops.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/fuse_tpu_compile_and_execute_ops.mlir
@@ -11,7 +11,7 @@
   // CHECK-NOT: tf.TPUExecuteOp
 
   // CHECK-NEXT: %0 = "tf.ReadVariableOp"(%arg1)
-  // CHECK:      [[key:%.*]], [[exec_result:%.*]] = "tf.TPUCompileMlirAndExecute"(%arg0, %0) {metadata = "metadata", mlir_module = "mlir_module", operandSegmentSizes = array<i32: 2, 0>, operands_with_static_shape = [], producer_name = "default"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<3x!tf_type.string>, tensor<*xi32>)
+  // CHECK:      [[key:%.*]], [[exec_result:%.*]] = "tf.TPUCompileMlirAndExecute"(%arg0, %0) <{metadata = "metadata", mlir_module = "mlir_module", operandSegmentSizes = array<i32: 2, 0>, operands_with_static_shape = [], producer_name = "default"}> : (tensor<*xi32>, tensor<*xi32>) -> (tensor<3x!tf_type.string>, tensor<*xi32>)
   // CHECK-NEXT: return [[exec_result]] : tensor<*xi32>
 
   %0 = "tf.ReadVariableOp"(%arg1) {device = "/CPU:0"} : (tensor<*x!tf_type.resource>) -> tensor<*xi32>
@@ -38,8 +38,8 @@
   // CHECK-NOT: tf.TPUExecuteOp
 
   // CHECK-NEXT: %0 = "tf.ReadVariableOp"(%arg1)
-  // CHECK:      [[key:%.*]], [[exec_result:%.*]] = "tf.TPUCompileMlirAndExecute"(%arg0, %0) {metadata = "metadata", mlir_module = "mlir_module", operandSegmentSizes = array<i32: 2, 0>, operands_with_static_shape = [], producer_name = "default"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<3x!tf_type.string>, tensor<*xi32>)
-  // CHECK-NEXT: "tf._XlaSendFromHost"(%arg0, %0, [[key]]) {_xla_has_host_transfer = true, device = "/job:localhost/replica:0/task:0/device:CPU:0", device_ordinal = 0 : i64, key = "host_compute_channel_0_retvals"} : (tensor<*xi32>, tensor<*xi32>, tensor<3x!tf_type.string>) -> ()
+  // CHECK:      [[key:%.*]], [[exec_result:%.*]] = "tf.TPUCompileMlirAndExecute"(%arg0, %0) <{metadata = "metadata", mlir_module = "mlir_module", operandSegmentSizes = array<i32: 2, 0>, operands_with_static_shape = [], producer_name = "default"}> : (tensor<*xi32>, tensor<*xi32>) -> (tensor<3x!tf_type.string>, tensor<*xi32>)
+  // CHECK-NEXT: "tf._XlaSendFromHost"(%arg0, %0, [[key]]) <{device_ordinal = 0 : i64, key = "host_compute_channel_0_retvals"}> {_xla_has_host_transfer = true, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*xi32>, tensor<*xi32>, tensor<3x!tf_type.string>) -> ()
   // CHECK-NEXT: return [[exec_result]] : tensor<*xi32>
   %0 = "tf.ReadVariableOp"(%arg1) {device = "/CPU:0"} : (tensor<*x!tf_type.resource>) -> tensor<*xi32>
   %1 = "tf.Shape"(%arg0) {device = "/CPU:0"} : (tensor<*xi32>) -> tensor<?xi64>
@@ -69,8 +69,8 @@
   // CHECK: [[read_result:%.*]] = "tf.ReadVariableOp"(%arg1)
   // CHECK: [[shape_result_1:%.*]] = "tf.Shape"(%arg0) {device = "/CPU:0"} : (tensor<?x?xi32>) -> tensor<?xi64>
   // CHECK: [[shape_result_2:%.*]] = "tf.Shape"([[read_result]]) {device = "/CPU:0"} : (tensor<*xi32>) -> tensor<?xi64>
-  // CHECK: [[key:%.*]], [[exec_result:%.*]] = "tf.TPUCompileMlirAndExecute"(%arg0, [[shape_result_2]], %0, %0, %arg2, %arg4, %arg3) {metadata = "metadata", mlir_module = "mlir_module", operandSegmentSizes = array<i32: 4, 3>, operands_with_static_shape = [0 : i32, 1 : i32, 3 : i32], producer_name = "default"} : (tensor<?x?xi32>, tensor<?xi64>, tensor<*xi32>, tensor<*xi32>, tensor<2xi64>, tensor<?xi64>, tensor<?xi64>) -> (tensor<3x!tf_type.string>, tensor<*xi32>)
-  // CHECK: [[key_1:%.*]], [[exec_result_1:%.*]] = "tf.TPUCompileMlirAndExecute"(%arg0, %2, %0, %1) {metadata = "metadata", mlir_module = "mlir_module", operandSegmentSizes = array<i32: 4, 0>, operands_with_static_shape = [], producer_name = "default"} : (tensor<?x?xi32>, tensor<?xi64>, tensor<*xi32>, tensor<?xi64>) -> (tensor<3x!tf_type.string>, tensor<*xi32>)
+  // CHECK: [[key:%.*]], [[exec_result:%.*]] = "tf.TPUCompileMlirAndExecute"(%arg0, [[shape_result_2]], %0, %0, %arg2, %arg4, %arg3) <{metadata = "metadata", mlir_module = "mlir_module", operandSegmentSizes = array<i32: 4, 3>, operands_with_static_shape = [0 : i32, 1 : i32, 3 : i32], producer_name = "default"}> : (tensor<?x?xi32>, tensor<?xi64>, tensor<*xi32>, tensor<*xi32>, tensor<2xi64>, tensor<?xi64>, tensor<?xi64>) -> (tensor<3x!tf_type.string>, tensor<*xi32>)
+  // CHECK: [[key_1:%.*]], [[exec_result_1:%.*]] = "tf.TPUCompileMlirAndExecute"(%arg0, %2, %0, %1) <{metadata = "metadata", mlir_module = "mlir_module", operandSegmentSizes = array<i32: 4, 0>, operands_with_static_shape = [], producer_name = "default"}> : (tensor<?x?xi32>, tensor<?xi64>, tensor<*xi32>, tensor<?xi64>) -> (tensor<3x!tf_type.string>, tensor<*xi32>)
   // CHECK-NEXT: return [[exec_result]] : tensor<*xi32>
   %0 = "tf.ReadVariableOp"(%arg1) {device = "/CPU:0"} : (tensor<*x!tf_type.resource>) -> tensor<*xi32>
   %dyn_arg0 = "tf.SetStaticDimensionBounds" (%arg0, %arg2) :(tensor<?x?xi32>, tensor<2xi64>) -> tensor<?x?xi32>
diff --git a/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir b/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir
index 28c08a1..e6d5aec 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir
@@ -5,16 +5,16 @@
 // Test hoisting varhandle op.
 
 // CHECK-LABEL: func @_tfrt_resource_init
-// CHECK: [[handle:%.*]] = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor<!tf_type.resource<tensor<i32>>>
+// CHECK: [[handle:%.*]] = "tf.VarHandleOp"() <{container = "", shared_name = "x"}> : () -> tensor<!tf_type.resource<tensor<i32>>>
 // CHECK: [[x:%.*]] = "tf.ReadVariableOp"([[handle]]) {device = "/CPU:0", dtype = i32} : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
-// CHECK: "tf._TfrtSetResource"([[x]]) {device = "/CPU:0", index = 0 : i64} : (tensor<i32>) -> ()
+// CHECK: "tf._TfrtSetResource"([[x]]) <{index = 0 : i64}> {device = "/CPU:0"} : (tensor<i32>) -> ()
 
 // CHECK-LABEL: func @test_hoist_varhandleop
 func.func @hoist_varhandleop(%arg: tensor<i32> {tf_saved_model.index_path = ["input"]}) -> (tensor<i32> {tf_saved_model.index_path = ["r"]})
   attributes {tf_saved_model.exported_names = ["test_hoist_varhandleop"]} {
   // CHECK-NOT: tf.VarHandleOp
   // CHECK-NOT: tf.ReadVariableOp
-  // CHECK: [[v:%.*]] = "tf._TfrtGetResource"() {container = [""], device = "/CPU:0", indices = [0], shared_name = [""]} : () -> tensor<i32>
+  // CHECK: [[v:%.*]] = "tf._TfrtGetResource"() <{container = [""], indices = [0], shared_name = [""]}> {device = "/CPU:0"} : () -> tensor<i32>
   // CHECK: [[r:%.*]] = "tf.AddV2"({{.*}}, [[v]]) {device = "/CPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
   // CHECK: return [[r]]
   %handle = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor<!tf_type.resource<tensor<i32>>>
@@ -34,16 +34,16 @@
 // CHECK-LABEL: func @_tfrt_resource_init
 // CHECK: [[handle:%.*]] = "tf.HashTableV2"()
 // CHECK-SAME: shared_name = "x"
-// CHECK: "tf._TfrtSetResource"([[handle]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0", index = [[handle_idx:.*]] : i64}
+// CHECK: "tf._TfrtSetResource"([[handle]]) <{index = [[handle_idx:.*]] : i64}> {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
 // CHECK: [[x:%.*]] = "tf.LookupTableSizeV2"([[handle]])
-// CHECK: "tf._TfrtSetResource"([[x]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0", index = [[size_idx:.*]] : i64} : (tensor<i64>) -> ()
+// CHECK: "tf._TfrtSetResource"([[x]]) <{index = [[size_idx:.*]] : i64}> {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<i64>) -> ()
 
 // CHECK: func @test_hoist_hash_table
 func.func @hoist_hash_table(%arg: tensor<?x!tf_type.string> {tf_saved_model.index_path = ["input"]}, %default: tensor<i64> {tf_saved_model.index_path = ["default"]}) -> (tensor<i64> {tf_saved_model.index_path = ["r"]}, tensor<*xi64> {tf_saved_model.index_path = ["r1"]})
   attributes {tf_saved_model.exported_names = ["test_hoist_hash_table"]} {
   // CHECK-NOT: tf.HashTableV2
   // CHECK-NOT: tf.LookupTableSizeV2
-  // CHECK: [[v:%.*]]:2 = "tf._TfrtGetResource"() {container = ["", ""], device = "/job:localhost/replica:0/task:0/device:CPU:0", indices = [0, 1], shared_name = [{{.*}}, {{.*}}]}
+  // CHECK: [[v:%.*]]:2 = "tf._TfrtGetResource"() <{container = ["", ""], indices = [0, 1], shared_name = [{{.*}}, {{.*}}]}> {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
   // CHECK: [[r:%.*]] = "tf.LookupTableFindV2"([[v]]#[[handle_idx]]
   // CHECK: return [[v]]#[[size_idx]], [[r]]
   %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "x", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf_type.resource>
@@ -61,17 +61,17 @@
 // Test hoisting const op.
 
 // CHECK-LABEL: func @_tfrt_resource_init
-// CHECK: [[const:%.*]] = "tf.Const"() {device = "/CPU:0", value = dense<0> : tensor<i32>} : () -> tensor<i32>
+// CHECK: [[const:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> {device = "/CPU:0"} : () -> tensor<i32>
 // CHECK: [[x:%.*]] = "tf.AddV2"([[const]], [[const]]) {device = "/CPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
-// CHECK: "tf._TfrtSetResource"([[x]]) {device = "/CPU:0", index = 0 : i64} : (tensor<i32>) -> ()
-// CHECK: [[const_1:%.*]] = "tf.Const"() {device = "/CPU:0", value = dense<1> : tensor<i32>} : () -> tensor<i32>
-// CHECK: "tf._TfrtSetResource"([[const_1]]) {device = "/CPU:0", index = 1 : i64} : (tensor<i32>) -> ()
+// CHECK: "tf._TfrtSetResource"([[x]]) <{index = 0 : i64}> {device = "/CPU:0"} : (tensor<i32>) -> ()
+// CHECK: [[const_1:%.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> {device = "/CPU:0"} : () -> tensor<i32>
+// CHECK: "tf._TfrtSetResource"([[const_1]]) <{index = 1 : i64}> {device = "/CPU:0"} : (tensor<i32>) -> ()
 
 // CHECK-LABEL: func @test_hoist_const
 func.func @hoist_const(%arg: tensor<i32> {tf_saved_model.index_path = ["input"]}) -> (tensor<i32> {tf_saved_model.index_path = ["r"]})
   attributes {tf_saved_model.exported_names = ["test_hoist_const"]} {
   // CHECK-NOT: tf.Const
-  // CHECK: [[v:%.*]] = "tf._TfrtGetResource"() {container = [""], device = "/CPU:0", indices = [0], shared_name = [""]} : () -> tensor<i32>
+  // CHECK: [[v:%.*]] = "tf._TfrtGetResource"() <{container = [""], indices = [0], shared_name = [""]}> {device = "/CPU:0"} : () -> tensor<i32>
   // CHECK-NEXT: "tf.AddV2"({{.*}}, [[v]]) {device = "/CPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
   // CHECK-NEXT: return
   %const = "tf.Const"() {device = "/CPU:0", value = dense<0> : tensor<i32>} : () -> tensor<i32>
@@ -84,7 +84,7 @@
 func.func @hoist_const_return(%arg: tensor<i32> {tf_saved_model.index_path = ["input"]}) -> (tensor<i32> {tf_saved_model.index_path = ["r"]})
   attributes {tf_saved_model.exported_names = ["test_hoist_const_return"]} {
   // CHECK-NOT: tf.Const
-  // CHECK: [[v:%.*]] = "tf._TfrtGetResource"() {container = [""], device = "/CPU:0", indices = [1], shared_name = [""]} : () -> tensor<i32>
+  // CHECK: [[v:%.*]] = "tf._TfrtGetResource"() <{container = [""], indices = [1], shared_name = [""]}> {device = "/CPU:0"} : () -> tensor<i32>
   // CHECK-NEXT: return [[v]]
   %const = "tf.Const"() {device = "/CPU:0", value = dense<1> : tensor<i32>} : () -> tensor<i32>
   func.return %const : tensor<i32>
@@ -99,17 +99,17 @@
 // Test hoisting write side-effect ops.
 
 // CHECK-LABEL: func @_tfrt_resource_init
-// CHECK: [[const:%.*]] = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor<i32>} : () -> tensor<i32>
-// CHECK: "tf._TfrtSetResource"([[const]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0", index = [[const_idx:.*]] : i64} : (tensor<i32>) -> ()
-// CHECK: [[handle:%.*]] = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor<!tf_type.resource<tensor<i32>>>
-// CHECK: "tf._TfrtSetResource"([[handle]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0", index = [[handle_idx:.*]] : i64} : (tensor<!tf_type.resource<tensor<i32>>>) -> ()
+// CHECK: [[const:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> tensor<i32>
+// CHECK: "tf._TfrtSetResource"([[const]]) <{index = [[const_idx:.*]] : i64}> {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<i32>) -> ()
+// CHECK: [[handle:%.*]] = "tf.VarHandleOp"() <{container = "", shared_name = "x"}> : () -> tensor<!tf_type.resource<tensor<i32>>>
+// CHECK: "tf._TfrtSetResource"([[handle]]) <{index = [[handle_idx:.*]] : i64}> {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<!tf_type.resource<tensor<i32>>>) -> ()
 
 // CHECK: func @test_hoist_var_read_write
 func.func @hoist_var_read_write() -> (tensor<i32> {tf_saved_model.index_path = ["x"]}, tensor<i32> {tf_saved_model.index_path = ["r"]})
   attributes {tf_saved_model.exported_names = ["test_hoist_var_read_write"]} {
   // CHECK-NOT: tf.Const
   // CHECK-NOT: tf.VarHandleOp
-  // CHECK: [[v:%.*]]:2 = "tf._TfrtGetResource"() {container = ["", ""], device = "/job:localhost/replica:0/task:0/device:CPU:0", indices = [0, 1], shared_name = [{{.*}}, {{.*}}]} : () -> ({{.*}})
+  // CHECK: [[v:%.*]]:2 = "tf._TfrtGetResource"() <{container = ["", ""], indices = [0, 1], shared_name = [{{.*}}, {{.*}}]}> {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> ({{.*}})
   // CHECK: [[x:%.*]] = "tf.ReadVariableOp"([[v]]#[[handle_idx]]) {device = "/CPU:0", dtype = i32} : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
   // CHECK-NEXT: "tf.AssignVariable"([[v]]#[[handle_idx]], [[v]]#[[const_idx]]) {device = "/CPU:0"} : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
   // CHECK-NEXT: [[r:%.*]] = "tf.ReadVariableOp"([[v]]#[[handle_idx]]) {device = "/CPU:0", dtype = i32} : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
@@ -131,13 +131,13 @@
 // Test not hoisting read variable op that used by control flow ops if var handle op and read variable op are separated, but still hoists const ops and var handle ops.
 
 // CHECK-LABEL: func @_tfrt_resource_init
-// CHECK: [[handle:%.*]] = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor<!tf_type.resource<tensor<i32>>>
+// CHECK: [[handle:%.*]] = "tf.VarHandleOp"() <{container = "", shared_name = "x"}> : () -> tensor<!tf_type.resource<tensor<i32>>>
 // CHECK: "tf._TfrtSetResource"([[handle]])
 // CHECK-SAME: index = [[handle_index:.*]]
-// CHECK: [[handle1:%.*]] = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor<!tf_type.resource<tensor<i32>>>
+// CHECK: [[handle1:%.*]] = "tf.VarHandleOp"() <{container = "", shared_name = "x"}> : () -> tensor<!tf_type.resource<tensor<i32>>>
 // CHECK: "tf._TfrtSetResource"([[handle1]])
 // CHECK-SAME: index = [[handle1_index:.*]]
-// CHECK: [[const:%.*]] = "tf.Const"() {device = "/CPU:0", value = dense<true> : tensor<i1>} : () -> tensor<i1>
+// CHECK: [[const:%.*]] = "tf.Const"() <{value = dense<true> : tensor<i1>}> {device = "/CPU:0"} : () -> tensor<i1>
 // CHECK: "tf._TfrtSetResource"([[const]])
 // CHECK-SAME: index = [[const_index:.*]]
 func.func private @some_func(
@@ -164,7 +164,7 @@
   attributes {tf_saved_model.exported_names = ["test_not_hoist_if"]} {
   %handle = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor<!tf_type.resource<tensor<i32>>>
   // CHECK-NOT: tf.Const
-  // CHECK:  "tf._TfrtGetResource"() 
+  // CHECK:  "tf._TfrtGetResource"()
   %cond = "tf.Const"() {device = "/CPU:0", value = dense<true> : tensor<i1>} : () -> tensor<i1>
   // CHECK: tf.If
   %x = "tf.If"(%cond, %handle) {then_branch = @some_func, else_branch = @some_func, is_stateless = false} : (tensor<i1>, tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
@@ -185,7 +185,7 @@
   attributes {tf._input_shapes = [#tf_type.shape<1x3>, #tf_type.shape<*>], tf.signature.is_stateful} {
   // CHECK-NOT: tf.VarHandleOp
   // CHECK-NOT: tf.ReadVariableOp
-  // CHECK:  "tf._TfrtGetResource"() 
+  // CHECK:  "tf._TfrtGetResource"()
   %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor<!tf_type.resource<tensor<1x3xf32>>>
   %1 = "tf.ReadVariableOp"(%0) {device = "/device:CPU:0"} : (tensor<!tf_type.resource<tensor<1x3xf32>>>) -> tensor<1x3xf32>
   %2 = "tf.AddV2"(%arg0, %1) {device = "/device:CPU:0"} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32>
@@ -197,7 +197,7 @@
 func.func @main(%arg0: tensor<1x3xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<*xf32> {tf_saved_model.index_path = ["r"]}) 
   attributes {tf_saved_model.exported_names = ["main"]} {
   // CHECK-NOT: tf.VarHandleOp
-  // CHECK:  "tf._TfrtGetResource"() 
+  // CHECK:  "tf._TfrtGetResource"()
   %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor<!tf_type.resource<tensor<1x3xf32>>>
   // CHECK: "tf.BatchFunction"(%arg0, %0)
   // CHECK: operandSegmentSizes = array<i32: 1, 1>
@@ -288,4 +288,4 @@
   func.return %r : tensor<i32>
 }
 
-}
\ No newline at end of file
+}
diff --git a/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops_mlrt.mlir b/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops_mlrt.mlir
index 7b797b3..7f82726 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops_mlrt.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops_mlrt.mlir
@@ -7,17 +7,17 @@
 // CHECK-LABEL: func @_tfrt_resource_init
 // CHECK: [[handle:%.*]] = "tf.HashTableV2"()
 // CHECK-SAME: shared_name = "x"
-// CHECK: "tf._TfrtSetResource"([[handle]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0", index = [[handle_id:.*]] : i64}
+// CHECK: "tf._TfrtSetResource"([[handle]]) <{index = [[handle_id:.*]] : i64}> {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
 // CHECK: [[x:%.*]] = "tf.LookupTableSizeV2"([[handle]])
-// CHECK: "tf._TfrtSetResource"([[x]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0", index = [[size_id:.*]] : i64} : (tensor<i64>) -> ()
+// CHECK: "tf._TfrtSetResource"([[x]]) <{index = [[size_id:.*]] : i64}> {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<i64>) -> ()
 
 // CHECK: func @test_hoist_hash_table
 func.func @hoist_hash_table(%arg: tensor<?x!tf_type.string> {tf_saved_model.index_path = ["input"]}, %default: tensor<i64> {tf_saved_model.index_path = ["default"]}) -> (tensor<i64> {tf_saved_model.index_path = ["r"]}, tensor<*xi64> {tf_saved_model.index_path = ["r1"]})
   attributes {tf_saved_model.exported_names = ["test_hoist_hash_table"]} {
   // CHECK-NOT: tf.HashTableV2
   // CHECK-NOT: tf.LookupTableSizeV2
-  // CHECK-DAG: [[v0:%.*]] = "tf._TfrtGetResource"() {container = [""], device = "/job:localhost/replica:0/task:0/device:CPU:0", indices = [[[handle_id]]], shared_name = [{{.*}}]}
-  // CHECK-DAG: [[v1:%.*]] = "tf._TfrtGetResource"() {container = [""], device = "/job:localhost/replica:0/task:0/device:CPU:0", indices = [[[size_id]]], shared_name = [{{.*}}]}
+  // CHECK-DAG: [[v0:%.*]] = "tf._TfrtGetResource"() <{container = [""], indices = [[[handle_id]]], shared_name = [{{.*}}]}> {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
+  // CHECK-DAG: [[v1:%.*]] = "tf._TfrtGetResource"() <{container = [""], indices = [[[size_id]]], shared_name = [{{.*}}]}> {device = "/job:localhost/replica:0/task:0/device:CPU:0"}
   // CHECK-DAG: [[r:%.*]] = "tf.LookupTableFindV2"([[v0]]
   // CHECK-DAG: return [[v1]], [[r]]
   %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "x", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf_type.resource>
diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/tests/ifrt/BUILD
new file mode 100644
index 0000000..8da9c4c
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/BUILD
@@ -0,0 +1,21 @@
+load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
+
+# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
+
+glob_lit_tests(
+    name = "all_tests",
+    data = [":test_utilities"],
+    driver = "//tensorflow/compiler/mlir:run_lit.sh",
+    test_file_exts = ["mlir"],
+)
+
+# Bundle together all of the test utilities that are used by tests.
+filegroup(
+    name = "test_utilities",
+    testonly = True,
+    data = [
+        "//tensorflow/compiler/mlir/tfrt:tf-tfrt-opt",
+        "@llvm-project//llvm:FileCheck",
+        "@llvm-project//mlir:run_lit.sh",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir
new file mode 100644
index 0000000..5a8782d
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir
@@ -0,0 +1,34 @@
+// RUN: tf-tfrt-opt -split-input-file -rewrite-cluster-to-ifrt-call %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func.func @serving_default(%arg0: tensor<3x1xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x1xf32> {
+// CHECK-NEXT:  %0 = "tf.IfrtCall"(%arg1, %arg0) 
+// CHECK-SAME:       {program_id = [[PROGRAM_ID:.*]] : i64, variable_names = []} 
+// CHECK-SAME:       (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32>
+// CHECK-NEXT:    %1 = "tf.Identity"(%arg1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32>
+// CHECK-NEXT:    %2 = "tf.IfrtCall"(%1, %arg0) 
+// CHECK-SAME:       {program_id = [[PROGRAM_ID]] : i64, variable_names = []} 
+// CHECK-SAME:       (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32>
+// CHECK-NEXT:    %3 = "tf.add"(%0, %2) : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>
+// CHECK:    return
+//
+// CHECK:  func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> tensor<1x1xf32> 
+// CHECK-SAME:      attributes {tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64
+// CHECK-NEXT:     %0 = "tf.MatMul"(%arg0, %arg1)
+// CHECK:          return
+
+func.func @serving_default(%arg0: tensor<3x1xf32>,  %arg1: tensor<1x3xf32>) -> (tensor<1x1xf32>) {
+  %outputs  =  "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster", device = ""} : () -> tensor<!tf_type.string>
+  %outputs_0 = "tf_device.cluster_func"(%arg1, %arg0) {_producer_name = "UNKNOWN", func = @_func } : (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32>
+  %duplicate_arg =  "tf.Identity"(%arg1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32>
+  %outputs_1 = "tf_device.cluster_func"(%duplicate_arg, %arg0) {_producer_name = "UNKNOWN", func = @_func } : (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32>
+  %outputs_2 = "tf.add"(%outputs_0, %outputs_1): (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>
+  return %outputs_2 : tensor<1x1xf32>
+}
+
+// CHECK-LABEL: @_func
+func.func private @_func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> (tensor<1x1xf32>) {
+  %outputs_0 =  "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32>
+  return %outputs_0 : tensor<1x1xf32>
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/async_while.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/async_while.mlir
index b13096c..d7fee9d 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/mlrt/async_while.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/async_while.mlir
@@ -22,11 +22,11 @@
 }
 
 // CHECK-LABEL: func.func private @"map/while_body/TfMlrtAsyncWhileBody"(%arg0: !mlrt.promise, %arg1: !mlrt.future, %arg2: !mlrt.promise, %arg3: !mlrt.future, %arg4: !mlrt.promise, %arg5: tensor<i32>, %arg6: tensor<?x!tf_type.resource>, %arg7: tensor<*xf32>) {
-// CHECK-NEXT:    %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-NEXT:    %cst = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK-NEXT:    %0 = "tf_mlrt.tf_await"(%arg1) : (!mlrt.future) -> tensor<i32>
 // CHECK-NEXT:    %1 = "tf.AddV2"(%0, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK-NEXT:    "tf_mlrt.tf_promise"(%arg2, %1) : (!mlrt.promise, tensor<i32>) -> ()
-// CHECK-NEXT:    %2 = "tf.PartitionedCall"(%1, %arg5) {config = "", config_proto = "", executor_type = "", f = @"map/while_cond/TfMlrtAsyncWhilePredicate"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+// CHECK-NEXT:    %2 = "tf.PartitionedCall"(%1, %arg5) <{config = "", config_proto = "", executor_type = "", f = @"map/while_cond/TfMlrtAsyncWhilePredicate"}> : (tensor<i32>, tensor<i32>) -> tensor<i1>
 // CHECK-NEXT:    "tf_mlrt.tf_promise"(%arg0, %2) : (!mlrt.promise, tensor<i1>) -> ()
 // CHECK-NEXT:    %3 = "tf.TensorArrayReadV3"(%arg6, %0, %arg7) : (tensor<?x!tf_type.resource>, tensor<i32>, tensor<*xf32>) -> tensor<3x3xf32>
 // CHECK-NEXT:    %4 = "tf_mlrt.tf_await"(%arg3) : (!mlrt.future) -> tensor<3x3xf32>
@@ -37,7 +37,7 @@
 //CHECK-LABEL: func.func @serving_default
 func.func @serving_default(%max_iterations: tensor<i32>, %array_handle: tensor<?x!tf_type.resource>, %array_flow: tensor<*xf32>, %matrix: tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<*xf32>) {
   %cst_0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: %0 = "tf.PartitionedCall"(%cst, %arg0) {config = "", config_proto = "", executor_type = "", f = @"map/while_cond/TfMlrtAsyncWhilePredicate"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+  // CHECK: %0 = "tf.PartitionedCall"(%cst, %arg0) <{config = "", config_proto = "", executor_type = "", f = @"map/while_cond/TfMlrtAsyncWhilePredicate"}> : (tensor<i32>, tensor<i32>) -> tensor<i1>
   // CHECK-NEXT: %1:6 = tf_mlrt.tf_async_while @"map/while_body/TfMlrtAsyncWhileBody"(%0, %cst, %arg3, %arg0, %arg1, %arg2) {invariant_size = 3 : i32} : (tensor<i1>, tensor<i32>, tensor<3x3xf32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>) -> (!mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future)
   %1:5 = "tf.While"(%cst_0, %max_iterations, %array_handle, %array_flow, %matrix) {body= @"map/while_body", cond = @"map/while_cond", is_stateless = false, parallel_iterations = 10 : i64, shape_invariant} : (tensor<i32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<3x3xf32>) ->  (tensor<i32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<3x3xf32>)
   // CHECK-NEXT: %2 = "tf_mlrt.tf_await"(%1#5) : (!mlrt.future) -> tensor<*xf32>
@@ -50,10 +50,10 @@
 //CHECK-LABEL: func.func @multi_while_test
 func.func @multi_while_test(%max_iterations: tensor<i32>, %array_handle: tensor<?x!tf_type.resource>, %array_flow: tensor<*xf32>, %matrix: tensor<3x3xf32>, %array_handle_2: tensor<?x!tf_type.resource>, %array_flow_2: tensor<*xf32>, %matrix_2: tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3x3xf32>) {
   %cst_0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: %0 = "tf.PartitionedCall"(%cst, %arg0) {config = "", config_proto = "", executor_type = "", f = @"map/while_cond/TfMlrtAsyncWhilePredicate"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+  // CHECK: %0 = "tf.PartitionedCall"(%cst, %arg0) <{config = "", config_proto = "", executor_type = "", f = @"map/while_cond/TfMlrtAsyncWhilePredicate"}> : (tensor<i32>, tensor<i32>) -> tensor<i1>
   // CHECK-NEXT: %1:6 = tf_mlrt.tf_async_while @"map/while_body/TfMlrtAsyncWhileBody"(%0, %cst, %arg3, %arg0, %arg1, %arg2) {invariant_size = 3 : i32} : (tensor<i1>, tensor<i32>, tensor<3x3xf32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>) -> (!mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future)
   %1:5 = "tf.While"(%cst_0, %max_iterations, %array_handle, %array_flow, %matrix) {body= @"map/while_body", cond = @"map/while_cond", is_stateless = false, parallel_iterations = 10 : i64, shape_invariant} : (tensor<i32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<3x3xf32>) ->  (tensor<i32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<3x3xf32>)
-  // CHECK: %2 = "tf.PartitionedCall"(%cst, %arg0) {config = "", config_proto = "", executor_type = "", f = @"map/while_cond/TfMlrtAsyncWhilePredicate"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+  // CHECK: %2 = "tf.PartitionedCall"(%cst, %arg0) <{config = "", config_proto = "", executor_type = "", f = @"map/while_cond/TfMlrtAsyncWhilePredicate"}> : (tensor<i32>, tensor<i32>) -> tensor<i1>
   // CHECK-NEXT: %3:6 = tf_mlrt.tf_async_while @"map/while_body/TfMlrtAsyncWhileBody"(%2, %cst, %arg6, %arg0, %arg4, %arg5) {invariant_size = 3 : i32} : (tensor<i1>, tensor<i32>, tensor<3x3xf32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>) -> (!mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future)
   %2:5 = "tf.While"(%cst_0, %max_iterations, %array_handle_2, %array_flow_2, %matrix_2) {body= @"map/while_body", cond = @"map/while_cond", is_stateless = false, parallel_iterations = 10 : i64, shape_invariant} : (tensor<i32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<3x3xf32>) ->  (tensor<i32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<3x3xf32>)
   // CHECK-NEXT: %4 = "tf_mlrt.tf_await"(%1#2) : (!mlrt.future) -> tensor<3x3xf32>
@@ -128,11 +128,11 @@
 }
 
 // CHECK-LABEL: func.func private @"random/while_body/TfMlrtAsyncWhileBody_1"(%arg0: !mlrt.promise, %arg1: !mlrt.future, %arg2: !mlrt.promise, %arg3: !mlrt.future, %arg4: !mlrt.promise, %arg5: tensor<i32>, %arg6: tensor<?x!tf_type.resource>, %arg7: tensor<*xf32>) {
-// CHECK-NEXT:    %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-NEXT:    %cst = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK-NEXT:    %0 = "tf_mlrt.tf_await"(%arg1) : (!mlrt.future) -> tensor<i32>
 // CHECK-NEXT:    %1 = "tf.AddV2"(%0, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK-NEXT:    "tf_mlrt.tf_promise"(%arg2, %1) : (!mlrt.promise, tensor<i32>) -> ()
-// CHECK-NEXT:    %2 = "tf.PartitionedCall"(%1, %arg5) {config = "", config_proto = "", executor_type = "", f = @"random/while_cond/TfMlrtAsyncWhilePredicate_0"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+// CHECK-NEXT:    %2 = "tf.PartitionedCall"(%1, %arg5) <{config = "", config_proto = "", executor_type = "", f = @"random/while_cond/TfMlrtAsyncWhilePredicate_0"}> : (tensor<i32>, tensor<i32>) -> tensor<i1>
 // CHECK-NEXT:    "tf_mlrt.tf_promise"(%arg0, %2) : (!mlrt.promise, tensor<i1>) -> ()
 // CHECK-NEXT:    %3 = "tf.TensorArrayReadV3"(%arg6, %0, %arg7) : (tensor<?x!tf_type.resource>, tensor<i32>, tensor<*xf32>) -> tensor<3x3xf32>
 // CHECK-NEXT:    %4 = "tf_mlrt.tf_await"(%arg3) : (!mlrt.future) -> tensor<3x3xf32>
@@ -143,7 +143,7 @@
 //CHECK-LABEL: func.func @random_serving_default
 func.func @random_serving_default(%max_iterations: tensor<i32>, %array_handle: tensor<?x!tf_type.resource>, %array_flow: tensor<*xf32>, %matrix: tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<*xf32>) {
   %cst_0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: %0 = "tf.PartitionedCall"(%cst, %arg0) {config = "", config_proto = "", executor_type = "", f = @"random/while_cond/TfMlrtAsyncWhilePredicate_0"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+  // CHECK: %0 = "tf.PartitionedCall"(%cst, %arg0) <{config = "", config_proto = "", executor_type = "", f = @"random/while_cond/TfMlrtAsyncWhilePredicate_0"}> : (tensor<i32>, tensor<i32>) -> tensor<i1>
   // CHECK-NEXT: %1:6 = tf_mlrt.tf_async_while @"random/while_body/TfMlrtAsyncWhileBody_1"(%0, %cst, %arg3, %arg0, %arg1, %arg2) {invariant_size = 3 : i32} : (tensor<i1>, tensor<i32>, tensor<3x3xf32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>) -> (!mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future)
   %1:5 = "tf.While"(%cst_0, %max_iterations, %array_handle, %array_flow, %matrix) {body= @"random/while_body", cond = @"random/while_cond", is_stateless = false, parallel_iterations = 10 : i64, shape_invariant} : (tensor<i32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<3x3xf32>) ->  (tensor<i32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<3x3xf32>)
   // CHECK-NEXT: %2 = "tf_mlrt.tf_await"(%1#5) : (!mlrt.future) -> tensor<*xf32>
@@ -166,7 +166,7 @@
 // CHECK-NEXT:    return %0 : tensor<i1>
 
 // CHECK-LABEL: func.func private @"sort_map/while_body"
-// CHECK-NEXT:    %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-NEXT:    %cst = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK-NEXT:    %0 = "tf.AddV2"(%arg0, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK-NEXT:    %1 = "tf.TensorArrayReadV3"(%arg2, %arg0, %arg3) : (tensor<?x!tf_type.resource>, tensor<i32>, tensor<*xf32>) -> tensor<3x3xf32>
 // CHECK-NEXT:    %2 = "tf.TensorArrayReadV3"(%arg5, %arg0, %arg6) : (tensor<?x!tf_type.resource>, tensor<i32>, tensor<*xf32>) -> tensor<3x3xf32>
@@ -177,7 +177,7 @@
 // CHECK-NEXT:   %7 = "tf.Identity"(%4) : (tensor<i1>) -> tensor<i1>
 // CHECK-NEXT:    %8 = "tf.MatMul"(%5, %arg7) : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
 // CHECK-NEXT:    %9 = "tf.Select"(%7, %8, %arg7) : (tensor<i1>, tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
-// CHECK-NEXT:    return %0, %arg1, %arg2, %arg3, %6, %arg5, %arg6, %9, %arg8 
+// CHECK-NEXT:    return %0, %arg1, %arg2, %arg3, %6, %arg5, %arg6, %9, %arg8
 func.func private @"sort_map/while_body"(%loop_count: tensor<i32>, %max_iterations: tensor<i32>, %handle: tensor<?x!tf_type.resource>, %flow_in: tensor<*xf32>, %matrix: tensor<3x3xf32>, %handle_2: tensor<?x!tf_type.resource>, %flow_in_2: tensor<*xf32>, %matrix_2: tensor<3x3xf32>, %bound: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<3x3xf32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<3x3xf32>, tensor<i32>) {
   %cst_1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
   %updated_loop_count = "tf.AddV2"(%loop_count, %cst_1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
@@ -194,11 +194,11 @@
 }
 
 // CHECK-LABEL: func.func private @"sort_map/while_body/TfMlrtAsyncWhileBody"
-// CHECK-NEXT:    %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-NEXT:    %cst = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK-NEXT:    %0 = "tf_mlrt.tf_await"(%arg1) : (!mlrt.future) -> tensor<i32>
 // CHECK-NEXT:    %1 = "tf.AddV2"(%0, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK-NEXT:    "tf_mlrt.tf_promise"(%arg2, %1) : (!mlrt.promise, tensor<i32>) -> ()
-// CHECK-NEXT:    %2 = "tf.PartitionedCall"(%1, %arg7) {config = "", config_proto = "", executor_type = "", f = @"sort_map/while_cond/TfMlrtAsyncWhilePredicate"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+// CHECK-NEXT:    %2 = "tf.PartitionedCall"(%1, %arg7) <{config = "", config_proto = "", executor_type = "", f = @"sort_map/while_cond/TfMlrtAsyncWhilePredicate"}> : (tensor<i32>, tensor<i32>) -> tensor<i1>
 // CHECK-NEXT:    "tf_mlrt.tf_promise"(%arg0, %2) : (!mlrt.promise, tensor<i1>) -> ()
 // CHECK-NEXT:    %3 = "tf.TensorArrayReadV3"(%arg8, %0, %arg9) : (tensor<?x!tf_type.resource>, tensor<i32>, tensor<*xf32>) -> tensor<3x3xf32>
 // CHECK-NEXT:    %4 = "tf.TensorArrayReadV3"(%arg10, %0, %arg11) : (tensor<?x!tf_type.resource>, tensor<i32>, tensor<*xf32>) -> tensor<3x3xf32>
@@ -218,7 +218,7 @@
 //CHECK-LABEL: func.func @sort_serving_default
 func.func @sort_serving_default(%max_iterations: tensor<i32>, %array_handle: tensor<?x!tf_type.resource>, %array_flow: tensor<*xf32>, %matrix: tensor<3x3xf32>, %bound: tensor<i32>) -> (tensor<3x3xf32>, tensor<3x3xf32>, tensor<*xf32>) {
   %cst_0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: %0 = "tf.PartitionedCall"(%cst, %arg0) {config = "", config_proto = "", executor_type = "", f = @"sort_map/while_cond/TfMlrtAsyncWhilePredicate"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+  // CHECK: %0 = "tf.PartitionedCall"(%cst, %arg0) <{config = "", config_proto = "", executor_type = "", f = @"sort_map/while_cond/TfMlrtAsyncWhilePredicate"}> : (tensor<i32>, tensor<i32>) -> tensor<i1>
   // CHECK-NEXT: %1:10 = tf_mlrt.tf_async_while @"sort_map/while_body/TfMlrtAsyncWhileBody"(%0, %cst, %arg3, %arg3, %arg0, %arg1, %arg2, %arg1, %arg2, %arg4) {invariant_size = 6 : i32} : (tensor<i1>, tensor<i32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<i32>) -> (!mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future, !mlrt.future)
   %1:9 = "tf.While"(%cst_0, %max_iterations, %array_handle, %array_flow, %matrix , %array_handle, %array_flow, %matrix, %bound) {body= @"sort_map/while_body", cond = @"sort_map/while_cond", is_stateless = false, parallel_iterations = 10 : i64, shape_invariant} : (tensor<i32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<3x3xf32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<3x3xf32>,tensor<i32>) ->  (tensor<i32>, tensor<i32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<3x3xf32>, tensor<?x!tf_type.resource>, tensor<*xf32>, tensor<3x3xf32>, tensor<i32>)
   // CHECK-NEXT:  %2 = "tf_mlrt.tf_await"(%1#6) : (!mlrt.future) -> tensor<*xf32>
diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir
index 27c9228..2b1f5fc 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir
@@ -37,7 +37,7 @@
 // CHECK-SAME: (%arg0: !mlrt.future, %arg1: !mlrt.promise, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<?xf32>)
 // CHECK: [[det:%.*]] = "tf.MatrixDeterminant"
 // CHECK-NEXT: [[ta_0:%.*]] = "tf_mlrt.tf_await"(%arg0) : (!mlrt.future) -> tensor<!tf_type.variant<tensor<*xf32>>>
-// CHECK-NEXT: [[ta_1:%.*]] = "tf.TensorListSetItem"([[ta_0]], %arg3, [[det]]) {
+// CHECK-NEXT: [[ta_1:%.*]] = "tf.TensorListSetItem"([[ta_0]], %arg3, [[det]]) <{
 // CHECK-NEXT:  "tf_mlrt.tf_promise"(%arg1, [[ta_1]]) : (!mlrt.promise, tensor<!tf_type.variant<tensor<*xf32>>>) -> ()
 // CHECK-NEXT: return
 
@@ -53,7 +53,7 @@
   // CHECK-SAME: {body_fn = @"map/while_body/MapFnBody", num_tensor_list_or_flow_in = 1 : i32}
   // CHECK-NOT: tf.While
   %1:4 = "tf.While"(%cst, %cst, %0, %arg0) {_lower_using_switch_merge = true, _num_original_outputs = 6 : i64, _read_only_resource_inputs = [], _xla_propagate_compile_time_consts = true, body = @"map/while_body", cond = @"map/while_cond", device = "/job:localhost/replica:0/task:0/device:CPU:0", is_stateless = true, parallel_iterations = 4 : i64, shape_invariant} : (tensor<i32>, tensor<i32>, tensor<!tf_type.variant<tensor<*xf32>>>, tensor<?xf32>) -> (tensor<i32>, tensor<i32>, tensor<!tf_type.variant<tensor<*xf32>>>, tensor<?xf32>)
-  // CHECK-NEXT: "tf.TensorListStack"([[map_fn_result]], %cst_0) {
+  // CHECK-NEXT: "tf.TensorListStack"([[map_fn_result]], %cst_0) <{
   %2 = "tf.TensorListStack"(%1#2, %cst_0) {device = "/job:localhost/replica:0/task:0/device:CPU:0", num_elements = 3 : i64} : (tensor<!tf_type.variant<tensor<*xf32>>>, tensor<0xi32>) -> tensor<3xf32>
   return %2 : tensor<3xf32>
 }
@@ -458,7 +458,7 @@
   %5 = "tf.TensorArrayGatherV3"(%handle_12, %1, %4#2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x!tf_type.resource<tensor<*x!tf_type.variant>>>, tensor<i32>, tensor<f32>) -> tensor<?x!tf_type.variant>
   // CHECK: TensorArrayGatherV3
   %6 = "tf.TensorArrayGatherV3"(%handle_14, %2, %4#3) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x!tf_type.resource<tensor<*x!tf_type.variant>>>, tensor<i32>, tensor<f32>) -> tensor<?x!tf_type.variant>
-  return %5, %6 : tensor<?x!tf_type.variant>, tensor<?x!tf_type.variant> 
+  return %5, %6 : tensor<?x!tf_type.variant>, tensor<?x!tf_type.variant>
 }
 
 // -----
diff --git a/tensorflow/compiler/mlir/tfrt/tests/runtime_lowering_tpu.mlir b/tensorflow/compiler/mlir/tfrt/tests/runtime_lowering_tpu.mlir
index 5225c2e..d6ffe03 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/runtime_lowering_tpu.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/runtime_lowering_tpu.mlir
@@ -4,7 +4,7 @@
 
   // CHECK-LABEL: @converts_cluster
   func.func @converts_cluster() {
-    // CHECK: %0:2 = "tf_device.launch"() ({
+    // CHECK: %0:2 = "tf_device.launch"() <{{.*}}> ({
     // CHECK: %compilation_status, %program = "tf._TPUCompileMlir"()
     "tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
     func.return
@@ -26,4 +26,4 @@
   func.func @empty_func() {
     func.return
   }
-}
\ No newline at end of file
+}
diff --git a/tensorflow/compiler/mlir/tfrt/tests/sink_in_invariant_ops.mlir b/tensorflow/compiler/mlir/tfrt/tests/sink_in_invariant_ops.mlir
index 2b0344c..42e2e7c 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/sink_in_invariant_ops.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/sink_in_invariant_ops.mlir
@@ -93,8 +93,8 @@
 // CHECK-LABEL: func private @batched_function
 func.func private @batched_function(%arg0: tensor<!tf_type.resource<tensor<1x3xf32>>>, %arg1: tensor<!tf_type.resource<tensor<1x3xf32>>>) -> tensor<1x3xf32>
   attributes {tf._input_shapes = [#tf_type.shape<1x3>, #tf_type.shape<*>], tf.signature.is_stateful} {
-  // CHECK-DAG: [[handle1:%.*]] = "tf.VarHandleOp"() {{{.*}}, shared_name = "variable1"}
-  // CHECK-DAG: [[handle2:%.*]] = "tf.VarHandleOp"() {{{.*}}, shared_name = "variable2"}
+  // CHECK-DAG: [[handle1:%.*]] = "tf.VarHandleOp"() <{{{.*}}, shared_name = "variable1"}>
+  // CHECK-DAG: [[handle2:%.*]] = "tf.VarHandleOp"() <{{{.*}}, shared_name = "variable2"}>
   // CHECK: "tf.ReadVariableOp"([[handle1]])
   // CHECK: "tf.ReadVariableOp"([[handle2]])
   %0 = "tf.ReadVariableOp"(%arg0) {device = "/device:CPU:0"} : (tensor<!tf_type.resource<tensor<1x3xf32>>>) -> tensor<1x3xf32>
@@ -298,7 +298,7 @@
 }
 
 // CHECK-LABEL: func @main
-func.func @main(%arg0: tensor<1x3xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<*xf32> {tf_saved_model.index_path = ["r"]}) 
+func.func @main(%arg0: tensor<1x3xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<*xf32> {tf_saved_model.index_path = ["r"]})
   attributes {tf_saved_model.exported_names = ["main"]} {
   // CHECK: [[handle:%.*]] = "tf.VarHandleOp"()
   %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor<!tf_type.resource<tensor<1x3xf32>>>
diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/merge_tf_if_ops.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/merge_tf_if_ops.mlir
index f79c319..cb19907 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/merge_tf_if_ops.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/merge_tf_if_ops.mlir
@@ -51,8 +51,8 @@
 
 // CHECK-LABEL: func private @merge_stateless_merged_if_0_0_else
 // CHECK-SAME: ([[x:%.*]]: tensor<i32>, [[y:%.*]]: tensor<i32>)
-// CHECK-DAG: [[cst:%.*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
-// CHECK-DAG: [[cst_0:%.*]] = "tf.Const"() {value = dense<2> : tensor<i32>}
+// CHECK-DAG: [[cst:%.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}>
+// CHECK-DAG: [[cst_0:%.*]] = "tf.Const"() <{value = dense<2> : tensor<i32>}>
 // CHECK: [[r0:%.*]] = "tf.AddV2"([[x]], [[cst]])
 // CHECK: [[r1:%.*]] = "tf.AddV2"([[y]], [[r0]])
 // CHECK: [[r2:%.*]] = "tf.AddV2"([[x]], [[cst_0]])
@@ -63,7 +63,7 @@
 // CHECK-SAME: ([[x:%.*]]: tensor<i32>, [[y:%.*]]: tensor<i32>, [[cond:%.*]]: tensor<i1>)
 func.func @merge_stateless(%x: tensor<i32>, %y: tensor<i32>, %cond: tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>) {
   // CHECK-NEXT: [[res:%.*]]:3 = "tf.If"([[cond]], [[x]], [[y]])
-  // CHECK-SAME: {else_branch = @merge_stateless_merged_if_0_0_else, is_stateless = true, then_branch = @merge_stateless_merged_if_0_0_then}
+  // CHECK-SAME: <{else_branch = @merge_stateless_merged_if_0_0_else, is_stateless = true, then_branch = @merge_stateless_merged_if_0_0_then}>
   // CHECK-SAME: (tensor<i1>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
   // CHECK-NEXT: return [[res]]#0, [[res]]#1, [[res]]#2
   %0, %1 = "tf.If"(%cond, %x, %y) {else_branch = @no_side_effect_else_0, then_branch = @no_side_effect_then_0, is_stateless = true} : (tensor<i1>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
@@ -83,8 +83,8 @@
 
 // CHECK-LABEL: func private @merge_nested_if_op_merged_if_0_0_else_merged_if_1_0_else
 // CHECK-SAME: ([[x:%.*]]: tensor<i32>, [[y:%.*]]: tensor<i32>)
-// CHECK-NEXT: [[cst:%.*]] = "tf.Const"() {value = dense<2> : tensor<i32>}
-// CHECK-NEXT: [[cst_0:%.*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
+// CHECK-NEXT: [[cst:%.*]] = "tf.Const"() <{value = dense<2> : tensor<i32>}>
+// CHECK-NEXT: [[cst_0:%.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}>
 // CHECK-NEXT: [[r0:%.*]] = "tf.AddV2"([[x]], [[cst_0]])
 // CHECK-NEXT: [[r1:%.*]] = "tf.AddV2"([[y]], [[r0]])
 // CHECK-NEXT: [[r2:%.*]] = "tf.AddV2"([[x]], [[cst]])
@@ -93,14 +93,14 @@
 
 // CHECK-LABEL: func private @merge_nested_if_op_merged_if_0_0_else
 // CHECK-SAME: ([[cond:%.*]]: tensor<i1>, [[x:%.*]]: tensor<i32>, [[y:%.*]]: tensor<i32>)
-// CHECK-NEXT: [[r0:%.*]]:3 = "tf.If"(%arg0, %arg1, %arg2) {else_branch = @merge_nested_if_op_merged_if_0_0_else_merged_if_1_0_else, is_stateless = true, then_branch = @merge_nested_if_op_merged_if_0_0_else_merged_if_1_0_then} : (tensor<i1>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
+// CHECK-NEXT: [[r0:%.*]]:3 = "tf.If"(%arg0, %arg1, %arg2) <{else_branch = @merge_nested_if_op_merged_if_0_0_else_merged_if_1_0_else, is_stateless = true, then_branch = @merge_nested_if_op_merged_if_0_0_else_merged_if_1_0_then}> : (tensor<i1>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
 // CHECK-NEXT: return [[r0]]#0, [[r0]]#1, [[r0]]#2
 
 // CHECK-LABEL: func @merge_nested_if_op
 // CHECK-SAME: ([[x:%.*]]: tensor<i32>, [[y:%.*]]: tensor<i32>, [[cond:%.*]]: tensor<i1>, [[nested_cond:%.*]]: tensor<i1>)
 func.func @merge_nested_if_op(%x: tensor<i32>, %y: tensor<i32>, %cond: tensor<i1>, %nested_cond: tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>) {
   // CHECK-NEXT: [[res:%.*]]:3 = "tf.If"([[cond]], [[nested_cond]], [[x]], [[y]])
-  // CHECK-SAME: {else_branch = @merge_nested_if_op_merged_if_0_0_else, is_stateless = true, then_branch = @merge_nested_if_op_merged_if_0_0_then}
+  // CHECK-SAME: <{else_branch = @merge_nested_if_op_merged_if_0_0_else, is_stateless = true, then_branch = @merge_nested_if_op_merged_if_0_0_then}>
   // CHECK-SAME: (tensor<i1>, tensor<i1>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
   // CHECK-NEXT: return [[res]]#0, [[res]]#1, [[res]]#2
   %0, %1 = "tf.If"(%cond, %nested_cond, %x, %y) {else_branch = @nested_if_op_else_0, then_branch = @nested_if_op_then_0, is_stateless = true} : (tensor<i1>, tensor<i1>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
@@ -122,7 +122,7 @@
 // CHECK-LABEL: func @multiple_uses
 func.func @multiple_uses(%x: tensor<i32>, %y: tensor<i32>, %cond: tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>) {
   // CHECK-NEXT: tf.If
-  // CHECK-SAME: {else_branch = @multiple_uses_merged_if_0_0_else, is_stateless = true, then_branch = @multiple_uses_merged_if_0_0_then}
+  // CHECK-SAME: <{else_branch = @multiple_uses_merged_if_0_0_else, is_stateless = true, then_branch = @multiple_uses_merged_if_0_0_then}>
   %0, %1 = "tf.If"(%cond, %x, %y) {else_branch = @no_side_effect_else_0, then_branch = @no_side_effect_then_0, is_stateless = true} : (tensor<i1>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
   %2 = "tf.If"(%cond, %x, %y) {else_branch = @no_side_effect_else_1, then_branch = @no_side_effect_then_1, is_stateless = true} : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
   func.return %0, %1, %2 : tensor<i32>, tensor<i32>, tensor<i32>
diff --git a/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc b/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc
index 6ab7337..1ae3e8f 100644
--- a/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc
+++ b/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
 #include "mlir/InitAllDialects.h"  // from @llvm-project
 #include "mlir/InitAllPasses.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Tools/mlir-opt/MlirOptMain.h"  // from @llvm-project
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/init_mlir.h"
@@ -29,12 +30,12 @@
 #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h"
 #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h"
+#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.h"
-#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h"
 #include "xla/mlir_hlo/gml_st/IR/gml_st_ops.h"
 #include "xla/mlir_hlo/gml_st/transforms/passes.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
-#include "tensorflow/core/platform/init_main.h"
 #include "tfrt/init_tfrt_dialects.h"  // from @tf_runtime
 
 int main(int argc, char **argv) {
@@ -48,6 +49,7 @@
   mlir::gml_st::registerGmlStPasses();
 
   tensorflow::mlrt_compiler::RegisterMlrtPasses();
+  tensorflow::ifrt_serving::RegisterTfIfrtPasses();
 
   mlir::DialectRegistry registry;
   mlir::registerAllDialects(registry);
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc
new file mode 100644
index 0000000..55ae462
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc
@@ -0,0 +1,187 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h"
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/IR/Verifier.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/visitor.h"
+#include "tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h"
+#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_serving_executable.h"
+#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h"
+#include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h"
+#include "tensorflow/core/tfrt/runtime/runtime.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/profiler/lib/traceme.h"
+
+namespace tensorflow {
+namespace ifrt_serving {
+
+absl::StatusOr<std::vector<ServingExecutableRegistry::Handle>>
+IfrtBackendCompiler::CompileAndRegisterIfrtPrograms(
+    absl::string_view model_name, mlir::ModuleOp module) const {
+  std::vector<ServingExecutableRegistry::Handle> handles;
+
+  // Compile Ifrt programs and register the executables. Outlined Ifrt
+  // programs are marked with `tfrt_ifrt_serving.program_id` attributes.
+  for (auto func : module.getOps<mlir::func::FuncOp>()) {
+    int64_t program_id;
+    if (auto attr = func->getAttrOfType<mlir::IntegerAttr>(
+            "tfrt_ifrt_serving.program_id")) {
+      DCHECK(attr.getType().isSignlessInteger());
+      program_id = attr.getInt();
+    } else {
+      continue;
+    }
+
+    mlir::StatusScopedDiagnosticHandler diag_handler(module->getContext());
+
+    auto entry_function_name = func.getSymName();
+    auto submodule = mlir::TF::CreatePrunedModule(module, entry_function_name);
+    if (mlir::failed(submodule)) {
+      return diag_handler.ConsumeStatus();
+    }
+
+    auto executable = IfrtServingExecutable::Create(
+        model_name, entry_function_name, *std::move(submodule));
+
+    // Register the Ifrt program with `ServingExecutableRegistry` so that
+    // the client TF program can invoke it via `tf.IfrtCall` ops.
+    TF_ASSIGN_OR_RETURN(auto handle, ServingExecutableRegistry::Register(
+                                         program_id, std::move(executable)));
+
+    handles.push_back(std::move(handle));
+  }
+
+  return handles;
+}
+
+absl::Status IfrtBackendCompiler::CompileTensorflowForIfrtServing(
+    absl::string_view model_name, IfrtModelContext& ifrt_model_context,
+    mlir::ModuleOp module) const {
+  tsl::profiler::TraceMe trace_me("CompileTensorflowForIfrtServing");
+  mlir::Builder builder(module.getContext());
+
+  TF_RETURN_IF_ERROR(
+      RunClusterToIfrtRuntimeOpsPassPipeline(module, model_name));
+
+  TF_ASSIGN_OR_RETURN(auto handles,
+                      CompileAndRegisterIfrtPrograms(model_name, module));
+
+  for (auto& handle : handles) {
+    ifrt_model_context.RegisterHandle(std::move(handle));
+  }
+
+  return absl::OkStatus();
+}
+
+// Compile ifrt programs in TF dialect into ifrt executables.
+// Remove ifrt programs afterwards.
+absl::Status IfrtBackendCompiler::CompileTensorflow(
+    tensorflow::tfrt_stub::ModelRuntimeContext& model_context,
+    mlir::ModuleOp module) const {
+  auto ifrt_model_context =
+      model_context.resource_context().GetResource<IfrtModelContext>(
+          kIfrtModelContextName);
+  if (!ifrt_model_context.has_value()) {
+    return absl::InternalError(
+        "Failed to find model context for ifrt serving.");
+  }
+
+  mlir::StatusScopedDiagnosticHandler diag_handler(module->getContext());
+  if (VLOG_IS_ON(1)) {
+    tensorflow::DumpMlirOpToFile("ifrt_tpu_bct_conversion_before", module);
+  }
+
+  // TODO(b/305734600): conditionally run the backward compat pass only on
+  // hosts with TPUs.
+  //
+  // Run the backward compat pass so that we can use the bridge for clustering.
+  auto backward_compat_result =
+      tensorflow::RunTPUBackwardCompatConversion(module, {});
+  if (mlir::failed(backward_compat_result)) {
+    return diag_handler.Combine(
+        absl::InternalError("Failed to handle legacy TPU Ops"));
+  }
+
+  if (VLOG_IS_ON(1)) {
+    tensorflow::DumpMlirOpToFile("ifrt_tpu_bct_conversion_after", module);
+  }
+
+  // Use bridge for cluster formation.
+  TF_RETURN_IF_ERROR(tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge(
+      module, tensorflow::tf2xla::v2::DeviceType::XLA_TPU_JIT,
+      /*is_in_fallback_enabled_mode=*/false));
+
+  if (VLOG_IS_ON(1)) {
+    tensorflow::DumpMlirOpToFile("before_ifrt_outlining", module);
+  }
+
+  // Extract TPU program for IFRT call.
+  TF_RETURN_IF_ERROR(CompileTensorflowForIfrtServing(
+      model_context.name(), **ifrt_model_context, module));
+
+  if (VLOG_IS_ON(1)) {
+    tensorflow::DumpMlirOpToFile("after_ifrt_outlining", module);
+  }
+
+  // IFRT program is no longer needed.
+  llvm::SmallVector<mlir::func::FuncOp> to_erase;
+  for (auto func : module.getOps<mlir::func::FuncOp>()) {
+    if (func->getAttr("tfrt_ifrt_serving.program_id")) {
+      to_erase.push_back(func);
+    }
+  }
+  for (auto func : to_erase) {
+    func->erase();
+  }
+
+  if (VLOG_IS_ON(1)) {
+    tensorflow::DumpMlirOpToFile("after_ifrt_program_removal", module);
+  }
+
+  if (mlir::failed(mlir::verify(module))) {
+    return diag_handler.ConsumeStatus();
+  }
+
+  return absl::OkStatus();
+}
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h
new file mode 100644
index 0000000..f97f0dd
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h
@@ -0,0 +1,55 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_BACKEND_COMPILER_H_
+#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_BACKEND_COMPILER_H_
+
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tfrt/backend_compiler.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h"
+#include "tensorflow/core/tfrt/runtime/runtime.h"
+
+namespace tensorflow {
+namespace ifrt_serving {
+
+// Implements the custom backend compiler for IFRT-based serving in TFRT.
+class IfrtBackendCompiler : public tensorflow::BackendCompiler {
+ public:
+  // Rewrites the TensorFlow graph in MLIR for IFRT serving. The method
+  // extracts regions for IFRT execution on accelerators (e.g. TPU).
+  absl::Status CompileTensorflow(
+      tensorflow::tfrt_stub::ModelRuntimeContext& model_context,
+      mlir::ModuleOp module) const override;
+
+ private:
+  absl::Status CompileTensorflowForIfrtServing(
+      absl::string_view model_name, IfrtModelContext& ifrt_model_context,
+      mlir::ModuleOp module) const;
+
+  absl::StatusOr<std::vector<ServingExecutableRegistry::Handle>>
+  CompileAndRegisterIfrtPrograms(absl::string_view model_name,
+                                 mlir::ModuleOp module) const;
+};
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_BACKEND_COMPILER_H_
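
For orientation, the new compiler is driven through tensorflow::tfrt_stub::ModelRuntimeContext: the caller registers an IfrtModelContext resource and then hands CompileTensorflow an MLIR module. A minimal sketch follows, assuming `module` is an already-parsed mlir::OwningOpRef<mlir::ModuleOp> holding a TF-dialect module with tf_device.cluster_func ops; the "IfrtModelContext" resource name mirrors the unit test below rather than any separately documented contract.

// Sketch only; unqualified names assume the tensorflow::ifrt_serving namespace.
auto runtime = tensorflow::tfrt_stub::DefaultTfrtRuntime(/*num_threads=*/1);
tensorflow::tfrt_stub::GraphExecutionOptions options(runtime.get());
tfrt::ResourceContext resource_context;
tensorflow::tfrt_stub::ModelRuntimeContext runtime_context(
    &options, /*export_dir=*/"", &resource_context);

// Make the IFRT model context available to the compiler.
IfrtModelContext ifrt_model_context;
runtime_context.resource_context().CreateResource<IfrtModelContext>(
    "IfrtModelContext", std::move(ifrt_model_context));

IfrtBackendCompiler compiler;
absl::Status status = compiler.CompileTensorflow(runtime_context, module.get());
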
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc
new file mode 100644
index 0000000..8330551
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/DialectRegistry.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
+#include "mlir/InitAllDialects.h"  // from @llvm-project
+#include "mlir/Parser/Parser.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
+#include "tensorflow/core/platform/resource_loader.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h"
+#include "tensorflow/core/tfrt/runtime/runtime.h"
+#include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
+#include "tsl/lib/core/status_test_util.h"
+#include "tfrt/host_context/resource_context.h"  // from @tf_runtime
+
+namespace tensorflow {
+namespace ifrt_serving {
+
+TEST(IfrtBackendCompilerTest, Basic) {
+  // Create test input module
+  constexpr absl::string_view kDataDirectory =
+      "tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata";
+  std::string mlir_module_path = tensorflow::GetDataDependencyFilepath(
+      absl::StrCat(kDataDirectory, "/ifrt_cluster.mlir"));
+
+  mlir::DialectRegistry registry;
+  mlir::registerAllDialects(registry);
+  mlir::RegisterAllTensorFlowDialects(registry);
+
+  mlir::MLIRContext context(registry);
+
+  mlir::OwningOpRef<mlir::ModuleOp> mlir_module =
+      mlir::parseSourceFile<mlir::ModuleOp>(mlir_module_path, &context);
+
+  ASSERT_TRUE(mlir_module);
+  ASSERT_TRUE(mlir_module.get() != nullptr);
+
+  // Create contexts required for the compiler execution.
+  IfrtModelContext model_context;
+
+  std::unique_ptr<tensorflow::tfrt_stub::Runtime> runtime =
+      tensorflow::tfrt_stub::DefaultTfrtRuntime(/*num_threads=*/1);
+  tensorflow::tfrt_stub::GraphExecutionOptions graph_execution_options(
+      runtime.get());
+  tfrt::ResourceContext resource_context;
+  tensorflow::tfrt_stub::ModelRuntimeContext runtime_context(
+      &graph_execution_options, /*export_dir=*/"", &resource_context);
+
+  runtime_context.resource_context().CreateResource<IfrtModelContext>(
+      "IfrtModelContext", std::move(model_context));
+
+  IfrtBackendCompiler compiler;
+  TF_ASSERT_OK(compiler.CompileTensorflow(runtime_context, mlir_module.get()));
+}
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_serving_executable.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_serving_executable.cc
new file mode 100644
index 0000000..508c796
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_serving_executable.cc
@@ -0,0 +1,38 @@
+
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_serving_executable.h"
+
+#include <vector>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/types/span.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace ifrt_serving {
+
+// Executes the computation.
+absl::StatusOr<std::vector<tensorflow::Tensor>> IfrtServingExecutable::Execute(
+    absl::Span<const tensorflow::Tensor> inputs) {
+  return absl::UnimplementedError("Not implemented");
+}
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_serving_executable.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_serving_executable.h
new file mode 100644
index 0000000..3ee05e1
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_serving_executable.h
@@ -0,0 +1,81 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_SERVING_EXECUTABLE_H_
+#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_SERVING_EXECUTABLE_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/log/log.h"
+#include "absl/memory/memory.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+namespace ifrt_serving {
+
+class IfrtServingExecutable {
+ public:
+  static std::unique_ptr<IfrtServingExecutable> Create(
+      absl::string_view model_name, absl::string_view signature_name,
+      mlir::OwningOpRef<mlir::ModuleOp> module) {
+    VLOG(1) << "Creating IfrtServingExecutable";
+    std::string serialized_mlir_module =
+        tensorflow::SerializeMlirModule(*module);
+
+    return absl::WrapUnique(new IfrtServingExecutable(
+        model_name, signature_name, std::move(serialized_mlir_module)));
+  }
+
+  // Movable but not copyable.
+  IfrtServingExecutable(IfrtServingExecutable&& other) = default;
+  IfrtServingExecutable& operator=(IfrtServingExecutable&& other) = default;
+  IfrtServingExecutable(const IfrtServingExecutable& other) = delete;
+  IfrtServingExecutable& operator=(const IfrtServingExecutable& other) = delete;
+
+  absl::string_view model_name() const { return model_name_; }
+  absl::string_view signature_name() const { return signature_name_; }
+
+  // Executes the computation.
+  absl::StatusOr<std::vector<tensorflow::Tensor>> Execute(
+      absl::Span<const tensorflow::Tensor> inputs);
+
+ private:
+  std::string model_name_;
+  std::string signature_name_;
+
+  std::string serialized_mlir_module_;
+
+  explicit IfrtServingExecutable(absl::string_view model_name,
+                                 absl::string_view signature_name,
+                                 std::string serialized_mlir_module)
+      : model_name_(std::string(model_name)),
+        signature_name_(std::string(signature_name)),
+        serialized_mlir_module_(std::move(serialized_mlir_module)) {}
+};
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_SERVING_EXECUTABLE_H_
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc
new file mode 100644
index 0000000..a379662
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc
@@ -0,0 +1,192 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.h"
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "absl/base/casts.h"
+#include "absl/strings/str_cat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/DialectRegistry.h"  // from @llvm-project
+#include "mlir/IR/IRMapping.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/SymbolTable.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/TypeID.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h"
+#include "tensorflow/core/platform/random.h"
+
+namespace tensorflow {
+namespace ifrt_serving {
+namespace {
+
+// A pass that inserts tf.ifrt_call ops and creates their callees as IFRT
+// programs.
+class RewriteClusterToIfrtCallPass
+    : public mlir::PassWrapper<RewriteClusterToIfrtCallPass,
+                               mlir::OperationPass<mlir::ModuleOp>> {
+ public:
+  RewriteClusterToIfrtCallPass() = default;
+  RewriteClusterToIfrtCallPass &operator=(
+      const RewriteClusterToIfrtCallPass &) = delete;
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RewriteClusterToIfrtCallPass)
+
+ private:
+  // Returns a new unique program id.
+  static int64_t NewProgramId() {
+    const uint64_t id = tensorflow::random::New64();
+    // Program ids are signed because TensorFlow doesn't support uint64_t
+    // attributes; bit_cast preserves all 64 bits of the random value.
+    return absl::bit_cast<int64_t>(id);
+  }
+
+  void getDependentDialects(mlir::DialectRegistry &registry) const override {}
+
+  llvm::StringRef getArgument() const final {
+    return "rewrite-cluster-to-ifrt-call";
+  }
+
+  llvm::StringRef getDescription() const final {
+    return "Convert tf_device.cluster_func to tf.ifrt_program_call";
+  }
+
+  void runOnOperation() override {
+    mlir::ModuleOp module = getOperation();
+    mlir::SymbolTable symbol_table(module);
+
+    // Maps the original callee function of a tf_device.cluster_func to the
+    // IFRT program generated for it.
+    llvm::DenseMap<mlir::func::FuncOp, mlir::func::FuncOp>
+        cluster_to_ifrt_program;
+
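+    // Collect the cluster_func ops before rewriting: Rewrite() erases ops,
+    // so mutating the module during the walk would not be safe.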
+    std::vector<mlir::tf_device::ClusterFuncOp> cluster_func_ops;
+    module.walk([&](mlir::tf_device::ClusterFuncOp cluster_func) {
+      cluster_func_ops.push_back(cluster_func);
+    });
+    for (auto cluster_func : cluster_func_ops) {
+      Rewrite(symbol_table, cluster_to_ifrt_program, cluster_func);
+    }
+
+    // TODO(b/304839793): Move this to a separate pass. The old
+    // remove-compilation-result pass relies on TPUPartitionedCall.
+    llvm::SmallVector<mlir::TF::TPUCompilationResultOp> compilation_result_ops;
+    module.walk([&](mlir::TF::TPUCompilationResultOp op) {
+      compilation_result_ops.push_back(op);
+    });
+    for (auto op : compilation_result_ops) {
+      if (!op.use_empty()) {
+        module->emitError("TPUCompilationResultOp is still in use");
+        return signalPassFailure();
+      }
+      op.erase();
+    }
+  }
+
+  void Rewrite(mlir::SymbolTable &symbol_table,
+               llvm::DenseMap<mlir::func::FuncOp, mlir::func::FuncOp>
+                   &cluster_to_ifrt_program,
+               mlir::tf_device::ClusterFuncOp cluster_func) {
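+    // Rewrites one tf_device.cluster_func into a tf.IfrtCall op, creating
+    // the corresponding IFRT program function on first use of the callee.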
+    mlir::OpBuilder builder(cluster_func);
+    mlir::FlatSymbolRefAttr callee_symbol = cluster_func.getFuncAttr();
+    mlir::func::FuncOp callee_func =
+        symbol_table.lookup<mlir::func::FuncOp>(callee_symbol.getValue());
+
+    auto ifrt_program_name =
+        absl::StrCat("_ifrt_program_", callee_func.getSymName().str());
+    if (mlir::func::FuncOp ifrt_program =
+            cluster_to_ifrt_program[callee_func]) {
+      // An IFRT program for this callee already exists; reuse it.
+      builder.setInsertionPoint(cluster_func);
+
+      mlir::TF::IfrtCallOp ifrt_call_op = builder.create<mlir::TF::IfrtCallOp>(
+          cluster_func->getLoc(), cluster_func.getResultTypes(),
+          cluster_func->getOperands());
+
+      int64_t program_id;
+      if (auto attr = ifrt_program->getAttrOfType<mlir::IntegerAttr>(
+              "tfrt_ifrt_serving.program_id")) {
+        program_id = attr.getInt();
+      } else {
+        return signalPassFailure();
+      }
+
+      // TODO(b/304839793): populate variable names after adding a variable
+      // hoisting pass.
+      ifrt_call_op.setVariableNamesAttr(builder.getArrayAttr({}));
+      ifrt_call_op.setProgramId(program_id);
+
+      cluster_func->replaceAllUsesWith(ifrt_call_op.getResults());
+      cluster_func->erase();
+
+      return;
+    }
+
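+    // First time this callee is seen: clone it into a new public function
+    // that serves as the IFRT program.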
+    mlir::OpBuilder::InsertionGuard insertion_guard(builder);
+    builder.setInsertionPoint(callee_func);
+
+    mlir::func::FuncOp cloned_ifrt_program = builder.create<mlir::func::FuncOp>(
+        callee_func->getLoc(), ifrt_program_name,
+        callee_func.getFunctionType());
+    mlir::IRMapping mapper;
+    callee_func.cloneInto(cloned_ifrt_program, mapper);
+
+    cloned_ifrt_program.setName(ifrt_program_name);
+    cloned_ifrt_program.setPublic();
+
+    int64_t program_id = NewProgramId();
+    cloned_ifrt_program->setAttr("tfrt_ifrt_serving.program_id",
+                                 builder.getI64IntegerAttr(program_id));
+
+    builder.setInsertionPoint(cluster_func);
+
+    mlir::TF::IfrtCallOp ifrt_call_op = builder.create<mlir::TF::IfrtCallOp>(
+        cluster_func->getLoc(), cluster_func.getResultTypes(),
+        cluster_func->getOperands());
+
+    // TODO(b/304839793): populate variable names after adding a variable
+    // hoisting pass.
+    ifrt_call_op.setVariableNamesAttr(builder.getArrayAttr({}));
+    ifrt_call_op.setProgramId(program_id);
+
+    cluster_func->replaceAllUsesWith(ifrt_call_op.getResults());
+    cluster_func->erase();
+
+    symbol_table.insert(cloned_ifrt_program);
+    cluster_to_ifrt_program[callee_func] = cloned_ifrt_program;
+  }
+};
+}  // namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+CreateRewriteClusterToIfrtCallPass() {
+  return std::make_unique<RewriteClusterToIfrtCallPass>();
+}
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.h
new file mode 100644
index 0000000..4809e82
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.h
@@ -0,0 +1,33 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_REWRITE_CLUSTER_TO_IFRT_CALL_H_
+#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_REWRITE_CLUSTER_TO_IFRT_CALL_H_
+
+#include <memory>
+
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+
+namespace tensorflow {
+namespace ifrt_serving {
+
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+CreateRewriteClusterToIfrtCallPass();
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_REWRITE_CLUSTER_TO_IFRT_CALL_H_
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/BUILD
new file mode 100644
index 0000000..8e6be47
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/BUILD
@@ -0,0 +1,12 @@
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = ["//tensorflow/compiler/mlir/tfrt:__subpackages__"],
+    licenses = ["notice"],
+)
+
+filegroup(
+    name = "testdata",
+    srcs = glob(
+        ["*"],
+    ),
+)
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/ifrt_cluster.mlir b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/ifrt_cluster.mlir
new file mode 100644
index 0000000..984608c
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/ifrt_cluster.mlir
@@ -0,0 +1,9 @@
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
+  func.func @main() {
+    "tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
+    func.return
+  }
+  func.func @empty_func() {
+    func.return
+  }
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc
new file mode 100644
index 0000000..d03eda0
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc
@@ -0,0 +1,120 @@
+
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "llvm/ADT/StringRef.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
+#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.h"
+#include "tensorflow/core/util/debug_data_dumper.h"
+
+namespace tensorflow {
+namespace ifrt_serving {
+namespace {
+
+using mlir::LogicalResult;
+using mlir::OpPassManager;
+using mlir::PassManager;
+using mlir::func::FuncOp;
+
+// Sets up the input pass manager to enable IR dumping after each pass.
+// Note that, as a side effect, multi-threading is disabled on the context.
+void EnablePassIRPrinting(PassManager& pm, const std::string& dump_group_name,
+                          llvm::StringRef module_name) {
+  // Print the whole module after each pass, which requires disabling
+  // multi-threading as well.
+  pm.getContext()->disableMultithreading();
+  pm.enableIRPrinting(std::make_unique<::tensorflow::DataDumperLoggerConfig>(
+      [module_name, dump_group_name](const std::string& pass_tag_name,
+                                     mlir::Operation* op) {
+        return DEBUG_DATA_DUMPER()->GetDumpFilename(
+            module_name.str(), dump_group_name, pass_tag_name);
+      },
+      /*pass_prefix=*/"",
+      /*print_module_scope=*/true));
+  pm.enableTiming();
+}
+
+void AddClusterToIfrtRuntimeOpsPassPipeline(OpPassManager& pm,
+                                            llvm::StringRef module_name) {
+  pm.addNestedPass<mlir::func::FuncOp>(
+      mlir::CreateExecutorDialectToFunctionalConversionPass());
+
+  pm.addNestedPass<mlir::func::FuncOp>(
+      mlir::TF::CreateCanonicalizeCompileAndReplicateAttributesPass());
+
+  pm.addPass(CreateRewriteClusterToIfrtCallPass());
+}
+
+}  // namespace
+
+absl::Status RunClusterToIfrtRuntimeOpsPassPipeline(
+    mlir::ModuleOp module, llvm::StringRef module_name) {
+  mlir::StatusScopedDiagnosticHandler diag_handler(
+      module.getContext(), /*propagate=*/false,
+      /*filter_stack=*/!VLOG_IS_ON(1));
+
+  PassManager runtime_lowering(module.getContext());
+  ::tensorflow::applyTensorflowAndCLOptions(runtime_lowering);
+
+  AddClusterToIfrtRuntimeOpsPassPipeline(runtime_lowering, module_name);
+
+  if (VLOG_IS_ON(1)) {
+    ::tensorflow::DumpMlirOpToFile(
+        DEBUG_DATA_DUMPER()->GetDumpFilename(module_name.str(), kDebugGroupMain,
+                                             "ifrt_runtime_lowering_before"),
+        module, llvm::StringRef(), &runtime_lowering);
+  }
+
+  if (VLOG_IS_ON(2)) {
+    EnablePassIRPrinting(runtime_lowering, kDebugGroupRuntimeLowering,
+                         module_name);
+  }
+
+  // Ignore the result; diag_handler consumes it and reports failures through
+  // the status returned below.
+  LogicalResult result = runtime_lowering.run(module);
+  (void)result;
+
+  if (VLOG_IS_ON(1)) {
+    ::tensorflow::DumpMlirOpToFile(
+        DEBUG_DATA_DUMPER()->GetDumpFilename(module_name.str(), kDebugGroupMain,
+                                             "ifrt_runtime_lowering_after"),
+        module, llvm::StringRef(), &runtime_lowering);
+  }
+
+  return diag_handler.ConsumeStatus();
+}
+
+// Registers all TF IFRT passes.
+void RegisterTfIfrtPasses() {
+  mlir::registerPass([]() { return CreateRewriteClusterToIfrtCallPass(); });
+}
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h
new file mode 100644
index 0000000..5709732
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h
@@ -0,0 +1,37 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_TF_IFRT_PASSES_H_
+#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_TF_IFRT_PASSES_H_
+
+#include "absl/status/status.h"
+#include "llvm/ADT/StringRef.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+
+namespace tensorflow {
+namespace ifrt_serving {
+
+// Register all passes.
+void RegisterTfIfrtPasses();
+
+// Converts tf_device.cluster_func ops to tf.ifrt_program_call ops.
+// Each callee function is converted into an IFRT program.
+absl::Status RunClusterToIfrtRuntimeOpsPassPipeline(
+    mlir::ModuleOp module, llvm::StringRef module_name = llvm::StringRef());
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_TF_IFRT_PASSES_H_
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc
index 2807264..6e04fe1 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc
@@ -14,7 +14,9 @@
 ==============================================================================*/
 #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h"
 
+#include <string>
 #include <utility>
+#include <vector>
 
 #include "absl/log/log.h"
 #include "absl/status/status.h"
@@ -49,10 +51,10 @@
 namespace mlrt_compiler {
 
 StatusOr<mlrt::bc::Buffer> ConvertTfMlirToBytecode(
-    const TfrtCompileOptions& options,
-    const tfrt_stub::FallbackState& fallback_state, mlir::ModuleOp module,
-    tfrt_stub::ModelRuntimeContext& model_context,
-    mlir::OwningOpRef<mlir::ModuleOp>* module_with_op_keys) {
+    const TfrtCompileOptions& options, tfrt_stub::FallbackState& fallback_state,
+    mlir::ModuleOp module, tfrt_stub::ModelRuntimeContext& model_context,
+    mlir::OwningOpRef<mlir::ModuleOp>* module_with_op_keys,
+    std::vector<std::string>* added_xla_function_names) {
   mlrt::bc::Buffer bytecode_buffer;
   TF_RETURN_IF_ERROR(ConvertTfMlirToRuntimeExecutable(
       options, module,
@@ -127,7 +129,7 @@
         bytecode_buffer = std::move(*statusor);
         return OkStatus();
       },
-      model_context));
+      model_context, &fallback_state, added_xla_function_names));
   return bytecode_buffer;
 }
 
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h
index 87dc685..ef9caeb 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h
+++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h
@@ -33,10 +33,10 @@
 //
 // This is for initial conversion.
 StatusOr<mlrt::bc::Buffer> ConvertTfMlirToBytecode(
-    const TfrtCompileOptions& options,
-    const tfrt_stub::FallbackState& fallback_state, mlir::ModuleOp module,
-    tfrt_stub::ModelRuntimeContext& model_context,
-    mlir::OwningOpRef<mlir::ModuleOp>* module_with_op_keys = nullptr);
+    const TfrtCompileOptions& options, tfrt_stub::FallbackState& fallback_state,
+    mlir::ModuleOp module, tfrt_stub::ModelRuntimeContext& model_context,
+    mlir::OwningOpRef<mlir::ModuleOp>* module_with_op_keys = nullptr,
+    std::vector<std::string>* added_xla_function_names = nullptr);
 
 // Converts an MLIR `module_with_op_keys` in TF dialect to MLRT's bytecode
 // format, with op costs from `cost_recorder`.
diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc
index 48598a0..9a8d41a4 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc
@@ -29,6 +29,7 @@
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
@@ -37,17 +38,21 @@
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
 #include "tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h"
+#include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h"
 #include "tensorflow/compiler/mlir/tfrt/backend_compiler.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h"
 #include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/core/common_runtime/function_body.h"
 #include "tensorflow/core/common_runtime/function_def_utils.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/statusor.h"
 #include "tensorflow/core/tfrt/fallback/fallback_state.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+#include "tsl/framework/device_type.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/statusor.h"
@@ -214,6 +219,13 @@
         tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge(
             module, tf2xla::v2::DeviceType::XLA_TPU_JIT,
             /*is_in_fallback_enabled_mode=*/VLOG_IS_ON(1)));
+
+    TF_RETURN_IF_ERROR(
+        tensorflow::tfrt_compiler::RunLowerClusterToRuntimeOpsPassPipeline(
+            module, tsl::DeviceType(DEVICE_TPU_XLA_JIT)));
+
+    TF_RETURN_IF_ERROR(
+        tensorflow::tf2xla::v2::ExportFromTensorflowDialectToExecutor(module));
   } else if (options.device_target == TfrtDeviceInfraTarget::kTfFallback) {
     auto tpu_partitioned_call_fallback_compat_result =
         tensorflow::RunTPUPartitionedCallFallbackCompatConversion(module);
@@ -222,7 +234,17 @@
           "Failed to process TPUPartitionedCallOp for fallback execution"));
     }
   } else if (options.device_target == TfrtDeviceInfraTarget::kGpu) {
-    TF_RETURN_IF_ERROR(mlir::TF::RunTFXLABridge(module));
+    TF_RETURN_IF_ERROR(
+        tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge(
+            module, tf2xla::v2::DeviceType::XLA_GPU_JIT,
+            /*is_in_fallback_enabled_mode=*/false));
+
+    TF_RETURN_IF_ERROR(
+        tensorflow::tfrt_compiler::RunLowerClusterToRuntimeOpsPassPipeline(
+            module, tsl::DeviceType(DEVICE_GPU_XLA_JIT)));
+
+    TF_RETURN_IF_ERROR(
+        tensorflow::tf2xla::v2::ExportFromTensorflowDialectToExecutor(module));
 
     if (options.serialize_mlir_module_to_aot_packages) {
       const std::string mlir_string = SerializeMlirModule(module);
diff --git a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h
index 34dc667..11790e9 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h
+++ b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h
@@ -155,6 +155,12 @@
 
   // Serialized MLIR module file under aot_packages.
   std::string aot_mlir_module_file;
+
+  // If true, BEF will be serialized to aot_packages.
+  bool serialize_bef_to_aot_packages = false;
+
+  // Serialized BEF file under aot_packages.
+  std::string aot_bef_file;
 };
 
 std::ostream& operator<<(std::ostream& os, const TfrtCompileOptions& options);
diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
index 26f76f9..53cbd84 100644
--- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
+++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
@@ -1137,3 +1137,12 @@
   %1 = "tf.BroadcastTo"(%arg0, %s) : (tensor<2x3x13x1xi32>, tensor<2xi32>) -> tensor<13x7xi32>
   return %1 : tensor<13x7xi32>
 }
+
+// -----
+
+// CHECK-LABEL: test_erf
+// CHECK: %[[VAR0:.*]] = tosa.erf %arg0 :
+func.func @test_erf(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = "tf.Erf"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+  func.return %0 : tensor<4x4xf32>
+}
diff --git a/tensorflow/compiler/mlir/tosa/transforms/tf_legalize_patterns.td b/tensorflow/compiler/mlir/tosa/transforms/tf_legalize_patterns.td
index cf5f078..194f949 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/tf_legalize_patterns.td
+++ b/tensorflow/compiler/mlir/tosa/transforms/tf_legalize_patterns.td
@@ -32,6 +32,7 @@
 def : Pat<(TF_CeilOp $arg), (Tosa_CeilOp $arg)>;
 def : Pat<(TF_FloorOp $arg), (Tosa_FloorOp $arg)>;
 def : Pat<(TF_ExpOp $arg), (Tosa_ExpOp $arg)>;
+def : Pat<(TF_ErfOp $arg), (Tosa_ErfOp $arg)>;
 def : Pat<(TF_LogOp $arg), (Tosa_LogOp $arg)>;
 def : Pat<(TF_ReciprocalOp $arg), (Tosa_ReciprocalOp $arg)>;
 def : Pat<(TF_RsqrtOp $arg), (Tosa_RsqrtOp $arg)>;
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 612dee3..bf64d8c 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -891,8 +891,8 @@
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/ops:resource_variable_ops",
+        "//tensorflow/python/ops:training_ops_gen",
         "//tensorflow/python/platform:test",
-        "//tensorflow/python/training:training_ops",
         "//third_party/py/numpy",
     ],
 )
@@ -1438,7 +1438,7 @@
     ],
     deps = [
         ":xla_test",
-        "//tensorflow/python/client",
+        "//tensorflow/python/client:device_lib",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/ops:array_ops",
         "//tensorflow/python/ops:math_ops",
@@ -1566,7 +1566,7 @@
     ],
     deps = [
         ":xla_test",
-        "//tensorflow/python/client",
+        "//tensorflow/python/client:device_lib",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/framework:config",
         "//tensorflow/python/framework:dtypes",
@@ -2919,3 +2919,42 @@
         "//tensorflow/python/platform:client_testlib",
     ],
 )
+
+tf_xla_py_strict_test(
+    name = "xla_dump_to_test",
+    size = "medium",
+    srcs = ["xla_dump_to_test.py"],
+    enable_mlir_bridge = True,
+    python_version = "PY3",
+    tags = [
+        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
+        "optonly",
+    ],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python/ops:array_ops",
+        "//tensorflow/python/ops:math_ops",
+        "//tensorflow/python/platform:test",
+        "//third_party/py/numpy",
+    ],
+)
+
+# copybara:uncomment_begin(google-only)
+# tf_xla_py_strict_test(
+#     name = "xla_dump_to_sponge_test",
+#     size = "medium",
+#     srcs = ["xla_dump_to_sponge_test.py"],
+#     enable_mlir_bridge = True,
+#     python_version = "PY3",
+#     tags = [
+#         "optonly",
+#     ],
+#     deps = [
+#         ":xla_test",
+#         "//third_party/py/numpy",
+#         "//tensorflow/python/ops:array_ops",
+#         "//tensorflow/python/ops:math_ops",
+#         "//tensorflow/python/platform:test",
+#     ],
+# )
+# copybara:uncomment_end
diff --git a/tensorflow/compiler/tests/ftrl_ops_test.py b/tensorflow/compiler/tests/ftrl_ops_test.py
index 61d73a2..d28a826 100644
--- a/tensorflow/compiler/tests/ftrl_ops_test.py
+++ b/tensorflow/compiler/tests/ftrl_ops_test.py
@@ -19,9 +19,9 @@
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import googletest
-from tensorflow.python.training import training_ops
 
 
 class ResourceApplyFtrlTest(xla_test.XLATestCase):
@@ -55,12 +55,12 @@
         session.run(v_linear.create)
         assert not (use_v2 and multiply_linear_by_lr)
         if use_v2:
-          session.run(training_ops.resource_apply_ftrl_v2(
+          session.run(gen_training_ops.resource_apply_ftrl_v2(
               v_var.handle, v_accum.handle, v_linear.handle,
               grad, lr, l1, l2, l2_shrinkage, lr_power,
               multiply_linear_by_lr=multiply_linear_by_lr))
         else:
-          session.run(training_ops.resource_apply_ftrl(
+          session.run(gen_training_ops.resource_apply_ftrl(
               v_var.handle, v_accum.handle, v_linear.handle,
               grad, lr, l1, l2, lr_power,
               multiply_linear_by_lr=multiply_linear_by_lr))
diff --git a/tensorflow/compiler/tests/xla_call_module_test.py b/tensorflow/compiler/tests/xla_call_module_test.py
index fc6accf..50f81b5 100644
--- a/tensorflow/compiler/tests/xla_call_module_test.py
+++ b/tensorflow/compiler/tests/xla_call_module_test.py
@@ -92,12 +92,12 @@
 
     self._assertOpOutputMatchesExpected(f, (x,), (np.sin(np.cos(x)),))
 
-  def test_basic_with_token(self):
+  def test_basic_with_token_v8(self):
     x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
 
     def f(x):
       # sin(cos(x))
-      module, version = serialize("""
+      module, _ = serialize("""
 module @jit_f.0 {
   func.func public @main(%arg0: !stablehlo.token, %arg1: tensor<3xf32>) -> (!stablehlo.token, tensor<3xf32>) {
     %0 = stablehlo.cosine %arg1 : tensor<3xf32>
@@ -108,11 +108,61 @@
 """)
       return xla.call_module(
           [x],
+          version=8,  # Version 8 uses only one prefix token
+          module=module,
+          Tout=[x.dtype],
+          Sout=[x.shape],
+          has_token_input_output=True,  # Version 8 cares about this
+          platforms=[self.testing_platform()],
+      )
+
+    self._assertOpOutputMatchesExpected(f, (x,), (np.sin(np.cos(x)),))
+
+  def test_basic_with_multiple_tokens(self):
+    x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+
+    def f(x):
+      # sin(cos(x))
+      module, version = serialize("""
+module @jit_f.0 {
+  func.func public @main(%arg0: !stablehlo.token {jax.token = true}, %arg1: !stablehlo.token {jax.token = true}, %arg2: tensor<3xf32>) -> (!stablehlo.token, !stablehlo.token, tensor<3xf32>) {
+    %0 = stablehlo.cosine %arg2 : tensor<3xf32>
+    %1 = stablehlo.sine %0 : tensor<3xf32>
+    return %arg0, %arg1, %1 : !stablehlo.token, !stablehlo.token, tensor<3xf32>
+  }
+}
+""")
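+      # Unlike test_basic_with_token_v8 above, the tokens here are declared
+      # via the jax.token argument attribute in the module itself, so
+      # has_token_input_output is not passed.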
+      return xla.call_module(
+          [x],
           version=version,
           module=module,
           Tout=[x.dtype],
           Sout=[x.shape],
-          has_token_input_output=True,
+          platforms=[self.testing_platform()],
+      )
+
+    self._assertOpOutputMatchesExpected(f, (x,), (np.sin(np.cos(x)),))
+
+  def test_basic_with_tokens_preceeded_by_other_args(self):
+    x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+
+    def f(x):
+      # sin(cos(x))
+      module, version = serialize("""
+module @jit_f.0 {
+  func.func public @main(%arg0: tensor<i32>, %arg1: !stablehlo.token {jax.token = true}, %arg2: !stablehlo.token {jax.token = true}, %arg3: tensor<3xf32>) -> (!stablehlo.token, !stablehlo.token, tensor<3xf32>) {
+    %0 = stablehlo.cosine %arg3 : tensor<3xf32>
+    %1 = stablehlo.sine %0 : tensor<3xf32>
+    return %arg1, %arg2, %1 : !stablehlo.token, !stablehlo.token, tensor<3xf32>
+  }
+}
+""")
+      return xla.call_module(
+          [np.int32(0), x],
+          version=version,
+          module=module,
+          Tout=[x.dtype],
+          Sout=[x.shape],
           platforms=[self.testing_platform()],
       )
 
@@ -183,7 +233,7 @@
     %0, %1 = call @dyn_main(%arg0_new, %arg1) : (tensor<{dim_var_type}>, tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<{dim_var_type}>)
     return %0, %1 : tensor<2x?xf32>, tensor<{dim_var_type}>
   }}
-  func.func private @dyn_main(%arg0: tensor<{dim_var_type}>, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<{dim_var_type}>) {{
+  func.func private @dyn_main(%arg0: tensor<{dim_var_type}> {{jax.global_constant = "b"}}, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor<{dim_var_type}>) {{
     %0 = stablehlo.sine %arg1 : tensor<2x?xf32>
     return %0, %arg0 : tensor<2x?xf32>, tensor<{dim_var_type}>
   }}
@@ -278,7 +328,7 @@
     #  returns x + 2. on CPU, x + 3. on GPU (CUDA or ROCM) and x + 4. on TPU
     module, version = serialize(f"""
 module @jit_f.0 {{
-  func.func public @main(%arg_platform_idx: tensor<{platform_idx_type}>, %arg0: tensor<f32>) -> tensor<f32> {{
+  func.func public @main(%arg_platform_idx: tensor<{platform_idx_type}> {{jax.global_constant = "_platform_index"}}, %arg0: tensor<f32>) -> tensor<f32> {{
     %0 = stablehlo.convert %arg_platform_idx : (tensor<{platform_idx_type}>) -> tensor<i32>
     %to_add = "stablehlo.case"(%0) ({{
       %cpu_val = stablehlo.constant dense<2.> : tensor<f32>
@@ -319,7 +369,7 @@
     #  returns x + 2. on CPU, x + 3. on GPU, and x + 4. on TPU
     module, version = serialize("""
 module @jit_f.0 {
-  func.func public @main(%arg_platform_idx: tensor<i32>, %arg0: tensor<f32>) -> tensor<f32> {
+  func.func public @main(%arg_platform_idx: tensor<i32> {jax.global_constant = "_platform_index"}, %arg0: tensor<f32>) -> tensor<f32> {
     %to_add = "stablehlo.case"(%arg_platform_idx) ({
       %cpu_val = stablehlo.constant dense<2.> : tensor<f32>
       stablehlo.return %cpu_val : tensor<f32>
@@ -358,13 +408,13 @@
 
     module, version = serialize("""
 module @jit_f_jax attributes {jax.uses_shape_polymorphism = true} {
-  func.func public @main(%arg_platform_idx: tensor<i32>, %arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  func.func public @main(%arg_platform_idx: tensor<i32> {jax.global_constant = "_platform_index"}, %arg0: tensor<?xf32>) -> (tensor<?xf32>) {
     %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
     %5 = call @_wrapped_jax_export_main(%arg_platform_idx, %0, %arg0) : (tensor<i32>, tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
     return %5 : tensor<?xf32>
   }
 
-  func.func private @_wrapped_jax_export_main(%arg_platform_idx: tensor<i32>, %arg0: tensor<i32>, %arg1: tensor<?xf32>) -> (tensor<?xf32>) {
+  func.func private @_wrapped_jax_export_main(%arg_platform_idx: tensor<i32> {jax.global_constant = "_platform_index"}, %arg0: tensor<i32> {jax.global_constant = "b"}, %arg1: tensor<?xf32>) -> (tensor<?xf32>) {
     %to_add = "stablehlo.case"(%arg_platform_idx) ({
       %cpu_val = stablehlo.constant dense<2.> : tensor<f32>
       stablehlo.return %cpu_val : tensor<f32>
@@ -395,6 +445,49 @@
     )
     self._assertOpOutputMatchesExpected(f, (x,), (expected_value,))
 
+  def test_platforms_and_poly_and_tokens(self):
+    x = np.arange(6, dtype=np.float32)
+    #  returns x + 2. on CPU, x + 3. on GPU (CUDA or ROCM) and x + 4. on TPU
+
+    module, version = serialize("""
+module @jit_f_jax attributes {jax.uses_shape_polymorphism = true} {
+  func.func public @main(%arg_platform_idx: tensor<i32> {jax.global_constant = "_platform_index"}, %arg_tok: !stablehlo.token {jax.token = true}, %arg0: tensor<?xf32>) -> (!stablehlo.token, tensor<?xf32>) {
+    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
+    %5:2 = call @_wrapped_jax_export_main(%arg_platform_idx, %0, %arg_tok, %arg0) : (tensor<i32>, tensor<i32>, !stablehlo.token, tensor<?xf32>) -> (!stablehlo.token, tensor<?xf32>)
+    return %5#0, %5#1 : !stablehlo.token, tensor<?xf32>
+  }
+
+  func.func private @_wrapped_jax_export_main(%arg_platform_idx: tensor<i32> {jax.global_constant = "_platform_index"}, %arg0: tensor<i32> {jax.global_constant = "b"}, %arg_tok: !stablehlo.token {jax.token = true}, %arg1: tensor<?xf32>) -> (!stablehlo.token, tensor<?xf32>) {
+    %to_add = "stablehlo.case"(%arg_platform_idx) ({
+      %cpu_val = stablehlo.constant dense<2.> : tensor<f32>
+      stablehlo.return %cpu_val : tensor<f32>
+    }, {
+      %gpu_val = stablehlo.constant dense<3.> : tensor<f32>
+      stablehlo.return %gpu_val : tensor<f32>
+    }, {
+      %tpu_val = stablehlo.constant dense<4.> : tensor<f32>
+      stablehlo.return %tpu_val : tensor<f32>
+    }) : (tensor<i32>) -> tensor<f32>
+    %1 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
+    %3 = stablehlo.dynamic_broadcast_in_dim %to_add, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
+    %4 = stablehlo.add %3, %arg1 : tensor<?xf32>
+    return %arg_tok, %4 : !stablehlo.token, tensor<?xf32>
+  }
+}
+""")
+    platforms = ['CPU', 'CUDA', 'ROCM', 'TPU']
+    def f(x):
+      return xla.call_module([x], version=version,
+                             module=module,
+                             Tout=[np.float32],
+                             Sout=[()],
+                             platforms=platforms)
+
+    expected_value = (
+        x + dict(CPU=2.0, CUDA=3.0, ROCM=3.0, TPU=4.0)[self.testing_platform()]
+    )
+    self._assertOpOutputMatchesExpected(f, (x,), (expected_value,))
+
   # A module used for testing errors related to use of "platforms".
   platforms_errors_module_str = """
   module @jit_f.0 {
@@ -742,7 +835,7 @@
     %0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?x5xi32>) -> tensor<?xi32>
     return %0 : tensor<?xi32>
   }
-  func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?x5xi32>) -> tensor<?xi32> {
+  func.func private @dyn_main(%arg0: tensor<i32> {jax.global_constant = "b"}, %arg1: tensor<?x5xi32>) -> tensor<?xi32> {
     %0 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
     %1 = "stablehlo.dynamic_iota"(%0) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor<?xi32>
     return %1 : tensor<?xi32>
@@ -790,7 +883,7 @@
     %0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?x3xf32>) -> tensor<?xf32>
     return %0 : tensor<?xf32>
   }
-  func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?x3xf32>) -> tensor<?xf32> {
+  func.func private @dyn_main(%arg0: tensor<i32> {jax.global_constant = "b"}, %arg1: tensor<?x3xf32>) -> tensor<?xf32> {
     %0 = stablehlo.constant dense<3> : tensor<i32>
     %1 = stablehlo.multiply %arg0, %0 : tensor<i32>
     %2 = stablehlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
@@ -819,7 +912,7 @@
     %0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?x4xf32>) -> tensor<?x2xf32>
     return %0 : tensor<?x2xf32>
   }
-  func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?x4xf32>) -> tensor<?x2xf32> {
+  func.func private @dyn_main(%arg0: tensor<i32> {jax.global_constant = "b"}, %arg1: tensor<?x4xf32>) -> tensor<?x2xf32> {
     %0 = stablehlo.constant dense<0> : tensor<i64>
     %1 = stablehlo.constant dense<0> : tensor<1xi64>
     %2 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
@@ -850,7 +943,7 @@
     %0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?x4xf32>) -> tensor<4xf32>
     return %0 : tensor<4xf32>
   }
-  func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?x4xf32>) -> tensor<4xf32> {
+  func.func private @dyn_main(%arg0: tensor<i32> {jax.global_constant = "b"}, %arg1: tensor<?x4xf32>) -> tensor<4xf32> {
     %0 = stablehlo.constant dense<-1> : tensor<i32>
     %1 = stablehlo.add %arg0, %0 : tensor<i32>
     %2 = stablehlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
@@ -887,7 +980,7 @@
     %0 = call @dyn_main(%arg0_new, %arg1, %arg2) : (tensor<i32>, tensor<?x4xf32>, tensor<i32>) -> tensor<?x4xf32>
     return %0 : tensor<?x4xf32>
   }
-  func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?x4xf32>, %arg2: tensor<i32>) -> tensor<?x4xf32> {
+  func.func private @dyn_main(%arg0: tensor<i32> {jax.global_constant = "b"}, %arg1: tensor<?x4xf32>, %arg2: tensor<i32>) -> tensor<?x4xf32> {
     %0 = stablehlo.constant dense<0> : tensor<i32>
     %1 = stablehlo.compare  LT, %arg2, %0,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
     %2 = stablehlo.add %arg2, %arg0 : tensor<i32>
@@ -920,7 +1013,7 @@
     %0, %1 = call @dyn_main(%arg0_new, %arg1, %arg2) : (tensor<i32>, tensor<?x4xf32>, tensor<2x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>)
     return %0, %1 : tensor<2x?x4xf32>, tensor<2x?x4xf32>
   }
-  func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?x4xf32>, %arg2: tensor<2x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) {
+  func.func private @dyn_main(%arg0: tensor<i32> {jax.global_constant = "b"}, %arg1: tensor<?x4xf32>, %arg2: tensor<2x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) {
     %0 = stablehlo.constant dense<2> : tensor<1xi32>
     %2 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
     %3 = stablehlo.constant dense<4> : tensor<1xi32>
@@ -952,7 +1045,7 @@
     %0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?xi32>) -> tensor<i32>
     return %0 : tensor<i32>
   }
-  func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?xi32>) -> tensor<i32> {
+  func.func private @dyn_main(%arg0: tensor<i32> {jax.global_constant = "b"}, %arg1: tensor<?xi32>) -> tensor<i32> {
     %0 = stablehlo.constant dense<0> : tensor<i32>
     %1 = stablehlo.reduce(%arg1 init: %0) across dimensions = [0] : (tensor<?xi32>, tensor<i32>) -> tensor<i32>
      reducer(%arg2: tensor<i32>, %arg3: tensor<i32>)  {
@@ -984,7 +1077,7 @@
     %0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?x5xf32>) -> tensor<?x1xf32>
     return %0 : tensor<?x1xf32>
   }
-  func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?x5xf32>) -> tensor<?x1xf32> {
+  func.func private @dyn_main(%arg0: tensor<i32> {jax.global_constant = "b"}, %arg1: tensor<?x5xf32>) -> tensor<?x1xf32> {
     %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
     %1 = stablehlo.reduce(%arg1 init: %0) across dimensions = [1] : (tensor<?x5xf32>, tensor<f32>) -> tensor<?xf32>
      reducer(%arg2: tensor<f32>, %arg3: tensor<f32>)  {
@@ -1020,11 +1113,11 @@
     %0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xi32>
     return %0 : tensor<?xi32>
   }
-  func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xi32> {
+  func.func private @dyn_main(%arg0: tensor<i32> {jax.global_constant = "b"}, %arg1: tensor<?xf32>) -> tensor<?xi32> {
     %0 = call @f(%arg0, %arg1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xi32>
     return %0 : tensor<?xi32>
   }
-  func.func private @f(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xi32> {
+  func.func private @f(%arg0: tensor<i32> {jax.global_constant = "b"}, %arg1: tensor<?xf32>) -> tensor<?xi32> {
     %0 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
     %1 = "stablehlo.dynamic_iota"(%0) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor<?xi32>
     return %1 : tensor<?xi32>
@@ -1051,7 +1144,7 @@
     %0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
     return %0 : tensor<?xf32>
   }
-  func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  func.func private @dyn_main(%arg0: tensor<i32> {jax.global_constant = "b"}, %arg1: tensor<?xf32>) -> tensor<?xf32> {
     return %arg1 : tensor<?xf32>
   }
 }
@@ -1081,7 +1174,7 @@
     %0, %1 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?xf32>) -> (tensor<?xf32>, tensor<i64>)
     return %0, %1 : tensor<?xf32>, tensor<i64>
   }
-  func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> (tensor<?xf32>, tensor<i64>) {
+  func.func private @dyn_main(%arg0: tensor<i32> {jax.global_constant = "b"}, %arg1: tensor<?xf32>) -> (tensor<?xf32>, tensor<i64>) {
     %0 = stablehlo.constant dense<0> : tensor<i64>
     %1:2 = "stablehlo.while"(%arg1, %0) ({
     ^bb0(%arg2: tensor<?xf32>, %arg3: tensor<i64>):
@@ -1123,7 +1216,7 @@
     %0 = call @dyn_main(%arg0_new, %arg1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
     return %0 : tensor<?xf32>
   }}
-  func.func private @dyn_main(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {{
+  func.func private @dyn_main(%arg0: tensor<i32> {{jax.global_constant = "b"}}, %arg1: tensor<?xf32>) -> tensor<?xf32> {{
     return %arg1 : tensor<?xf32>
   }}
 }}
@@ -1311,7 +1404,6 @@
           Sout=[res.shape],
           platforms=[self.testing_platform()],
           function_list=(foo,),
-          has_token_input_output=True,
       )
 
     self._assertOpOutputMatchesExpected(f, (x, y), (res,))
diff --git a/tensorflow/compiler/tests/xla_dump_to_sponge_test.py b/tensorflow/compiler/tests/xla_dump_to_sponge_test.py
new file mode 100644
index 0000000..89a367d
--- /dev/null
+++ b/tensorflow/compiler/tests/xla_dump_to_sponge_test.py
@@ -0,0 +1,47 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for debug XLA dumps."""
+
+import glob
+import os
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class XlaDumpToSpongeTest(xla_test.XLATestCase):
+  """Test that ensures XLA_FLAGS=--xla_dump_to=sponge produces output."""
+
+  def _compute(self):
+    with self.session() as sess, self.device_scope():
+      data = np.array([0], dtype=np.float32)
+      indices = np.array([0], dtype=np.int32)
+      d = array_ops.placeholder(data.dtype, shape=data.shape)
+      i = array_ops.placeholder(indices.dtype, shape=indices.shape)
+      sess.run(math_ops.segment_max_v2(data, indices, 1), {d: data, i: indices})
+
+  def testDumpToSponge(self):
+    os.environ['XLA_FLAGS'] = '--xla_dump_to=sponge'
+    self._compute()
+    out_dir = os.environ['TEST_UNDECLARED_OUTPUTS_DIR']
+    self.assertNotEmpty(glob.glob(os.path.join(out_dir, 'module_0*')))
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/compiler/tests/xla_dump_to_test.py b/tensorflow/compiler/tests/xla_dump_to_test.py
new file mode 100644
index 0000000..56c74ea
--- /dev/null
+++ b/tensorflow/compiler/tests/xla_dump_to_test.py
@@ -0,0 +1,47 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for debug XLA dumps."""
+
+import glob
+import os
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class XlaDumpToDirTest(xla_test.XLATestCase):
+  """Test that ensures XLA_FLAGS=--xla_dump_to=<dir> produces output."""
+
+  def _compute(self):
+    with self.session() as sess, self.device_scope():
+      data = np.array([0], dtype=np.float32)
+      indices = np.array([0], dtype=np.int32)
+      d = array_ops.placeholder(data.dtype, shape=data.shape)
+      i = array_ops.placeholder(indices.dtype, shape=indices.shape)
+      sess.run(math_ops.segment_max_v2(data, indices, 1), {d: data, i: indices})
+
+  def testDumpToTempDir(self):
+    tmp_dir = self.create_tempdir().full_path
+    os.environ['XLA_FLAGS'] = '--xla_dump_to=' + tmp_dir
+    self._compute()
+    self.assertNotEmpty(glob.glob(os.path.join(tmp_dir, 'module_0*')))
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index dde02e7..55e19a2 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -256,8 +256,6 @@
     copts = runtime_copts() + tf_openmp_copts(),
     defines = [
         "EIGEN_NEON_GEBP_NR=4",
-        # TODO(b/238649163): remove this once no longer necessary.
-        "EIGEN_USE_AVX512_GEMM_KERNELS=0",
     ],
     features = [
         "fully_static_link",
@@ -968,6 +966,7 @@
         "//tensorflow/core/framework:tensor_testutil",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/platform:statusor",
         "@local_xla//xla:literal",
         "@local_xla//xla:shape_util",
         "@local_xla//xla:status_macros",
@@ -976,6 +975,7 @@
         "@local_xla//xla/client:xla_builder",
         "@local_xla//xla/service:cpu_plugin",
         "@local_xla//xla/service:hlo_proto_cc",
+        "@local_xla//xla/service:hlo_proto_util",
         "@local_xla//xla/tests:literal_test_util",
     ],
 )
@@ -1153,21 +1153,29 @@
     visibility = [":internal"],
     deps = [
         ":tf2xla_defs",
+        ":xla_op_registry",
         "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:device_util",
         "//tensorflow/compiler/mlir/tensorflow/transforms:bridge",
+        "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops",
         "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy",
         "//tensorflow/compiler/mlir/tf2xla/api/v1:cluster_tf",
+        "//tensorflow/compiler/mlir/tf2xla/api/v1:tf_dialect_to_executor",
         "//tensorflow/compiler/mlir/tf2xla/api/v2:cluster_tf",
+        "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_dialect_to_executor",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core/common_runtime:device_set",
+        "//tensorflow/core/tpu:tpu_defs",
         "@com_google_absl//absl/base",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index 794485a..b39d6f8 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -65,7 +65,8 @@
     std::vector<int>* const_input_idxs, FunctionLibraryRuntime* flib_runtime) {
   TF_RET_CHECK(!branch_bodies.empty());
   TF_RET_CHECK(branch_bodies[0] != nullptr);
-  int num_inputs = branch_bodies[0]->fdef.signature().input_arg_size();
+  int num_inputs =
+      branch_bodies[0]->record->fdef().signature().input_arg_size();
   // Stores indices of the "branch function" inputs that are expected to be
   // compile time constants.
   std::vector<bool> compile_time_const_arg_indices(num_inputs);
@@ -99,7 +100,7 @@
     TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "body", &fbody));
     TF_RET_CHECK(fcond);
     TF_RET_CHECK(fbody);
-    int num_inputs = fbody->fdef.signature().input_arg_size();
+    int num_inputs = fbody->record->fdef().signature().input_arg_size();
 
     // Stores which of the loop inputs are expected to be compile time
     // constants.
@@ -151,7 +152,7 @@
              node.op() == "StatefulPartitionedCall") {
     const FunctionBody* fbody;
     TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "f", &fbody));
-    int num_inputs = fbody->fdef.signature().input_arg_size();
+    int num_inputs = fbody->record->fdef().signature().input_arg_size();
     std::vector<bool> compile_time_const_arg_indices(num_inputs);
     TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
         *(fbody->graph), &compile_time_const_arg_indices,
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index e5ff2ae..e1605b5 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -369,6 +369,7 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/tpu:tpu_defs",
+        "@com_google_absl//absl/log",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
@@ -2052,6 +2053,9 @@
         "//tensorflow/compiler/tf2xla/lib:broadcast",
         "//tensorflow/compiler/tf2xla/ops:xla_ops",
         "//tensorflow/core:framework",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
         "@local_xla//xla:shape_util",
         "@local_xla//xla/client:client_library",
         "@local_xla//xla/client:xla_builder",
@@ -2789,7 +2793,12 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core/tpu:tpu_defs",
+        "@com_google_absl//absl/strings",
+        "@local_xla//xla:shape_util",
+        "@local_xla//xla:status_macros",
+        "@local_xla//xla:util",
         "@local_xla//xla/client:xla_builder",
+        "@local_xla//xla/client:xla_computation",
         "@local_xla//xla/client/lib:arithmetic",
         "@local_xla//xla/client/lib:comparators",
         "@local_xla//xla/client/lib:constants",
diff --git a/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc
index eb0250d..8fbaf49 100644
--- a/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc
@@ -13,6 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <vector>
+
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -75,8 +77,12 @@
     xla::ChannelHandle channel_handle;
     channel_handle.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);
     channel_handle.set_handle(*channel_id);
-    ctx->SetOutput(0,
-                   xla::AllReduce(ctx->Input(0), *reducer, {}, channel_handle));
+    std::vector<xla::ReplicaGroup> replica_groups(1);
+    for (int64_t i = 0; i < group_size; i++) {
+      replica_groups[0].add_replica_ids(i);
+    }
+    ctx->SetOutput(0, xla::AllReduce(ctx->Input(0), *reducer, replica_groups,
+                                     channel_handle));
   }
 
  private:
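
The hunk above replaces the implicit empty replica_groups (which XLA treats as "all replicas in one group") with an explicit single group listing ids 0 through group_size-1. A minimal sketch of the same pattern, not part of this PR, assuming a reducer computation, an input op x, a channel handle, and a group_size value as in the surrounding kernel:

    std::vector<xla::ReplicaGroup> groups(1);
    for (int64_t i = 0; i < group_size; ++i) {
      groups[0].add_replica_ids(i);  // one group containing replicas 0..group_size-1
    }
    // AllReduce over the explicit group; channel_handle ties it to the collective channel.
    xla::XlaOp reduced = xla::AllReduce(x, *reducer, groups, channel_handle);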
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
index 9b4e574..11d7ce3 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
@@ -18,21 +18,20 @@
 #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
 
 #include <algorithm>
+#include <cstdint>
 #include <utility>
 #include <vector>
 
+#include "absl/algorithm/container.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/tf2xla/lib/broadcast.h"
-#include "tensorflow/compiler/tf2xla/type_util.h"
-#include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "xla/client/client_library.h"
 #include "xla/client/lib/constants.h"
 #include "xla/client/xla_builder.h"
 #include "xla/shape.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/util/bcast.h"
 
 namespace tensorflow {
@@ -46,60 +45,106 @@
   auto lhs_handle = ctx->Input(0);
   auto rhs_handle = ctx->Input(1);
   if (lhs_shape.dims() == rhs_shape.dims()) {
-    auto reconcile_tensor_mismatched_dims =
-        [ctx](xla::XlaOp op, const xla::Shape& lhs_xla_shape,
-              const xla::Shape& rhs_xla_shape, TensorShape* lhs_tensor_shape) {
-          // Find out mismatched dimensions that are non-broadcastable.
-          // Reconcile the
-          // difference by slicing the bigger dimension.
-          for (int64_t i = 0; i < lhs_xla_shape.rank(); ++i) {
-            if (lhs_xla_shape.is_dynamic_dimension(i)) {
-              if (!rhs_xla_shape.is_dynamic_dimension(i) &&
-                  lhs_xla_shape.dimensions(i) > rhs_xla_shape.dimensions(i) &&
-                  rhs_xla_shape.dimensions(i) != 1) {
-                // e.g., :
-                // lhs = [..., <=N, ...]
-                // rhs = [..., 2  , ...]
-                // Slice N into 2.
-                // Size 1 dim doesn't need slice as the other side is
-                // broadcastable.
-                auto size = xla::GetDimensionSize(op, i);
-                op = xla::SliceInDim(op, 0, rhs_xla_shape.dimensions(i), 1,
-                                     /*dimno=*/i);
-                lhs_tensor_shape->set_dim(i, rhs_xla_shape.dimensions(i));
-                // Propagate dynamic dimension.
-                op = xla::SetDimensionSize(op, size, i);
-              }
-              if (rhs_xla_shape.is_dynamic_dimension(i) &&
-                  lhs_xla_shape.dimensions(i) < rhs_xla_shape.dimensions(i) &&
-                  rhs_xla_shape.dimensions(i) != 1 &&
-                  lhs_xla_shape.dimensions(i) != 1) {
-                // e.g., :
-                // lhs = [..., <=M, ...]
-                // rhs = [..., <=N  , ...]
-                // where M < N
-                //
-                // In this case we pad M into N to make the bounds the same.
-                // Note that we can't slice N into M because M could be a
-                // dynamic size 1 dim that's meant to be broadcasted to N.
-                auto size = xla::GetDimensionSize(op, i);
-                int64_t diff =
-                    rhs_xla_shape.dimensions(i) - lhs_xla_shape.dimensions(i);
-                op = xla::PadInDim(
-                    op, xla::Zero(ctx->builder(), lhs_xla_shape.element_type()),
-                    i, 0, diff);
-                lhs_tensor_shape->set_dim(i, rhs_xla_shape.dimensions(i));
-                // Propagate dynamic dimension.
-                op = xla::SetDimensionSize(op, size, i);
-              }
-            }
+    auto reconcile_tensor_mismatched_dims = [ctx](
+                                                xla::XlaOp lhs, xla::XlaOp rhs,
+                                                const xla::Shape& lhs_xla_shape,
+                                                const xla::Shape& rhs_xla_shape,
+                                                TensorShape* lhs_tensor_shape) {
+      // Find out mismatched dimensions that are non-broadcastable.
+      // Reconcile the difference by slicing, padding, or broadcasting the
+      // mismatched dimension, depending on the case handled below.
+      for (int64_t i = 0; i < lhs_xla_shape.rank(); ++i) {
+        if (lhs_xla_shape.is_dynamic_dimension(i)) {
+          if (!rhs_xla_shape.is_dynamic_dimension(i) &&
+              lhs_xla_shape.dimensions(i) > rhs_xla_shape.dimensions(i) &&
+              rhs_xla_shape.dimensions(i) != 1) {
+            // e.g.:
+            // lhs = [..., <=N, ...]
+            // rhs = [..., 2  , ...]
+            // Slice N into 2.
+            // Size 1 dim doesn't need slice as the other side is
+            // broadcastable.
+            auto size = xla::GetDimensionSize(lhs, i);
+            lhs = xla::SliceInDim(lhs, 0, rhs_xla_shape.dimensions(i), 1,
+                                  /*dimno=*/i);
+            lhs_tensor_shape->set_dim(i, rhs_xla_shape.dimensions(i));
+            // Propagate dynamic dimension.
+            lhs = xla::SetDimensionSize(lhs, size, i);
           }
-          return op;
-        };
-    lhs_handle = reconcile_tensor_mismatched_dims(lhs_handle, lhs_xla_shape,
-                                                  rhs_xla_shape, &lhs_shape);
-    rhs_handle = reconcile_tensor_mismatched_dims(rhs_handle, rhs_xla_shape,
-                                                  lhs_xla_shape, &rhs_shape);
+          if (rhs_xla_shape.is_dynamic_dimension(i) &&
+              lhs_xla_shape.dimensions(i) < rhs_xla_shape.dimensions(i) &&
+              rhs_xla_shape.dimensions(i) != 1 &&
+              lhs_xla_shape.dimensions(i) != 1) {
+            // e.g.:
+            // lhs = [..., <=M, ...]
+            // rhs = [..., <=N  , ...]
+            // where M < N
+            //
+            // In this case we pad M into N to make the bounds the same.
+            // Note that we can't slice N into M because M could be a
+            // dynamic size 1 dim that's meant to be broadcasted to N.
+            auto size = xla::GetDimensionSize(lhs, i);
+            int64_t diff =
+                rhs_xla_shape.dimensions(i) - lhs_xla_shape.dimensions(i);
+            lhs = xla::PadInDim(
+                lhs, xla::Zero(ctx->builder(), lhs_xla_shape.element_type()), i,
+                0, diff);
+            lhs_tensor_shape->set_dim(i, rhs_xla_shape.dimensions(i));
+            // Propagate dynamic dimension.
+            lhs = xla::SetDimensionSize(lhs, size, i);
+          }
+          if (lhs_xla_shape.dimensions(i) == 1 &&
+              rhs_xla_shape.dimensions(i) != 1) {
+            // lhs = [..., <=1, ...]
+            // rhs = [...,   N, ...] or [..., <=N, ...]
+            // where N != 1.
+            //
+            // In this case we will need to broadcast this dimension to N.
+            // If the dynamic size is 0, the result size is zero.
+            // If the dynamic size is 1, the result size is N.
+            //
+            // However, XLA only does degenerate broadcasts for non-dynamic
+            // dimensions of size 1.
+
+            // Get the original size.
+            auto size = xla::GetDimensionSize(lhs, i);
+
+            // Remove the dynamic dimension.
+            lhs = xla::RemoveDynamicDimension(lhs, i);
+
+            // Broadcast the dimension to N.
+            std::vector<int64_t> dimensions(lhs_xla_shape.dimensions().begin(),
+                                            lhs_xla_shape.dimensions().end());
+            dimensions[i] = rhs_xla_shape.dimensions(i);
+            std::vector<int64_t> broadcast_dimensions(lhs_xla_shape.rank());
+            absl::c_iota(broadcast_dimensions, 0);
+            lhs = xla::BroadcastInDim(lhs, dimensions, broadcast_dimensions);
+
+            xla::XlaOp rhs_size;
+            if (rhs_xla_shape.is_dynamic_dimension(i)) {
+              rhs_size = xla::GetDimensionSize(rhs, i);
+            } else {
+              rhs_size = xla::ConstantR0<int32_t>(lhs.builder(),
+                                                  rhs_xla_shape.dimensions(i));
+            }
+
+            // The original size is 0 or 1, so we can multiply it by the RHS
+            // size to get the size of the resulting broadcast.
+            size = xla::Mul(size, rhs_size);
+
+            // Set the resulting dimension size.
+            lhs = xla::SetDimensionSize(lhs, size, i);
+
+            lhs_tensor_shape->set_dim(i, rhs_xla_shape.dimensions(i));
+          }
+        }
+      }
+      return lhs;
+    };
+    lhs_handle = reconcile_tensor_mismatched_dims(
+        lhs_handle, rhs_handle, lhs_xla_shape, rhs_xla_shape, &lhs_shape);
+    rhs_handle = reconcile_tensor_mismatched_dims(
+        rhs_handle, lhs_handle, rhs_xla_shape, lhs_xla_shape, &rhs_shape);
   }
   // By TensorFlow conventions the inputs may not have the same
   // shapes, in which case they will be automatically broadcast if
@@ -110,9 +155,9 @@
   BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape),
               /*fewer_dims_optimization=*/false);
   if (!bcast.IsValid()) {
-    ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ",
-                                           lhs_shape.DebugString(), " vs. ",
-                                           rhs_shape.DebugString()));
+    ctx->SetStatus(absl::InvalidArgumentError(
+        absl::StrCat("Incompatible shapes: ", lhs_shape.DebugString(), " vs. ",
+                     rhs_shape.DebugString())));
     return;
   }
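
The new third branch above handles a bounded dynamic dimension <=1 on the lhs against a size-N dimension on the rhs. XLA only performs degenerate broadcasts on static size-1 dimensions, so the code removes the dynamic bit, broadcasts to N, and then restores a dynamic size equal to original_size * N (a runtime size of 0 stays 0, a size of 1 becomes N). A condensed sketch of that sequence, assuming ops lhs/rhs with XLA shapes lhs_shape/rhs_shape and a mismatched dimension i:

    xla::XlaOp size = xla::GetDimensionSize(lhs, i);   // runtime size, 0 or 1
    lhs = xla::RemoveDynamicDimension(lhs, i);          // dim i becomes static size 1
    std::vector<int64_t> dims(lhs_shape.dimensions().begin(),
                              lhs_shape.dimensions().end());
    dims[i] = rhs_shape.dimensions(i);                  // broadcast target N
    std::vector<int64_t> bcast_dims(lhs_shape.rank());
    absl::c_iota(bcast_dims, 0);                        // identity dimension mapping
    lhs = xla::BroadcastInDim(lhs, dims, bcast_dims);
    xla::XlaOp rhs_size = xla::GetDimensionSize(rhs, i);
    lhs = xla::SetDimensionSize(lhs, xla::Mul(size, rhs_size), i);  // 0*N=0, 1*N=N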
 
diff --git a/tensorflow/compiler/tf2xla/kernels/where_op.cc b/tensorflow/compiler/tf2xla/kernels/where_op.cc
index b505d90..c780dde 100644
--- a/tensorflow/compiler/tf2xla/kernels/where_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/where_op.cc
@@ -14,12 +14,11 @@
 ==============================================================================*/
 
 #include <array>
+#include <cstdint>
 #include <memory>
 #include <vector>
 
-#include "tensorflow/compiler/tf2xla/literal_util.h"
-#include "tensorflow/compiler/tf2xla/type_util.h"
-#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "xla/client/lib/arithmetic.h"
@@ -27,14 +26,17 @@
 #include "xla/client/lib/constants.h"
 #include "xla/client/lib/dynamic_shaped_ops.h"
 #include "xla/client/xla_builder.h"
+#include "xla/client/xla_computation.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/util.h"
 #include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/ops_util.h"
-#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/op_requires.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/bits.h"
-#include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/statusor.h"
-#include "tensorflow/core/tpu/tpu_defs.h"
+#include "tsl/platform/errors.h"
 
 namespace tensorflow {
 namespace {
@@ -157,8 +159,8 @@
   XlaOp condition = ctx->Input(0);
   TF_ASSIGN_OR_RETURN(xla::Shape input_shape,
                       ctx->builder()->GetShape(condition));
-  auto iota_shape = input_shape;
-  iota_shape.set_element_type(xla::S32);
+  auto iota_shape =
+      xla::ShapeUtil::MakeShape(xla::S32, input_shape.dimensions());
 
   int64_t flattened_size = xla::Product(iota_shape.dimensions());
   XlaOp reshaped_condition = xla::Reshape(condition, {flattened_size});
@@ -272,8 +274,7 @@
   //
   // and then scatter iotas[out_idxs] into the output.
   std::vector<XlaOp> iotas_to_concat;
-  auto iota_shape = input_shape;
-  iota_shape.set_element_type(S32);
+  auto iota_shape = xla::ShapeUtil::MakeShape(S32, input_shape.dimensions());
   for (int64_t axis = 0; axis < iota_shape.rank(); ++axis) {
     iotas_to_concat.push_back(
         xla::Reshape(xla::Iota(b, iota_shape, axis), {flattened_size, 1}));
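
Both hunks above switch from copying the condition's shape (and overwriting its element type) to building the iota shape directly with ShapeUtil::MakeShape. MakeShape yields a plain static S32 shape from the dimension extents, whereas a copied shape would also carry the condition's dynamic-dimension flags into the iota. A one-line sketch of the preferred form, assuming input_shape is the condition's shape:

    xla::Shape iota_shape =
        xla::ShapeUtil::MakeShape(xla::S32, input_shape.dimensions());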
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc
index 2f3fec6..f602e66 100644
--- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc
+++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc
@@ -80,11 +80,11 @@
 constexpr int kVersionStartSupportDisabledChecks = 6;
 constexpr int kVersionStartSupportShapeAssertions = 7;
 constexpr int kVersionStartSupportUsesShapePolymorphismAttr = 8;
-constexpr int kVersionMinimumSupported = kVersionStartSupportCallTFGraph;
+constexpr int kVersionStartSupportEffects = 9;
+constexpr int kVersionMinimumSupported = kVersionStartStableHloCompatibility;
 
 // This should match xla.py:call_module_maximum_supported_version
-constexpr int kVersionMaximumSupported =
-    kVersionStartSupportUsesShapePolymorphismAttr;
+constexpr int kVersionMaximumSupported = kVersionStartSupportEffects;
 
 constexpr llvm::StringRef kDisabledCheckPlatform = "platform";
 
@@ -141,6 +141,11 @@
 
 }  // namespace
 
+bool IsTokenType(mlir::Type type) {
+  return type.isa<mlir::stablehlo::TokenType>() ||
+         type.isa<mlir::mhlo::TokenType>();
+}
+
 tsl::StatusOr<std::unique_ptr<XlaCallModuleLoader>> XlaCallModuleLoader::Create(
     mlir::MLIRContext *context, int version, std::string module_str,
     std::vector<std::string> disabled_checks,
@@ -236,7 +241,8 @@
   // Refine 'main' argument types to use static input types instead. The main
   // arguments may occur as return values, or as inputs to called functions,
   // and changing their types may invalidate the module. To prevent this
-  // we insert dummy conversion ops as the sole uses of the main arguments.
+  // we insert dummy conversion ops as the sole uses of those main arguments
+  // that are not tokens and do not already have a static shape.
   // If we use stablehlo.convert, we end up with "convert 3xf32 -> *xf32"
   // after we set the static shapes for the main arguments. The "convert"
   // op does not support unranked result for ranked inputs. So, we use
@@ -246,9 +252,16 @@
   op_builder.setInsertionPointToStart(&main_body);
   for (auto i = 0; i < main_body.getNumArguments(); ++i) {
     mlir::BlockArgument arg = main_body.getArgument(i);
-    auto convert_op = op_builder.create<mlir::stablehlo::BitcastConvertOp>(
-        arg.getLoc(), arg.getType(), arg);
-    arg.replaceAllUsesExcept(convert_op, convert_op);
+    mlir::Type arg_type = arg.getType();
+    if (IsTokenType(arg_type)) {
+      continue;
+    }
+    auto ranked_arg_type = arg_type.dyn_cast<mlir::RankedTensorType>();
+    if (!ranked_arg_type || !ranked_arg_type.hasStaticShape()) {
+      auto convert_op = op_builder.create<mlir::stablehlo::BitcastConvertOp>(
+          arg.getLoc(), arg_type, arg);
+      arg.replaceAllUsesExcept(convert_op, convert_op);
+    }
   }
 
   auto static_array_output_types = llvm::to_vector(main_.getResultTypes());
@@ -376,7 +389,19 @@
   }
 
   mlir::Block &main_body = main_.front();
-  int nr_token_arguments = main_has_token_input_output ? 1 : 0;
+
+  int nr_token_arguments = llvm::count_if(InputTypes(), IsTokenType);
+  if (version < kVersionStartSupportEffects) {
+    bool has_token_at_start = (nr_token_arguments == 1 &&
+                               IsTokenType(main_.getArgument(0).getType()));
+    if (main_has_token_input_output != has_token_at_start) {
+      return absl::InvalidArgumentError(absl::StrCat(
+          "Expected a token at start iff main_has_token_input_output. ",
+          "Found main function type ",
+          mlir::debugString(main_.getFunctionType()),
+          " and main_has_token_input_output = ", main_has_token_input_output));
+    }
+  }
   int nr_platform_args = (platform_index_ >= 0 ? 1 : 0);
   if (num_invocation_args != main_body.getNumArguments() - nr_token_arguments) {
     return absl::InvalidArgumentError(absl::StrCat(
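
The refinement loop above now skips token arguments and arguments that are already static ranked tensors, and only wraps the remaining (dynamic or unranked) arguments in a placeholder bitcast_convert so that later type refinement cannot invalidate their other uses. A condensed sketch of the per-argument decision, assuming op_builder is positioned at the start of the main block as in the code above:

    for (mlir::BlockArgument arg : main_body.getArguments()) {
      mlir::Type arg_type = arg.getType();
      if (IsTokenType(arg_type)) continue;  // token arguments are never refined
      auto ranked = arg_type.dyn_cast<mlir::RankedTensorType>();
      if (ranked && ranked.hasStaticShape()) continue;  // already static, no guard needed
      auto guard = op_builder.create<mlir::stablehlo::BitcastConvertOp>(
          arg.getLoc(), arg_type, arg);
      arg.replaceAllUsesExcept(guard, guard);  // the guard becomes the sole use of arg
    }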
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h
index 33de164..bb2e73e 100644
--- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h
+++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h
@@ -26,11 +26,15 @@
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/OwningOpRef.h"  // from @llvm-project
 #include "mlir/IR/TypeRange.h"  // from @llvm-project
+#include "stablehlo/dialect/StablehloOps.h"  // from @stablehlo
 #include "xla/client/xla_computation.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
 #include "tsl/platform/statusor.h"
 
 namespace tensorflow {
 
+bool IsTokenType(mlir::Type type);
+
 class XlaCallModuleLoader {
  public:
   static tsl::StatusOr<std::unique_ptr<XlaCallModuleLoader>> Create(
@@ -39,8 +43,11 @@
       std::vector<std::string> platforms, std::string loading_platform,
       int num_invocation_args, bool main_has_token_input_output);
 
-  int nr_outputs() { return main_.getNumResults(); }
-  mlir::TypeRange output_types() { return main_.getResultTypes(); }
+  int NrInputs() { return main_.getNumArguments(); }
+  mlir::TypeRange InputTypes() { return main_.getArgumentTypes(); }
+
+  int NrOutputs() { return main_.getNumResults(); }
+  mlir::TypeRange OutputTypes() { return main_.getResultTypes(); }
 
   // Refines the dynamic module arguments based on the static argument shapes.
   // This assumes that the module has a "main" function without dimension args,
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc
index 200b249..0491420 100644
--- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc
@@ -19,10 +19,12 @@
 #include <utility>
 #include <vector>
 
+#include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"  // from @llvm-project
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
@@ -150,7 +152,36 @@
     OP_REQUIRES_OK(ctx, ctx->GetAttr("disabled_checks", &disabled_checks));
     std::vector<string> platforms;
     OP_REQUIRES_OK(ctx, ctx->GetAttr("platforms", &platforms));
+    // TODO(necula): change this to OP_REQUIRES_OK when 6 months have passed
+    // since we added the function_list and has_token_input_output
+    // attributes (May 25, 2023).
+    bool main_has_token_input_output = false;
+    if (!ctx->GetAttr("has_token_input_output", &main_has_token_input_output)
+             .ok()) {
+      // Whether the StableHLO module's main function has token input/output as
+      // the first argument and the first result.
+      // This is used only prior to version 9; afterwards, we just look for
+      // tokens among the types of the arguments and results, and we support
+      // multiple tokens, not necessarily at the start.
+      main_has_token_input_output = false;
+    }
+    if (!ctx->GetAttr("function_list", &function_list_).ok()) {
+      function_list_.clear();
+    }
 
+    if (VLOG_IS_ON(3)) {
+      VLOG(3) << "Initializing XlaCallModuleOp (version = " << version
+              << ", platforms = [" << absl::StrJoin(platforms, ", ")
+              << "], has_token_input_output = " << main_has_token_input_output
+              << ", disabled_checks = [" << absl::StrJoin(disabled_checks, ", ")
+              << "], "
+              << "function_list = ["
+              << absl::StrJoin(function_list_, ",",
+                               [](std::string *out, NameAttrList x) {
+                                 absl::StrAppend(out, x.name());
+                               })
+              << "])";
+    }
     string loading_device_type = ctx->device_type().type_string();
     string loading_platform = "";
     if (loading_device_type == DEVICE_CPU_XLA_JIT) {
@@ -171,31 +202,21 @@
                   absl::UnimplementedError(absl::StrCat(
                       "Unexpected device type ", loading_device_type)));
     }
-    VLOG(3) << "Initialized XlaCallModuleOp on " << loading_platform;
-    if (!ctx->GetAttr("has_token_input_output", &module_has_token_input_output_)
-             .ok()) {
-      module_has_token_input_output_ = false;
-    }
+    VLOG(3) << "Initializing XlaCallModuleOp on " << loading_platform;
     {
       auto loader = XlaCallModuleLoader::Create(
           &context_, version, std::move(module_str), std::move(disabled_checks),
           std::move(platforms), loading_platform,
           /*num_invocation_args=*/ctx->num_inputs(),
-          module_has_token_input_output_);
+          main_has_token_input_output);
       OP_REQUIRES_OK(ctx, loader.status());
       loader_ = *std::move(loader);
     }
     OP_REQUIRES_OK(ctx, loader_->ValidateDialect());
 
-    if (!ctx->GetAttr("function_list", &function_list_).ok()) {
-      function_list_.clear();
-    }
-
-    if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
-      token_input_nodes_.clear();
-      op_has_token_input_output_ = false;
-    } else {
-      op_has_token_input_output_ = !token_input_nodes_.empty();
+    if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &op_token_input_nodes_)
+             .ok()) {
+      op_token_input_nodes_.clear();
     }
     if (!ctx->GetAttr(kXlaOriginalOutsideCompilationNodeName,
                       &original_node_name_)
@@ -213,13 +234,15 @@
     xla::XlaBuilder *const b = ctx->builder();
 
     std::vector<xla::Shape> input_shapes;
-    if (module_has_token_input_output_) {
-      input_shapes.push_back(xla::ShapeUtil::MakeTokenShape());
-    }
-    for (int i = 0; i < ctx->num_inputs(); ++i) {
-      auto shape = ctx->InputXlaShape(i);
-      OP_REQUIRES_OK(ctx, shape.status());
-      input_shapes.push_back(*std::move(shape));
+    int next_actual_input = 0;
+    for (mlir::Type inputType : loader_->InputTypes()) {
+      if (IsTokenType(inputType)) {
+        input_shapes.push_back(xla::ShapeUtil::MakeTokenShape());
+      } else {
+        auto shape = ctx->InputXlaShape(next_actual_input++);
+        OP_REQUIRES_OK(ctx, shape.status());
+        input_shapes.push_back(*std::move(shape));
+      }
     }
     OP_REQUIRES_OK(ctx, loader_->RefineDynamicShapes(input_shapes));
     OP_REQUIRES_OK(ctx, loader_->ValidateStaticShapes());
@@ -228,27 +251,30 @@
       OP_REQUIRES_OK(ctx, LowerTfFunctionCalls(ctx));
     }
 
-    std::vector<xla::XlaOp> inputs;
-    if (module_has_token_input_output_) {
-      // The main function expects a token input at the start.
-      if (!token_input_nodes_.empty()) {
-        std::vector<xla::XlaOp> token_inputs;
-        for (const string &node_name : token_input_nodes_) {
-          auto token = compiler->GetNodeToken(node_name);
-          OP_REQUIRES_OK(ctx, token.status());
-          token_inputs.push_back(token.value());
-        }
-        inputs.push_back(xla::AfterAll(b, token_inputs));
-      } else {
-        // Generate a dummy token if the main function expects a token but the
-        // XlaCallModule doesn't take one.
-        inputs.push_back(xla::CreateToken(b));
+    xla::XlaOp token_input;
+    if (!op_token_input_nodes_.empty()) {
+      std::vector<xla::XlaOp> token_inputs;
+      for (const string &node_name : op_token_input_nodes_) {
+        auto token = compiler->GetNodeToken(node_name);
+        OP_REQUIRES_OK(ctx, token.status());
+        token_inputs.push_back(token.value());
       }
-    }
-    for (int i = 0, end = ctx->num_inputs(); i < end; ++i) {
-      inputs.push_back(ctx->Input(i));
+      token_input = xla::AfterAll(b, token_inputs);
     }
 
+    std::vector<xla::XlaOp> inputs;
+    next_actual_input = 0;
+    for (mlir::Type inputType : loader_->InputTypes()) {
+      if (IsTokenType(inputType)) {
+        if (token_input.IsUninitialized()) {
+          // Generate a dummy token if the XlaCallModule doesn't take one.
+          token_input = xla::CreateToken(b);
+        }
+        inputs.push_back(token_input);
+      } else {
+        inputs.push_back(ctx->Input(next_actual_input++));
+      }
+    }
     auto xla_computation = loader_->ToXlaComputation();
     OP_REQUIRES_OK(ctx, xla_computation.status());
 
@@ -266,47 +292,58 @@
                                      hlo_module->ToString(options)));
     }
 
-    xla::XlaOp output = xla::Call(b, *xla_computation, inputs);
+    xla::XlaOp computation_output = xla::Call(b, *xla_computation, inputs);
 
     // Check that the resulting computation returns the expected shape
-    OP_REQUIRES_VALUE(xla::Shape found_output_shape, ctx, b->GetShape(output));
+    OP_REQUIRES_VALUE(xla::Shape found_output_shape, ctx,
+                      b->GetShape(computation_output));
     VLOG(3) << "XlaCallModule compiled output shape : "
             << xla::ShapeUtil::HumanString(found_output_shape);
-
-    std::vector<xla::XlaOp> outputs;
-    if (loader_->nr_outputs() == 1) {
-      outputs.push_back(output);
+    std::vector<xla::XlaOp> computation_outputs;
+    if (loader_->NrOutputs() == 1) {
+      computation_outputs.push_back(computation_output);
     } else {
-      for (int i = 0; i < loader_->nr_outputs(); ++i) {
-        outputs.push_back(xla::GetTupleElement(output, i));
+      for (int i = 0; i < loader_->NrOutputs(); ++i) {
+        computation_outputs.push_back(
+            xla::GetTupleElement(computation_output, i));
       }
     }
 
-    xla::XlaOp token_output;
-    if (module_has_token_input_output_) {
-      // The main function returns a token as the first output.
-      token_output = outputs.front();
-      outputs.erase(outputs.begin());
-      auto shape = b->GetShape(token_output);
+    // Collect the token outputs and set the non-token outputs
+    std::vector<xla::XlaOp> token_outputs;
+    int next_actual_output = 0;
+    for (auto it : llvm::enumerate(loader_->OutputTypes())) {
+      int i = it.index();
+      mlir::Type output_type = it.value();
+      auto shape = b->GetShape(computation_outputs[i]);
       OP_REQUIRES_OK(ctx, shape.status());
-      OP_REQUIRES(ctx, shape->IsToken(),
-                  absl::FailedPreconditionError(
-                      absl::StrCat("Token output is not token type: ",
-                                   xla::ShapeUtil::HumanString(*shape))));
+      if (IsTokenType(output_type)) {
+        OP_REQUIRES(ctx, shape->IsToken(),
+                    absl::FailedPreconditionError(absl::StrCat(
+                        "Token output at index ", i, " is not token type: ",
+                        xla::ShapeUtil::HumanString(*shape))));
+        token_outputs.push_back(computation_outputs[i]);
+      } else {
+        OP_REQUIRES(ctx, !shape->IsToken(),
+                    absl::FailedPreconditionError(absl::StrCat(
+                        "Non-token output at index ", i, " is a token type: ",
+                        xla::ShapeUtil::HumanString(*shape))));
+        ctx->SetOutput(next_actual_output++, computation_outputs[i]);
+      }
     }
-    if (op_has_token_input_output_) {
-      if (token_output.IsUninitialized()) {
-        // The main function does not return any token, but the XlaCallModule is
-        // expected to return one. Create a dummy token.
-        token_output = xla::CreateToken(b);
+
+    if (!op_token_input_nodes_.empty()) {
+      xla::XlaOp token_output = token_input;
+      if (!token_outputs.empty()) {
+        token_output = xla::AfterAll(b, token_outputs);
+      } else {
+        if (token_output.IsUninitialized()) {
+          token_output = xla::CreateToken(b);
+        }
       }
       OP_REQUIRES_OK(ctx,
                      compiler->SetNodeToken(original_node_name_, token_output));
     }
-
-    for (int i = 0; i < outputs.size(); ++i) {
-      ctx->SetOutput(i, outputs[i]);
-    }
   }
 
  private:
@@ -404,7 +441,7 @@
       options.always_return_tuple = true;
       options.is_entry_computation = false;
       // Propagate tokens from XlaCallModule to inner computation.
-      options.add_token_input_output = op_has_token_input_output_;
+      options.add_token_input_output = !op_token_input_nodes_.empty();
 
       XlaCompiler::CompilationResult result;
       TF_RETURN_IF_ERROR(
@@ -518,11 +555,8 @@
   std::unique_ptr<XlaCallModuleLoader> loader_;
   std::vector<NameAttrList> function_list_;
 
-  // Whether the StableHLO module's main function has token input/output.
-  bool module_has_token_input_output_;
   // Whether the XlaCallModule op has token input/output.
-  bool op_has_token_input_output_;
-  std::vector<std::string> token_input_nodes_;
+  std::vector<std::string> op_token_input_nodes_;
   std::string original_node_name_;
 };
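
On the result side the op mirrors the new input handling: it walks the module's result types, forwards non-token results to the op's TF outputs in order, and collects token results so they can be merged with AfterAll and registered through the TF2XLA side-effect mechanism. A condensed sketch of the splitting step, assuming computation_outputs is aligned with loader_->OutputTypes() as in the code above:

    std::vector<xla::XlaOp> token_outputs;
    int next_actual_output = 0;
    for (auto it : llvm::enumerate(loader_->OutputTypes())) {
      if (IsTokenType(it.value())) {
        token_outputs.push_back(computation_outputs[it.index()]);
      } else {
        ctx->SetOutput(next_actual_output++, computation_outputs[it.index()]);
      }
    }
    // If the op has token input nodes, the merged xla::AfterAll(b, token_outputs)
    // (or the incoming token) is then registered via compiler->SetNodeToken(...).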
 
diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
index 143a38c..0cfcc0a 100644
--- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
+++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
@@ -19,19 +19,31 @@
 
 #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h"
 #include "absl/base/call_once.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/Visitors.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
-#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
 #include "tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.h"
+#include "tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.h"
 #include "tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h"
+#include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_defs.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/core/common_runtime/device_set.h"
 #include "tensorflow/core/framework/metrics.h"
 #include "tensorflow/core/lib/monitoring/gauge.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
 #include "tensorflow/core/util/device_name_utils.h"
+#include "tsl/framework/device_type.h"
+#include "tsl/platform/errors.h"
 
 namespace tensorflow {
 
@@ -44,6 +56,8 @@
 
 namespace {
 
+using ::mlir::ModuleOp;
+
 bool HasTPUDevice(mlir::ModuleOp module) {
   mlir::TF::RuntimeDevices devices;
   if (failed(GetDevicesFromOp(module.getOperation(), &devices))) return false;
@@ -125,6 +139,33 @@
   return has_tpu_partitioned_call;
 }
 
+// V1 Compat Bridge extracts out a program into a submodule and runs clustering
+// only on the submodule.
+absl::Status RunLowerToRuntimeOpsOnSubmodule(ModuleOp parent_module,
+                                             bool is_in_fallback_enabled_mode) {
+  int num_submodules = 0;
+  absl::Status runtime_lowering_status;
+  parent_module.walk([&](ModuleOp submodule) {
+    if (submodule == parent_module) return mlir::WalkResult::advance();
+    num_submodules++;
+    runtime_lowering_status =
+        tensorflow::tfrt_compiler::RunLowerClusterToRuntimeOpsPassPipeline(
+            submodule, tsl::DeviceType(DEVICE_TPU_XLA_JIT));
+    if (num_submodules > 1) {
+      return mlir::WalkResult::interrupt();
+    }
+
+    return mlir::WalkResult::advance();
+  });
+
+  if (num_submodules > 1) {
+    return absl::InternalError(
+        "Lowering to runtime ops found more than one extracted submodule.");
+  }
+
+  return runtime_lowering_status;
+}
+
 }  // namespace
 
 // Analyzes the user requested policy as well as the contents of the graph and
@@ -246,7 +287,10 @@
   }
 
   if (HasTPUPartitionedCallOpInModule(module)) {
-    VLOG(1) << "This is an inference module.";
+    VLOG(1) << "Skipping MLIR TF2XLA Bridge. This is an inference graph; the "
+               "Session V1 Bridge should be used during execution of "
+               "TPUPartitionedCall.";
+    return OkStatus();
   }
 
   // TODO(b/241853328): Add caching of pass state and call logging/metrics
@@ -263,8 +307,8 @@
     return OkStatus();
   }
 
+  bool fallback_enabled = false;
   if (run_tpu_bridge) {
-    bool fallback_enabled = false;
     if (pass_state == MlirOptimizationPassState::FallbackEnabled) {
       // We set `uses_uninitialized_resource_args` to false here because the
       // first phase of the bridge is not affected by uninitialized resource
@@ -278,12 +322,29 @@
     }
     VLOG(1) << "Running MLIR TPU Bridge";
     mlir_bridge_gauge_v2->GetCell()->Set(true);
-    return tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge(
-        module, tf2xla::v2::DeviceType::XLA_TPU_JIT,
-        /*is_in_fallback_enabled_mode=*/fallback_enabled, function_name);
+
+    TF_RETURN_IF_ERROR(
+        tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge(
+            module, tf2xla::v2::DeviceType::XLA_TPU_JIT, fallback_enabled,
+            function_name));
+
+    TF_RETURN_IF_ERROR(
+        tensorflow::tfrt_compiler::RunLowerClusterToRuntimeOpsPassPipeline(
+            module, tsl::DeviceType(DEVICE_TPU_XLA_JIT), function_name));
+  } else {
+    VLOG(1) << "Running GPU/CPU Bridge";
+    TF_RETURN_IF_ERROR(
+        tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge(
+            module, tf2xla::v2::DeviceType::XLA_GPU_JIT, fallback_enabled,
+            function_name));
+
+    TF_RETURN_IF_ERROR(
+        tensorflow::tfrt_compiler::RunLowerClusterToRuntimeOpsPassPipeline(
+            module, tsl::DeviceType(DEVICE_GPU_XLA_JIT), function_name));
   }
-  VLOG(1) << "Running MLIR CPU/GPU Bridge";
-  return mlir::TF::RunTFXLABridge(module, function_name);
+
+  return tensorflow::tf2xla::v2::ExportFromTensorflowDialectToExecutor(
+      module, function_name);
 }
 
 MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState(
@@ -371,6 +432,7 @@
     return OkStatus();
   }
 
+  bool fallback_enabled = false;
   if (pass_state == MlirOptimizationPassState::FallbackEnabled) {
     // We set `uses_uninitialized_resource_args` to false here because the first
     // phase of the bridge is not affected by uninitialized resource args.
@@ -380,12 +442,23 @@
                      options.session_options->config,
                      /*uses_uninitialized_resource_args=*/false,
                      /*is_v1_compat=*/true);
+    fallback_enabled = true;
   }
 
   VLOG(1) << "Running MLIR TPU Bridge V1 Compat";
-
   mlir_bridge_gauge_v1->GetCell()->Set(true);
-  return tensorflow::tf2xla::v1::RunSessionTf2xlaClusteringBridge(module);
+  TF_RETURN_IF_ERROR(tensorflow::tf2xla::v1::RunSessionTf2xlaClusteringBridge(
+      module, fallback_enabled));
+
+  auto lower_cluster_to_runtime_ops_pass_pipeline =
+      RunLowerToRuntimeOpsOnSubmodule(module, fallback_enabled);
+  if (!lower_cluster_to_runtime_ops_pass_pipeline.ok()) {
+    VLOG(1) << "Error while lowering cluster to runtime ops: "
+            << lower_cluster_to_runtime_ops_pass_pipeline;
+    return lower_cluster_to_runtime_ops_pass_pipeline;
+  }
+
+  return tensorflow::tf2xla::v1::ExportFromTensorflowDialectToExecutor(module);
 }
 
 }  // namespace tensorflow
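
As the helper's comment notes, the V1 compat bridge clusters inside an extracted submodule, so runtime lowering must run on that submodule rather than on the parent module. The helper uses the standard MLIR walk pattern: visit nested ModuleOps, skip the parent itself, and interrupt once a second submodule is seen. A stripped-down sketch, where RunOnSubmodule is a hypothetical stand-in for the runtime-lowering pipeline:

    int num_submodules = 0;
    absl::Status status;
    parent_module.walk([&](mlir::ModuleOp submodule) {
      if (submodule == parent_module) return mlir::WalkResult::advance();
      if (++num_submodules > 1) return mlir::WalkResult::interrupt();  // at most one expected
      status = RunOnSubmodule(submodule);  // hypothetical stand-in for the lowering pipeline
      return mlir::WalkResult::advance();
    });
    if (num_submodules > 1) {
      return absl::InternalError("Lowering to runtime ops found more than one submodule.");
    }
    return status;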
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 4aa2811..480dc47 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -1392,7 +1392,10 @@
 has_token_input_output: If true, the embedded StableHLO module's main function
   must take a `!stablehlo.token` as its first argument and return a token as
   its first result. This can be used in conjunction with the TF2XLA's side
-  effect mechanism in order to model side effects.
+  effect mechanism in order to model side effects. This attribute is used only
+  prior to version 9. Starting with version 9, the number and position of
+  tokens among the arguments and results are obtained from the main function
+  type, which allows more than one token, not necessarily at the start.
 disabled_checks: A list of strings describing the safety checks that were
   disabled at serialization time. This attribute was added in version 6.
   For more details see
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 80eecfb..27940b7 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -669,7 +669,7 @@
   See versioning details documentation for the XlaCallModule op at:
   https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code
   """
-  return 8
+  return 9
 
 # pylint: enable=g-doc-args
 # pylint: enable=g-doc-return-or-yield
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 101b728..aa2c761 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -444,7 +444,6 @@
 
 }  // namespace
 
-
 string XlaCompiler::Argument::HumanString() const {
   string common;
   if (!name.empty()) {
@@ -602,7 +601,7 @@
   CopyGraph(*fbody->graph, graph.get());
 
   bool is_inside_mustcompile = false;
-  TryGetNodeAttr(AttrSlice(&fbody->fdef.attr()), kXlaMustCompileAttr,
+  TryGetNodeAttr(AttrSlice(&fbody->record->fdef().attr()), kXlaMustCompileAttr,
                  &is_inside_mustcompile);
 
   // Performs a first function inlining pass before shape inference, since
@@ -954,6 +953,15 @@
         TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
             arg_sharding, /*use_fast_memory=*/false,
             options_.shape_determination_fns, xla_shape));
+        // If the arg is dynamic then we update the shape to reflect that. The
+        // arg's value_dynamism is a Tensor of bools set to the dynamism of each
+        // dimension.
+        if (arg.value_dynamism.has_value()) {
+          auto dynamism = arg.value_dynamism.value().vec<bool>();
+          for (int i = 0; i < dynamism.size(); ++i) {
+            xla_shape->set_dynamic_dimension(i, dynamism(i));
+          }
+        }
       } else {
         if (absl::holds_alternative<xla::Shape>(arg.shape)) {
           *xla_shape = std::get<xla::Shape>(arg.shape);
@@ -1427,10 +1435,11 @@
       StackFrame({"dummy_file_name", 10, "dummy_function_name"})};
 };
 
-Status XlaCompiler::CompileGraph(
-    const XlaCompiler::CompileOptions& options, string const& name,
-    std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
-    CompilationResult* result) {
+Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
+                                 string const& name,
+                                 std::unique_ptr<Graph> graph,
+                                 absl::Span<const XlaCompiler::Argument> args,
+                                 CompilationResult* result) {
   VLOG(1) << "Executing graph symbolically to populate XlaBuilder: " << name;
   if (VLOG_IS_ON(2)) {
     VLOG(2) << "XlaCompiler::CompileGraph: "
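
The new block above copies per-dimension dynamism from the argument metadata into the XLA parameter shape; value_dynamism is a boolean tensor with one entry per dimension. A minimal caller-side sketch of marking a parameter's leading dimension dynamic, assuming the same Argument fields exercised by the test added in xla_compiler_test.cc below (the test additionally sets value_bound, which this sketch omits):

    XlaCompiler::Argument arg;
    arg.kind = XlaCompiler::Argument::kParameter;
    arg.type = DT_INT32;
    arg.shape = TensorShape({2});
    Tensor dynamism(DT_BOOL, TensorShape({1}));
    dynamism.vec<bool>()(0) = true;  // dimension 0 of the parameter is dynamic
    arg.value_dynamism = dynamism;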
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index e32f1fb..f8b394e 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -25,6 +25,7 @@
 #include "tensorflow/cc/ops/math_ops.h"
 #include "tensorflow/cc/ops/resource_variable_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
@@ -35,6 +36,7 @@
 #include "xla/client/xla_builder.h"
 #include "xla/literal.h"
 #include "xla/service/hlo.pb.h"
+#include "xla/service/hlo_proto_util.h"
 #include "xla/shape_util.h"
 #include "xla/status_macros.h"
 #include "xla/tests/literal_test_util.h"
@@ -59,6 +61,7 @@
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/public/version.h"
+#include "tsl/platform/statusor.h"
 
 namespace tensorflow {
 
@@ -241,6 +244,56 @@
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 }
 
+StatusOr<std::unique_ptr<xla::HloModule>> LoadModuleFromHloProto(
+    const xla::HloModuleProto& module_proto) {
+  TF_ASSIGN_OR_RETURN(auto module_config,
+                      xla::HloModule::CreateModuleConfigFromProto(
+                          module_proto, xla::GetDebugOptionsFromFlags()));
+  return xla::CreateModuleFromProto(module_proto, module_config);
+}
+
+// Tests compilation of a graph that adds two tensors, where one argument is
+// marked as a dynamic shape parameter via value_dynamism.
+TEST_F(XlaCompilerTest, SimpleDynamicShapeParameter) {
+  // Builds a graph that adds two Tensors.
+  Scope scope = Scope::NewRootScope().ExitOnError();
+  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
+  auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
+  auto c = ops::Add(scope.WithOpName("C"), a, b);
+  auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+  // Builds a description of the arguments.
+  std::vector<XlaCompiler::Argument> args(2);
+  args[0].kind = XlaCompiler::Argument::kParameter;
+  args[0].type = DT_INT32;
+  args[0].shape = TensorShape({2});
+  args[0].value_bound = Tensor(DT_INT32, std::get<0>(args[0].shape));
+  Tensor dynamism_tensor(DT_BOOL);
+  TF_ASSERT_OK(LiteralToHostTensor(xla::LiteralUtil::CreateR1<bool>({true}),
+                                   DT_BOOL, &dynamism_tensor));
+  args[0].value_dynamism = dynamism_tensor;
+  args[1].kind = XlaCompiler::Argument::kParameter;
+  args[1].type = DT_INT32;
+  args[1].shape = TensorShape({2});
+
+  // Compiles the graph.
+  XlaCompiler compiler(DefaultOptions());
+
+  XlaCompiler::CompilationResult result;
+  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+                                     std::move(graph), args, &result));
+
+  auto hlo = result.computation->proto();
+  TF_ASSERT_OK_AND_ASSIGN(auto module, LoadModuleFromHloProto(hlo));
+  EXPECT_EQ(module->computation_count(), 1);
+  EXPECT_TRUE(module->mutable_computation(0)
+                  ->parameter_instruction(0)
+                  ->shape()
+                  .is_dynamic());
+}
+
 // Tests compilation of a graph where the _Retval node is not necessarily last
 // amongst the graph nodes in construction order, and always_return_tuple is
 // false. Regression test for bug where the wrong value was returned.
diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD
index f6d4252..0139bd1 100644
--- a/tensorflow/compiler/xrt/tests/BUILD
+++ b/tensorflow/compiler/xrt/tests/BUILD
@@ -1,9 +1,9 @@
-load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_cc_test")
 load(
     "//tensorflow/core/platform:build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -27,24 +27,32 @@
         "//tensorflow/compiler/xrt:xrt_proto_cc",
         "//tensorflow/compiler/xrt:xrt_server",
         "//tensorflow/compiler/xrt/cc:xrt_ops",
+        "//tensorflow/core:core_cpu",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:tensorflow_opensource",
         "//tensorflow/core:test",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:status",
         "@local_xla//xla:literal",
         "@local_xla//xla:literal_util",
         "@local_xla//xla:shape_util",
         "@local_xla//xla:xla_data_proto_cc",
         "@local_xla//xla/client:client_library",
+        "@local_xla//xla/client:executable_build_options",
         "@local_xla//xla/client:local_client",
+        "@local_xla//xla/client:padding",
         "@local_xla//xla/client:xla_builder",
         "@local_xla//xla/client:xla_computation",
         "@local_xla//xla/client/lib:arithmetic",
         "@local_xla//xla/client/lib:constants",
         "@local_xla//xla/service:platform_util",
+        "@local_xla//xla/stream_executor:platform",
     ],
 )
 
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index cae878f..10f32f4 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -13,40 +13,60 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <cstddef>
+#include <cstdint>
 #include <functional>
+#include <initializer_list>
 #include <memory>
 #include <string>
 #include <utility>
 #include <vector>
 
+#include "absl/log/check.h"
+#include "absl/log/log.h"
 #include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
 #include "tensorflow/cc/client/client_session.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/framework/scope.h"
-#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/const_op.h"
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "xla/client/client_library.h"
+#include "xla/client/executable_build_options.h"
 #include "xla/client/lib/arithmetic.h"
 #include "xla/client/lib/constants.h"
 #include "xla/client/local_client.h"
+#include "xla/client/padding.h"
 #include "xla/client/xla_builder.h"
 #include "xla/client/xla_computation.h"
+#include "xla/layout.h"
+#include "xla/layout_util.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
 #include "xla/service/platform_util.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
+#include "xla/stream_executor/platform.h"
 #include "xla/xla_data.pb.h"
 #include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h"
 #include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h"
 #include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h"
 #include "tensorflow/compiler/xrt/xrt.pb.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/tstring.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/util/command_line_flags.h"
+#include "tsl/lib/core/status_test_util.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
 
 namespace tensorflow {
 namespace {
@@ -112,7 +132,7 @@
   auto pad_sum = xla::SetDimensionSize(sum, p2, 0);
   auto pad_sub = xla::SetDimensionSize(sub, p2 + one, 0);
   auto tuple = xla::Tuple(&builder, {pad_sum, sum, pad_sub});
-  return builder.Build(tuple, /*remove_dynamic_dimensions=*/true).value();
+  return builder.Build(tuple).value();
 }
 
 xla::XlaComputation AcceptDynamicR1Tuple() {
diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD
index 4884972..b3efda8 100644
--- a/tensorflow/core/api_def/BUILD
+++ b/tensorflow/core/api_def/BUILD
@@ -6,12 +6,12 @@
 #   :python_api_def
 #   :java_api_def
 
-load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 load(
     "//tensorflow:tensorflow.bzl",
     "tf_cc_binary",
     "tf_cc_test",
 )
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 load(
     "//third_party/mkl:build_defs.bzl",
     "if_mkl",
@@ -116,5 +116,8 @@
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/platform:resource_loader",
+        "//tensorflow/core/tpu/ops:sparse_core_ops",
+        "//tensorflow/core/tpu/ops:sparse_core_preprocess_ops",
+        "//tensorflow/core/tpu/ops:tpu_copy_with_dynamic_shape_op",
     ],
 )
diff --git a/tensorflow/core/api_def/base_api/api_def_ConvertToCooTensor.pbtxt b/tensorflow/core/api_def/base_api/api_def_ConvertToCooTensor.pbtxt
new file mode 100644
index 0000000..c9d6295
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ConvertToCooTensor.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ConvertToCooTensor"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_GetMinibatchSplitsWithPhysicalReplica.pbtxt b/tensorflow/core/api_def/base_api/api_def_GetMinibatchSplitsWithPhysicalReplica.pbtxt
new file mode 100644
index 0000000..e402d2b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_GetMinibatchSplitsWithPhysicalReplica.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "GetMinibatchSplitsWithPhysicalReplica"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_GetMinibatchesInCsrWithPhysicalReplica.pbtxt b/tensorflow/core/api_def/base_api/api_def_GetMinibatchesInCsrWithPhysicalReplica.pbtxt
new file mode 100644
index 0000000..49493ee
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_GetMinibatchesInCsrWithPhysicalReplica.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "GetMinibatchesInCsrWithPhysicalReplica"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StoreMinibatchStatisticsInFdo.pbtxt b/tensorflow/core/api_def/base_api/api_def_StoreMinibatchStatisticsInFdo.pbtxt
new file mode 100644
index 0000000..9545ffd
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StoreMinibatchStatisticsInFdo.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "StoreMinibatchStatisticsInFdo"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TPUAnnotateTensorsWithDynamicShape.pbtxt b/tensorflow/core/api_def/base_api/api_def_TPUAnnotateTensorsWithDynamicShape.pbtxt
new file mode 100644
index 0000000..84ac58c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TPUAnnotateTensorsWithDynamicShape.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "TPUAnnotateTensorsWithDynamicShape"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TPUCopyWithDynamicShape.pbtxt b/tensorflow/core/api_def/base_api/api_def_TPUCopyWithDynamicShape.pbtxt
new file mode 100644
index 0000000..423e1a2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TPUCopyWithDynamicShape.pbtxt
@@ -0,0 +1,8 @@
+op {
+  graph_op_name: "TPUCopyWithDynamicShape"
+  visibility: HIDDEN
+  summary: <<END
+Op that copies host tensor to device with dynamic shape support.
+For internal use only.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseCoreAdagrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseCoreAdagrad.pbtxt
new file mode 100644
index 0000000..4775d3d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseCoreAdagrad.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseCoreAdagrad"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseCoreAdagradMomentum.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseCoreAdagradMomentum.pbtxt
new file mode 100644
index 0000000..8c0f70d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseCoreAdagradMomentum.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseCoreAdagradMomentum"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseCoreAdam.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseCoreAdam.pbtxt
new file mode 100644
index 0000000..289521f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseCoreAdam.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseCoreAdam"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseCoreFtrl.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseCoreFtrl.pbtxt
new file mode 100644
index 0000000..fc3122b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseCoreFtrl.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseCoreFtrl"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseCoreSgd.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseCoreSgd.pbtxt
new file mode 100644
index 0000000..9ea35a6
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseCoreSgd.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseCoreSgd"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmul.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmul.pbtxt
new file mode 100644
index 0000000..acc0d15
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmul.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseDenseMatmul"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt
new file mode 100644
index 0000000..478d544
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseDenseMatmulGradWithAdagradAndCsrInput"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt
new file mode 100644
index 0000000..c8e69c3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt
new file mode 100644
index 0000000..62f6b15
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseDenseMatmulGradWithAdamAndCsrInput"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt
new file mode 100644
index 0000000..36a0c24
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseDenseMatmulGradWithFtrlAndCsrInput"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt
new file mode 100644
index 0000000..b5d12b4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseDenseMatmulGradWithSgdAndCsrInput"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulWithCsrInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulWithCsrInput.pbtxt
new file mode 100644
index 0000000..e16a65c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulWithCsrInput.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseDenseMatmulWithCsrInput"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ConvertToCooTensor.pbtxt b/tensorflow/core/api_def/python_api/api_def_ConvertToCooTensor.pbtxt
new file mode 100644
index 0000000..c9d6295
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ConvertToCooTensor.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ConvertToCooTensor"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_GetMinibatchSplitsWithPhysicalReplica.pbtxt b/tensorflow/core/api_def/python_api/api_def_GetMinibatchSplitsWithPhysicalReplica.pbtxt
new file mode 100644
index 0000000..e402d2b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_GetMinibatchSplitsWithPhysicalReplica.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "GetMinibatchSplitsWithPhysicalReplica"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_GetMinibatchesInCsrWithPhysicalReplica.pbtxt b/tensorflow/core/api_def/python_api/api_def_GetMinibatchesInCsrWithPhysicalReplica.pbtxt
new file mode 100644
index 0000000..49493ee
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_GetMinibatchesInCsrWithPhysicalReplica.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "GetMinibatchesInCsrWithPhysicalReplica"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StoreMinibatchStatisticsInFdo.pbtxt b/tensorflow/core/api_def/python_api/api_def_StoreMinibatchStatisticsInFdo.pbtxt
new file mode 100644
index 0000000..9545ffd
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StoreMinibatchStatisticsInFdo.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "StoreMinibatchStatisticsInFdo"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TPUAnnotateTensorsWithDynamicShape.pbtxt b/tensorflow/core/api_def/python_api/api_def_TPUAnnotateTensorsWithDynamicShape.pbtxt
new file mode 100644
index 0000000..84ac58c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TPUAnnotateTensorsWithDynamicShape.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "TPUAnnotateTensorsWithDynamicShape"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TPUCopyWithDynamicShape.pbtxt b/tensorflow/core/api_def/python_api/api_def_TPUCopyWithDynamicShape.pbtxt
new file mode 100644
index 0000000..f628f72
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TPUCopyWithDynamicShape.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "TPUCopyWithDynamicShape"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseCoreAdagrad.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseCoreAdagrad.pbtxt
new file mode 100644
index 0000000..4775d3d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseCoreAdagrad.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseCoreAdagrad"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseCoreAdagradMomentum.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseCoreAdagradMomentum.pbtxt
new file mode 100644
index 0000000..8c0f70d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseCoreAdagradMomentum.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseCoreAdagradMomentum"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseCoreAdam.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseCoreAdam.pbtxt
new file mode 100644
index 0000000..289521f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseCoreAdam.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseCoreAdam"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseCoreFtrl.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseCoreFtrl.pbtxt
new file mode 100644
index 0000000..fc3122b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseCoreFtrl.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseCoreFtrl"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseCoreSgd.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseCoreSgd.pbtxt
new file mode 100644
index 0000000..9ea35a6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseCoreSgd.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseCoreSgd"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmul.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmul.pbtxt
new file mode 100644
index 0000000..acc0d15
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmul.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseDenseMatmul"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt
new file mode 100644
index 0000000..478d544
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseDenseMatmulGradWithAdagradAndCsrInput"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt
new file mode 100644
index 0000000..c8e69c3
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt
new file mode 100644
index 0000000..62f6b15
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseDenseMatmulGradWithAdamAndCsrInput"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt
new file mode 100644
index 0000000..36a0c24
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseDenseMatmulGradWithFtrlAndCsrInput"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt
new file mode 100644
index 0000000..b5d12b4
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseDenseMatmulGradWithSgdAndCsrInput"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulWithCsrInput.pbtxt b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulWithCsrInput.pbtxt
new file mode 100644
index 0000000..e16a65c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulWithCsrInput.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "XlaSparseDenseMatmulWithCsrInput"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/build_defs.bzl b/tensorflow/core/build_defs.bzl
new file mode 100644
index 0000000..b9952c2
--- /dev/null
+++ b/tensorflow/core/build_defs.bzl
@@ -0,0 +1,88 @@
+"""Defines the build rules to disable non-core TF libraries."""
+
+load("//third_party/bazel_rules/rules_python/python:py_binary.bzl", "py_binary")
+
+def _tf_core_transition_impl(settings, attr):
+    _ignore = (settings, attr)  # @unused
+    return {"@local_tsl//tsl/framework/contraction:disable_onednn_contraction_kernel": True}
+
+_tf_core_transition = transition(
+    implementation = _tf_core_transition_impl,
+    inputs = [],
+    outputs = ["@local_tsl//tsl/framework/contraction:disable_onednn_contraction_kernel"],
+)
+
+def _py_binary_tf_core_impl(ctx):
+    out = ctx.actions.declare_file(ctx.label.name)
+
+    # Put the py binary in the expected location.
+    ctx.actions.run_shell(
+        inputs = [ctx.executable.py_binary],
+        outputs = [out],
+        command = "cp %s %s" % (ctx.executable.py_binary.path, out.path),
+    )
+
+    wrapped_defaultinfo = ctx.attr.py_binary[0][DefaultInfo]
+    runfiles = ctx.runfiles(files = [out])
+    wrapped_default_runfiles = wrapped_defaultinfo.default_runfiles.files.to_list()
+
+    # Remove the wrapped py_binary from the runfiles
+    if ctx.executable.py_binary in wrapped_default_runfiles:
+        wrapped_default_runfiles.remove(ctx.executable.py_binary)
+
+    return [
+        DefaultInfo(
+            executable = out,
+            files = depset([out]),
+            # Merge the wrapped executable's data into runfiles
+            runfiles = runfiles.merge(ctx.runfiles(files = wrapped_default_runfiles)),
+        ),
+    ]
+
+# This rule sets the flag values to disable non-core TF libraries when compiling the referenced
+# py_binary.
+_py_binary_tf_core = rule(
+    implementation = _py_binary_tf_core_impl,
+    attrs = {
+        "py_binary": attr.label(
+            cfg = _tf_core_transition,
+            mandatory = True,
+            executable = True,
+        ),
+        # Deps is unused, but some other rules assume all targets have a "deps" attribute
+        # (such as scaffolding_registration_test)
+        "deps": attr.label_list(
+            default = [],
+        ),
+        "_allowlist_function_transition": attr.label(
+            default = "@bazel_tools//tools/allowlists/function_transition_allowlist",
+        ),
+    },
+    # Marking this rule as executable means it works with "$ bazel run".
+    executable = True,
+)
+
+def py_binary_tf_core(name, visibility = [], **kwargs):
+    """A wrapper of py_binary that disables non-core TF libraries.
+
+    Args:
+      name: The name of the resulting binary.
+      visibility: The visibility of the resulting binary.
+      **kwargs: All other args are passed to the wrapped py_binary.
+    """
+
+    wrapped_binary_name = "%s_wrapped_binary" % name
+
+    # When users reference ":${name}" they will actually reference the output
+    # of this transition rule, instead of the wrapped binary. This causes the
+    # build system to apply our transition when evaluating the build graph.
+    _py_binary_tf_core(
+        name = name,
+        py_binary = ":%s" % wrapped_binary_name,
+        visibility = visibility,
+    )
+    py_binary(
+        name = wrapped_binary_name,
+        visibility = visibility,
+        **kwargs
+    )
diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index 89bf11e..bf5b15e 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -1,7 +1,6 @@
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load(
     "//tensorflow:tensorflow.bzl",
-    "if_google",
     "if_libtpu",
     "if_macos",
     "if_oss",
@@ -50,13 +49,6 @@
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
     default_visibility = default_package_visibility,
-    features = if_google(
-        [
-            "-layering_check",
-            "-parse_headers",
-        ],
-        ["-layering_check"],
-    ),
     licenses = ["notice"],
 )
 
@@ -78,6 +70,7 @@
         "//tensorflow/core/public:session.h",
         "//tensorflow/core/public:session_options.h",
     ],
+    features = ["-layering_check"],
     visibility = ["//visibility:public"],
     deps = [
         ":core_cpu_internal",
@@ -86,6 +79,7 @@
 
 cc_header_only_library(
     name = "core_cpu_headers_lib",
+    features = ["-parse_headers"],
     visibility = ["//visibility:public"],
     deps = [
         ":core_cpu_lib",
@@ -137,6 +131,7 @@
     srcs = ["collective_test_util.cc"],
     hdrs = ["collective_test_util.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":device_resolver_local",
         ":process_util",
@@ -226,6 +221,7 @@
         "//tensorflow/core/public:session.h",
     ],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":scoped_allocator",
         ":stats_publisher_interface",
@@ -333,6 +329,7 @@
     srcs = ["all_to_all.cc"],
     hdrs = ["all_to_all.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":base_collective_executor",
         ":collective_rma_local",
@@ -356,6 +353,7 @@
         "arg_ret_placement.h",
     ],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         "//tensorflow/core:graph",
     ],
@@ -387,6 +385,7 @@
     srcs = ["buf_rendezvous.cc"],
     hdrs = ["buf_rendezvous.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":device",
         ":device_mgr",
@@ -463,6 +462,7 @@
     srcs = ["collective_param_resolver_local.cc"],
     hdrs = ["collective_param_resolver_local.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":device_mgr",
         "//tensorflow/core:framework",
@@ -476,6 +476,7 @@
     srcs = ["collective_rma_local.cc"],
     hdrs = ["collective_rma_local.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":buf_rendezvous",
         ":copy_tensor",
@@ -498,6 +499,7 @@
         "inspecting_placer.h",
     ],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":composite_device",
         ":device",
@@ -523,6 +525,7 @@
     srcs = ["composite_device.cc"],
     hdrs = ["composite_device.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":device",
         "//tensorflow/core:framework",
@@ -536,6 +539,7 @@
     srcs = ["constant_folding.cc"],
     hdrs = ["constant_folding.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":device",
         ":device_factory",
@@ -557,6 +561,7 @@
     srcs = ["costmodel_manager.cc"],
     hdrs = ["costmodel_manager.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
@@ -570,6 +575,7 @@
     srcs = ["debugger_state_interface.cc"],
     hdrs = ["debugger_state_interface.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":device",
         "//tensorflow/core:graph",
@@ -581,6 +587,7 @@
     name = "device",
     hdrs = ["device.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         "//tensorflow/core:framework_internal",
     ],
@@ -590,6 +597,7 @@
     name = "device_factory",
     hdrs = ["device_factory.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         "//tensorflow/core:framework_internal",
     ],
@@ -603,6 +611,7 @@
     ],
     hdrs = ["device_mgr.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":device",
         ":local_device",
@@ -632,6 +641,7 @@
     name = "entry",
     hdrs = ["entry.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -643,6 +653,7 @@
     srcs = ["executor.cc"],
     hdrs = ["executor.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":costmodel_manager",
         ":device",
@@ -692,6 +703,7 @@
     srcs = ["type_inference.cc"],
     hdrs = ["type_inference.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     visibility = default_package_visibility,
     deps = [
         ":optimization_registry",
@@ -708,6 +720,7 @@
     srcs = ["single_threaded_executor.cc"],
     hdrs = ["single_threaded_executor.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":entry",
         ":executor",
@@ -754,6 +767,7 @@
     name = "type_inference_test",
     size = "small",
     srcs = ["type_inference_test.cc"],
+    features = ["-layering_check"],
     deps = [
         ":core_cpu",
         ":core_cpu_internal",
@@ -810,6 +824,7 @@
     srcs = ["device_set.cc"],
     hdrs = ["device_set.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":device",
         ":device_factory",
@@ -837,6 +852,7 @@
         "process_function_library_runtime.h",
     ],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":arg_ret_placement",
         ":composite_device",
@@ -872,6 +888,8 @@
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/config:flag_defs",
+        "//tensorflow/core/config:flags",
         "//tensorflow/core/profiler/lib:connected_traceme",
         "//tensorflow/core/profiler/lib:traceme",
         "@com_google_absl//absl/algorithm:container",
@@ -889,10 +907,10 @@
     hdrs = ["function_body.h"],
     copts = tf_copts(),
     deps = [
-        ":device",
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
         "//tensorflow/core:lib",
+        "//tensorflow/core/platform:refcount",
     ],
 )
 
@@ -913,6 +931,7 @@
     srcs = ["function_def_utils.cc"],
     hdrs = ["function_def_utils.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":function_body",
         ":graph_constructor",
@@ -973,6 +992,7 @@
         ":core_cpu_lib_headers",
     ],
     copts = tf_copts(),
+    features = ["-layering_check"],
     visibility = default_package_visibility + [
         "//platforms/performance/autograppler:__subpackages__",
         "//platforms/performance/tf_sim:__subpackages__",
@@ -1026,6 +1046,7 @@
     srcs = ["graph_optimizer.cc"],
     hdrs = ["graph_optimizer.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":constant_folding",
         ":function_utils",
@@ -1073,6 +1094,7 @@
     srcs = ["immutable_executor_state.cc"],
     hdrs = ["immutable_executor_state.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":graph_view",
         ":local_executor_params",
@@ -1164,6 +1186,7 @@
     srcs = ["local_device.cc"],
     hdrs = ["local_device.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":device",
         ":process_state",
@@ -1220,6 +1243,7 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
         "//tensorflow/core:lib",
+        "//tensorflow/core/platform:refcount",
         "@com_google_absl//absl/algorithm:container",
     ],
 )
@@ -1260,6 +1284,7 @@
         "optimization_registry.h",
     ],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":optimization_registry",
         "//tensorflow/core:framework",
@@ -1379,6 +1404,7 @@
     srcs = ["node_file_writer.cc"],
     hdrs = ["node_file_writer.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
@@ -1434,6 +1460,7 @@
     name = "pending_counts",
     hdrs = ["pending_counts.h"],
     copts = tf_copts(),
+    features = ["-parse_headers"],
     deps = [
         "//tensorflow/core:lib",
     ],
@@ -1465,6 +1492,7 @@
     srcs = ["pool_allocator.cc"],
     hdrs = ["pool_allocator.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -1478,6 +1506,7 @@
     srcs = ["placer.cc"],
     hdrs = ["placer.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":colocation_graph",
         ":device",
@@ -1499,6 +1528,7 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
         "//tensorflow/core:lib",
+        "//tensorflow/core/platform:refcount",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
     ],
@@ -1525,6 +1555,7 @@
     srcs = ["process_util.cc"],
     hdrs = ["process_util.h"],
     copts = tf_copts() + tf_openmp_copts(),
+    features = ["-layering_check"],
     linkopts = tf_openmp_lopts(),
     deps = [
         ":session_options",
@@ -1538,6 +1569,7 @@
     name = "profile_handler",
     hdrs = ["profile_handler.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
@@ -1608,6 +1640,7 @@
     srcs = ["renamed_device.cc"],
     hdrs = ["renamed_device.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":device",
         "//tensorflow/core:lib",
@@ -1622,6 +1655,7 @@
     srcs = ["rendezvous_mgr.cc"],
     hdrs = ["rendezvous_mgr.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":copy_tensor",
         ":device",
@@ -1700,6 +1734,7 @@
     srcs = ["rendezvous_util.cc"],
     hdrs = ["rendezvous_util.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -1758,6 +1793,7 @@
     srcs = ["session.cc"],
     hdrs = ["//tensorflow/core/public:session.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":session_factory",
         "//tensorflow/core:framework",
@@ -1847,6 +1883,7 @@
     srcs = ["stats_publisher_interface.cc"],
     hdrs = ["stats_publisher_interface.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":build_graph_options",
         ":profile_handler",
@@ -1875,6 +1912,7 @@
     srcs = ["threadpool_device.cc"],
     hdrs = ["threadpool_device.h"],
     copts = tf_copts() + tf_openmp_copts(),
+    features = ["-layering_check"],
     linkopts = tf_openmp_lopts(),
     deps = [
         ":device_factory",
@@ -1914,6 +1952,7 @@
     name = "core_cpu_impl",
     hdrs = [":core_cpu_lib_headers"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":accumulate_n_optimizer",
         ":all_to_all",
@@ -1984,6 +2023,7 @@
 tf_cuda_library(
     name = "core_cpu_lib",
     hdrs = [":core_cpu_lib_headers"],
+    features = ["-layering_check"],
     deps = [
         "//tensorflow/core:core_cpu_base",
         "//tensorflow/core/grappler:grappler_item",
@@ -1993,6 +2033,7 @@
 tf_cuda_library(
     name = "core_cpu_lib_no_ops",
     hdrs = [":core_cpu_lib_headers"],
+    features = ["-layering_check"],
     deps = [
         ":core_cpu_base_no_ops",
         "//tensorflow/core/grappler:grappler_item",
@@ -2009,6 +2050,7 @@
         ":core_cpu_lib_headers",
     ],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
@@ -2040,7 +2082,10 @@
         "allocator_retry.h",
     ],
     hdrs = ["bfc_allocator.h"],
-    features = ["parse_headers"],
+    features = [
+        "-layering_check",
+        "parse_headers",
+    ],
     visibility = ["//visibility:public"],
     deps = [
         ":shared_counter",
@@ -2076,6 +2121,7 @@
         "direct_session.h",
     ],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":core_cpu_internal",
         ":local_session_selection",
@@ -2292,6 +2338,7 @@
 tf_cuda_cc_test(
     name = "all_to_all_test",
     srcs = ["all_to_all_test.cc"],
+    features = ["-layering_check"],
     tags = ["no_cuda_on_cpu_tap"],
     deps = [
         ":collective_test_util",
@@ -2390,6 +2437,7 @@
     srcs = [
         "replicate_constants_pass_test.cc",
     ],
+    features = ["-layering_check"],
     deps = [
         ":core",
         ":core_cpu",
@@ -2679,6 +2727,7 @@
     name = "process_function_library_runtime_test",
     size = "small",
     srcs = ["process_function_library_runtime_test.cc"],
+    features = ["-layering_check"],
     deps = [
         ":core_cpu",
         ":core_cpu_internal",
@@ -2783,6 +2832,7 @@
     size = "medium",
     srcs = ["direct_session_test.cc"],
     args = [] + if_cuda(["--heap_check="]),  # The GPU tracer leaks memory
+    features = ["-layering_check"],
     deps = [
         ":core_cpu",
         ":core_cpu_internal",
@@ -2822,6 +2872,7 @@
 tf_cc_test(
     name = "direct_session_with_debug_test",
     srcs = ["direct_session_test.cc"],
+    features = ["-layering_check"],
     deps = [
         ":core",
         ":core_cpu",
@@ -2956,6 +3007,7 @@
     name = "function_test",
     size = "small",
     srcs = ["function_test.cc"],
+    features = ["-layering_check"],
     tags = [
         "manual",
         "no_oss",
@@ -3043,6 +3095,7 @@
     name = "inline_function_utils_test",
     size = "small",
     srcs = ["inline_function_utils_test.cc"],
+    features = ["-layering_check"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:test_main",
@@ -3212,6 +3265,7 @@
     name = "graph_constructor_test",
     size = "small",
     srcs = ["graph_constructor_test.cc"],
+    features = ["-layering_check"],
     linkopts = select({
         "//tensorflow:macos": ["-headerpad_max_install_names"],
         "//conditions:default": [],
@@ -3244,6 +3298,7 @@
 tf_cc_test(
     name = "cost_measurement_registry_test",
     srcs = ["cost_measurement_registry_test.cc"],
+    features = ["-layering_check"],
     deps = [
         ":cost_measurement_registry",
         "//tensorflow/core:test",
@@ -3255,6 +3310,7 @@
 tf_cc_test(
     name = "no_op_cost_measurement_test",
     srcs = ["no_op_cost_measurement_test.cc"],
+    features = ["-layering_check"],
     deps = [
         ":no_op_cost_measurement",
         "//tensorflow/core:test",
@@ -3297,6 +3353,7 @@
 tf_cc_test(
     name = "cost_util_test",
     srcs = ["cost_util_test.cc"],
+    features = ["-layering_check"],
     deps = [
         ":cost_measurement_registry",
         ":cost_util",
@@ -3311,6 +3368,7 @@
     name = "device_propagation_test",
     size = "small",
     srcs = ["device_propagation_test.cc"],
+    features = ["-layering_check"],
     deps = [
         ":device_propagation",
         "//tensorflow/cc:array_ops",
@@ -3329,6 +3387,7 @@
     srcs = ["optimized_function_graph_info.cc"],
     hdrs = ["optimized_function_graph_info.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     visibility = ["//visibility:public"],
     deps = [
         ":graph_constructor",
@@ -3340,6 +3399,7 @@
 tf_cc_test(
     name = "optimized_function_graph_info_test",
     srcs = ["optimized_function_graph_info_test.cc"],
+    features = ["-layering_check"],
     tags = if_oss([
         "no_oss",
     ]),  # b/169705709, no protobuf matchers in OSS.
@@ -3362,6 +3422,7 @@
     srcs = ["optimize_function_graph_utils.cc"],
     hdrs = ["optimize_function_graph_utils.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":composite_device",
         ":device_set",
@@ -3388,6 +3449,7 @@
 tf_cc_test(
     name = "optimize_function_graph_utils_test",
     srcs = ["optimize_function_graph_utils_test.cc"],
+    features = ["-layering_check"],
     deps = [
         ":device",
         ":device_factory",
@@ -3410,6 +3472,7 @@
     srcs = [
         "int32_fulltype_test.cc",
     ],
+    features = ["-layering_check"],
     deps = [
         ":core",
         ":int32_fulltype",
@@ -3427,6 +3490,7 @@
     srcs = [
         "arg_ret_placement_test.cc",
     ],
+    features = ["-layering_check"],
     deps = [
         ":arg_ret_placement",
         "//tensorflow/cc:scope",
@@ -3464,6 +3528,7 @@
     srcs = ["serving_device_selector_policies.cc"],
     hdrs = ["serving_device_selector_policies.h"],
     copts = tf_copts(),
+    features = ["-layering_check"],
     deps = [
         ":serving_device_selector",
     ],
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index db5984c..587137c 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -694,8 +694,9 @@
           ((measure_step_count + 1) % build_cost_model_every == 0);
     }
   }
-  if (do_trace || update_cost_model ||
-      run_options.report_tensor_allocations_upon_oom()) {
+  if (run_metadata != nullptr &&
+      (do_trace || update_cost_model ||
+       run_options.report_tensor_allocations_upon_oom())) {
     run_state.collector.reset(
         new StepStatsCollector(run_metadata->mutable_step_stats()));
     args.stats_collector = run_state.collector.get();
@@ -781,7 +782,7 @@
     run_status.Update(errors::Cancelled("Run call was cancelled"));
   }
 
-  if (device_profiler_session) {
+  if (run_metadata != nullptr && device_profiler_session) {
     TF_RETURN_IF_ERROR(device_profiler_session->CollectData(
         run_metadata->mutable_step_stats()));
   }
@@ -814,11 +815,13 @@
     mutex_lock l(executor_lock_);
     run_state.collector->BuildCostModel(&cost_model_manager_, device_to_graph);
 
-    // annotate stats onto cost graph.
-    CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
-    for (const auto& item : executors_and_keys->items) {
-      TF_RETURN_IF_ERROR(
-          cost_model_manager_.AddToCostGraphDef(item.graph.get(), cost_graph));
+    if (run_metadata != nullptr) {
+      // annotate stats onto cost graph.
+      CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
+      for (const auto& item : executors_and_keys->items) {
+        TF_RETURN_IF_ERROR(cost_model_manager_.AddToCostGraphDef(
+            item.graph.get(), cost_graph));
+      }
     }
   }
 
@@ -828,7 +831,7 @@
       return errors::InvalidArgument(
           "RunOptions.output_partition_graphs() is not supported when "
           "disable_output_partition_graphs is true.");
-    } else {
+    } else if (run_metadata != nullptr) {
       protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
           run_metadata->mutable_partition_graphs();
       for (const PerPartitionExecutorsAndLib& exec_and_lib :
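
Note on the direct_session.cc hunks above: they make stats collection tolerate a null RunMetadata pointer, so the collector, profiler data, cost graph, and partition graphs are only produced when the caller supplied somewhere to put them. A minimal, self-contained sketch of the same guard pattern, using hypothetical stand-ins (StepStats, StatsCollector, RunStep) rather than the real TensorFlow types:

#include <iostream>
#include <memory>
#include <string>

// Hypothetical stand-ins for RunMetadata/StepStatsCollector; this only
// illustrates the null-guard pattern from the hunks above.
struct StepStats { std::string summary; };

struct StatsCollector {
  explicit StatsCollector(StepStats* stats) : stats_(stats) {}
  void Record(const std::string& event) { stats_->summary += event + ";"; }
  StepStats* stats_;
};

void RunStep(bool do_trace, StepStats* step_stats /* may be null */) {
  std::unique_ptr<StatsCollector> collector;
  // Only build a collector when tracing is requested *and* the caller gave us
  // an output slot; a null pointer now simply skips collection.
  if (step_stats != nullptr && do_trace) {
    collector = std::make_unique<StatsCollector>(step_stats);
  }
  if (collector) collector->Record("step");
}

int main() {
  StepStats stats;
  RunStep(/*do_trace=*/true, &stats);   // collected
  RunStep(/*do_trace=*/true, nullptr);  // safely skipped
  std::cout << stats.summary << "\n";   // prints "step;"
  return 0;
}
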
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 208ed4b..227821a 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/core/common_runtime/function.h"
 
 #include <deque>
+#include <utility>
 #include <vector>
 
 #include "absl/algorithm/container.h"
@@ -25,6 +26,7 @@
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/executor.h"
 #include "tensorflow/core/common_runtime/executor_factory.h"
+#include "tensorflow/core/common_runtime/function_def_utils.h"
 #include "tensorflow/core/common_runtime/gradients.h"
 #include "tensorflow/core/common_runtime/graph_constructor.h"
 #include "tensorflow/core/common_runtime/graph_optimizer.h"
@@ -50,10 +52,12 @@
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/refcount.h"
 #include "tensorflow/core/platform/str_util.h"
 #include "tensorflow/core/profiler/lib/connected_traceme.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/core/protobuf/config.pb.h"
+#include "tsl/platform/statusor.h"
 
 // See core/kernels/function_ops.cc for related kernels.
 
@@ -152,8 +156,8 @@
 class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
  public:
   FunctionLibraryRuntimeOverlay(FunctionLibraryRuntime* base_flr,
-                                const FunctionLibraryDefinition* lib_def)
-      : base_flr_(base_flr), lib_def_(lib_def) {}
+                                FunctionLibraryDefinition lib_def)
+      : base_flr_(base_flr), lib_def_(std::move(lib_def)) {}
   ~FunctionLibraryRuntimeOverlay() override;
 
   Status Instantiate(const string& function_name, AttrSlice attrs,
@@ -203,7 +207,7 @@
 
  private:
   FunctionLibraryRuntime* base_flr_;          // not owned
-  const FunctionLibraryDefinition* lib_def_;  // not owned
+  const FunctionLibraryDefinition lib_def_;
 };
 
 FunctionLibraryRuntimeOverlay::~FunctionLibraryRuntimeOverlay() = default;
@@ -213,9 +217,9 @@
     const InstantiateOptions& options, Handle* handle) {
   // We automatically set the `lib_def` option for all instantiations, if the
   // caller doesn't set this option explicitly.
-  if (!options.lib_def && lib_def_) {
+  if (!options.lib_def) {
     InstantiateOptions options_copy = options;
-    options_copy.lib_def = lib_def_;
+    options_copy.lib_def = &lib_def_;
     return base_flr_->Instantiate(function_name, attrs, options_copy, handle);
   } else {
     return base_flr_->Instantiate(function_name, attrs, options, handle);
@@ -276,7 +280,7 @@
     const string& function_name) const {
   // Important: we do not forward lookup to the base FLR.
   const OpDef* op_def;
-  const Status s = lib_def_->LookUpOpDef(function_name, &op_def);
+  const Status s = lib_def_.LookUpOpDef(function_name, &op_def);
   return s.ok() && op_def->is_stateful();
 }
 
@@ -303,7 +307,7 @@
 
 const FunctionLibraryDefinition*
 FunctionLibraryRuntimeOverlay::GetFunctionLibraryDefinition() const {
-  return lib_def_ ? lib_def_ : base_flr_->GetFunctionLibraryDefinition();
+  return &lib_def_;
 }
 
 string FunctionLibraryRuntimeOverlay::DebugString(Handle handle) {
@@ -437,7 +441,8 @@
   // FunctionLibraryDefinition.
   Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                       FunctionLibraryRuntime* flr, OpKernel** kernel);
-  Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
+  Status FunctionDefToBody(core::RefCountPtr<FunctionRecord>&& record,
+                           AttrSlice attrs,
                            const FunctionLibraryDefinition* lib_def,
                            std::unique_ptr<FunctionBody>* fbody);
   Status CreateItem(Item** item);
@@ -666,7 +671,7 @@
   // Constructs a CallOp kernel for running the instantiated function.
   auto device_type = DeviceType(device_->attributes().device_type());
   auto new_props = std::make_shared<NodeProperties>(
-      &fbody->fdef.signature(), props->node_def, fbody->arg_types,
+      &fbody->record->fdef().signature(), props->node_def, fbody->arg_types,
       fbody->ret_types);
   OpKernelConstruction construction(
       device_type, device_, device_->GetAllocator(AllocatorAttributes()), flr,
@@ -679,16 +684,18 @@
 }
 
 Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
-    const FunctionDef& fdef, AttrSlice attrs,
+    core::RefCountPtr<FunctionRecord>&& record, AttrSlice attrs,
     const FunctionLibraryDefinition* lib_def,
     std::unique_ptr<FunctionBody>* fbody) {
   if (lib_def == base_lib_def_) {
-    return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig_, fbody);
+    return FunctionDefToBodyHelper(std::move(record), attrs, lib_def,
+                                   get_func_sig_, fbody);
   } else {
     auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
       return lib_def->LookUpOpDef(op, sig);
     };
-    return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
+    return FunctionDefToBodyHelper(std::move(record), attrs, lib_def,
+                                   get_func_sig, fbody);
   }
 }
 
@@ -708,8 +715,10 @@
     // TODO(josh11b): Should filter out the attrs from func that aren't used
     // by the gradient function.
     TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
-    TF_RETURN_IF_ERROR(
-        FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), lib_def, g_body));
+    core::RefCountPtr<FunctionRecord> record(
+        new FunctionRecord(std::move(grad_fdef), {}, true));
+    TF_RETURN_IF_ERROR(FunctionDefToBody(
+        std::move(record), AttrSlice(&func.attr()), lib_def, g_body));
   } else {
     // f is a user-defined function.
     InstantiateOptions options;
@@ -805,11 +814,12 @@
     }
     TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, lib_def, &fbody));
   } else {
-    const FunctionDef* fdef = lib_def->Find(function_name);
+    core::RefCountPtr<FunctionRecord> fdef = lib_def->FindRecord(function_name);
     if (fdef == nullptr) {
       return errors::NotFound("Function ", function_name, " is not defined.");
     }
-    TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
+    TF_RETURN_IF_ERROR(
+        FunctionDefToBody(std::move(fdef), attrs, lib_def, &fbody));
     Int32FulltypePass int32_fulltype("FunctionLibraryRuntime::Instantiate");
     TF_RETURN_IF_ERROR(
         int32_fulltype.ProcessGraph(fbody->graph, /*ints_on_device=*/false));
@@ -833,8 +843,11 @@
       item->allow_control_flow_sync_execution =
           options.allow_control_flow_sync_execution;
       if (options.lib_def) {
-        item->overlay_flr.reset(
-            new FunctionLibraryRuntimeOverlay(this, options.lib_def));
+        TF_ASSIGN_OR_RETURN(
+            FunctionLibraryDefinition reachable_lib_def,
+            options.lib_def->ReachableDefinitions(function_name));
+        item->overlay_flr.reset(new FunctionLibraryRuntimeOverlay(
+            this, std::move(reachable_lib_def)));
       }
       local_handle = next_handle_++;
       items_->emplace(local_handle, std::unique_ptr<Item>(item));
@@ -945,7 +958,7 @@
   auto g = std::make_unique<Graph>(lib_def);
   CopyGraph(*fbody->graph, g.get());
 
-  PruneFunctionBody(fbody->fdef, g.get());
+  PruneFunctionBody(fbody->record->fdef(), g.get());
   optimizer_.Optimize(this, env(), device(), &g, GraphOptimizer::Options());
   TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
                                        device()->name(), g.get()));
@@ -1448,7 +1461,10 @@
 
   // Copy just the fdef attributes (copy '_noinline' and other similar flags to
   // the gradient function body).
-  *(gbody->fdef.mutable_attr()) = fbody_->fdef.attr();
+  FunctionDef fdef;
+  *(fdef.mutable_attr()) = fbody_->record->fdef().attr();
+  gbody->record = core::RefCountPtr<FunctionRecord>(
+      new FunctionRecord(std::move(fdef), {}, true));
 
   // Copy the nodes.
   node_map[src.source_node()->id()] = dst->source_node();
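
Note on the function.cc hunks above: FunctionLibraryRuntimeOverlay now owns a FunctionLibraryDefinition by value (only the definitions reachable from the instantiated function), and FunctionDefToBody consumes a ref-counted FunctionRecord instead of a FunctionDef reference. A hedged sketch of the wrapping pattern used in the hunks; it relies on TensorFlow-internal headers, so it only builds inside the TF tree, and MakeGradientBody is a hypothetical caller, not part of this change:

#include <memory>
#include <utility>

#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/refcount.h"

namespace tensorflow {

// Hypothetical helper mirroring the diff: wrap a freshly built FunctionDef in
// a ref-counted FunctionRecord and hand it to FunctionDefToBodyHelper, so the
// definition is shared rather than copied into the FunctionBody.
Status MakeGradientBody(FunctionDef grad_fdef, AttrSlice attrs,
                        const FunctionLibraryDefinition* lib_def,
                        std::unique_ptr<FunctionBody>* fbody) {
  // Constructor arguments as used in the hunks above: the moved FunctionDef,
  // an empty second argument, and true.
  core::RefCountPtr<FunctionRecord> record(
      new FunctionRecord(std::move(grad_fdef), {}, true));
  return FunctionDefToBodyHelper(std::move(record), attrs, lib_def, fbody);
}

}  // namespace tensorflow
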
diff --git a/tensorflow/core/common_runtime/function_body.cc b/tensorflow/core/common_runtime/function_body.cc
index 3b3442b..1ca6f6a 100644
--- a/tensorflow/core/common_runtime/function_body.cc
+++ b/tensorflow/core/common_runtime/function_body.cc
@@ -15,14 +15,19 @@
 
 #include "tensorflow/core/common_runtime/function_body.h"
 
+#include <utility>
+
+#include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/refcount.h"
 
 namespace tensorflow {
 
-FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
-                           DataTypeSlice ret_t, Graph* g)
-    : fdef(f),
+FunctionBody::FunctionBody(core::RefCountPtr<FunctionRecord>&& record,
+                           DataTypeSlice arg_t, DataTypeSlice ret_t, Graph* g)
+    : record(std::move(record)),
       graph(g),
       arg_types(arg_t.begin(), arg_t.end()),
       ret_types(ret_t.begin(), ret_t.end()) {
@@ -48,7 +53,7 @@
   }
   // 2. Find ControlRet nodes that must be always executed.
   std::unordered_set<StringPiece, StringPieceHasher> control_ret_node_names;
-  for (const auto& control_ret : fdef.control_ret()) {
+  for (const auto& control_ret : this->record->fdef().control_ret()) {
     control_ret_node_names.insert(control_ret.second);
   }
   this->control_ret_nodes.reserve(control_ret_node_names.size());
diff --git a/tensorflow/core/common_runtime/function_body.h b/tensorflow/core/common_runtime/function_body.h
index cbd6026..97d27f5 100644
--- a/tensorflow/core/common_runtime/function_body.h
+++ b/tensorflow/core/common_runtime/function_body.h
@@ -19,9 +19,11 @@
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/refcount.h"
 
 namespace tensorflow {
 
+class FunctionRecord;
 class Graph;
 class Node;
 
@@ -29,7 +31,7 @@
 // instantiated function that is represented as a Graph with arg/ret
 // nodes annotated.
 struct FunctionBody {
-  FunctionDef fdef;
+  core::RefCountPtr<FunctionRecord> record;
   Graph* graph = nullptr;  // owned.
   DataTypeVector arg_types;
   DataTypeVector ret_types;
@@ -42,8 +44,8 @@
   gtl::InlinedVector<Node*, 4> control_ret_nodes;
 
   FunctionBody() {}
-  FunctionBody(const FunctionDef& f, DataTypeSlice arg_types,
-               DataTypeSlice ret_types, Graph* g);
+  FunctionBody(core::RefCountPtr<FunctionRecord>&& record,
+               DataTypeSlice arg_types, DataTypeSlice ret_types, Graph* g);
   ~FunctionBody();
 };
 
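
Note on the function_body.h change above: FunctionBody no longer stores a FunctionDef member; it holds a ref-counted FunctionRecord, and the definition is reached through record->fdef(), as the inline_function_utils.cc hunks further down show. A small sketch of the new access pattern (TensorFlow-internal headers again; LogFunctionName is a hypothetical helper):

#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// Hypothetical helper: read-only access to the instantiated function's
// definition now goes through the ref-counted record.
void LogFunctionName(const FunctionBody& fbody) {
  // Previously: fbody.fdef.signature().name()
  VLOG(1) << "Instantiated function: "
          << fbody.record->fdef().signature().name();
}

}  // namespace tensorflow
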
diff --git a/tensorflow/core/common_runtime/function_def_utils.cc b/tensorflow/core/common_runtime/function_def_utils.cc
index b39d2ba..1d7b520 100644
--- a/tensorflow/core/common_runtime/function_def_utils.cc
+++ b/tensorflow/core/common_runtime/function_def_utils.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/core/common_runtime/function_def_utils.h"
 
+#include <utility>
 #include <vector>
 
 #include "tensorflow/core/common_runtime/function_body.h"
@@ -24,22 +25,26 @@
 #include "tensorflow/core/graph/control_flow.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_debug_info_builder.h"
+#include "tensorflow/core/platform/refcount.h"
+#include "tsl/platform/errors.h"
 
 namespace tensorflow {
 
 Status FunctionDefToBodyHelper(
-    const FunctionDef& fdef, const AttrSlice& attrs,
+    core::RefCountPtr<FunctionRecord>&& record, const AttrSlice& attrs,
     const FunctionLibraryDefinition* const lib_def,
     const std::function<Status(const string&, const OpDef**)>& get_func_sig,
     std::unique_ptr<FunctionBody>* fbody) {
   // Instantiates the function template into a graph def.
   InstantiationResult result;
-  TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
+  TF_RETURN_IF_ERROR(
+      InstantiateFunction(record->fdef(), attrs, get_func_sig, &result));
 
   auto graph = std::make_unique<Graph>(lib_def);
 
-  auto construction_context_iter = fdef.attr().find("_construction_context");
-  if (construction_context_iter != fdef.attr().end()) {
+  auto construction_context_iter =
+      record->fdef().attr().find("_construction_context");
+  if (construction_context_iter != record->fdef().attr().end()) {
     if (construction_context_iter->second.s() == "kEagerRuntime") {
       graph->SetConstructionContext(ConstructionContext::kEagerRuntime);
     } else {
@@ -56,7 +61,7 @@
                                             /*debug_info=*/nullptr));
 
   const StackTracesMap* stack_traces =
-      lib_def->GetStackTraces(fdef.signature().name());
+      lib_def->GetStackTraces(record->fdef().signature().name());
   if (stack_traces) {
     for (Node* n : graph->nodes()) {
       if (n) {
@@ -73,18 +78,32 @@
   std::vector<ControlFlowInfo> dummy;
   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
 
-  *fbody = std::make_unique<FunctionBody>(fdef, result.arg_types,
+  *fbody = std::make_unique<FunctionBody>(std::move(record), result.arg_types,
                                           result.ret_types, graph.release());
   return OkStatus();
 }
 
-Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
+Status FunctionDefToBodyHelper(core::RefCountPtr<FunctionRecord>&& record,
+                               const AttrSlice& attrs,
                                const FunctionLibraryDefinition* lib_def,
                                std::unique_ptr<FunctionBody>* fbody) {
   const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) {
     return lib_def->LookUpOpDef(op, sig);
   };
-  return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
+  return FunctionDefToBodyHelper(std::move(record), attrs, lib_def,
+                                 get_func_sig, fbody);
+}
+
+Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
+                               const FunctionLibraryDefinition* lib_def,
+                               std::unique_ptr<FunctionBody>* fbody) {
+  core::RefCountPtr<FunctionRecord> record(
+      new FunctionRecord(FunctionDef(fdef), {}, true));
+  const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) {
+    return lib_def->LookUpOpDef(op, sig);
+  };
+  return FunctionDefToBodyHelper(std::move(record), attrs, lib_def,
+                                 get_func_sig, fbody);
 }
 
 }  // end namespace tensorflow
diff --git a/tensorflow/core/common_runtime/function_def_utils.h b/tensorflow/core/common_runtime/function_def_utils.h
index f269cc6..1d60ce3 100644
--- a/tensorflow/core/common_runtime/function_def_utils.h
+++ b/tensorflow/core/common_runtime/function_def_utils.h
@@ -19,7 +19,9 @@
 #include <functional>
 #include <memory>
 
+#include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/refcount.h"
 
 namespace tensorflow {
 
@@ -27,10 +29,21 @@
 struct FunctionBody;
 class FunctionDef;
 class FunctionLibraryDefinition;
+class FunctionRecord;
 class OpDef;
 
 // Instantiates FunctionDef into a graph. Set *fbody to point to the
 // FunctionBody that holds the instantiated FunctionDef.
+Status FunctionDefToBodyHelper(core::RefCountPtr<FunctionRecord>&& record,
+                               const AttrSlice& attrs,
+                               const FunctionLibraryDefinition* lib_def,
+                               std::unique_ptr<FunctionBody>* fbody);
+
+// Instantiates FunctionDef into a graph. Set *fbody to point to the
+// FunctionBody that holds the instantiated FunctionDef.
+//
+// NOTE(mrry): This implementation incurs a copy of `fdef`. If possible, use
+//   the overload that takes a `core::RefCountPtr<FunctionRecord>`.
 Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
                                const FunctionLibraryDefinition* lib_def,
                                std::unique_ptr<FunctionBody>* fbody);
@@ -39,7 +52,7 @@
 // FunctionBody that holds the instantiated FunctionDef. Use custom function
 // signature lookup, in case instantiated function is not in the 'lib_def'.
 Status FunctionDefToBodyHelper(
-    const FunctionDef& fdef, const AttrSlice& attrs,
+    core::RefCountPtr<FunctionRecord>&& record, const AttrSlice& attrs,
     const FunctionLibraryDefinition* lib_def,
     const std::function<Status(const string&, const OpDef**)>& get_func_sig,
     std::unique_ptr<FunctionBody>* fbody);
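
Note on the function_def_utils.h change above: both overloads are kept, and the NOTE spells out the trade-off; the const FunctionDef& overload copies the definition into a new record, while the FunctionRecord overload takes ownership without a copy. A hedged usage sketch that prefers FindRecord (as the lower_function_call_op.cc hunk below does) and falls back to the copying overload; BuildBody and the function name "MyFunction" are hypothetical:

#include <memory>
#include <utility>

#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"

namespace tensorflow {

// Hypothetical caller: use the move overload when the library can hand out a
// ref-counted record; otherwise fall back to the copying overload.
Status BuildBody(const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
                 std::unique_ptr<FunctionBody>* fbody) {
  core::RefCountPtr<FunctionRecord> record = lib_def.FindRecord("MyFunction");
  if (record != nullptr) {
    // No FunctionDef copy: the record is moved into the FunctionBody.
    return FunctionDefToBodyHelper(std::move(record), attrs, &lib_def, fbody);
  }
  const FunctionDef* fdef = lib_def.Find("MyFunction");
  if (fdef == nullptr) {
    return errors::NotFound("Function MyFunction is not defined.");
  }
  // Copying overload, kept for existing callers (see the NOTE above).
  return FunctionDefToBodyHelper(*fdef, attrs, &lib_def, fbody);
}

}  // namespace tensorflow
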
diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD
index 1c4fe07..b99b417 100644
--- a/tensorflow/core/common_runtime/gpu/BUILD
+++ b/tensorflow/core/common_runtime/gpu/BUILD
@@ -32,10 +32,8 @@
     ],
     features = if_google(
         [
-            "-layering_check",
             "-parse_headers",
         ],
-        ["-layering_check"],
     ),
     licenses = ["notice"],
 )
@@ -93,6 +91,7 @@
         "gpu_id.h",
         "gpu_id_manager.h",
     ],
+    features = ["-layering_check"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -163,6 +162,7 @@
         ":gpu_virtual_mem_allocator",
     ],
     defines = if_linux_x86_64(["TF_PLATFORM_LINUX_X86_64"]),
+    features = ["-layering_check"],
     visibility = [
         "//tensorflow:internal",
         "//tensorflow_models:__subpackages__",
@@ -245,7 +245,10 @@
         "gpu_bfc_allocator.cc",
     ],
     hdrs = ["gpu_bfc_allocator.h"],
-    features = ["parse_headers"],
+    features = [
+        "-layering_check",
+        "parse_headers",
+    ],
     visibility = ["//visibility:public"],
     deps = [
         ":gpu_virtual_mem_allocator",
@@ -270,7 +273,10 @@
         "@local_xla//xla/stream_executor/gpu:gpu_driver_header",
         "@local_xla//xla/stream_executor/gpu:gpu_types_header",
     ],
-    features = ["parse_headers"],
+    features = [
+        "-layering_check",
+        "parse_headers",
+    ],
     visibility = ["//visibility:public"],
     deps = [
         ":gpu_id",
@@ -294,6 +300,7 @@
     name = "gpu_device_on_non_gpu_machine_test",
     size = "small",
     srcs = ["gpu_device_on_non_gpu_machine_test.cc"],
+    features = ["-layering_check"],
     deps = [
         ":gpu_headers_lib",
         ":gpu_id",
@@ -308,6 +315,7 @@
     srcs = [
         "gpu_bfc_allocator_test.cc",
     ],
+    features = ["-layering_check"],
     tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_id",
@@ -334,6 +342,7 @@
     srcs = [
         "gpu_device_test.cc",
     ],
+    features = ["-layering_check"],
     tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_id",
@@ -362,6 +371,7 @@
     srcs = [
         "pool_allocator_test.cc",
     ],
+    features = ["-layering_check"],
     tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_id",
@@ -388,6 +398,7 @@
     srcs = [
         "gpu_device_test.cc",
     ],
+    features = ["-layering_check"],
     # Runs test on a Guitar cluster that uses P100s to test unified memory
     # allocations.
     tags = tf_cuda_tests_tags() + [
@@ -439,6 +450,7 @@
     size = "medium",
     srcs = ["gpu_debug_allocator_test.cc"],
     args = ["--gtest_death_test_style=threadsafe"],
+    features = ["-layering_check"],
     tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_id",
@@ -464,6 +476,7 @@
     name = "gpu_virtual_mem_allocator_test",
     size = "small",
     srcs = ["gpu_virtual_mem_allocator_test.cc"],
+    features = ["-layering_check"],
     tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_virtual_mem_allocator",
@@ -481,6 +494,7 @@
     name = "gpu_serving_device_selector",
     srcs = ["gpu_serving_device_selector.cc"],
     hdrs = ["gpu_serving_device_selector.h"],
+    features = ["-layering_check"],
     deps = [
         "//tensorflow/core/common_runtime:serving_device_selector",
         "@com_google_absl//absl/container:fixed_array",
diff --git a/tensorflow/core/common_runtime/inline_function_utils.cc b/tensorflow/core/common_runtime/inline_function_utils.cc
index 829e589..6259406 100644
--- a/tensorflow/core/common_runtime/inline_function_utils.cc
+++ b/tensorflow/core/common_runtime/inline_function_utils.cc
@@ -272,7 +272,7 @@
 namespace {
 
 Status ValidateNoInline(const FunctionBody* fbody) {
-  const auto attr = AttrSlice(&fbody->fdef.attr());
+  const auto attr = AttrSlice(&fbody->record->fdef().attr());
   bool noinline = false;
   if (TryGetNodeAttr(attr, kNoInlineAttr, &noinline) && noinline) {
     return errors::InvalidArgument(
@@ -380,11 +380,12 @@
 
   if (!options.inline_impl_selection_group_functions) {
     bool is_impl_selection_group_function =
-        fbody->fdef.attr().find("api_implements") != fbody->fdef.attr().end();
+        fbody->record->fdef().attr().find("api_implements") !=
+        fbody->record->fdef().attr().end();
     if (is_impl_selection_group_function) {
       return errors::InvalidArgument(
           "Inlining of implementation selection group function ",
-          fbody->fdef.signature().name(),
+          fbody->record->fdef().signature().name(),
           " is disabled by options.inline_impl_selection_group_functions");
     }
   }
@@ -480,7 +481,8 @@
                           const InlineFunctionBodyOptions& options) {
   VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " ["
           << options.DebugString() << "]";
-  VLOG(4) << "Inlining function: " << fbody->fdef.DebugString();
+  VLOG(4) << "Inlining function: "
+          << fbody->record->fdef().DebugString();  // NOLINT
   VLOG(4) << "Current graphdef: " << g->ToGraphDefDebug().DebugString();
   VLOG(4) << "Caller: " << caller->DebugString();
 
@@ -582,8 +584,8 @@
       return errors::Internal("Null node found for input ", i);
 
     Node* n = input_identity("input", inputs[i], i);
-    input_node_name_map[arg_name(fbody->fdef.signature().input_arg(), i)] =
-        n->name();
+    input_node_name_map[arg_name(fbody->record->fdef().signature().input_arg(),
+                                 i)] = n->name();
     input_nodes.push_back(n);
   }
 
@@ -607,7 +609,8 @@
     if (device.has_value()) ndef.set_device(*device);
 
     // Add inlined function name to inlined node debug information.
-    PropagateDebugInfoToNode(fbody->fdef.signature().name(), {n}, &ndef);
+    PropagateDebugInfoToNode(fbody->record->fdef().signature().name(), {n},
+                             &ndef);
 
     // Add the function node name as a prefix:
     //  1) to node name to avoid collisions
@@ -701,8 +704,8 @@
     Node* arg = node_map[fbody->arg_nodes[i]->id()];
     Node* n = input_nodes[i];
     VLOG(4) << "    [index " << i << "] "
-            << arg_name(fbody->fdef.signature().input_arg(), i) << " as "
-            << n->name() << " (input: " << inputs[i].name()
+            << arg_name(fbody->record->fdef().signature().input_arg(), i)
+            << " as " << n->name() << " (input: " << inputs[i].name()
             << ", requested_device: " << n->requested_device() << ")";
 
     if (input_control_node) {
@@ -753,9 +756,10 @@
     Node* n = output_identity("output", data, i);
     outputs[i] = n;
     VLOG(4) << "    [index " << i << "] "
-            << arg_name(fbody->fdef.signature().output_arg(), i) << " as "
-            << n->name() << " (ret: " << data.node->name() << ":" << data.index
-            << ", requested_device: " << n->requested_device() << ")";
+            << arg_name(fbody->record->fdef().signature().output_arg(), i)
+            << " as " << n->name() << " (ret: " << data.node->name() << ":"
+            << data.index << ", requested_device: " << n->requested_device()
+            << ")";
     for (const Edge* e : ret->in_edges()) {
       if (e->IsControlEdge()) {
         g->AddControlEdge(e->src(), n, kDoNotCheckDuplicates);
diff --git a/tensorflow/core/common_runtime/inspecting_placer.cc b/tensorflow/core/common_runtime/inspecting_placer.cc
index 626ad78..ea6a1d4 100644
--- a/tensorflow/core/common_runtime/inspecting_placer.cc
+++ b/tensorflow/core/common_runtime/inspecting_placer.cc
@@ -16,6 +16,7 @@
 
 #include <memory>
 #include <unordered_map>
+#include <utility>
 #include <vector>
 
 #include "absl/strings/str_join.h"
@@ -29,6 +30,7 @@
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/graph_node_util.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/refcount.h"
 
 namespace tensorflow {
 
@@ -125,13 +127,13 @@
 
 Status InspectingPlacer::ComputeIOColocationGroups(const Node& node,
                                                    IOColocationGroups* groups) {
-  const FunctionDef* fdef;
+  core::RefCountPtr<FunctionRecord> fdef;
   NameAttrList func;
   TF_RETURN_IF_ERROR(GetFunctionDefAndAttrs(flib_def_, node, &fdef, &func));
   std::unique_ptr<FunctionBody> fbody;
 
-  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
-                                             &flib_def_, &fbody));
+  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+      std::move(fdef), AttrSlice(&func.attr()), &flib_def_, &fbody));
 
   TF_RETURN_IF_ERROR(
       IsolatePlacerInspectionRequiredOps(flib_def_, fbody->graph));
diff --git a/tensorflow/core/common_runtime/lower_function_call_op.cc b/tensorflow/core/common_runtime/lower_function_call_op.cc
index 3ff98ac..b5077da 100644
--- a/tensorflow/core/common_runtime/lower_function_call_op.cc
+++ b/tensorflow/core/common_runtime/lower_function_call_op.cc
@@ -15,6 +15,8 @@
 
 #include "tensorflow/core/common_runtime/lower_function_call_op.h"
 
+#include <utility>
+
 #include "absl/algorithm/container.h"
 #include "tensorflow/core/common_runtime/function_def_utils.h"
 #include "tensorflow/core/common_runtime/inline_function_utils.h"
@@ -23,6 +25,7 @@
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_node_util.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/refcount.h"
 
 namespace tensorflow {
 
@@ -64,16 +67,16 @@
     return errors::InvalidArgument("Unsupported function inlining policy");
   }
 
-  const FunctionDef* fdef;
+  core::RefCountPtr<FunctionRecord> fdef;
   if (n->IsPartitionedCall()) {
     NameAttrList func;
     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "f", &func));
-    fdef = flib_def.Find(func.name());
+    fdef = flib_def.FindRecord(func.name());
   } else if (n->type_string() == FunctionLibraryDefinition::kGradientOp) {
     VLOG(2) << "Skip SymbolicGradient lowering";
     return OkStatus();
   } else {
-    fdef = flib_def.Find(n->type_string());
+    fdef = flib_def.FindRecord(n->type_string());
   }
 
   if (fdef == nullptr) {
@@ -82,7 +85,7 @@
 
   std::unique_ptr<FunctionBody> fbody;
   TF_RETURN_IF_ERROR(
-      FunctionDefToBodyHelper(*fdef, n->attrs(), &flib_def, &fbody));
+      FunctionDefToBodyHelper(std::move(fdef), n->attrs(), &flib_def, &fbody));
 
   Status can_inline_function_call =
       ValidateInlining(n, fbody.get(), inline_options);
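Editor's note: the hunks above all apply the same ownership change. Function lookups now go through FunctionLibraryDefinition::FindRecord(), which returns a reference-counted FunctionRecord, and FunctionDefToBodyHelper() consumes that reference via std::move() instead of copying a raw FunctionDef. A minimal stand-alone sketch of the new call shape (the helper name InlineBodyForFunction is hypothetical; the calls mirror the code above):

#include <memory>
#include <string>
#include <utility>

#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"

namespace tensorflow {

// Hypothetical helper: look up `name` and build its FunctionBody, handing the
// ref-counted record to FunctionDefToBodyHelper instead of a raw FunctionDef*.
Status InlineBodyForFunction(const FunctionLibraryDefinition& flib_def,
                             const std::string& name,
                             std::unique_ptr<FunctionBody>* fbody) {
  core::RefCountPtr<FunctionRecord> record = flib_def.FindRecord(name);
  if (record == nullptr) {
    return errors::InvalidArgument("Failed to find function \"", name, "\"");
  }
  // The helper takes ownership of this reference; `record` is empty afterwards.
  return FunctionDefToBodyHelper(std::move(record), AttrSlice(), &flib_def,
                                 fbody);
}

}  // namespace tensorflow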
diff --git a/tensorflow/core/common_runtime/optimize_function_graph_utils.cc b/tensorflow/core/common_runtime/optimize_function_graph_utils.cc
index e52e219..c2f330c 100644
--- a/tensorflow/core/common_runtime/optimize_function_graph_utils.cc
+++ b/tensorflow/core/common_runtime/optimize_function_graph_utils.cc
@@ -44,9 +44,11 @@
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/metrics.h"
 #include "tensorflow/core/framework/optimized_function_graph.pb.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_node_util.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/refcount.h"
 #include "tensorflow/core/util/debug_data_dumper.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
@@ -253,17 +255,21 @@
                       tsl::port::TaskId(), "_", plain_func_name, "_",
                       fdef->node_def_size());
 }
-}  // namespace
 
-Status GetGraphAndArgRets(
-    const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
-    const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
-    std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
-    std::vector<string>* ret_node_names, DataTypeVector* ret_types,
-    std::vector<string>* control_ret_node_names) {
+// Generates graph and return information given the input function name,
+// attributes and function definition.
+Status GetGraphAndArgRets(const string& function_name, AttrSlice attrs,
+                          core::RefCountPtr<FunctionRecord>&& fdef,
+                          const FunctionLibraryDefinition* lib_def,
+                          std::unique_ptr<Graph>* graph,
+                          std::vector<Node*>* arg_nodes,
+                          std::vector<Node*>* ret_nodes,
+                          std::vector<string>* ret_node_names,
+                          DataTypeVector* ret_types,
+                          std::vector<string>* control_ret_node_names) {
   std::unique_ptr<FunctionBody> fbody;
-  // TODO(iga): FunctionDefToBodyHelper copies fdef. Avoid this copy.
-  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, attrs, lib_def, &fbody));
+  TF_RETURN_IF_ERROR(
+      FunctionDefToBodyHelper(std::move(fdef), attrs, lib_def, &fbody));
   if (!fbody) {
     LOG(ERROR) << "Failed to get FunctionBody for \"" << function_name << "\"";
     return errors::Internal("Failed to construct FunctionBody for ",
@@ -290,6 +296,7 @@
   }
   return OkStatus();
 }
+}  // namespace
 
 Status PinArgsAndRets(const std::vector<string>& input_devices,
                       const std::vector<string>& output_devices,
@@ -467,13 +474,13 @@
   const FunctionLibraryDefinition* lib_def =
       options.lib_def == nullptr ? input_lib_def : options.lib_def;
 
-  const FunctionDef* fdef = lib_def->Find(function_name);
+  core::RefCountPtr<FunctionRecord> fdef = lib_def->FindRecord(function_name);
   if (fdef == nullptr) {
     return errors::InvalidArgument("Failed to find function \"", function_name,
                                    "\" in function library: ", lib_def);
   }
 
-  TF_RETURN_IF_ERROR(ValidateMultiDeviceOptions(*fdef, options));
+  TF_RETURN_IF_ERROR(ValidateMultiDeviceOptions(fdef->fdef(), options));
 
   std::unique_ptr<Graph> graph;
   std::vector<Node*> arg_nodes, ret_nodes;
@@ -482,8 +489,8 @@
   std::vector<string> control_ret_node_names;
 
   TF_RETURN_IF_ERROR(GetGraphAndArgRets(
-      function_name, attrs, fdef, lib_def, &graph, &arg_nodes, &ret_nodes,
-      &ret_node_names, &ret_types, &control_ret_node_names));
+      function_name, attrs, fdef.GetNewRef(), lib_def, &graph, &arg_nodes,
+      &ret_nodes, &ret_node_names, &ret_types, &control_ret_node_names));
 
   DEBUG_DATA_DUMPER()->DumpOpCreationStackTraces(
       function_name, kDebugGroupOpStacktrace, "before_opt", graph.get());
@@ -555,7 +562,7 @@
       node_name_to_control_ret.emplace(control_ret, control_ret);
     }
   } else {
-    for (const auto& control_ret : fdef->control_ret()) {
+    for (const auto& control_ret : fdef->fdef().control_ret()) {
       node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
     }
   }
@@ -572,7 +579,7 @@
   optimization_options.is_function_graph = true;
   optimization_options.composite_devices = &composite_devices;
   optimization_options.default_function_device = default_device;
-  optimization_options.function_def = fdef;
+  optimization_options.function_def = &fdef->fdef();
   optimization_options.shape_inference_on_tfe_dialect_import =
       options.shape_inference_on_tfe_dialect_import;
   optimization_options.debug_filename_prefix = function_name;
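Editor's note: at the call site above, GetGraphAndArgRets receives fdef.GetNewRef() rather than std::move(fdef), because the same record is still read later in the function (fdef->fdef().control_ret() and optimization_options.function_def = &fdef->fdef()). A tiny stand-alone sketch of that ref-count semantics, using the plain core::RefCounted base purely for illustration rather than FunctionRecord:

#include <utility>

#include "tensorflow/core/platform/refcount.h"

// Hypothetical consumer that takes ownership of one reference and drops it
// when it returns.
void Consume(tensorflow::core::RefCountPtr<tensorflow::core::RefCounted> p) {}

int main() {
  using tensorflow::core::RefCountPtr;
  using tensorflow::core::RefCounted;

  RefCountPtr<RefCounted> rec(new RefCounted);  // refcount == 1
  Consume(rec.GetNewRef());  // bumps the refcount; the callee drops its ref
  // `rec` is still usable here. After `Consume(std::move(rec))` it would not be.
  return 0;
}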
diff --git a/tensorflow/core/common_runtime/optimize_function_graph_utils.h b/tensorflow/core/common_runtime/optimize_function_graph_utils.h
index 526e35a..d9bcd85 100644
--- a/tensorflow/core/common_runtime/optimize_function_graph_utils.h
+++ b/tensorflow/core/common_runtime/optimize_function_graph_utils.h
@@ -39,15 +39,6 @@
 // Note: setting this threshold to 0 means to cache for every function.
 constexpr absl::Duration kCachingThresholdDuration = absl::Seconds(3);
 
-// Generates graph and return information given the input function name,
-// attributes and function definition.
-Status GetGraphAndArgRets(
-    const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
-    const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
-    std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
-    std::vector<string>* ret_node_names, DataTypeVector* ret_types,
-    std::vector<string>* control_ret_node_names);
-
 // TODO(iga): Reword
 // Pins each arg that emits a `DT_RESOURCE` tensor to the device on which the
 // corresponding resource lives. This ensures that the Placer assigns ops that
diff --git a/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.cc b/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.cc
index 4f48469..c6c90d4 100644
--- a/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.cc
+++ b/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.cc
@@ -21,9 +21,11 @@
 #include "absl/types/optional.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/refcount.h"
 
 namespace tensorflow {
 namespace {
@@ -63,12 +65,12 @@
   if (!IsFunctionCall(node)) {
     return Set(node, false, is_deep, &cache_);
   }
-  const FunctionDef* fdef;
+  core::RefCountPtr<FunctionRecord> fdef;
   NameAttrList func;
   TF_RETURN_IF_ERROR(GetFunctionDefAndAttrs(flib_def_, node, &fdef, &func));
   DataTypeVector types;
-  TF_RETURN_IF_ERROR(
-      OutputTypesForNode(AttrSlice(&func.attr()), fdef->signature(), &types));
+  TF_RETURN_IF_ERROR(OutputTypesForNode(AttrSlice(&func.attr()),
+                                        fdef->fdef().signature(), &types));
   for (DataType type : types) {
     if (type == DT_RESOURCE) {
       return Set(node, true, is_deep, &cache_);
@@ -78,11 +80,12 @@
 }
 
 Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def,
-                              const Node& node, const FunctionDef** fdef,
+                              const Node& node,
+                              core::RefCountPtr<FunctionRecord>* fdef,
                               NameAttrList* func) {
   TF_RETURN_IF_ERROR(GetNodeAttr(node.def(), "f", func));
   const string& function_name = func->name();
-  *fdef = flib_def.Find(function_name);
+  *fdef = flib_def.FindRecord(function_name);
   if (*fdef == nullptr) {
     return errors::InvalidArgument(
         "Failed to find function \"", function_name,
diff --git a/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h b/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h
index f882d02..daa717e 100644
--- a/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h
+++ b/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h
@@ -73,7 +73,8 @@
 // Extracts `fdef` and `func` from `flib_def` for the function identified
 // in "f" attribute of `node`.
 Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def,
-                              const Node& node, const FunctionDef** fdef,
+                              const Node& node,
+                              core::RefCountPtr<FunctionRecord>* fdef,
                               NameAttrList* func);
 
 // The "call" stack of functions.
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index aa82468..801addf6 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -24,6 +24,7 @@
 #include <utility>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_cat.h"
 #include "absl/types/optional.h"
 #include "absl/types/variant.h"
 #include "tensorflow/core/common_runtime/build_graph_options.h"
@@ -39,10 +40,12 @@
 #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
 #include "tensorflow/core/common_runtime/single_threaded_executor.h"
 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
+#include "tensorflow/core/config/flag_defs.h"
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/framework/metrics.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
@@ -55,6 +58,7 @@
 #include "tensorflow/core/platform/blocking_counter.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/platform/random.h"
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/util/device_name_utils.h"
 #include "tensorflow/core/util/dump_graph.h"
@@ -462,44 +466,13 @@
 
 void ProcessFunctionLibraryRuntime::PublishSubgraphs(
     const std::string& function_name,
-    std::unique_ptr<std::unordered_map<std::string, std::unique_ptr<Graph>>>
-        subgraphs) {
-  // Use shared_ptr since std::function cannot capture move-only objects
-  auto subgraphs_new =
-      std::shared_ptr<std::unordered_map<std::string, std::unique_ptr<Graph>>>(
-          subgraphs.release());
-  auto completed = std::make_unique<tsl::Notification>();
-  // Converting graphs to GraphDefs involves expensive copies. Delegate the work
-  // to a separate thread to unblock the caller.
-  std::function<void()> thread_fn = [this, function_name, n = completed.get(),
-                                     subgraphs = subgraphs_new]() {
-    std::unique_ptr<StatsPublisherInterface> stats_publisher =
-        stats_publisher_factory_(function_name, BuildGraphOptions(),
-                                 SessionOptions());
-    std::vector<GraphDef> published_graph_defs;
-    published_graph_defs.reserve(subgraphs->size());
-    for (const auto& pair : *subgraphs) {
-      Graph* subgraph = pair.second.get();
-      GraphDef gd;
-      subgraph->ToGraphDef(&gd);
-      published_graph_defs.push_back(std::move(gd));
-    }
-    stats_publisher->PublishGraphProto(std::move(published_graph_defs));
-    {
-      mutex_lock l(mu_);
-      stats_publishers_.push_back(std::move(stats_publisher));
-    }
-    n->Notify();
-  };
-  {
-    mutex_lock l(mu_);
-    stats_publisher_completed_.push_back(std::move(completed));
-  }
-  if (default_thread_pool_ != nullptr) {
-    default_thread_pool_->Schedule(std::move(thread_fn));
-  } else {
-    env_->SchedClosure(std::move(thread_fn));
-  }
+    std::vector<core::RefCountPtr<FunctionRecord>>&& function_records) {
+  std::unique_ptr<StatsPublisherInterface> stats_publisher =
+      stats_publisher_factory_(function_name, BuildGraphOptions(),
+                               SessionOptions());
+  stats_publisher->PublishGraphProto(std::move(function_records));
+  mutex_lock l(mu_);
+  stats_publishers_.push_back(std::move(stats_publisher));
 }
 
 Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
@@ -626,15 +599,15 @@
 
   auto data = std::make_unique<MultiDeviceFunctionData>(
       function_name, function_key, optimized_graph_info->num_return_nodes,
-      std::move(optimized_graph_info->lib_def),
       std::move(optimized_graph_info->ret_types));
 
   int i = 0;
+  FunctionLibraryDefinition data_lib_def =
+      std::move(optimized_graph_info->lib_def);
   // Generate a random function_name to avoid one function reusing the partition
   // function instantiated by another function.
-  FunctionLibraryDefinition* data_lib_def = &data->lib_def_;
   FunctionNameGenerator name_generator(
-      data_lib_def, absl::StrCat(function_name, "_", random::New64()));
+      &data_lib_def, absl::StrCat(function_name, "_", random::New64()));
   const int num_subgraphs = subgraphs->size();
   gtl::InlinedVector<Status, 4> instantiate_status(num_subgraphs);
   BlockingCounter counter(static_cast<int>(num_subgraphs));
@@ -668,10 +641,10 @@
   // Instantiate each component function (subgraph).
   for (const auto& pair : *subgraphs) {
     Status* status = &instantiate_status[i];
-    string unique_name = name_generator.GetName();
     ComponentFunctionData* comp_data = &data->glue_[pair.first];
-    runner([this, &pair, dev_set, comp_data, unique_name, data_lib_def,
-            &control_ret, &options, status, &counter, &data] {
+    comp_data->name = name_generator.GetName();
+    runner([this, &pair, dev_set, comp_data, &data_lib_def, &control_ret,
+            &options, status, &counter, &data] {
       const string& target = pair.first;
 
       const string& device_type =
@@ -698,12 +671,19 @@
       }
       FunctionDef shard;
       status->Update(
-          GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
+          GraphToFunctionDef(*subgraph, comp_data->name, control_ret, &shard));
       if (!status->ok()) {
         counter.DecrementCount();
         return;
       }
-      status->Update(data_lib_def->AddFunctionDef(shard));
+
+      // NOTE(mrry): Currently, `shard.attr()` is never set by
+      // `GraphToFunctionDef()` but we previously used it directly in the
+      // call to `Instantiate()`. To avoid subtle bugs, we retain a copy here
+      // before the move, in case `GraphToFunctionDef()` changes in the future.
+      AttrValueMap attrs(shard.attr());
+
+      status->Update(data_lib_def.AddFunctionDef(std::move(shard)));
       if (!status->ok()) {
         counter.DecrementCount();
         return;
@@ -711,7 +691,7 @@
       FunctionLibraryRuntime::InstantiateOptions opts;
       opts.executor_type = options.executor_type;
       opts.target = target;
-      opts.lib_def = data_lib_def;
+      opts.lib_def = &data_lib_def;
       opts.create_kernels_eagerly = options.create_kernels_eagerly;
       opts.state_handle = options.state_handle;
       opts.allow_small_function_optimizations = data->enable_sync_execution;
@@ -719,20 +699,20 @@
           options.allow_control_flow_sync_execution;
       AttrValue ints_on_device_attr;
       ints_on_device_attr.set_b(options.int_args_and_retvals_on_device);
-      shard.mutable_attr()->insert(
+      attrs.insert(
           {FunctionLibraryDefinition::kIntsOnDeviceAttr, ints_on_device_attr});
-      auto attrs = AttrSlice(&shard.attr());
-      VLOG(1) << "Start instantiating component function " << unique_name
+      VLOG(1) << "Start instantiating component function " << comp_data->name
               << " on device " << target;
       VLOG(4) << DebugString(shard);
 
       auto* component_handle = new FunctionLibraryRuntime::Handle;
-      auto done = [this, status, unique_name, comp_data, component_handle,
-                   &data, &counter](const Status& s) {
+      auto done = [this, status, comp_data, component_handle, &data,
+                   &counter](const Status& s) {
         status->Update(s);
 
-        VLOG(1) << "Finished instantiating component function " << unique_name
-                << " with handle " << *component_handle << " status: " << s;
+        VLOG(1) << "Finished instantiating component function "
+                << comp_data->name << " with handle " << *component_handle
+                << " status: " << s;
         if (status->ok()) {
           {
             mutex_lock l(mu_);
@@ -749,12 +729,14 @@
       FunctionLibraryRuntime* flr = GetFLR(opts.target);
       if (flr != nullptr) {
         // Initialize local function synchronously.
-        Status s = flr->Instantiate(unique_name, attrs, opts, component_handle);
+        Status s = flr->Instantiate(comp_data->name, AttrSlice(&attrs), opts,
+                                    component_handle);
         done(s);
       } else {
         opts.ret_indices = comp_data->ret_indices;
         // Initialize remote function asynchronously.
-        InstantiateRemote(unique_name, attrs, opts, component_handle, done);
+        InstantiateRemote(comp_data->name, AttrSlice(&attrs), opts,
+                          component_handle, done);
       }
     });
     i += 1;
@@ -766,11 +748,23 @@
   }
   TF_RETURN_IF_ERROR(group.as_summary_status());
 
+  std::vector<core::RefCountPtr<FunctionRecord>> function_records;
+  const bool should_publish_function_graphs =
+      flags::Global().publish_function_graphs.value();
+  if (should_publish_function_graphs) {
+    for (const auto& pair : *subgraphs) {
+      ComponentFunctionData* comp_data = &data->glue_[pair.first];
+      function_records.push_back(data_lib_def.FindRecord(comp_data->name));
+    }
+  }
+
   *handle = AddMultiDeviceHandle(std::move(data), function_key);
   VLOG(1) << "Instantiated MultiDevice function \"" << function_name
           << "\" with handle " << *handle;
 
-  PublishSubgraphs(function_name, std::move(subgraphs));
+  if (should_publish_function_graphs) {
+    PublishSubgraphs(function_name, std::move(function_records));
+  }
   return OkStatus();
 }
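Editor's note: two details in the hunks above are easy to miss. First, AddFunctionDef() now takes the shard by move, so the attrs that are still needed for Instantiate() are copied out beforehand (the NOTE(mrry) comment). Second, graph publication is now gated on the publish_function_graphs flag and works from FunctionRecords rather than a map of Graphs. A minimal sketch of the copy-before-move detail only (AddShard is a hypothetical helper):

#include <utility>

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.h"

namespace tensorflow {

// Hypothetical helper: register `shard` in `lib_def`, but keep a usable copy
// of its attrs, since the moved-from proto must not be read afterwards.
Status AddShard(FunctionLibraryDefinition* lib_def, FunctionDef shard,
                AttrValueMap* attrs_out) {
  *attrs_out = shard.attr();                          // copy the attrs first...
  return lib_def->AddFunctionDef(std::move(shard));   // ...then move the proto
}

}  // namespace tensorflow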
 
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 8721fb4..433459e 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -16,23 +16,23 @@
 #define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
 
 #include <functional>
+#include <memory>
 #include <optional>
 #include <string>
 #include <unordered_map>
+#include <vector>
 
-#include "absl/types/variant.h"
 #include "tensorflow/core/common_runtime/composite_device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/device_set.h"
-#include "tensorflow/core/common_runtime/optimized_function_graph_info.h"
 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/platform.h"
+#include "tensorflow/core/platform/refcount.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/protobuf/config.pb.h"
-#include "tsl/platform/notification.h"
 #include "tsl/platform/thread_annotations.h"
 
 #if !defined(IS_MOBILE_PLATFORM)
@@ -84,11 +84,6 @@
     // since the flr_map_ may have already been deleted. Explicitly releasing
     // flr_map_ here and checking flr_map_ in ReleaseHandle to avoid this.
     flr_map_.reset();
-    // Graph and stats publishers might have pending work in async threads that
-    // requires access to PFLR instance. Wait for completion before destructing.
-    for (const auto& n : stats_publisher_completed_) {
-      n->WaitForNotification();
-    }
   }
 
   // Sends `tensors_to_send` from `source_device` to `target_device` using
@@ -270,6 +265,8 @@
   struct ComponentFunctionData {
     // The handle for the instantiated component function.
     FunctionLibraryRuntime::Handle handle;
+    // The name for the component function.
+    string name;
     // arg_indices.size() is the number of arguments to the component function.
     // The i-th argument of the component function comes from the
     // `arg_indices[i]`-th argument of the multi-device function.
@@ -295,12 +292,10 @@
   struct MultiDeviceFunctionData {
     MultiDeviceFunctionData(const string& function_name,
                             const string& function_key, int num_outputs,
-                            FunctionLibraryDefinition&& lib_def,
                             DataTypeVector ret_types)
         : function_name_(function_name),
           function_key_(function_key),
           instantiation_counter_(1),
-          lib_def_(std::move(lib_def)),
           num_outputs_(num_outputs),
           ret_types_(std::move(ret_types)),
           is_cross_process_(false),
@@ -309,9 +304,6 @@
     const string function_name_;
     const string function_key_;
     uint64 instantiation_counter_;
-    // A library that contains definitions of component functions and their
-    // transitive dependencies.
-    FunctionLibraryDefinition lib_def_;
     // Stored here to resize the output tensor vector when function is run.
     const int num_outputs_;
     DataTypeVector ret_types_;
@@ -452,8 +444,7 @@
 
   void PublishSubgraphs(
       const std::string& function_name,
-      std::unique_ptr<std::unordered_map<string, std::unique_ptr<Graph>>>
-          subgraphs);
+      std::vector<core::RefCountPtr<FunctionRecord>>&& function_records);
 
   // Data structure holding information for a single instantiated remote
   // (to be executed on `target_device`) function.
@@ -545,8 +536,6 @@
   // instantiated function.
   std::vector<std::unique_ptr<StatsPublisherInterface>> stats_publishers_
       TF_GUARDED_BY(mu_);
-  std::vector<std::unique_ptr<tsl::Notification>> stats_publisher_completed_
-      TF_GUARDED_BY(mu_);
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/request_cost.h b/tensorflow/core/common_runtime/request_cost.h
index c678270..ce4e5cc 100644
--- a/tensorflow/core/common_runtime/request_cost.h
+++ b/tensorflow/core/common_runtime/request_cost.h
@@ -50,6 +50,8 @@
     int64_t input_size = 0;
     // In this batch, the padding amount.
     int64_t padding_size = 0;
+    // Costs for processing this batch.
+    absl::flat_hash_map<std::string, absl::Duration> batch_costs;
   };
 
   // Records the metrics of a batch.
diff --git a/tensorflow/core/common_runtime/request_cost_test.cc b/tensorflow/core/common_runtime/request_cost_test.cc
index 17f8820..052f1ee 100644
--- a/tensorflow/core/common_runtime/request_cost_test.cc
+++ b/tensorflow/core/common_runtime/request_cost_test.cc
@@ -22,6 +22,8 @@
 namespace tensorflow {
 namespace {
 
+using ::testing::ElementsAre;
+using ::testing::FieldsAre;
 using ::testing::Pair;
 using ::testing::UnorderedElementsAre;
 
@@ -53,13 +55,26 @@
   RequestCost request_cost;
 
   request_cost.RecordBatchMetrics(RequestCost::BatchMetrics{
-      /*processed_size=*/8, /*input_size=*/8, /*padding_size=*/0});
+      /*processed_size=*/8,
+      /*input_size=*/8,
+      /*padding_size=*/0,
+      {{"gcu", absl::Milliseconds(80)}, {"tpu", absl::Milliseconds(160)}}});
   request_cost.RecordBatchMetrics(RequestCost::BatchMetrics{
-      /*processed_size=*/4, /*input_size=*/2, /*padding_size=*/1});
+      /*processed_size=*/4,
+      /*input_size=*/2,
+      /*padding_size=*/1,
+      {{"gcu", absl::Milliseconds(40)}, {"tpu", absl::Milliseconds(80)}}});
 
-  EXPECT_THAT(request_cost.GetBatchMetrics(),
-              testing::ElementsAre(testing::FieldsAre(8, 8, 0),
-                                   testing::FieldsAre(4, 2, 1)));
+  EXPECT_THAT(
+      request_cost.GetBatchMetrics(),
+      ElementsAre(
+          FieldsAre(8, 8, 0,
+                    UnorderedElementsAre(Pair("gcu", absl::Milliseconds(80)),
+                                         Pair("tpu", absl::Milliseconds(160)))),
+          FieldsAre(
+              4, 2, 1,
+              UnorderedElementsAre(Pair("gcu", absl::Milliseconds(40)),
+                                   Pair("tpu", absl::Milliseconds(80))))));
 }
 
 }  // namespace
diff --git a/tensorflow/core/common_runtime/stats_publisher_interface.cc b/tensorflow/core/common_runtime/stats_publisher_interface.cc
index 3dab626..8b04ac9 100644
--- a/tensorflow/core/common_runtime/stats_publisher_interface.cc
+++ b/tensorflow/core/common_runtime/stats_publisher_interface.cc
@@ -19,7 +19,9 @@
 #include <string>
 #include <vector>
 
+#include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/platform/refcount.h"
 
 namespace tensorflow {
 namespace {
@@ -37,6 +39,9 @@
 
   void PublishGraphProto(std::vector<GraphDef> graph_defs) override {}
 
+  void PublishGraphProto(std::vector<core::RefCountPtr<FunctionRecord>>&&
+                             function_records) override {}
+
   std::unique_ptr<ProfileHandler> GetProfileHandler(
       uint64 step, int64_t execution_count, const RunOptions& ropts) override {
     return nullptr;
diff --git a/tensorflow/core/common_runtime/stats_publisher_interface.h b/tensorflow/core/common_runtime/stats_publisher_interface.h
index 6f77106..450683e 100644
--- a/tensorflow/core/common_runtime/stats_publisher_interface.h
+++ b/tensorflow/core/common_runtime/stats_publisher_interface.h
@@ -23,6 +23,7 @@
 #include "tensorflow/core/common_runtime/build_graph_options.h"
 #include "tensorflow/core/common_runtime/profile_handler.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/refcount.h"
 #include "tensorflow/core/protobuf/config.pb.h"
 #include "tensorflow/core/public/session_options.h"
 
@@ -52,6 +53,8 @@
   virtual void PublishGraphProto(
       const std::vector<const GraphDef*>& graph_defs) = 0;
   virtual void PublishGraphProto(std::vector<GraphDef> graph_defs) = 0;
+  virtual void PublishGraphProto(
+      std::vector<core::RefCountPtr<FunctionRecord>>&& function_records) = 0;
 
   // Returns a profile handler for the given step based on the execution_count
   // and RunOptions.
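Editor's note: the new pure-virtual overload receives the component functions themselves rather than pre-built GraphDefs, so the expensive conversion to GraphDef becomes the publisher's choice instead of being done eagerly on instantiation. A hedged sketch of a helper a concrete publisher might use, reading only names through the records (LogPublishedFunctions is hypothetical):

#include <vector>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/refcount.h"

namespace tensorflow {

// Hypothetical helper: each record still wraps the FunctionDef of one
// component function, so its name (and full proto) remain accessible.
void LogPublishedFunctions(
    const std::vector<core::RefCountPtr<FunctionRecord>>& records) {
  for (const auto& record : records) {
    LOG(INFO) << "component function: " << record->fdef().signature().name();
  }
}

}  // namespace tensorflow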
diff --git a/tensorflow/core/config/flag_defs.h b/tensorflow/core/config/flag_defs.h
index 14ae3cb..78bfae7 100644
--- a/tensorflow/core/config/flag_defs.h
+++ b/tensorflow/core/config/flag_defs.h
@@ -53,6 +53,10 @@
                   "Enable a graph optimization pass that replicate each small "
                   "constant to its successors' devices. This can decrease "
                   "message passing.");
+  TF_DECLARE_FLAG(publish_function_graphs, true,
+                  "Enables the publication of partitioned function graphs "
+                  "via StatsPublisherInterface. Disabling this flag can "
+                  "reduce memory consumption.");
   // LINT.ThenChange(//tensorflow/core/config/flags_api_wrapper.cc)
 };
 
diff --git a/tensorflow/core/config/flags_api_wrapper.cc b/tensorflow/core/config/flags_api_wrapper.cc
index 974581e..b7fa39e 100644
--- a/tensorflow/core/config/flags_api_wrapper.cc
+++ b/tensorflow/core/config/flags_api_wrapper.cc
@@ -52,5 +52,6 @@
   TF_PY_DECLARE_FLAG(tf_shape_default_int64);
   TF_PY_DECLARE_FLAG(more_stack_traces);
   TF_PY_DECLARE_FLAG(replicate_small_constants);
+  TF_PY_DECLARE_FLAG(publish_function_graphs);
   // LINT.ThenChange(//tensorflow/core/config/flag_defs.h)
 };
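Editor's note: for reference, this is how runtime code consults the new flag (the same check appears in process_function_library_runtime.cc above); ShouldPublishFunctionGraphs is just an illustrative wrapper, not part of the change:

#include "tensorflow/core/config/flag_defs.h"

// Illustrative wrapper around the flag check used by the runtime.
bool ShouldPublishFunctionGraphs() {
  return tensorflow::flags::Global().publish_function_graphs.value();
}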
diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc
index 268c7d3..ee9a934 100644
--- a/tensorflow/core/data/dataset_utils.cc
+++ b/tensorflow/core/data/dataset_utils.cc
@@ -871,13 +871,14 @@
 absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options) {
   absl::flat_hash_set<tstring> configs;
   const auto& autotune_options = options.autotune_options();
-  std::array<tstring, 9> autotune_only_optimizations = {
+  std::array<tstring, 10> autotune_only_optimizations = {
       kAutotuneBufferSizesOpt,
       kBatchParallelizationOpt,
       kDisablePrefetchLegacyAutotuneOpt,
       kEnableGradientDescentOpt,
       kFilterParallelizationOpt,
       kMapParallelizationOpt,
+      kMapFusionOpt,
       kInjectPrefetchOpt,
       kInjectIoPrefetchEligibleOpt,
       kInjectIoPrefetchOpt};
@@ -1005,7 +1006,9 @@
 REGISTER_DATASET_EXPERIMENT("inject_io_prefetch", RandomJobSamplePercentage<0>,
                             AllTasks);
 REGISTER_DATASET_EXPERIMENT("reduce_array_record_dataset_memory_usage",
-                            RandomJobSamplePercentage<50>, AllTasks);
+                            RandomJobSamplePercentage<0>, AllTasks);
+REGISTER_DATASET_EXPERIMENT("map_fusion", RandomJobSamplePercentage<1>,
+                            IndependentHostTasks);
 }  // namespace
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD
index 4e02eba..ee93b3b 100644
--- a/tensorflow/core/data/service/BUILD
+++ b/tensorflow/core/data/service/BUILD
@@ -380,6 +380,7 @@
     deps = [
         ":common_proto_cc",
         ":data_transfer",
+        ":dataset_store",
         ":dispatcher_client",
         ":test_cluster",
         ":test_util",
@@ -392,6 +393,7 @@
         "//tensorflow/core/platform:status_matchers",
         "//tensorflow/core/platform:statusor",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@local_tsl//tsl/platform:path",
         "@local_tsl//tsl/protobuf:protos_all_cc",
     ] + tf_grpc_cc_dependencies() + tf_protos_profiler_service(),
 )
@@ -1098,6 +1100,7 @@
         "//tensorflow/core/data/service/snapshot:snapshot_stream_writer",
         "//tensorflow/core/platform:env",
         "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:logging",
         "//tensorflow/core/platform:platform_port",
         "//tensorflow/core/platform:status",
         "//tensorflow/core/platform:statusor",
diff --git a/tensorflow/core/data/service/client/data_service_client.cc b/tensorflow/core/data/service/client/data_service_client.cc
index eddb40f..def2d20 100644
--- a/tensorflow/core/data/service/client/data_service_client.cc
+++ b/tensorflow/core/data/service/client/data_service_client.cc
@@ -351,10 +351,10 @@
               << task_info.worker_address() << "'.";
     return worker;
   }
-  LOG(WARNING) << "Failed to start client for data transfer protocol '"
-               << transfer_server.protocol() << "' for worker '"
-               << task_info.worker_address() << "'; falling back to grpc. "
-               << "Original error: " << worker.status();
+  LOG(INFO) << "Failed to start client for data transfer protocol '"
+            << transfer_server.protocol() << "' for worker '"
+            << task_info.worker_address() << "'; falling back to grpc. "
+            << "Original error: " << worker.status();
   metrics::RecordTFDataServiceDataTransferProtocolFallback(
       transfer_server.protocol(),
       static_cast<error::Code>(worker.status().raw_code()),
diff --git a/tensorflow/core/data/service/dataset_store.cc b/tensorflow/core/data/service/dataset_store.cc
index f7db630..105c593 100644
--- a/tensorflow/core/data/service/dataset_store.cc
+++ b/tensorflow/core/data/service/dataset_store.cc
@@ -37,10 +37,6 @@
 Status FileSystemDatasetStore::Put(const std::string& key,
                                    const DatasetDef& dataset) {
   std::string path_to_write = io::JoinPath(datasets_dir_, key);
-
-  if (Env::Default()->FileExists(path_to_write).ok()) {
-    return errors::AlreadyExists("File ", path_to_write, " already exists");
-  }
   TF_RETURN_IF_ERROR(WriteDatasetDef(path_to_write, dataset));
   return OkStatus();
 }
@@ -58,10 +54,6 @@
 Status MemoryDatasetStore::Put(const std::string& key,
                                const DatasetDef& dataset) {
   auto& stored_dataset = datasets_[key];
-  if (stored_dataset) {
-    return errors::AlreadyExists("Dataset with key ", key,
-                                 " is already stored.");
-  }
   stored_dataset = std::make_shared<const DatasetDef>(dataset);
   return OkStatus();
 }
diff --git a/tensorflow/core/data/service/dataset_store.h b/tensorflow/core/data/service/dataset_store.h
index 790d0d7..437066d 100644
--- a/tensorflow/core/data/service/dataset_store.h
+++ b/tensorflow/core/data/service/dataset_store.h
@@ -33,8 +33,8 @@
  public:
   virtual ~DatasetStore() = default;
 
-  // Stores the given dataset under the given key. Returns ALREADY_EXISTS if the
-  // key already exists.
+  // Stores the given dataset under the given key. Overwrites a dataset if it
+  // already exists.
   virtual Status Put(const std::string& key, const DatasetDef& dataset) = 0;
   // Gets the dataset for the given key, storing the dataset in `dataset_def`.
   virtual Status Get(const std::string& key,
diff --git a/tensorflow/core/data/service/dataset_store_test.cc b/tensorflow/core/data/service/dataset_store_test.cc
index 41bb370..46c1111 100644
--- a/tensorflow/core/data/service/dataset_store_test.cc
+++ b/tensorflow/core/data/service/dataset_store_test.cc
@@ -102,8 +102,7 @@
   DatasetDef dataset_def = DatasetDefWithVersion(version);
   std::string key = "key";
   TF_ASSERT_OK(store->Put(key, dataset_def));
-  Status s = store->Put(key, dataset_def);
-  EXPECT_EQ(s.code(), error::ALREADY_EXISTS);
+  TF_EXPECT_OK(store->Put(key, dataset_def));
   std::shared_ptr<const DatasetDef> result;
   TF_ASSERT_OK(store->Get(key, result));
   EXPECT_EQ(result->graph().version(), version);
diff --git a/tensorflow/core/data/service/dispatcher_client_test.cc b/tensorflow/core/data/service/dispatcher_client_test.cc
index 86ad7a0..ef11446 100644
--- a/tensorflow/core/data/service/dispatcher_client_test.cc
+++ b/tensorflow/core/data/service/dispatcher_client_test.cc
@@ -23,6 +23,7 @@
 #include "absl/container/flat_hash_set.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/data_transfer.h"
+#include "tensorflow/core/data/service/dataset_store.h"
 #include "tensorflow/core/data/service/snapshot/path_utils.h"
 #include "tensorflow/core/data/service/test_cluster.h"
 #include "tensorflow/core/data/service/test_util.h"
@@ -36,6 +37,8 @@
 #include "tensorflow/core/protobuf/error_codes.pb.h"
 #include "tensorflow/core/protobuf/snapshot.pb.h"
 #include "tensorflow/core/protobuf/struct.pb.h"
+#include "tsl/platform/path.h"
+#include "tsl/platform/test.h"
 #include "tsl/protobuf/error_codes.pb.h"
 
 namespace tensorflow {
@@ -361,6 +364,51 @@
                      HasSubstr("Existing processing mode: <>"),
                      HasSubstr("Existing cross-trainer cache: <disabled>"))));
 }
+
+class DispatcherClientTest_DatasetId
+    : public DispatcherClientTest,
+      public ::testing::WithParamInterface<std::optional<std::string>> {};
+
+TEST_P(DispatcherClientTest_DatasetId, SyncDatasetStoreWithDispatcherState) {
+  TestCluster::Config config;
+  config.num_workers = 1;
+  config.work_dir = tsl::io::JoinPath(tsl::testing::TmpDir(), "work_dir");
+
+  test_cluster_ = std::make_unique<TestCluster>(config);
+  TF_ASSERT_OK(test_cluster_->Initialize());
+  dispatcher_client_ = std::make_unique<DataServiceDispatcherClient>(
+      test_cluster_->DispatcherAddress(), kProtocol);
+
+  DatasetDef dataset_def = RangeDataset(10);
+  std::optional<std::string> requested_dataset_id = GetParam();
+  std::string dataset_id;
+  TF_ASSERT_OK(dispatcher_client_->RegisterDataset(
+      dataset_def, GetDefaultMetadata(),
+      /*requested_dataset_id=*/std::nullopt, dataset_id));
+  EXPECT_EQ(dataset_id, "1000");
+
+  // Writes an inconsistent dataset file. It should be discarded when the user
+  // registers a new dataset.
+  std::string datasets_dir = tsl::io::JoinPath(config.work_dir, "datasets");
+  FileSystemDatasetStore dataset_store(datasets_dir);
+  TF_ASSERT_OK(dataset_store.Put("1001", dataset_def));
+  if (requested_dataset_id.has_value()) {
+    TF_ASSERT_OK(dataset_store.Put(*requested_dataset_id, dataset_def));
+  }
+
+  TF_ASSERT_OK(dispatcher_client_->RegisterDataset(
+      dataset_def, GetDefaultMetadata(),
+      /*requested_dataset_id=*/requested_dataset_id, dataset_id));
+  if (requested_dataset_id.has_value()) {
+    EXPECT_EQ(dataset_id, *requested_dataset_id);
+  } else {
+    EXPECT_EQ(dataset_id, "1001");
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(DatasetId, DispatcherClientTest_DatasetId,
+                         ::testing::Values(std::nullopt, "dataset_id"));
+
 }  // namespace
 }  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/data/service/snapshot/BUILD b/tensorflow/core/data/service/snapshot/BUILD
index 7938b1b..46e420a 100644
--- a/tensorflow/core/data/service/snapshot/BUILD
+++ b/tensorflow/core/data/service/snapshot/BUILD
@@ -235,7 +235,9 @@
         "//tensorflow/core/data:snapshot_utils",
         "//tensorflow/core/data:utils",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
         "@local_tsl//tsl/platform:env",
+        "@local_tsl//tsl/platform:path",
         "@local_tsl//tsl/platform:tstring",
     ],
 )
@@ -338,12 +340,10 @@
         "//tensorflow/core/data/service:task_runner",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@local_tsl//tsl/platform:env",
-        "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:path",
-        "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc
index fe24485..ad2e7ba 100644
--- a/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc
+++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc
@@ -19,15 +19,18 @@
 #include <vector>
 
 #include "absl/status/status.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/data/name_utils.h"
 #include "tensorflow/core/data/snapshot_utils.h"
 #include "tensorflow/core/data/utils.h"
 #include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/metrics.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tsl/platform/env.h"
+#include "tsl/platform/path.h"
 #include "tsl/platform/tstring.h"
 
 namespace tensorflow {
@@ -39,9 +42,16 @@
 constexpr const char* const kStartIndex = "start_index";
 constexpr const char* const kOutputTypes = "output_types";
 constexpr const char* const kOutputShapes = "output_shapes";
+constexpr const char* const kSnapshotChunkDataset = "SnapshotChunkDataset";
 
 constexpr int64_t kTFRecordReaderOutputBufferSize = 512 << 20;  // 512MB
 
+absl::string_view GetSnapshotPath(absl::string_view chunk_file) {
+  // Snapshot chunks are placed in snapshot_path/chunks/chunk_x.
+  absl::string_view chunk_dir = tsl::io::Dirname(chunk_file);
+  return tsl::io::Dirname(chunk_dir);
+}
+
 // A reader dataset is responsible for reading one chunk file of a snapshot.
 // TODO(b/250921378): Merge this with `snapshot_util::Reader::Dataset`.
 class SnapshotChunkDatasetOp : public DatasetOpKernel {
@@ -119,6 +129,7 @@
       reader_ = std::make_unique<snapshot_util::TFRecordReader>(
           TranslateFileName(dataset()->chunk_file_), dataset()->compression_,
           dataset()->dtypes_, kTFRecordReaderOutputBufferSize);
+      bytes_read_ = 0;
       return reader_->Initialize(ctx->env());
     }
 
@@ -128,7 +139,7 @@
                                  bool* end_of_sequence) override {
       *end_of_sequence = false;
       absl::Status status = reader_->ReadTensors(out_tensors);
-      if (errors::IsOutOfRange(status)) {
+      if (absl::IsOutOfRange(status)) {
         *end_of_sequence = true;
         return absl::OkStatus();
       }
@@ -136,6 +147,7 @@
           status,
           " Failed to read tf.data snapshot file: ", dataset()->chunk_file_);
       ++start_index_;
+      RecordBytesRead();
       return status;
     }
 
@@ -166,8 +178,17 @@
       return absl::OkStatus();
     }
 
+    void RecordBytesRead() {
+      uint64_t bytes_read = reader_->BytesRead();
+      static auto* bytes_counter =
+          metrics::GetTFDataBytesReadCounter(kSnapshotChunkDataset);
+      bytes_counter->IncrementBy(bytes_read - bytes_read_);
+      bytes_read_ = bytes_read;
+    }
+
     std::unique_ptr<snapshot_util::TFRecordReader> reader_;
     int64_t start_index_ = 0;
+    uint64_t bytes_read_ = 0;
   };
 
   const tstring chunk_file_;
@@ -191,9 +212,11 @@
   *output = new SnapshotChunkDatasetOp::Dataset(DatasetContext(ctx), chunk_file,
                                                 compression_, output_types_,
                                                 output_shapes_);
+  metrics::RecordTFDataServiceSnapshotOp(
+      std::string(GetSnapshotPath(chunk_file)), kSnapshotChunkDataset);
 }
 
-REGISTER_KERNEL_BUILDER(Name("SnapshotChunkDataset").Device(DEVICE_CPU),
+REGISTER_KERNEL_BUILDER(Name(kSnapshotChunkDataset).Device(DEVICE_CPU),
                         SnapshotChunkDatasetOp);
 
 }  // namespace
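Editor's note: RecordBytesRead() above reports only the increase in the reader's cumulative BytesRead() since the previous call, so the shared tf.data bytes-read counter is never double-counted across GetNextInternal() calls. A stand-alone sketch of that delta pattern, with the monitoring counter modeled as a plain integer (BytesReadTracker is hypothetical):

#include <cstdint>

// Hypothetical tracker: turns a cumulative byte count into per-call deltas.
struct BytesReadTracker {
  uint64_t last_total = 0;

  // `total` is the reader's cumulative BytesRead(); the return value is the
  // amount to add to the monitoring counter for this call.
  uint64_t Delta(uint64_t total) {
    const uint64_t delta = total - last_total;
    last_total = total;
    return delta;
  }
};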
diff --git a/tensorflow/core/data/service/snapshot/test_utils.cc b/tensorflow/core/data/service/snapshot/test_utils.cc
index 7b7048f..f4046d4 100644
--- a/tensorflow/core/data/service/snapshot/test_utils.cc
+++ b/tensorflow/core/data/service/snapshot/test_utils.cc
@@ -22,6 +22,7 @@
 
 #include "absl/container/flat_hash_set.h"
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_split.h"
@@ -31,17 +32,14 @@
 #include "tensorflow/core/data/service/task_runner.h"
 #include "tensorflow/core/data/standalone.h"
 #include "tsl/platform/env.h"
-#include "tsl/platform/errors.h"
 #include "tsl/platform/path.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
 
 namespace tensorflow {
 namespace data {
 namespace testing {
 namespace {
 
-tsl::StatusOr<std::string> CreateTmpDirectory() {
+absl::StatusOr<std::string> CreateTmpDirectory() {
   std::string snapshot_path;
   if (!Env::Default()->LocalTempFilename(&snapshot_path)) {
     return absl::FailedPreconditionError(
@@ -52,7 +50,7 @@
   return snapshot_path;
 }
 
-tsl::StatusOr<int64_t> CommittedChunkIndex(const std::string& chunk_file) {
+absl::StatusOr<int64_t> CommittedChunkIndex(const std::string& chunk_file) {
   std::vector<std::string> tokens = absl::StrSplit(chunk_file, '_');
   int64_t result = 0;
   if (tokens.size() != 4 || !absl::SimpleAtoi(tokens[2], &result)) {
@@ -61,7 +59,7 @@
   return result;
 }
 
-tsl::StatusOr<int64_t> CheckpointIndex(const std::string& checkpoint_file) {
+absl::StatusOr<int64_t> CheckpointIndex(const std::string& checkpoint_file) {
   std::vector<std::string> tokens = absl::StrSplit(checkpoint_file, '_');
   int64_t result = 0;
   if (tokens.size() != 3 || !absl::SimpleAtoi(tokens[1], &result)) {
@@ -85,7 +83,7 @@
       max_chunk_size_bytes_(max_chunk_size_bytes),
       checkpoint_interval_(checkpoint_interval) {}
 
-tsl::StatusOr<PartialSnapshotWriter> PartialSnapshotWriter::Create(
+absl::StatusOr<PartialSnapshotWriter> PartialSnapshotWriter::Create(
     const DatasetDef& dataset, const std::string& snapshot_path,
     int64_t stream_index, const std::string& compression,
     int64_t max_chunk_size_bytes, absl::Duration checkpoint_interval) {
@@ -96,7 +94,7 @@
   return writer;
 }
 
-tsl::Status PartialSnapshotWriter::Initialize() {
+absl::Status PartialSnapshotWriter::Initialize() {
   TF_ASSIGN_OR_RETURN(tmp_snapshot_path_, CreateTmpDirectory());
   // Each chunk contains one record.
   SnapshotWriterParams writer_params{tmp_snapshot_path_,
@@ -112,7 +110,7 @@
   return snapshot_writer.Wait().status();
 }
 
-tsl::Status PartialSnapshotWriter::WriteCommittedChunks(
+absl::Status PartialSnapshotWriter::WriteCommittedChunks(
     const absl::flat_hash_set<int64_t>& committed_chunk_indexes) const {
   std::string tmp_chunks_directory =
       CommittedChunksDirectory(tmp_snapshot_path_);
@@ -135,10 +133,10 @@
           Env::Default()->CopyFile(tmp_chunk_path, committed_chunk_path));
     }
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-tsl::Status PartialSnapshotWriter::WriteUncommittedChunks(
+absl::Status PartialSnapshotWriter::WriteUncommittedChunks(
     const absl::flat_hash_set<int64_t>& uncommitted_chunk_indexes) const {
   std::string tmp_chunks_directory =
       CommittedChunksDirectory(tmp_snapshot_path_);
@@ -162,10 +160,10 @@
           Env::Default()->CopyFile(tmp_chunk_path, uncommitted_chunk_path));
     }
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-tsl::Status PartialSnapshotWriter::WriteCheckpoints(
+absl::Status PartialSnapshotWriter::WriteCheckpoints(
     const absl::flat_hash_set<int64_t>& checkpoint_indexes) const {
   std::string tmp_checkpoints_directory =
       CheckpointsDirectory(tmp_snapshot_path_, stream_index_);
@@ -189,10 +187,10 @@
           Env::Default()->CopyFile(tmp_checkpoint_path, checkpoint_path));
     }
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-tsl::StatusOr<std::unique_ptr<StandaloneTaskIterator>> TestIterator(
+absl::StatusOr<std::unique_ptr<StandaloneTaskIterator>> TestIterator(
     const DatasetDef& dataset_def) {
   std::unique_ptr<standalone::Dataset> dataset;
   TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
diff --git a/tensorflow/core/data/service/snapshot/test_utils.h b/tensorflow/core/data/service/snapshot/test_utils.h
index 329a816..3d9f2ab 100644
--- a/tensorflow/core/data/service/snapshot/test_utils.h
+++ b/tensorflow/core/data/service/snapshot/test_utils.h
@@ -22,6 +22,7 @@
 
 #include "absl/container/flat_hash_set.h"
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/snapshot/file_utils.h"
 #include "tensorflow/core/data/service/snapshot/path_utils.h"
@@ -29,10 +30,7 @@
 #include "tensorflow/core/data/snapshot_utils.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tsl/platform/env.h"
-#include "tsl/platform/errors.h"
 #include "tsl/platform/path.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/statusor.h"
 
 namespace tensorflow {
 namespace data {
@@ -40,11 +38,11 @@
 
 // Reads the records from a distributed tf.data snapshot written at `base_path`.
 template <class T>
-tsl::StatusOr<std::vector<T>> ReadSnapshot(const std::string& base_path,
-                                           const std::string& compression) {
+absl::StatusOr<std::vector<T>> ReadSnapshot(const std::string& base_path,
+                                            const std::string& compression) {
   std::vector<T> result;
   std::string chunks_directory = CommittedChunksDirectory(base_path);
-  TF_ASSIGN_OR_RETURN(std::vector<string> chunk_files,
+  TF_ASSIGN_OR_RETURN(std::vector<std::string> chunk_files,
                       GetChildren(chunks_directory, Env::Default()));
   for (const std::string& chunk_file : chunk_files) {
     std::string chunk_file_path =
@@ -55,7 +53,7 @@
 
     while (true) {
       std::vector<Tensor> tensors;
-      Status status = tfrecord_reader.ReadTensors(&tensors);
+      absl::Status status = tfrecord_reader.ReadTensors(&tensors);
       if (absl::IsOutOfRange(status)) {
         break;
       }
@@ -71,7 +69,7 @@
 // checkpoints.
 class PartialSnapshotWriter {
  public:
-  static tsl::StatusOr<PartialSnapshotWriter> Create(
+  static absl::StatusOr<PartialSnapshotWriter> Create(
       const DatasetDef& dataset, const std::string& snapshot_path,
       int64_t stream_index, const std::string& compression,
       int64_t max_chunk_size_bytes = 1,
@@ -83,15 +81,15 @@
   PartialSnapshotWriter& operator=(PartialSnapshotWriter&&) = delete;
 
   // Writes the specified chunks.
-  tsl::Status WriteCommittedChunks(
+  absl::Status WriteCommittedChunks(
       const absl::flat_hash_set<int64_t>& committed_chunk_indexes) const;
 
   // Writes the specified uncommitted chunks.
-  tsl::Status WriteUncommittedChunks(
+  absl::Status WriteUncommittedChunks(
       const absl::flat_hash_set<int64_t>& uncommitted_chunk_indexes) const;
 
   // Writes the specified checkpoints.
-  tsl::Status WriteCheckpoints(
+  absl::Status WriteCheckpoints(
       const absl::flat_hash_set<int64_t>& checkpoint_indexes) const;
 
  private:
@@ -101,7 +99,7 @@
                         int64_t max_chunk_size_bytes,
                         absl::Duration checkpoint_interval);
 
-  tsl::Status Initialize();
+  absl::Status Initialize();
 
   const DatasetDef dataset_;
   const std::string snapshot_path_;
@@ -115,7 +113,7 @@
 
 // Creates a test iterator for the input dataset. The iterator will generate all
 // elements of the dataset.
-tsl::StatusOr<std::unique_ptr<StandaloneTaskIterator>> TestIterator(
+absl::StatusOr<std::unique_ptr<StandaloneTaskIterator>> TestIterator(
     const DatasetDef& dataset_def);
 
 }  // namespace testing
diff --git a/tensorflow/core/data/service/test_cluster.cc b/tensorflow/core/data/service/test_cluster.cc
index ffe68e7..9383de1 100644
--- a/tensorflow/core/data/service/test_cluster.cc
+++ b/tensorflow/core/data/service/test_cluster.cc
@@ -49,6 +49,9 @@
   }
   initialized_ = true;
   experimental::DispatcherConfig dispatcher_config;
+  if (!config_.work_dir.empty()) {
+    dispatcher_config.set_work_dir(config_.work_dir);
+  }
   dispatcher_config.set_protocol(kProtocol);
   for (int i = 0; i < num_workers_; ++i) {
     dispatcher_config.add_worker_addresses("localhost");
diff --git a/tensorflow/core/data/service/test_cluster.h b/tensorflow/core/data/service/test_cluster.h
index 985a762..0528a74 100644
--- a/tensorflow/core/data/service/test_cluster.h
+++ b/tensorflow/core/data/service/test_cluster.h
@@ -54,6 +54,7 @@
     int64_t worker_heartbeat_interval_ms = 0;
     int64_t job_gc_check_interval_ms = 0;
     int64_t job_gc_timeout_ms = 0;
+    std::string work_dir;
   };
 
   // Creates a new test cluster with a dispatcher and `num_workers` workers.
diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc
index 1c4cf2b..ce1a3fe 100644
--- a/tensorflow/core/data/service/worker_impl.cc
+++ b/tensorflow/core/data/service/worker_impl.cc
@@ -24,6 +24,7 @@
 #include "grpcpp/create_channel.h"
 #include "absl/algorithm/container.h"
 #include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/string_view.h"
 #include "absl/strings/substitute.h"
@@ -53,12 +54,14 @@
 #include "tensorflow/core/platform/env_time.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/host_info.h"
+#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/statusor.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/protobuf/service_config.pb.h"
 #include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/dump_graph.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status_to_from_proto.h"
 #include "tsl/platform/statusor.h"
@@ -407,16 +410,17 @@
   TF_ASSIGN_OR_RETURN(bool compression_disabled_at_runtime,
                       DisableCompressionAtRuntime(task_def.dataset_id()));
   GraphDef graph = dataset_def.graph();
+  if (VLOG_IS_ON(1)) {
+    std::string prefix = absl::StrCat(task_def.dataset_id(), "_", worker_uid_);
+    DumpGraphDefToFile(absl::StrCat(prefix, "-prerewrite_GraphDef"), graph);
+    DumpProtoToFile(absl::StrCat(prefix, "-prerewrite_TaskDef"), task_def);
+  }
   if (compression_disabled_at_runtime) {
     RemoveCompressionMapRewriter remove_compression_map_rewriter;
-    VLOG(2) << "Applying compression map rewrite. GraphDef: "
-            << graph.DebugString();
     TF_ASSIGN_OR_RETURN(
         graph, remove_compression_map_rewriter.ApplyRemoveCompressionMapRewrite(
                    graph));
   }
-  VLOG(2) << "Applying autoshard rewrite. TaskDef: " << task_def.DebugString()
-          << ", GraphDef: " << graph.DebugString();
   TF_ASSIGN_OR_RETURN(AutoShardRewriter auto_shard_rewriter,
                       AutoShardRewriter::Create(task_def));
   // `ApplyAutoShardRewrite` does nothing if auto-sharding is disabled.
diff --git a/tensorflow/core/data/snapshot_utils.cc b/tensorflow/core/data/snapshot_utils.cc
index ee3b3d4..92d9d82 100644
--- a/tensorflow/core/data/snapshot_utils.cc
+++ b/tensorflow/core/data/snapshot_utils.cc
@@ -749,6 +749,7 @@
     std::optional<int64_t> output_buffer_size)
     : filename_(filename),
       offset_(0),
+      bytes_read_(0),
       compression_(compression),
       output_buffer_size_(output_buffer_size) {}
 
@@ -763,12 +764,14 @@
   }
 #endif  // IS_SLIM_BUILD
   record_reader_ = std::make_unique<io::RecordReader>(file_.get(), options);
+  bytes_read_ = 0;
   return OkStatus();
 }
 
 StatusOr<Tensor> TFRecordReaderImpl::GetNext() {
   tstring record;
   TF_RETURN_IF_ERROR(record_reader_->ReadRecord(&offset_, &record));
+  bytes_read_ += record.size();
   return Parse(record);
 }
 
diff --git a/tensorflow/core/data/snapshot_utils.h b/tensorflow/core/data/snapshot_utils.h
index 0ca3273..92f0053 100644
--- a/tensorflow/core/data/snapshot_utils.h
+++ b/tensorflow/core/data/snapshot_utils.h
@@ -269,6 +269,9 @@
   // Reads all Tensors in the input file.
   StatusOr<std::vector<Tensor>> GetTensors();
 
+  // Returns the number of bytes read.
+  uint64_t BytesRead() const { return bytes_read_; }
+
  private:
   // Parses `record` into a Tensor.
   StatusOr<Tensor> Parse(const tstring& record);
@@ -276,7 +279,8 @@
   std::string filename_;
   std::unique_ptr<RandomAccessFile> file_;
   std::unique_ptr<io::RecordReader> record_reader_;
-  uint64_t offset_;
+  uint64_t offset_ = 0;
+  uint64_t bytes_read_ = 0;
 
   const string compression_;
   const std::optional<int64_t> output_buffer_size_;
@@ -299,6 +303,9 @@
   // end of file, or an error status if there is an error.
   Status ReadTensors(std::vector<Tensor>* read_tensors) override;
 
+  // Returns the number of bytes read.
+  uint64_t BytesRead() const { return reader_impl_.BytesRead(); }
+
  private:
   TFRecordReaderImpl reader_impl_;
   const DataTypeVector dtypes_;
diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc
index 853656d..7a077cb 100644
--- a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc
+++ b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc
@@ -346,7 +346,7 @@
     // To make sure the sender won't delete the data it sent before the receiver
     // retrieves it, we need to do the following steps:
     // 1. Since we created async EagerContext, we need to force each worker to
-    //    wait until all pening operations finish before deleting the context.
+    //    wait until all pending operations finish before deleting the context.
     // 2. In addition, use the blocking counter to notify the 2 workers when
     //    it is safe to clean up all the data.
     TFE_ContextAsyncWait(ctx, status);
diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD
index a647fa4..27c92c1 100644
--- a/tensorflow/core/framework/BUILD
+++ b/tensorflow/core/framework/BUILD
@@ -712,6 +712,7 @@
     hdrs = ["resource_base.h"],
     visibility = default_visibility + [
         "//learning/brain/google/data/core/kernels:__pkg__",
+        "//learning/deepmind/tensorflow/sstable:__pkg__",
     ],
     deps = [
         "//tensorflow/core/lib/core:refcount",
diff --git a/tensorflow/core/framework/device.h b/tensorflow/core/framework/device.h
index a7d5bb2..3c60452 100644
--- a/tensorflow/core/framework/device.h
+++ b/tensorflow/core/framework/device.h
@@ -76,7 +76,7 @@
   // human-readable and not computer-parsed, except that two devices
   // with the same device_type() are expected to perform similarly
   // (both from a computation and communication perspective).
-  const std::string& device_type() const {
+  const std::string& device_type() const override {
     return device_attributes_.device_type();
   }
 
diff --git a/tensorflow/core/framework/device_base.cc b/tensorflow/core/framework/device_base.cc
index 3430414..8e33fc7 100644
--- a/tensorflow/core/framework/device_base.cc
+++ b/tensorflow/core/framework/device_base.cc
@@ -78,6 +78,11 @@
   std::abort();
 }
 
+const std::string& DeviceBase::device_type() const {
+  LOG(FATAL) << "DeviceBase does not implement device_type()";  // Crash OK
+  std::abort();
+}
+
 void DeviceBase::set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) {
   // Eigen::ThreadPoolDevice is a very cheap struct (two pointers and
   // an int).  Therefore, we can afford a pre-allocated array of
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index ada04e0..c8fbf9e 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -24,6 +24,7 @@
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -236,6 +237,7 @@
   virtual int NumaNode() const { return attributes().locality().numa_node(); }
   virtual const std::string& name() const;
   virtual const DeviceNameUtils::ParsedName& parsed_name() const;
+  virtual const std::string& device_type() const;
 
   // Updates `attributes()`, indicating the XLA global ID associated with this
   // device. This ID is unique across clients in a multi-client setup. For TPUs
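Since `device_type()` is now a virtual on `DeviceBase` whose default implementation crashes (see the device_base.cc hunk above), concrete devices that can reach this call need to override it. Below is a minimal, hedged sketch of such an override, mirroring the `DeviceSimple` change in evaluation_utils.h later in this diff; the class name and constructor are illustrative assumptions.

#include <string>

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

class ExampleCpuDevice : public DeviceBase {
 public:
  explicit ExampleCpuDevice(Env* env) : DeviceBase(env) {}

  // Without this override, device_type() would hit the LOG(FATAL)
  // default added in device_base.cc.
  const std::string& device_type() const override { return device_type_; }

 private:
  const std::string device_type_ = DEVICE_CPU;
};

}  // namespace tensorflow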
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 7dbfc4f..9e2925d 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -24,6 +24,7 @@
 #include <vector>
 
 #include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
 #include "absl/strings/escaping.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
@@ -1405,6 +1406,17 @@
   return status;
 }
 
+Status FunctionLibraryDefinition::AddFunctionDef(
+    FunctionDef&& fdef, StackTracesMap&& stack_traces) {
+  mutex_lock l(mu_);
+  bool added;
+  FunctionRecord* record =
+      new FunctionRecord(std::move(fdef), std::move(stack_traces), true);
+  core::ScopedUnref scoped_unref(record);
+  Status status = AddHelper(record, &added);
+  return status;
+}
+
 Status FunctionLibraryDefinition::AddFunctionDefHelper(
     FunctionDef&& fdef, StackTracesMap&& stack_traces, bool* added) {
   FunctionRecord* record =
@@ -1464,12 +1476,16 @@
           "Cannot copy function '", name,
           "' because a different function with the same name already "
           "exists.");
+    } else {
+      return OkStatus();
     }
+  } else if (other_record->finalized()) {
+    bool added;
+    mutex_lock l(mu_);
+    return AddHelper(other_record.get(), &added);
   } else {
-    TF_RETURN_IF_ERROR(
-        AddFunctionDef(other_record->fdef(), other_record->stack_traces()));
+    return AddFunctionDef(other_record->fdef(), other_record->stack_traces());
   }
-  return OkStatus();
 }
 
 Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
@@ -1945,6 +1961,20 @@
   return ReachableFunctionLibraryDefinition(*this, func.node_def());
 }
 
+StatusOr<FunctionLibraryDefinition>
+FunctionLibraryDefinition::ReachableDefinitions(
+    const std::string& function_name) const {
+  auto* func = Find(function_name);
+  if (func) {
+    FunctionLibraryDefinition ret =
+        ReachableFunctionLibraryDefinition(*this, func->node_def());
+    TF_RETURN_IF_ERROR(ret.CopyFunctionDefFrom(function_name, *this));
+    return ret;
+  } else {
+    return absl::NotFoundError(function_name);
+  }
+}
+
 string FunctionLibraryRuntime::Options::DebugString() const {
   return absl::StrCat(
       "FLR::Options(step_id=", step_id, " rendezvous=", IsSet(rendezvous),
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 2d31f6b..d234b20 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -451,6 +451,8 @@
   Status AddFunctionDef(const FunctionDef& fdef,
                         const StackTracesMap& stack_traces = {})
       TF_LOCKS_EXCLUDED(mu_);
+  Status AddFunctionDef(FunctionDef&& fdef, StackTracesMap&& stack_traces = {})
+      TF_LOCKS_EXCLUDED(mu_);
 
   // Adds gradient definition 'grad' to this function library.
   // This is a no-op if 'grad' already exists in this function library.
@@ -563,6 +565,8 @@
   // reachable from the nodes of `graph` or `func`.
   FunctionLibraryDefinition ReachableDefinitions(const GraphDef& graph) const;
   FunctionLibraryDefinition ReachableDefinitions(const FunctionDef& func) const;
+  StatusOr<FunctionLibraryDefinition> ReachableDefinitions(
+      const std::string& function_name) const;
 
   // Copies the function named `func` from `other` to this
   // FunctionLibraryDefinition.
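A hedged sketch of using the two additions above: the rvalue `AddFunctionDef` overload and the name-based `ReachableDefinitions`. The helper `test::function::XTimesTwo()` and the exact include set are illustrative assumptions; error propagation follows the usual TF status macros.

#include <utility>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace tensorflow {

Status AddAndPruneExample() {
  FunctionLibraryDefinition lib_def(OpRegistry::Global(),
                                    FunctionDefLibrary());
  FunctionDef fdef = test::function::XTimesTwo();
  // Moving avoids copying the (potentially large) FunctionDef proto;
  // `fdef` is left in a moved-from state afterwards.
  TF_RETURN_IF_ERROR(lib_def.AddFunctionDef(std::move(fdef)));

  // Returns a library restricted to the functions reachable from
  // "XTimesTwo", or NotFound if no function with that name exists.
  TF_ASSIGN_OR_RETURN(FunctionLibraryDefinition reachable,
                      lib_def.ReachableDefinitions("XTimesTwo"));
  return OkStatus();
}

}  // namespace tensorflow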
diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc
index 8a659b8..49d5d8b 100644
--- a/tensorflow/core/framework/function_test.cc
+++ b/tensorflow/core/framework/function_test.cc
@@ -1119,6 +1119,22 @@
   TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::XTimesTwo()));
 }
 
+TEST(FunctionLibraryDefinitionTest, AddFunctionDefMove) {
+  FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary());
+  FunctionDef fdef = test::function::XTimesTwo();
+  EXPECT_GT(fdef.node_def_size(), 0);
+  TF_CHECK_OK(lib_def.AddFunctionDef(std::move(fdef)));
+  // The protobuf move constructor will empty the node defs from the function.
+  EXPECT_EQ(fdef.node_def_size(), 0);  // NOLINT
+
+  // Test lookup of existing function.
+  const OpDef* op_def;
+  TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &op_def));
+  ASSERT_NE(op_def, nullptr);
+  EXPECT_EQ(op_def->DebugString(),
+            test::function::XTimesTwo().signature().DebugString());
+}
+
 TEST(FunctionLibraryDefinitionTest, AddGradientDef) {
   // AddGradientDef() doesn't check that functions referenced exist (yet?)
   FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary());
diff --git a/tensorflow/core/framework/graph_to_functiondef.h b/tensorflow/core/framework/graph_to_functiondef.h
index 83e56ca..834bf50 100644
--- a/tensorflow/core/framework/graph_to_functiondef.h
+++ b/tensorflow/core/framework/graph_to_functiondef.h
@@ -60,13 +60,6 @@
                           const std::vector<std::string>& output_names,
                           FunctionDef* fdef);
 
-Status GetGraphAndArgRets(
-    const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
-    const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
-    std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
-    std::vector<string>* ret_node_names, DataTypeVector* ret_types,
-    std::vector<string>* control_ret_node_names);
-
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_
diff --git a/tensorflow/core/framework/metrics.cc b/tensorflow/core/framework/metrics.cc
index 26b67ce..104af6d 100644
--- a/tensorflow/core/framework/metrics.cc
+++ b/tensorflow/core/framework/metrics.cc
@@ -219,6 +219,10 @@
         "/tensorflow/data/service/snapshot_bytes_committed",
         "tf.data service distributed snapshot committed bytes.");
 
+auto* tf_data_service_snapshot_ops_counter = tsl::monitoring::Counter<2>::New(
+    "/tensorflow/data/service/snapshot_ops",
+    "Number times a tf.data snapshot is saved/loaded.", "path", "op");
+
 auto* tf_data_service_data_transfer_protocol_used =
     tsl::monitoring::Counter<1>::New(
         "/tensorflow/data/service/data_transfer_protocol_used",
@@ -616,6 +620,11 @@
   tf_data_service_snapshot_bytes_committed->GetCell()->IncrementBy(bytes);
 }
 
+void RecordTFDataServiceSnapshotOp(const std::string& path,
+                                   const std::string& op) {
+  tf_data_service_snapshot_ops_counter->GetCell(path, op)->IncrementBy(1);
+}
+
 void RecordTFDataServiceOptimalNumberOfWorkers(int64_t number_of_workers) {
   tf_data_service_optimal_number_of_workers->GetCell()->Set(number_of_workers);
 }
diff --git a/tensorflow/core/framework/metrics.h b/tensorflow/core/framework/metrics.h
index 38c12be..5b15ee1 100644
--- a/tensorflow/core/framework/metrics.h
+++ b/tensorflow/core/framework/metrics.h
@@ -16,6 +16,7 @@
 #define TENSORFLOW_CORE_FRAMEWORK_METRICS_H_
 
 #include <cstdint>
+#include <string>
 
 #include "tensorflow/core/framework/dataset_options.pb.h"
 #include "tensorflow/core/lib/monitoring/counter.h"
@@ -180,9 +181,13 @@
 // Records tf.data service cross-trainer cache memory usage in bytes.
 void RecordTFDataServiceCrossTrainerCacheSizeBytes(size_t bytes);
 
-// Records distributed tf.data snapshot bytes committed.
+// Records tf.data distributed snapshot bytes committed.
 void RecordTFDataServiceSnapshotBytesCommitted(int64_t bytes);
 
+// Records tf.data distributed snapshot save/load ops.
+void RecordTFDataServiceSnapshotOp(const std::string& path,
+                                   const std::string& op);
+
 // Records the current estimated optimal number of tf.data service workers.
 void RecordTFDataServiceOptimalNumberOfWorkers(int64_t number_of_workers);
 
diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc
index 21b593d..1212fad 100644
--- a/tensorflow/core/framework/rendezvous_test.cc
+++ b/tensorflow/core/framework/rendezvous_test.cc
@@ -511,6 +511,7 @@
       TF_CHECK_OK(rendez->Recv(KeyBar(), args, &bar, &is_dead));
     }
     CHECK_EQ("bar", V(bar));
+    rendez->Unref();
   }
   state.SetItemsProcessed(messages_count * state.iterations());
   delete pool;
diff --git a/tensorflow/core/framework/tensor_util.cc b/tensorflow/core/framework/tensor_util.cc
index ee19fc3..ec624d8 100644
--- a/tensorflow/core/framework/tensor_util.cc
+++ b/tensorflow/core/framework/tensor_util.cc
@@ -182,7 +182,7 @@
 }
 
 namespace internal {
-void SetTensorProtoShape(std::vector<size_t> shape,
+void SetTensorProtoShape(const absl::Span<const size_t> shape,
                          TensorShapeProto* shape_proto) {
   for (auto dim : shape) {
     shape_proto->mutable_dim()->Add()->set_size(dim);
diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h
index 6e09b8e..213025d 100644
--- a/tensorflow/core/framework/tensor_util.h
+++ b/tensorflow/core/framework/tensor_util.h
@@ -66,7 +66,7 @@
              std::vector<Tensor>* result) TF_MUST_USE_RESULT;
 
 namespace internal {
-void SetTensorProtoShape(std::vector<size_t> shape,
+void SetTensorProtoShape(absl::Span<const size_t> shape,
                          TensorShapeProto* shape_proto);
 
 template <typename Type>
@@ -265,31 +265,58 @@
   }
 };
 
-}  // namespace internal
-
-// Creates a 'TensorProto' with specified shape and values.
-// The dtype and a field to represent data values of the returned 'TensorProto'
-// are determined based on type of the 'values' parameter.
-template <typename Type>
+template <typename Type, typename IterType>
 typename std::enable_if<internal::TensorProtoHelper<Type>::value,
                         TensorProto>::type
-CreateTensorProto(const std::vector<Type>& values,
-                  const std::vector<size_t>& shape) {
+CreateTensorProto(IterType values_begin, IterType values_end,
+                  const size_t values_size,
+                  const absl::Span<const size_t> shape) {
   TensorProto tensor;
   TensorShapeProto tensor_shape_proto;
   internal::SetTensorProtoShape(shape, &tensor_shape_proto);
-  if (TensorShape(tensor_shape_proto).num_elements() != values.size()) {
-    LOG(ERROR) << "Shape and number of values (" << values.size()
+  if (TensorShape(tensor_shape_proto).num_elements() != values_size) {
+    LOG(ERROR) << "Shape and number of values (" << values_size
                << ") are incompatible.";
     return tensor;
   }
   using TypeHelper = internal::TensorProtoHelper<Type>;
   tensor.set_dtype(TypeHelper::GetDataType());
-  tensor.mutable_tensor_shape()->Swap(&tensor_shape_proto);
-  TypeHelper::AddValues(values.begin(), values.end(), &tensor);
+  *tensor.mutable_tensor_shape() = std::move(tensor_shape_proto);
+  TypeHelper::AddValues(values_begin, values_end, &tensor);
   return tensor;
 }
 
+}  // namespace internal
+
+// Creates a 'TensorProto' with the specified shape and values. The dtype and a
+// field to represent data values of the returned 'TensorProto' are determined
+// based on Type. Note that unless the argument provided to `values` is already
+// an absl::Span, `Type` will need to be provided as a template parameter--the
+// compiler can't infer it:
+//   auto proto = CreateTensorProtoSpan<float>(my_array, shape);
+template <typename Type>
+typename std::enable_if<internal::TensorProtoHelper<Type>::value,
+                        TensorProto>::type
+CreateTensorProtoSpan(const absl::Span<const Type> values,
+                      const absl::Span<const size_t> shape) {
+  return internal::CreateTensorProto<Type>(values.begin(), values.end(),
+                                           values.size(), shape);
+}
+
+// Version of the above that's more convenient if `values` is an std::vector, in
+// which case Type can automatically be inferred:
+//   auto proto = CreateTensorProto(my_vector, shape);
+template <typename Type>
+typename std::enable_if<internal::TensorProtoHelper<Type>::value,
+                        TensorProto>::type
+CreateTensorProto(const std::vector<Type>& values,
+                  const absl::Span<const size_t> shape) {
+  // This awkward iterator passing is essentially just to support vector<bool>,
+  // otherwise we could just represent the vector as a Span.
+  return internal::CreateTensorProto<Type>(values.begin(), values.end(),
+                                           values.size(), shape);
+}
+
 // Converts values in tensor to run-length encoded compressed form.
 //
 // The elements of a tensor can be stored in a TensorProto in one of the
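A brief, hedged illustration of the two entry points documented above: the Span overload needs an explicit template argument, while the std::vector overload can infer it. The values, shapes, and wrapper function here are arbitrary placeholders.

#include <cstddef>
#include <vector>

#include "tensorflow/core/framework/tensor_util.h"

namespace tensorflow {

void CreateTensorProtoExample() {
  const std::vector<size_t> shape = {2, 2};

  // Raw array: `Type` cannot be inferred, so pass it explicitly.
  const float raw[4] = {1.f, 2.f, 3.f, 4.f};
  TensorProto from_span = tensor::CreateTensorProtoSpan<float>(raw, shape);

  // std::vector: `Type` is inferred from the element type.
  const std::vector<float> values = {1.f, 2.f, 3.f, 4.f};
  TensorProto from_vector = tensor::CreateTensorProto(values, shape);

  (void)from_span;
  (void)from_vector;
}

}  // namespace tensorflow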
diff --git a/tensorflow/core/framework/tensor_util_test.cc b/tensorflow/core/framework/tensor_util_test.cc
index 1777ed4..51aa40e 100644
--- a/tensorflow/core/framework/tensor_util_test.cc
+++ b/tensorflow/core/framework/tensor_util_test.cc
@@ -294,6 +294,34 @@
   }
 }
 
+TEST(TensorProtoUtil, CreateTensorProtoSpan_string) {
+  // Don't use vector to trigger Span version.
+  string s[2] = {"a", "b"};
+  std::vector<size_t> shape{1, 2};
+  auto proto = tensor::CreateTensorProtoSpan<string>(s, shape);
+  TensorProto expected_tensor_proto;
+  expected_tensor_proto.set_dtype(DT_STRING);
+  expected_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
+  expected_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
+  expected_tensor_proto.add_string_val("a");
+  expected_tensor_proto.add_string_val("b");
+  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
+}
+
+TEST(TensorProtoUtil, CreateTensorProtoSpan_int32) {
+  // Don't use vector to trigger Span version.
+  int32 s[2] = {123, 456};
+  std::vector<size_t> shape{1, 2};
+  auto proto = tensor::CreateTensorProtoSpan<int32>(s, shape);
+  TensorProto expected_tensor_proto;
+  expected_tensor_proto.set_dtype(DT_INT32);
+  expected_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
+  expected_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
+  expected_tensor_proto.add_int_val(123);
+  expected_tensor_proto.add_int_val(456);
+  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
+}
+
 TEST(TensorProtoUtil, CreatesStringTensorProto) {
   std::vector<string> values{"a", "b", "c"};
   std::vector<size_t> shape{1, 3};
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 84d94db..5519abc 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -687,6 +687,7 @@
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
         "//tensorflow/core/grappler/utils:topological_sort",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
     ] + tf_protos_all(),
     alwayslink = 1,
 )
@@ -706,6 +707,9 @@
         "//tensorflow/core:testlib",
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/kernels:control_flow_ops",
+        "//tensorflow/core/platform:status",
+        "@com_google_googletest//:gtest_main",
+        "@local_tsl//tsl/platform:errors",
     ],
 )
 
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
index 1ce431a..cfd6826 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
@@ -16,7 +16,9 @@
 #include "tensorflow/core/grappler/optimizers/data/map_fusion.h"
 
 #include "absl/container/flat_hash_set.h"
+#include "absl/log/log.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/grappler/clusters/cluster.h"
 #include "tensorflow/core/grappler/grappler_item.h"
@@ -29,11 +31,54 @@
 #include "tensorflow/core/grappler/utils/topological_sort.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace grappler {
 namespace {
 
+constexpr char kMapDatasetOp[] = "MapDataset";
+constexpr char kParallelMapDatasetOp[] = "ParallelMapDatasetV2";
+constexpr char kDeterministicAttr[] = "deterministic";
+constexpr char kConstOp[] = "Const";
+constexpr char kValueAttr[] = "value";
+constexpr int kAutotuneValue = -1;
+
+// Returns true if the node named `node_name` is a `tf.data.AUTOTUNE` constant.
+bool IsAutotuneNode(const string& node_name, const MutableGraphView& graph) {
+  const NodeDef* node = graph.GetNode(node_name);
+  if (!node) return false;
+  if (node->op() != kConstOp) return false;
+
+  const auto* value = gtl::FindOrNull(node->attr(), kValueAttr);
+  if (!value) return false;
+
+  if (value->has_tensor()) {
+    if (value->tensor().int64_val_size()) {
+      return value->tensor().int64_val(0) == kAutotuneValue;
+    }
+  }
+
+  return false;
+}
+
+// Returns true if both the parent and child parallel map nodes have the same
+// `deterministic` attr value.
+bool SameDeterministicAttr(const NodeDef& parallel_map_node,
+                           const NodeDef& parent_parallel_map_node) {
+  const auto* first_deterministic_val =
+      gtl::FindOrNull(parallel_map_node.attr(), kDeterministicAttr);
+  const auto* second_deterministic_val =
+      gtl::FindOrNull(parent_parallel_map_node.attr(), kDeterministicAttr);
+
+  if (first_deterministic_val && second_deterministic_val) {
+    return first_deterministic_val->s() == second_deterministic_val->s();
+  }
+
+  return false;
+}
+
 // Sets basic function parameters and copies attributes from parent and map
 // node.
 NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node,
@@ -41,8 +86,15 @@
                       MutableGraphView* graph) {
   NodeDef fused_node;
   graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(), &fused_node);
-  fused_node.set_op("MapDataset");
-  fused_node.add_input(parent_map_node.input(0));
+
+  if (map_node.op() == kMapDatasetOp) {
+    fused_node.set_op(kMapDatasetOp);
+    fused_node.add_input(parent_map_node.input(0));  // `input_dataset`
+  } else if (map_node.op() == kParallelMapDatasetOp) {
+    fused_node.set_op(kParallelMapDatasetOp);
+    fused_node.add_input(parent_map_node.input(0));  // `input_dataset`
+    fused_node.add_input(parent_map_node.input(1));  // `num_parallel_calls`
+  }
 
   auto attr = parent_map_node.attr().at("f");
   *attr.mutable_func()->mutable_name() = fused_function.signature().name();
@@ -74,6 +126,10 @@
 
   graph_utils::MaybeSetFusedMetadata(parent_map_node, map_node, &fused_node);
 
+  if (map_node.op() == kParallelMapDatasetOp) {
+    graph_utils::CopyAttribute(kDeterministicAttr, map_node, &fused_node);
+  }
+
   return fused_node;
 }
 
@@ -87,15 +143,28 @@
   TF_RETURN_IF_ERROR(TopologicalSort(&sorted_old_graph));
   *output = sorted_old_graph;
 
+  if (!autotune_) {
+    VLOG(1) << "The optimization map_fusion is not applied if "
+               "autotune is off.";
+    return OkStatus();
+  }
+
   MutableGraphView graph(output);
   absl::flat_hash_set<string> nodes_to_delete;
   FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                              item.graph.library());
 
-  auto get_map_node = [](const NodeDef& node) -> const NodeDef* {
+  auto get_map_node = [&graph](const NodeDef& node) -> const NodeDef* {
     // TODO(b/148614504): Support ParallelMapDataset and MapAndBatchDataset.
     // TODO(b/148614315): Support captured inputs.
-    if (node.op() == "MapDataset" && node.input_size() == 1) return &node;
+    if (node.op() == kMapDatasetOp && node.input_size() == 1) return &node;
+    // Only parallel map nodes with no captured inputs (empty `other_arguments`)
+    // and parallelism set to "AUTOTUNE" are eligible for rewrite.
+    if (node.op() == kParallelMapDatasetOp) {
+      if (node.input_size() != 2) return nullptr;
+      if (!IsAutotuneNode(node.input(1), graph)) return nullptr;
+      return &node;
+    }
     return nullptr;
   };
 
@@ -129,6 +198,15 @@
         get_map_node(*graph_utils::GetInputNode(*map_node, graph));
     if (!parent_map_node) continue;
 
+    // TODO(b/148614504): Support fusing different types of map operations.
+    if (parent_map_node->op() != map_node->op()) continue;
+
+    // TODO(b/148614504): Support fusing parallel map operations with different
+    // `deterministic` attr values.
+    if (map_node->op() == kParallelMapDatasetOp) {
+      if (!SameDeterministicAttr(*parent_map_node, *map_node)) continue;
+    }
+
     const auto* fused_function = make_fused_function(parent_map_node, map_node);
     if (fused_function == nullptr) continue;
 
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.h b/tensorflow/core/grappler/optimizers/data/map_fusion.h
index 74549d0..7ce44d5 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion.h
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion.h
@@ -16,11 +16,14 @@
 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_FUSION_H_
 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_FUSION_H_
 
+#include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
 
 namespace tensorflow {
 namespace grappler {
 
+constexpr char kAutotune[] = "autotune";
+
 // This optimization fuses map transformations by merging their map functions.
 class MapFusion : public TFDataOptimizerBase {
  public:
@@ -33,12 +36,26 @@
 
   Status Init(
       const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+    if (!config) return OkStatus();
+
+    const string& autotune = config->parameter_map().at(kAutotune).s();
+    if (autotune == "true") {
+      autotune_ = true;
+    } else if (autotune == "false") {
+      autotune_ = false;
+    } else {
+      return errors::InvalidArgument("Received an invalid value for parameter ",
+                                     kAutotune, ": ", autotune);
+    }
     return OkStatus();
   }
 
   Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item,
                                  GraphDef* output,
                                  OptimizationStats* stats) override;
+
+ private:
+  bool autotune_ = true;
 };
 
 }  // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
index 8889f9d..76b4524 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
@@ -15,21 +15,99 @@
 
 #include "tensorflow/core/grappler/optimizers/data/map_fusion.h"
 
+#include <functional>
+#include <memory>
+
+#include <gtest/gtest.h>
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/function_testlib.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/test.h"
+#include "tsl/lib/core/status_test_util.h"
+#include "tsl/platform/errors.h"
 
 namespace tensorflow {
 namespace grappler {
 namespace {
 
 using graph_tests_utils::MakeMapNode;
+using graph_tests_utils::MakeParallelMapV2Node;
+
+constexpr char kConstOpName[] = "Const";
+
+NodeDef CreateScalarConstNodeHelper(
+    const std::string& node_name, DataType dtype,
+    const std::function<void(TensorProto*)>& add_value) {
+  NodeDef node;
+  node.set_op(kConstOpName);
+  node.set_name(node_name);
+
+  (*node.mutable_attr())["dtype"].set_type(dtype);
+  auto tensor = std::make_unique<tensorflow::TensorProto>();
+  auto tensor_shape = std::make_unique<tensorflow::TensorShapeProto>();
+  tensor->set_allocated_tensor_shape(tensor_shape.release());
+  tensor->set_dtype(dtype);
+  add_value(tensor.get());
+  (*node.mutable_attr())["value"].set_allocated_tensor(tensor.release());
+
+  return node;
+}
+
+Status OptimizeWithMapFusion(const GrapplerItem& item, GraphDef* output,
+                             bool autotune) {
+  MapFusion optimizer;
+  RewriterConfig_CustomGraphOptimizer config;
+  if (autotune) {
+    (*config.mutable_parameter_map())["autotune"].set_s("true");
+  } else {
+    (*config.mutable_parameter_map())["autotune"].set_s("false");
+  }
+  TF_RETURN_IF_ERROR(optimizer.Init(&config));
+  return optimizer.Optimize(nullptr, item, output);
+}
+
+class AutotuneSetting : public ::testing::TestWithParam<bool> {};
+
+TEST_P(AutotuneSetting, MapFusionTest) {
+  const bool autotune = GetParam();
+
+  using test::function::NDef;
+  GrapplerItem item;
+  NodeDef num_parallel_calls_node = CreateScalarConstNodeHelper(
+      "num_parallel_calls", DT_INT64,
+      [](TensorProto* proto) { proto->add_int64_val(-1); });
+  item.graph = test::function::GDef(
+      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+       num_parallel_calls_node,
+       MakeParallelMapV2Node("map1", "range", num_parallel_calls_node.name(),
+                             "XTimesTwo", "default"),
+       MakeParallelMapV2Node("map2", "map1", num_parallel_calls_node.name(),
+                             "XTimesTwo", "default")},
+      // FunctionLib
+      {
+          test::function::XTimesTwo(),
+      });
+
+  MapFusion optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(OptimizeWithMapFusion(item, &output, autotune));
+  if (autotune) {
+    EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
+    EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
+  } else {
+    EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map1", output));
+    EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map2", output));
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(Test, AutotuneSetting, ::testing::Values(false, true));
 
 TEST(MapFusionTest, FuseTwoMapNodesIntoOne) {
   using test::function::NDef;
@@ -47,7 +125,7 @@
 
   MapFusion optimizer;
   GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+  TF_ASSERT_OK(OptimizeWithMapFusion(item, &output, true));
   EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapDataset", output));
   EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
   EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
@@ -72,13 +150,42 @@
 
   MapFusion optimizer;
   GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+  TF_ASSERT_OK(OptimizeWithMapFusion(item, &output, true));
   EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapDataset", output));
   EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
   EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
   EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map3", output));
 }
 
+TEST(MapFusionTest, FuseTwoParallelMapNodesIntoOne) {
+  using test::function::NDef;
+  GrapplerItem item;
+  NodeDef num_parallel_calls_node = CreateScalarConstNodeHelper(
+      "num_parallel_calls", DT_INT64,
+      [](TensorProto* proto) { proto->add_int64_val(-1); });
+  item.graph = test::function::GDef(
+      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+       num_parallel_calls_node,
+       MakeParallelMapV2Node("map1", "range", num_parallel_calls_node.name(),
+                             "XTimesTwo", "default"),
+       MakeParallelMapV2Node("map2", "map1", num_parallel_calls_node.name(),
+                             "XTimesTwo", "default")},
+      // FunctionLib
+      {
+          test::function::XTimesTwo(),
+      });
+
+  MapFusion optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(OptimizeWithMapFusion(item, &output, true));
+  EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDatasetV2", output));
+  EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
+  EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc
index 3762806..87bb504 100644
--- a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc
@@ -43,10 +43,10 @@
     "disable_intra_op_parallelism",
     "use_private_thread_pool",
     "shuffle_and_repeat_fusion",
+    "map_parallelization",
     "map_fusion",
     "filter_fusion",
     "map_and_filter_fusion",
-    "map_parallelization",
     "map_and_batch_fusion",
     "batch_parallelization",
     "filter_parallelization",
diff --git a/tensorflow/core/grappler/optimizers/evaluation_utils.cc b/tensorflow/core/grappler/optimizers/evaluation_utils.cc
index e51eb2b..2a541f3 100644
--- a/tensorflow/core/grappler/optimizers/evaluation_utils.cc
+++ b/tensorflow/core/grappler/optimizers/evaluation_utils.cc
@@ -15,7 +15,9 @@
 
 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
 
+#include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/platform/cpu_info.h"
 #include "tensorflow/core/platform/denormal.h"
@@ -69,7 +71,7 @@
   }
 
   std::unique_ptr<OpKernel> op_kernel(
-      CreateOpKernel("CPU", cpu_device, cpu_device->GetAllocator({}), node,
+      CreateOpKernel(DEVICE_CPU, cpu_device, cpu_device->GetAllocator({}), node,
                      TF_GRAPH_DEF_VERSION, &status));
   TF_RETURN_IF_ERROR(status);
   OpKernelContext::Params params;
diff --git a/tensorflow/core/grappler/optimizers/evaluation_utils.h b/tensorflow/core/grappler/optimizers/evaluation_utils.h
index dd7b877..a146c9a 100644
--- a/tensorflow/core/grappler/optimizers/evaluation_utils.h
+++ b/tensorflow/core/grappler/optimizers/evaluation_utils.h
@@ -22,6 +22,7 @@
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 
 namespace Eigen {
@@ -45,9 +46,12 @@
     return cpu_allocator();
   }
 
+  const std::string& device_type() const override { return device_type_; }
+
  private:
   DeviceBase::CpuWorkerThreads eigen_worker_threads_;
   std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
+  const std::string device_type_ = DEVICE_CPU;
 };
 
 Status EvaluateNode(const NodeDef& node,
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index b7fe106..6b07197 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -1361,7 +1361,8 @@
 
       TF_RETURN_IF_ERROR(InlineFunctionBody(flib_def, graph.get(), n,
                                             fbody.get(), inline_options));
-      inlined_function_names.push_back(fbody->fdef.signature().name());
+      inlined_function_names.push_back(
+          fbody->record->fdef().signature().name());
 
     } else {
       VLOG(2) << "Failed to inline function call node: "
diff --git a/tensorflow/core/grappler/optimizers/inference/BUILD b/tensorflow/core/grappler/optimizers/inference/BUILD
index e96fd3b..c9ccd64 100644
--- a/tensorflow/core/grappler/optimizers/inference/BUILD
+++ b/tensorflow/core/grappler/optimizers/inference/BUILD
@@ -15,11 +15,12 @@
     licenses = ["notice"],
 )
 
+# Expand the DEFAULT_VISIBILITY so that copybara can replace it with public visibility.
 tf_proto_library(
     name = "batch_op_rewriter_proto",
     srcs = ["batch_op_rewriter.proto"],
     cc_api_version = 2,
-    visibility = DEFAULT_VISIBILITY,
+    visibility = ["//visibility:public"],
 )
 
 # copybara:uncomment_begin(google-only)
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index cce1aba..504620a 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3593,58 +3593,6 @@
     deps = MATH_DEPS,
 )
 
-# This wrapper library requires special build flags, so must
-# by compiled separately.
-cc_library(
-    name = "fft_impl",
-    srcs = ["fft_impl.cc"],
-    hdrs = ["fft_impl.h"],
-    # DUCC requires exceptions and RTTI.
-    copts = [
-        "-fexceptions",
-        "-frtti",
-    ],
-    features = ["-use_header_modules"],
-    deps = [
-        "//tensorflow/core:framework",
-        "@com_google_absl//absl/status",
-        "@ducc//:fft",
-        "@eigen_archive//:eigen3",
-    ],
-)
-
-# This wrapper library requires special build flags, so must
-# by compiled separately.  This version is specific to "portable"
-# ops, where we have different linking requirements.
-cc_library(
-    name = "portable_fft_impl",
-    srcs = if_mobile([
-        "fft_impl.cc",
-        "@ducc//:mobile_srcs_no_runtime",
-    ]),
-    hdrs = ["fft_impl.h"],
-    copts = [
-        # DUCC requires exceptions and RTTI.
-        "-fexceptions",
-        "-frtti",
-        # DUCC custom threading.
-        "-DDUCC0_CUSTOM_LOWLEVEL_THREADING=1",
-        # Required for relative includes in the DUCC library.
-        "-Ithird_party/ducc/src",
-        "-Ithird_party/ducc/google",
-        # Required for relative includes in OSS builds.
-        "-Iexternal",
-        "-Iexternal/ducc/src",
-        "-Iexternal/ducc/google",
-    ],
-    features = ["-use_header_modules"],
-    deps = [
-        "//tensorflow/core:portable_tensorflow_lib_lite",
-        "@com_google_absl//absl/status",
-        "@eigen_archive//:eigen3",
-    ],
-)
-
 tf_kernel_library(
     name = "fft_ops",
     prefix = "fft_ops",
@@ -3652,7 +3600,7 @@
         "@com_google_absl//absl/container:flat_hash_map",
         "//tensorflow/core/platform:stream_executor",
         "@local_xla//xla/stream_executor/cuda:cufft_plugin",
-    ]) + [":fft_impl"],
+    ]) + ["@ducc//:fft_wrapper"],
 )
 
 tf_kernel_library(
@@ -3832,7 +3780,6 @@
 
 tf_cuda_cc_test(
     name = "matmul_op_test",
-    size = "small",
     srcs = ["matmul_op_test.cc"],
     tags = [
         "no_arm64",  # b/282068262
@@ -6976,8 +6923,6 @@
             "decode_proto_op.cc",
             "encode_proto_op.cc",
             "sobol_op.cc",
-            # Excluded because require special build rules:
-            "fft_impl.cc",  # Must instead link in ":portable_fft_impl".
             # Excluded due to experimental status:
             "debug_ops.*",
             "mutex_ops.*",
@@ -7042,7 +6987,6 @@
     textual_hdrs = ANDROID_TEXTUAL_HDRS,
     visibility = ["//visibility:public"],
     deps = [
-        ":portable_fft_impl",
         "//tensorflow/core:portable_gif_internal",
         "//tensorflow/core:portable_jpeg_internal",
         "//tensorflow/core:portable_tensorflow_lib_lite",
@@ -7074,6 +7018,7 @@
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
         "@com_google_protobuf//:protobuf",
+        "@ducc//:fft_wrapper",
         "@eigen_archive//:eigen3",
         "@fft2d",
         "@gemmlowp",
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index c1c7d3995..8b8bb86 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -366,6 +366,7 @@
         "//tensorflow/core/profiler/lib:traceme_encode",
         "//tensorflow/core/protobuf:for_core_protos_cc",
         "//tensorflow/core/util:incremental_barrier",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/time",
@@ -402,7 +403,6 @@
         "//tensorflow/core/common_runtime:no_op_cost_measurement",
         "//tensorflow/core/common_runtime:request_cost",
         "//tensorflow/core/framework:types_proto_cc",
-        "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/time",
         "@com_google_googletest//:gtest_main",
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc
index dcbc994..c1395ed 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc
@@ -26,6 +26,7 @@
 #include <utility>
 #include <vector>
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "absl/synchronization/blocking_counter.h"
@@ -48,7 +49,9 @@
 #include "tensorflow/core/lib/monitoring/gauge.h"
 #include "tensorflow/core/lib/monitoring/percentile_sampler.h"
 #include "tensorflow/core/lib/monitoring/sampler.h"
+#include "tensorflow/core/lib/monitoring/types.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/core/profiler/lib/traceme_encode.h"
 #include "tensorflow/core/util/incremental_barrier.h"
@@ -831,8 +834,7 @@
   Status status;
   bool cleanup_done = false;
   int64_t processed_size = batch->size();
-  auto cleanup_fn = [&cleanup_done, &batch, &processed_size,
-                     &batch_cost_measurements](const Status& status) {
+  auto cleanup_fn = [&](const Status& status) {
     if (cleanup_done) {
       return;
     }
@@ -1054,10 +1056,11 @@
 }
 
 void BatchResourceBase::SplitBatchCostsAndRecordMetrics(
-    std::vector<std::unique_ptr<CostMeasurement>>& batch_cost_measurements,
+    const std::vector<std::unique_ptr<CostMeasurement>>&
+        batch_cost_measurements,
     const int64_t processed_size, BatchT& batch) {
   // 1. Split the batch costs to each task.
-  for (auto& batch_cost_measurement : batch_cost_measurements) {
+  for (const auto& batch_cost_measurement : batch_cost_measurements) {
     if (batch_cost_measurement->GetTotalCost() <= absl::ZeroDuration()) {
       continue;
     }
@@ -1096,15 +1099,21 @@
 
   // 2. Records the batch metrics in each task.
   const int64_t padding_size = processed_size - batch.size();
+  absl::flat_hash_map<std::string, absl::Duration> batch_costs;
+  for (const auto& batch_cost_measurement : batch_cost_measurements) {
+    if (batch_cost_measurement->GetTotalCost() > absl::ZeroDuration()) {
+      batch_costs[batch_cost_measurement->GetCostType()] =
+          batch_cost_measurement->GetTotalCost();
+    }
+  }
   for (int i = 0; i < batch.num_tasks(); i++) {
     RequestCost* request_cost = batch.task(i).request_cost;
     // Skip recording the metrics if the request_cost is null.
     if (!request_cost) continue;
 
     request_cost->RecordBatchMetrics(RequestCost::BatchMetrics{
-        /*processed_size=*/processed_size,
-        /*input_size=*/static_cast<int64_t>(batch.task(i).size()),
-        /*padding_size=*/padding_size});
+        processed_size, static_cast<int64_t>(batch.task(i).size()),
+        padding_size, batch_costs});
   }
 }
 
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h
index b7e634f..5124e9f 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base.h
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h
@@ -238,7 +238,8 @@
   //   2) the input size from this task;
   //   3) the padding amount.
   static void SplitBatchCostsAndRecordMetrics(
-      std::vector<std::unique_ptr<CostMeasurement>>& batch_cost_measurements,
+      const std::vector<std::unique_ptr<CostMeasurement>>&
+          batch_cost_measurements,
       int64_t processed_size, BatchT& batch);
 
  private:
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc
index b7d4d6d..dc75fde 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc
@@ -21,7 +21,6 @@
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
-#include "absl/memory/memory.h"
 #include "absl/strings/string_view.h"
 #include "absl/time/time.h"
 #include "tensorflow/core/common_runtime/cost_measurement.h"
@@ -75,10 +74,10 @@
                                                      /*processed_size=*/16,
                                                      batch);
   EXPECT_TRUE(batch.task(0).request_cost->GetCosts().empty());
-  EXPECT_THAT(
-      batch.task(0).request_cost->GetBatchMetrics(),
-      ::testing::ElementsAre(::testing::FieldsAre(
-          /*processed_size=*/16, /*input_size=*/1, /*padding_size=*/15)));
+  EXPECT_THAT(batch.task(0).request_cost->GetBatchMetrics(),
+              ::testing::ElementsAre(::testing::FieldsAre(
+                  /*processed_size=*/16, /*input_size=*/1, /*padding_size=*/15,
+                  ::testing::IsEmpty())));
 }
 
 TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnZeroCost) {
@@ -95,10 +94,10 @@
                                                      /*processed_size=*/16,
                                                      batch);
   EXPECT_TRUE(batch.task(0).request_cost->GetCosts().empty());
-  EXPECT_THAT(
-      batch.task(0).request_cost->GetBatchMetrics(),
-      ::testing::ElementsAre(::testing::FieldsAre(
-          /*processed_size=*/16, /*input_size=*/1, /*padding_size=*/15)));
+  EXPECT_THAT(batch.task(0).request_cost->GetBatchMetrics(),
+              ::testing::ElementsAre(::testing::FieldsAre(
+                  /*processed_size=*/16, /*input_size=*/1, /*padding_size=*/15,
+                  ::testing::IsEmpty())));
 }
 
 TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnZeroBatchSize) {
@@ -154,7 +153,8 @@
   EXPECT_THAT(
       batch.task(0).request_cost->GetBatchMetrics(),
       ::testing::ElementsAre(::testing::FieldsAre(
-          /*processed_size=*/20, /*input_size=*/1, /*padding_size=*/10)));
+          /*processed_size=*/20, /*input_size=*/1, /*padding_size=*/10,
+          UnorderedElementsAre(Pair("test_tpu", absl::Milliseconds(100))))));
   EXPECT_THAT(
       batch.task(1).request_cost->GetCosts(),
       UnorderedElementsAre(Pair("test_tpu_with_smear", absl::Milliseconds(90)),
@@ -162,7 +162,8 @@
   EXPECT_THAT(
       batch.task(1).request_cost->GetBatchMetrics(),
       ::testing::ElementsAre(::testing::FieldsAre(
-          /*processed_size=*/20, /*input_size=*/9, /*padding_size=*/10)));
+          /*processed_size=*/20, /*input_size=*/9, /*padding_size=*/10,
+          UnorderedElementsAre(Pair("test_tpu", absl::Milliseconds(100))))));
 }
 
 TEST(SplitBatchCostsAndRecordMetricsTest, SplitMultiCostTypes) {
@@ -191,7 +192,9 @@
   EXPECT_THAT(
       batch.task(0).request_cost->GetBatchMetrics(),
       ::testing::ElementsAre(::testing::FieldsAre(
-          /*processed_size=*/20, /*input_size=*/1, /*padding_size=*/10)));
+          /*processed_size=*/20, /*input_size=*/1, /*padding_size=*/10,
+          UnorderedElementsAre(Pair("test_tpu", absl::Milliseconds(100)),
+                               Pair("test_gcu", absl::Milliseconds(200))))));
 
   EXPECT_THAT(
       batch.task(1).request_cost->GetCosts(),
@@ -202,7 +205,9 @@
   EXPECT_THAT(
       batch.task(1).request_cost->GetBatchMetrics(),
       ::testing::ElementsAre(::testing::FieldsAre(
-          /*processed_size=*/20, /*input_size=*/9, /*padding_size=*/10)));
+          /*processed_size=*/20, /*input_size=*/9, /*padding_size=*/10,
+          UnorderedElementsAre(Pair("test_tpu", absl::Milliseconds(100)),
+                               Pair("test_gcu", absl::Milliseconds(200))))));
 }
 
 TEST(SplitBatchCostsAndRecordMetricsTest, SplitOnlyNonZeroCostTypes) {
@@ -229,7 +234,8 @@
   EXPECT_THAT(
       batch.task(0).request_cost->GetBatchMetrics(),
       ::testing::ElementsAre(::testing::FieldsAre(
-          /*processed_size=*/20, /*input_size=*/1, /*padding_size=*/10)));
+          /*processed_size=*/20, /*input_size=*/1, /*padding_size=*/10,
+          UnorderedElementsAre(Pair("test_tpu", absl::Milliseconds(100))))));
 
   EXPECT_THAT(
       batch.task(1).request_cost->GetCosts(),
@@ -238,7 +244,8 @@
   EXPECT_THAT(
       batch.task(1).request_cost->GetBatchMetrics(),
       ::testing::ElementsAre(::testing::FieldsAre(
-          /*processed_size=*/20, /*input_size=*/9, /*padding_size=*/10)));
+          /*processed_size=*/20, /*input_size=*/9, /*padding_size=*/10,
+          UnorderedElementsAre(Pair("test_tpu", absl::Milliseconds(100))))));
 }
 
 }  // namespace
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc
index ff4933d..454e7a7 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc
@@ -742,7 +742,8 @@
     OP_REQUIRES_OK(context, stream->ThenBlasGemm(
                                 se::blas::Transpose::kNoTranspose,
                                 se::blas::Transpose::kTranspose, n, m, k, a_ptr,
-                                n, b_ptr, m, &c_ptr, n, GetNumericOptions()));
+                                n, b_ptr, m, &c_ptr, n, GetNumericOptions(),
+                                se::blas::CallContext::kNone));
     return;
   } else if (!is_grouped_convolution &&
              dims.filter_size(0) == dims.input_size(0) &&
@@ -764,7 +765,8 @@
     OP_REQUIRES_OK(context, stream->ThenBlasGemm(
                                 se::blas::Transpose::kNoTranspose,
                                 se::blas::Transpose::kTranspose, n, m, k, b_ptr,
-                                n, a_ptr, m, &c_ptr, n, GetNumericOptions()));
+                                n, a_ptr, m, &c_ptr, n, GetNumericOptions(),
+                                se::blas::CallContext::kNone));
     return;
   }
 
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc
index 499a961..1c1472e 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc
@@ -264,7 +264,8 @@
     OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
                                              se::blas::Transpose::kTranspose, n,
                                              m, k, a_ptr, n, b_ptr, m, &c_ptr,
-                                             n, GetNumericOptions()));
+                                             n, GetNumericOptions(),
+                                             se::blas::CallContext::kNone));
     return;
   } else if (dims.spatial_dims[0].filter_size ==
                  dims.spatial_dims[0].input_size &&
@@ -289,7 +290,8 @@
     OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
                                              se::blas::Transpose::kTranspose, n,
                                              m, k, b_ptr, n, a_ptr, m, &c_ptr,
-                                             n, GetNumericOptions()));
+                                             n, GetNumericOptions(),
+                                             se::blas::CallContext::kNone));
     return;
   }
 
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index cbf56bc..3278556 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -159,7 +159,8 @@
 
     OP_REQUIRES_OK(
         ctx, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k,
-                                  a_ptr, k, &c_ptr, n, GetNumericOptions()));
+                                  a_ptr, k, &c_ptr, n, GetNumericOptions(),
+                                  se::blas::CallContext::kNone));
     return;
   } else if (dims.spatial_dims[0].filter_size ==
                  dims.spatial_dims[0].input_size &&
@@ -186,7 +187,8 @@
 
     OP_REQUIRES_OK(
         ctx, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k,
-                                  a_ptr, k, &c_ptr, n, GetNumericOptions()));
+                                  a_ptr, k, &c_ptr, n, GetNumericOptions(),
+                                  se::blas::CallContext::kNone));
     return;
   }
 
diff --git a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc
index e3cda75..06cf67d 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc
@@ -743,7 +743,8 @@
 
     OP_REQUIRES_OK(context, stream->ThenBlasGemm(transpose, no_transpose, n, m,
                                                  k, b_ptr, k, a_ptr, k, &c_ptr,
-                                                 n, GetNumericOptions()));
+                                                 n, GetNumericOptions(),
+                                                 se::blas::CallContext::kNone));
     return;
   } else if (!is_grouped_convolution &&
              dims.filter_size(0) == dims.input_size(0) &&
@@ -767,7 +768,8 @@
 
     OP_REQUIRES_OK(context, stream->ThenBlasGemm(transpose, no_transpose, n, m,
                                                  k, b_ptr, k, a_ptr, k, &c_ptr,
-                                                 n, GetNumericOptions()));
+                                                 n, GetNumericOptions(),
+                                                 se::blas::CallContext::kNone));
     return;
   }
 
diff --git a/tensorflow/core/kernels/conv_ops_impl.h b/tensorflow/core/kernels/conv_ops_impl.h
index 6a3773b..432093b 100644
--- a/tensorflow/core/kernels/conv_ops_impl.h
+++ b/tensorflow/core/kernels/conv_ops_impl.h
@@ -805,9 +805,10 @@
                                 output->template flat<T>().size());
 
     auto no_transpose = se::blas::Transpose::kNoTranspose;
-    OP_REQUIRES_OK(context, stream->ThenBlasGemm(
-                                no_transpose, no_transpose, n, m, k, b_ptr, n,
-                                a_ptr, k, &c_ptr, n, GetNumericOptions()));
+    OP_REQUIRES_OK(context, stream->ThenBlasGemm(no_transpose, no_transpose, n,
+                                                 m, k, b_ptr, n, a_ptr, k,
+                                                 &c_ptr, n, GetNumericOptions(),
+                                                 se::blas::CallContext::kNone));
     return;
   } else if (!is_grouped_convolution && filter_same_dims && padding == VALID &&
              data_format == FORMAT_NHWC) {
@@ -826,9 +827,10 @@
                                 output->template flat<T>().size());
 
     auto no_transpose = se::blas::Transpose::kNoTranspose;
-    OP_REQUIRES_OK(context, stream->ThenBlasGemm(
-                                no_transpose, no_transpose, n, m, k, b_ptr, n,
-                                a_ptr, k, &c_ptr, n, GetNumericOptions()));
+    OP_REQUIRES_OK(context, stream->ThenBlasGemm(no_transpose, no_transpose, n,
+                                                 m, k, b_ptr, n, a_ptr, k,
+                                                 &c_ptr, n, GetNumericOptions(),
+                                                 se::blas::CallContext::kNone));
     return;
   }
 
diff --git a/tensorflow/core/kernels/data/experimental/distributed_save_op.cc b/tensorflow/core/kernels/data/experimental/distributed_save_op.cc
index cc17235..ed04719 100644
--- a/tensorflow/core/kernels/data/experimental/distributed_save_op.cc
+++ b/tensorflow/core/kernels/data/experimental/distributed_save_op.cc
@@ -24,6 +24,7 @@
 #include "tensorflow/core/data/service/grpc_util.h"
 #include "tensorflow/core/data/service/py_utils.h"
 #include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/metrics.h"
 #include "tensorflow/core/protobuf/snapshot.pb.h"
 #include "tsl/lib/io/compression.h"
 
@@ -101,6 +102,7 @@
           /*description=*/
           strings::StrCat("save with tf.data service dispatcher at ", address),
           deadline_micros));
+  metrics::RecordTFDataServiceSnapshotOp(directory, kDistributedSave);
 }
 
 REGISTER_KERNEL_BUILDER(Name(kDistributedSave).Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index aa24c01..2ae2829 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -931,7 +931,12 @@
     void CurrentWorkerThread(std::shared_ptr<IteratorContext> ctx)
         TF_LOCKS_EXCLUDED(mu_) {
       RecordStart(ctx.get());
+      std::shared_ptr<Element> element;
       auto done = [&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        // Release the shared ownership so that the iterator managed by
+        // `element` is guaranteed to be destroyed before this class
+        // instance's members.
+        element.reset();
         RecordStop(ctx.get());
         DecrementActiveWorkers();
         DecrementCurrentActiveWorkers();
@@ -940,7 +945,7 @@
       };
       while (true) {
         int element_index;
-        std::shared_ptr<Element> element;
+        element.reset();
         // Find an element to process.
         {
           mutex_lock l(*mu_);
@@ -1000,12 +1005,16 @@
     void FutureWorkerThread(std::shared_ptr<IteratorContext> ctx)
         TF_LOCKS_EXCLUDED(mu_) {
       RecordStart(ctx.get());
+      std::shared_ptr<Element> element;
       auto done = [&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        // Release the shared ownership so that the iterator managed by
+        // `element` is guaranteed to be destroyed before this class's
+        // instance members.
+        element.reset();
         RecordStop(ctx.get());
         DecrementActiveWorkers();
         DecrementOutstandingThreads();
       };
-      std::shared_ptr<Element> element;
       while (true) {
         {
           mutex_lock l(*mu_);
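
The motivation for hoisting `element` out of the loop and resetting it in `done` is destruction order: the cleanup lambda must drop the last reference while the members of the enclosing iterator are still alive, otherwise whatever `element` owns could be destroyed after them. A standalone sketch of the pattern, using hypothetical `Worker`/`Element` stand-ins rather than the TF classes:

#include <iostream>
#include <memory>

struct Element {
  ~Element() { std::cout << "Element destroyed\n"; }
};

struct Worker {
  ~Worker() { std::cout << "Worker members destroyed\n"; }

  void Run() {
    // Declared outside the work loop so the cleanup code below can see it.
    std::shared_ptr<Element> element = std::make_shared<Element>();
    auto done = [&]() {
      // Drop the last reference first, so whatever `element` owns is
      // destroyed while the Worker's members are still alive.
      element.reset();
      std::cout << "Cleanup ran\n";
    };
    done();
  }
};

int main() {
  Worker w;
  w.Run();
  // Prints: Element destroyed, Cleanup ran, Worker members destroyed.
}
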
diff --git a/tensorflow/core/kernels/decode_raw_op.cc b/tensorflow/core/kernels/decode_raw_op.cc
index 38bfe48..f763bec 100644
--- a/tensorflow/core/kernels/decode_raw_op.cc
+++ b/tensorflow/core/kernels/decode_raw_op.cc
@@ -82,8 +82,7 @@
     // output type is a single byte, we can copy the memory directly.
     if (!convert_data_endianness_ || sizeof(T) == 1) {
       for (int64_t i = 0; i < flat_in.size(); ++i) {
-        const T* in_data = reinterpret_cast<const T*>(flat_in(i).data());
-        memcpy(out_data, in_data, str_size);
+        memcpy(out_data, flat_in(i).data(), str_size);
         out_data += added_dim;
       }
     } else {
diff --git a/tensorflow/core/kernels/fft_impl.cc b/tensorflow/core/kernels/fft_impl.cc
deleted file mode 100644
index 6ee859f..0000000
--- a/tensorflow/core/kernels/fft_impl.cc
+++ /dev/null
@@ -1,155 +0,0 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#define EIGEN_USE_THREADS
-
-#include "tensorflow/core/kernels/fft_impl.h"  // NOLINT: declarations.
-
-#include <complex>
-#include <cstddef>
-#include <cstdint>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "ducc/google/threading.h"  // from @ducc
-#include "ducc/src/ducc0/fft/fft.h"  // from @ducc
-#include "ducc/src/ducc0/fft/fft1d_impl.h"  // from @ducc  // NOLINT: DUCC definitions.
-#include "ducc/src/ducc0/fft/fftnd_impl.h"  // from @ducc  // NOLINT: DUCC definitions.
-#include "ducc/src/ducc0/infra/mav.h"  // from @ducc
-#include "ducc/src/ducc0/infra/threading.h"  // from @ducc
-#include "unsupported/Eigen/CXX11/Tensor"  // from @eigen_archive
-#include "tensorflow/core/framework/op_requires.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tsl/framework/numeric_types.h"
-
-namespace tensorflow {
-namespace internal {
-
-using CPUDevice = Eigen::ThreadPoolDevice;
-
-template <>
-absl::Status FftImpl<CPUDevice>(const CPUDevice& device, const Tensor& in,
-                                Tensor* out, const uint64_t* fft_shape,
-                                const std::vector<size_t>& axes, bool forward) {
-  const size_t fft_rank = axes.size();
-  ducc0::fmav_info::shape_t in_shape(in.dims());
-  ducc0::fmav_info::stride_t in_stride(in.dims());
-  ducc0::fmav_info::shape_t out_shape(out->dims());
-  ducc0::fmav_info::stride_t out_stride(out->dims());
-
-  size_t next_stride = 1;
-  for (int i = in.dims(); i-- > 0;) {
-    in_shape[i] = in.dim_size(i);
-    in_stride[i] = next_stride;
-    next_stride *= in_shape[i];
-  }
-  next_stride = 1;
-  for (int i = out->dims(); i-- > 0;) {
-    out_shape[i] = out->dim_size(i);
-    out_stride[i] = next_stride;
-    next_stride *= out_shape[i];
-  }
-
-  // DUCC doesn't handle the case where fft_size[i] < input_size[i],
-  // so manually adjust inputs if required.  If doing irfft, the limit
-  // of the last axis is actually fft_size[i]/2 + 1.
-  const bool is_iffrt = !(forward || out->dtype() == DT_COMPLEX128 ||
-                          out->dtype() == DT_COMPLEX64);
-  for (int i = 0; i < fft_rank; ++i) {
-    int limit = (is_iffrt && (i == (fft_rank - 1))) ? fft_shape[i] / 2 + 1
-                                                    : fft_shape[i];
-    if (in_shape[axes[i]] > limit) {
-      in_shape[axes[i]] = limit;
-    }
-  }
-
-  double inv_scale = 1.0;
-  for (int i = 0; i < fft_rank; ++i) {
-    inv_scale *= out_shape[axes[i]];
-  }
-  double scale = forward ? 1.0 : 1.0 / inv_scale;
-
-  // Set DUCC to use the current device threadpool.  Since this is a
-  // thread-local setting, this is thread-safe.
-  ducc0::google::EigenThreadPool thread_pool(*device.getPool());
-  ducc0::detail_threading::ScopedUseThreadPool thread_pool_guard(thread_pool);
-  size_t nthreads = thread_pool.nthreads();
-
-  try {
-    if (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_COMPLEX128) {
-      auto input = in.template flat<complex128>();
-      auto output = out->template flat<complex128>();
-      ducc0::cfmav<std::complex<double>> m_in(input.data(), in_shape,
-                                              in_stride);
-      ducc0::vfmav<std::complex<double>> m_out(output.data(), out_shape,
-                                               out_stride);
-      ducc0::c2c<double>(m_in, m_out, axes, forward, scale, nthreads);
-    } else if (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_COMPLEX64) {
-      auto input = in.flat<complex64>();
-      auto output = out->flat<complex64>();
-      ducc0::cfmav<std::complex<float>> m_in(input.data(), in_shape, in_stride);
-      ducc0::vfmav<std::complex<float>> m_out(output.data(), out_shape,
-                                              out_stride);
-      ducc0::c2c<float>(m_in, m_out, axes, forward, static_cast<float>(scale),
-                        nthreads);
-    } else if (in.dtype() == DT_DOUBLE && out->dtype() == DT_COMPLEX128 &&
-               forward) {
-      auto input = in.flat<double>();
-      auto output = out->flat<complex128>();
-      ducc0::cfmav<double> m_in(input.data(), in_shape, in_stride);
-      ducc0::vfmav<std::complex<double>> m_out(output.data(), out_shape,
-                                               out_stride);
-      ducc0::r2c<double>(m_in, m_out, axes, forward, scale, nthreads);
-    } else if (in.dtype() == DT_FLOAT && out->dtype() == DT_COMPLEX64 &&
-               forward) {
-      auto input = in.flat<float>();
-      auto output = out->flat<complex64>();
-      ducc0::cfmav<float> m_in(input.data(), in_shape, in_stride);
-      ducc0::vfmav<std::complex<float>> m_out(output.data(), out_shape,
-                                              out_stride);
-      ducc0::r2c<float>(m_in, m_out, axes, forward, static_cast<float>(scale),
-                        nthreads);
-    } else if (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_DOUBLE &&
-               !forward) {
-      auto input = in.flat<complex128>();
-      auto output = out->flat<double>();
-      ducc0::cfmav<std::complex<double>> m_in(input.data(), in_shape,
-                                              in_stride);
-      ducc0::vfmav<double> m_out(output.data(), out_shape, out_stride);
-      ducc0::c2r<double>(m_in, m_out, axes, forward, scale, nthreads);
-    } else if (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_FLOAT &&
-               !forward) {
-      auto input = in.flat<complex64>();
-      auto output = out->flat<float>();
-      ducc0::cfmav<std::complex<float>> m_in(input.data(), in_shape, in_stride);
-      ducc0::vfmav<float> m_out(output.data(), out_shape, out_stride);
-      ducc0::c2r<float>(m_in, m_out, axes, forward, static_cast<float>(scale),
-                        nthreads);
-    } else {
-      return absl::InvalidArgumentError(
-          absl::StrCat("Invalid FFT parameters, in.dtype=", in.dtype(),
-                       ", out->dtype=", out->dtype(), ", forward=", forward));
-    }
-  } catch (const std::runtime_error& ex) {
-    return absl::InternalError(ex.what());
-  } catch (const std::invalid_argument& ex) {
-    return absl::InvalidArgumentError(ex.what());
-  }
-  return absl::OkStatus();
-}
-
-}  // namespace internal
-}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/fft_impl.h b/tensorflow/core/kernels/fft_impl.h
deleted file mode 100644
index ebaa24d..0000000
--- a/tensorflow/core/kernels/fft_impl.h
+++ /dev/null
@@ -1,39 +0,0 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_KERNELS_FFT_IMPL_H_
-#define TENSORFLOW_CORE_KERNELS_FFT_IMPL_H_
-
-// Generic interface for N-D FFT implementation.
-// Required to isolate the DUCC FFT implementation on CPU, so we can limit
-// required build flags to a single module.
-
-#include <vector>
-
-#include "absl/status/status.h"
-#include "tensorflow/core/framework/tensor.h"
-
-namespace tensorflow {
-namespace internal {
-
-template <typename Device>
-absl::Status FftImpl(const Device& device, const Tensor& in, Tensor* out,
-                     const uint64_t* fft_shape, const std::vector<size_t>& axes,
-                     bool forward);
-
-}  // namespace internal
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_KERNELS_FFT_IMPL_H_
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc
index 28fd113..099a570 100644
--- a/tensorflow/core/kernels/fft_ops.cc
+++ b/tensorflow/core/kernels/fft_ops.cc
@@ -17,28 +17,31 @@
 
 // See docs in ../ops/fft_ops.cc.
 
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
 #include "absl/log/check.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
+#include "ducc/google/fft.h"  // from @ducc
 #include "unsupported/Eigen/CXX11/Tensor"  // from @eigen_archive
 #include "unsupported/Eigen/CXX11/ThreadPool"  // from @eigen_archive
-#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/op_requires.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/fft_impl.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/env_var.h"
-#include "tsl/framework/numeric_types.h"
 
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/core/util/env_var.h"
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if defined(GOOGLE_CUDA) && GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cuda.h"  // CUDA_VERSION
@@ -46,6 +49,105 @@
 
 namespace tensorflow {
 
+namespace {
+
+using std::size_t;
+using Shape = ducc0::google::Shape;
+using Stride = ducc0::google::Stride;
+absl::Status DuccFftImpl(const Eigen::ThreadPoolDevice& device,
+                         const Tensor& in, Tensor* out,
+                         const uint64_t* fft_shape,
+                         const std::vector<size_t>& axes, bool forward) {
+  const size_t fft_rank = axes.size();
+  Shape in_shape(in.dims());
+  Stride in_stride(in.dims());
+  Shape out_shape(out->dims());
+  Stride out_stride(out->dims());
+
+  size_t next_stride = 1;
+  for (int i = in.dims(); i-- > 0;) {
+    in_shape[i] = in.dim_size(i);
+    in_stride[i] = next_stride;
+    next_stride *= in_shape[i];
+  }
+  next_stride = 1;
+  for (int i = out->dims(); i-- > 0;) {
+    out_shape[i] = out->dim_size(i);
+    out_stride[i] = next_stride;
+    next_stride *= out_shape[i];
+  }
+
+  // DUCC doesn't handle the case where fft_size[i] < input_size[i],
+  // so manually adjust inputs if required.  If doing irfft, the limit
+  // of the last axis is actually fft_size[i]/2 + 1.
+  const bool is_irfft = !(forward || out->dtype() == DT_COMPLEX128 ||
+                          out->dtype() == DT_COMPLEX64);
+  for (int i = 0; i < fft_rank; ++i) {
+    int limit = (is_irfft && (i == (fft_rank - 1))) ? fft_shape[i] / 2 + 1
+                                                    : fft_shape[i];
+    if (in_shape[axes[i]] > limit) {
+      in_shape[axes[i]] = limit;
+    }
+  }
+
+  double inv_scale = 1.0;
+  for (int i = 0; i < fft_rank; ++i) {
+    inv_scale *= out_shape[axes[i]];
+  }
+  double scale = forward ? 1.0 : 1.0 / inv_scale;
+
+  Eigen::ThreadPoolInterface* thread_pool = device.getPool();
+
+  if (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_COMPLEX128) {
+    auto input = in.template flat<complex128>();
+    auto output = out->template flat<complex128>();
+    ducc0::google::c2c<double>(input.data(), in_shape, in_stride, output.data(),
+                               out_shape, out_stride, axes, forward, scale,
+                               thread_pool);
+  } else if (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_COMPLEX64) {
+    auto input = in.template flat<complex64>();
+    auto output = out->template flat<complex64>();
+    ducc0::google::c2c<float>(input.data(), in_shape, in_stride, output.data(),
+                              out_shape, out_stride, axes, forward,
+                              static_cast<float>(scale), thread_pool);
+  } else if (in.dtype() == DT_DOUBLE && out->dtype() == DT_COMPLEX128 &&
+             forward) {
+    auto input = in.flat<double>();
+    auto output = out->flat<complex128>();
+    ducc0::google::r2c<double>(input.data(), in_shape, in_stride, output.data(),
+                               out_shape, out_stride, axes, forward, scale,
+                               thread_pool);
+  } else if (in.dtype() == DT_FLOAT && out->dtype() == DT_COMPLEX64 &&
+             forward) {
+    auto input = in.flat<float>();
+    auto output = out->flat<complex64>();
+    ducc0::google::r2c<float>(input.data(), in_shape, in_stride, output.data(),
+                              out_shape, out_stride, axes, forward,
+                              static_cast<float>(scale), thread_pool);
+  } else if (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_DOUBLE &&
+             !forward) {
+    auto input = in.flat<complex128>();
+    auto output = out->flat<double>();
+    ducc0::google::c2r<double>(input.data(), in_shape, in_stride, output.data(),
+                               out_shape, out_stride, axes, forward, scale,
+                               thread_pool);
+  } else if (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_FLOAT &&
+             !forward) {
+    auto input = in.flat<complex64>();
+    auto output = out->flat<float>();
+    ducc0::google::c2r<float>(input.data(), in_shape, in_stride, output.data(),
+                              out_shape, out_stride, axes, forward,
+                              static_cast<float>(scale), thread_pool);
+  } else {
+    return absl::InvalidArgumentError(
+        absl::StrCat("Invalid FFT parameters, in.dtype=", in.dtype(),
+                     ", out->dtype=", out->dtype(), ", forward=", forward));
+  }
+  return absl::OkStatus();
+}
+
+}  // namespace
+
 class FFTBase : public OpKernel {
  public:
   explicit FFTBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -297,9 +399,8 @@
       axes[i] = batch_dims + i;
     }
 
-    OP_REQUIRES_OK(
-        ctx, internal::FftImpl<CPUDevice>(ctx->eigen_device<CPUDevice>(), in,
-                                          out, fft_shape, axes, Forward));
+    OP_REQUIRES_OK(ctx, DuccFftImpl(ctx->eigen_device<CPUDevice>(), in, out,
+                                    fft_shape, axes, Forward));
   }
 };
 
@@ -1007,4 +1108,4 @@
                         FFTGPU<false, false, 3>);
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-}  // end namespace tensorflow
+}  // namespace tensorflow
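
The shape/stride loops in `DuccFftImpl` build dense row-major strides: the innermost dimension gets stride 1, and each outer stride is the product of the dimension sizes inside it. A small self-contained sketch of the same computation over a plain dims vector, not the TF `Tensor` API:

#include <cstddef>
#include <iostream>
#include <vector>

// Row-major strides for a dense tensor: stride[last] = 1, and each outer
// stride multiplies in the sizes of the dimensions inside it.
std::vector<std::size_t> RowMajorStrides(const std::vector<std::size_t>& dims) {
  std::vector<std::size_t> strides(dims.size());
  std::size_t next_stride = 1;
  for (int i = static_cast<int>(dims.size()); i-- > 0;) {
    strides[i] = next_stride;
    next_stride *= dims[i];
  }
  return strides;
}

int main() {
  for (std::size_t s : RowMajorStrides({2, 3, 4})) std::cout << s << ' ';
  std::cout << '\n';  // Prints: 12 4 1
}

The same `next_stride` accumulation is applied to both the input and output descriptors above before dispatching on dtype.
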
diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h
index faa5b79..71f338f 100644
--- a/tensorflow/core/kernels/matmul_op_impl.h
+++ b/tensorflow/core/kernels/matmul_op_impl.h
@@ -711,7 +711,8 @@
                     static_cast<Coefficient>(1.0), b_ptrs,
                     adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
                     static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
-                    GetNumericOptions(), &scratch_allocator)
+                    GetNumericOptions(), &scratch_allocator,
+                    se::blas::CallContext::kNone)
                 .ok();
         if (!blas_launch_status) {
           context->SetStatus(errors::Internal(
@@ -805,20 +806,22 @@
           }
         }
 
-        OP_REQUIRES_OK(context, stream->ThenBlasGemm(
-                                    blas_transpose_b, blas_transpose_a, n, m, k,
-                                    *(b_ptrs[0]), adj_y || trans_y ? k : n,
-                                    *(a_ptrs[0]), adj_x || trans_x ? m : k,
-                                    c_ptrs[0], n, GetNumericOptions()));
+        OP_REQUIRES_OK(context,
+                       stream->ThenBlasGemm(
+                           blas_transpose_b, blas_transpose_a, n, m, k,
+                           *(b_ptrs[0]), adj_y || trans_y ? k : n, *(a_ptrs[0]),
+                           adj_x || trans_x ? m : k, c_ptrs[0], n,
+                           GetNumericOptions(), se::blas::CallContext::kNone));
       } else if (use_strided_batched) {
         OP_REQUIRES_OK(
-            context, stream->ThenBlasGemmStridedBatched(
-                         blas_transpose_b, blas_transpose_a, n, m, k,
-                         static_cast<Coefficient>(1.0), *b_ptrs[0],
-                         adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
-                         adj_x || trans_x ? m : k, a_stride,
-                         static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
-                         batch_size, GetNumericOptions()));
+            context,
+            stream->ThenBlasGemmStridedBatched(
+                blas_transpose_b, blas_transpose_a, n, m, k,
+                static_cast<Coefficient>(1.0), *b_ptrs[0],
+                adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
+                adj_x || trans_x ? m : k, a_stride,
+                static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
+                batch_size, GetNumericOptions(), se::blas::CallContext::kNone));
       } else {
         BlasScratchAllocator scratch_allocator(context);
         bool blas_launch_status =
@@ -828,7 +831,8 @@
                     static_cast<Coefficient>(1.0), b_ptrs,
                     adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
                     static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
-                    GetNumericOptions(), &scratch_allocator)
+                    GetNumericOptions(), &scratch_allocator,
+                    se::blas::CallContext::kNone)
                 .ok();
         if (!blas_launch_status) {
           context->SetStatus(errors::Internal(
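
For reference, the stride arguments threaded through `ThenBlasGemmStridedBatched` above describe where each batch's matrices start: batch b reads A and B at fixed element offsets b*a_stride and b*b_stride and writes C at b*c_stride. A naive sketch of that contract with plain loops, not the StreamExecutor API, assuming row-major slices and alpha = 1, beta = 0 as in the call above:

#include <vector>

// Reference semantics of a strided-batched GEMM with row-major MxK, KxN and
// MxN slices: C[b] = A[b] * B[b] for each batch b.
void GemmStridedBatched(const float* a, long long a_stride, const float* b,
                        long long b_stride, float* c, long long c_stride,
                        int m, int k, int n, int batch) {
  for (int bi = 0; bi < batch; ++bi) {
    const float* A = a + bi * a_stride;
    const float* B = b + bi * b_stride;
    float* C = c + bi * c_stride;
    for (int i = 0; i < m; ++i)
      for (int j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
        C[i * n + j] = acc;
      }
  }
}

int main() {
  // Two batches of 1x2 times 2x1.
  std::vector<float> a = {1, 2, 3, 4}, b = {5, 6, 7, 8}, c(2);
  GemmStridedBatched(a.data(), 2, b.data(), 2, c.data(), 1, 1, 2, 1, 2);
  return (c[0] == 17.0f && c[1] == 53.0f) ? 0 : 1;
}

The kernel passes b before a so that the column-major BLAS routine produces a row-major C; the sketch ignores that detail and only shows what each batch computes.
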
diff --git a/tensorflow/core/kernels/mkl/BUILD b/tensorflow/core/kernels/mkl/BUILD
index 979fc0a..291a5aa 100644
--- a/tensorflow/core/kernels/mkl/BUILD
+++ b/tensorflow/core/kernels/mkl/BUILD
@@ -323,6 +323,7 @@
     linkstatic = 1,  # Fixes dyld error on MacOS.
     deps = [
         ":mkl_dequantize_op",
+        ":mkl_kernel_util",
         ":mkl_tfconv_op",
         "//tensorflow/core:array_ops_op_lib",
         "//tensorflow/core:math_ops_op_lib",
diff --git a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc
index a41afd6..a93b059 100644
--- a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc
+++ b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc
@@ -35,7 +35,7 @@
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
-template <typename Device, typename T, bool native_format = false>
+template <typename Device, typename T, typename U, bool native_format = false>
 class MklDequantizeOp : public OpKernel {
  public:
   explicit MklDequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -47,6 +47,12 @@
                     mode_string + "'"));
 
     OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
+    OP_REQUIRES(
+        ctx,
+        (ctx->output_type(0) == DT_FLOAT || ctx->output_type(0) == DT_BFLOAT16),
+        errors::InvalidArgument("Output type must be float or bfloat16,"
+                                " is '" +
+                                DataTypeString(ctx->output_type(0)) + "'"));
   }
 
   void Compute(OpKernelContext* ctx) override {
@@ -69,7 +75,7 @@
 
       // Create reorder memory for src and dst
       MklDnnData<T> src(&cpu_engine);
-      MklDnnData<float> dst(&cpu_engine);
+      MklDnnData<U> dst(&cpu_engine);
 
       std::shared_ptr<stream> reorder_stream;
       // Create the oneDNN wrapper over Eigen threadpool and set max threads
@@ -116,9 +122,8 @@
       MklDnnShape output_mkl_shape;
       TensorShape output_tf_shape;
       memory::desc dst_md =
-          memory::desc(src_dims, MklDnnType<float>(), dst_layout_type);
+          memory::desc(src_dims, MklDnnType<U>(), dst_layout_type);
 
-      // If input is MKL shape, output is also MKL shape.
       // If input is TF shape, output is also TF shape.
       output_mkl_shape.SetMklTensor(false);
       output_tf_shape = MklDnnDimsToTFShape(output_dims);
@@ -149,8 +154,7 @@
       } else {
         scale_factor = max_range / v_max;
       }
-      std::vector<float> scales;
-      scales.push_back(scale_factor);
+      std::vector<float> scales = {scale_factor};
       primitive_attr attr;
 #ifndef ENABLE_ONEDNN_V3
       attr.set_output_scales(0, scales);
@@ -161,13 +165,12 @@
                                memory::format_tag::x},
                               cpu_engine, scales.data());
 #endif  // !ENABLE_ONEDNN_V3
-      std::vector<primitive> net;
 
       // Create reorder primitive and then execute.
       auto reorder_pd =
           ReorderPd(cpu_engine, src.GetUsrMem()->get_desc(), cpu_engine,
                     dst.GetUsrMem()->get_desc(), attr);
-      net.push_back(reorder(reorder_pd));
+      std::vector<primitive> net = {reorder(reorder_pd)};
       std::vector<std::unordered_map<int, memory>> reorder_net_args;
 #ifndef ENABLE_ONEDNN_V3
       reorder_net_args.push_back({{DNNL_ARG_FROM, *src.GetUsrMem()},
@@ -199,13 +202,27 @@
 REGISTER_KERNEL_BUILDER(Name("_MklDequantize")
                             .Device(DEVICE_CPU)
                             .TypeConstraint<quint8>("T")
+                            .TypeConstraint<float>("dtype")
                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
-                        MklDequantizeOp<CPUDevice, quint8, true>);
+                        MklDequantizeOp<CPUDevice, quint8, float, true>);
 REGISTER_KERNEL_BUILDER(Name("_MklDequantize")
                             .Device(DEVICE_CPU)
                             .TypeConstraint<qint8>("T")
+                            .TypeConstraint<float>("dtype")
                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
-                        MklDequantizeOp<CPUDevice, qint8, true>);
+                        MklDequantizeOp<CPUDevice, qint8, float, true>);
+REGISTER_KERNEL_BUILDER(Name("_MklDequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("T")
+                            .TypeConstraint<bfloat16>("dtype")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklDequantizeOp<CPUDevice, quint8, bfloat16, true>);
+REGISTER_KERNEL_BUILDER(Name("_MklDequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint8>("T")
+                            .TypeConstraint<bfloat16>("dtype")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklDequantizeOp<CPUDevice, qint8, bfloat16, true>);
 
 }  // namespace tensorflow
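
The SCALED path above boils down to multiplying each quantized value by `scale_factor = max_range / v_max`, where v_max is the largest value representable by the quantized type (255 for quint8). A minimal sketch of that arithmetic for the unsigned 8-bit case, using plain ints and floats rather than the oneDNN reorder the kernel actually runs:

#include <cstdint>
#include <iostream>

int main() {
  // SCALED-mode dequantization of quint8 values: out = q * max_range / 255.
  const float max_range = 200.0f;
  const std::uint8_t quantized[] = {0, 10, 50, 40, 25, 115, 190, 255};
  const float scale_factor = max_range / 255.0f;
  for (std::uint8_t q : quantized) std::cout << q * scale_factor << ' ';
  std::cout << '\n';  // ~ 0 7.84 39.2 31.4 19.6 90.2 149 200
}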
 
diff --git a/tensorflow/core/kernels/mkl/mkl_dequantize_op_test.cc b/tensorflow/core/kernels/mkl/mkl_dequantize_op_test.cc
index 6c46ca1..63eeb95 100644
--- a/tensorflow/core/kernels/mkl/mkl_dequantize_op_test.cc
+++ b/tensorflow/core/kernels/mkl/mkl_dequantize_op_test.cc
@@ -14,12 +14,17 @@
 ==============================================================================*/
 
 #if defined(INTEL_MKL) && defined(ENABLE_MKL)
+#define EIGEN_USE_THREADS
 
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/framework/fake_input.h"
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/mkl/mkl_kernel_util.h"
 #include "tensorflow/core/kernels/ops_testutil.h"
 #include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
@@ -27,53 +32,95 @@
 
 namespace tensorflow {
 
-class MklDequantizeOpTest : public OpsTestBase {};
+class MklDequantizeOpTest : public OpsTestBase {
+ protected:
+  template <typename Tinput, typename Toutput>
+  void RunMklDequantize(const Tensor& input_quantized,
+                        const Tensor& min_range_float,
+                        const Tensor& max_range_float,
+                        const Tensor& expected_output) {
+    AddInputFromArray<Tinput>(input_quantized.shape(),
+                              input_quantized.flat<Tinput>());
+    AddInputFromArray<float>(min_range_float.shape(),
+                             min_range_float.flat<float>());
+    AddInputFromArray<float>(max_range_float.shape(),
+                             max_range_float.flat<float>());
 
-TEST_F(MklDequantizeOpTest, small) {
-  TF_ASSERT_OK(NodeDefBuilder("dequantize_op", "_MklDequantize")
-                   .Input(FakeInput(DT_QUINT8))
-                   .Input(FakeInput(DT_FLOAT))
-                   .Input(FakeInput(DT_FLOAT))
-                   .Attr("T", DataTypeToEnum<quint8>::v())
-                   .Attr("mode", "SCALED")
-                   .Attr("_kernel", "QuantizedMklOp")
-                   .Finalize(node_def()));
-  TF_ASSERT_OK(InitOp());
-  AddInputFromArray<quint8>(TensorShape({1, 2, 2, 2}),
-                            {0, 10, 50, 40, 25, 115, 190, 255});
-  // min_range = 0
-  AddInputFromArray<float>(TensorShape({}), {0});
-  // max_range = 200
-  AddInputFromArray<float>(TensorShape({}), {200.0f});
-  TF_ASSERT_OK(RunOpKernel());
-  Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 2, 2}));
-  test::FillValues<float>(&expected,
-                          {0.0, 7.84, 39.21, 31.37, 19.6, 90.2, 149.0, 200});
-  const Tensor& output = *GetOutput(0);
-  test::ExpectTensorNear<float>(expected, output, 0.1);
+    TF_ASSERT_OK(RunOpKernel());
+
+    const Tensor& actual_output = *GetOutput(0);
+    test::ExpectTensorNear<Toutput>(expected_output, actual_output, 0.1);
+  }
+
+  template <typename Tinput, typename Toutput>
+  void TestMklDequantize() {
+    const DataType input_dt = DataTypeToEnum<Tinput>::v();
+    const DataType output_dt = DataTypeToEnum<Toutput>::v();
+
+    TF_ASSERT_OK(NodeDefBuilder("dequantize_op", "_MklDequantize")
+                     .Input(FakeInput(input_dt))
+                     .Input(FakeInput(DT_FLOAT))  // min_range
+                     .Input(FakeInput(DT_FLOAT))  // max_range
+                     .Attr("T", input_dt)
+                     .Attr("dtype", output_dt)
+                     .Attr("mode", "SCALED")
+                     .Attr("_kernel", "QuantizedMklOp")
+                     .Finalize(node_def()));
+
+    TF_ASSERT_OK(InitOp());
+
+    Tensor input_float(DT_FLOAT, {1, 2, 2, 2});
+    test::FillValues<float>(&input_float, {0, 10, 50, 40, 25, 115, 190, 255});
+
+    const float min_range = 0.0f;
+    const float max_range = 255.0f;
+
+    Tensor min_range_float(DT_FLOAT, {});
+    test::FillValues<float>(&min_range_float, {min_range});
+
+    Tensor max_range_float(DT_FLOAT, {});
+    test::FillValues<float>(&max_range_float, {max_range});
+
+    Tensor input_quantized =
+        FloatTensorToQuantized<Tinput>(input_float, min_range, max_range);
+
+    Tensor expected_output_float32;
+    MklTestingUtil::RunDequantizeOp(input_quantized, min_range_float,
+                                    max_range_float, "SCALED",
+                                    &expected_output_float32);
+
+    if (output_dt == DT_BFLOAT16) {
+      // Since DequantizeOp does not support "SCALED" mode for bf16 output,
+      // use a workaround by casting fp32 output (computed using "SCALED" mode)
+      // into bf16 output.
+      Tensor expected_output_bfloat16(DT_BFLOAT16, {1, 2, 2, 2});
+      expected_output_bfloat16.flat<bfloat16>() =
+          expected_output_float32.flat<float>().cast<bfloat16>();
+      RunMklDequantize<Tinput, Toutput>(input_quantized, min_range_float,
+                                        max_range_float,
+                                        expected_output_bfloat16);
+    } else {
+      RunMklDequantize<Tinput, Toutput>(input_quantized, min_range_float,
+                                        max_range_float,
+                                        expected_output_float32);
+    }
+  }
+};
+
+TEST_F(MklDequantizeOpTest, MklDequantize_Unsigned_Input_Float_Output) {
+  TestMklDequantize<quint8, float>();
 }
 
-TEST_F(MklDequantizeOpTest, MKLInput) {
-  TF_ASSERT_OK(NodeDefBuilder("dequantize_op", "_MklDequantize")
-                   .Input(FakeInput(DT_QUINT8))
-                   .Input(FakeInput(DT_FLOAT))
-                   .Input(FakeInput(DT_FLOAT))
-                   .Attr("T", DataTypeToEnum<quint8>::v())
-                   .Attr("mode", "SCALED")
-                   .Attr("_kernel", "QuantizedMklOp")
-                   .Finalize(node_def()));
-  TF_ASSERT_OK(InitOp());
-  AddInputFromArray<quint8>(TensorShape({1, 2, 2, 2}),
-                            {0, 10, 50, 40, 25, 115, 190, 255});
-  // min_range = 0
-  AddInputFromArray<float>(TensorShape({}), {0});
-  // max_range = 200
-  AddInputFromArray<float>(TensorShape({}), {200.0f});
-  TF_ASSERT_OK(RunOpKernel());
-  Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 2, 2}));
-  test::FillValues<float>(&expected,
-                          {0.0, 7.84, 39.21, 31.37, 19.6, 90.2, 149.0, 200});
-  test::ExpectTensorNear<float>(expected, *GetOutput(0), 0.1);
+TEST_F(MklDequantizeOpTest, MklDequantize_Signed_Input_Float_Output) {
+  TestMklDequantize<qint8, float>();
+}
+
+TEST_F(MklDequantizeOpTest, MklDequantize_Unsigned_Input_Bfloat16_Output) {
+  TestMklDequantize<quint8, bfloat16>();
+}
+
+TEST_F(MklDequantizeOpTest, MklDequantize_Signed_Input_Bfloat16_Output) {
+  TestMklDequantize<qint8, bfloat16>();
 }
 
 }  // namespace tensorflow
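
The bf16 test variants compare against an fp32 reference cast to bfloat16 with a 0.1 tolerance; that tolerance absorbs bfloat16's reduced precision, since bfloat16 keeps only the sign, exponent, and top 7 mantissa bits of a float32. A small sketch of that bit-level relationship (a truncating conversion is shown for simplicity; the Eigen cast used in the test typically rounds rather than truncates, so results can differ by one ulp):

#include <cstdint>
#include <cstring>
#include <iostream>

// bfloat16 <-> float32 via the top 16 bits of the float's bit pattern.
std::uint16_t FloatToBfloat16Truncate(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<std::uint16_t>(bits >> 16);
}

float Bfloat16ToFloat(std::uint16_t b) {
  const std::uint32_t bits = static_cast<std::uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  // 7.84f survives only approximately (prints 7.8125 with truncation),
  // well within the test's 0.1 tolerance.
  std::cout << Bfloat16ToFloat(FloatToBfloat16Truncate(7.84f)) << '\n';
}
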
diff --git a/tensorflow/core/kernels/mkl/mkl_kernel_util.h b/tensorflow/core/kernels/mkl/mkl_kernel_util.h
index 1fb9720..fb9df4d 100644
--- a/tensorflow/core/kernels/mkl/mkl_kernel_util.h
+++ b/tensorflow/core/kernels/mkl/mkl_kernel_util.h
@@ -25,6 +25,8 @@
 
 using dnnl::memory;
 
+using dnnl::memory;
+
 namespace tensorflow {
 
 class MklTestingUtil {
diff --git a/tensorflow/core/kernels/rnn/blas_gemm.cc b/tensorflow/core/kernels/rnn/blas_gemm.cc
index b83de9f..96c9cde 100644
--- a/tensorflow/core/kernels/rnn/blas_gemm.cc
+++ b/tensorflow/core/kernels/rnn/blas_gemm.cc
@@ -54,7 +54,7 @@
       ctx, ctx->op_device_context()->stream()->ThenBlasGemm(
                trans[transa], trans[transb], m, n, k, static_cast<T>(alpha),
                a_ptr, lda, b_ptr, ldb, static_cast<T>(beta), &c_ptr, ldc,
-               GetNumericOptions()));
+               GetNumericOptions(), se::blas::CallContext::kNone));
 #else
   ctx->SetStatus(errors::InvalidArgument("CuBlasGemm needs CUDA."));
 #endif
diff --git a/tensorflow/core/lib/gtl/manual_constructor.h b/tensorflow/core/lib/gtl/manual_constructor.h
index c5a4e89..4431f5e 100644
--- a/tensorflow/core/lib/gtl/manual_constructor.h
+++ b/tensorflow/core/lib/gtl/manual_constructor.h
@@ -56,8 +56,7 @@
 #if defined(_MSC_VER)
 #define TF_LIB_GTL_ALIGN_ATTRIBUTE(X) __declspec(align(X))
 #define TF_LIB_GTL_ALIGN_OF(T) __alignof(T)
-#elif defined(COMPILER_GCC3) || __GNUC__ >= 3 || defined(__APPLE__) || \
-    defined(COMPILER_ICC) || defined(OS_NACL) || defined(__clang__)
+#else
 #define TF_LIB_GTL_ALIGN_ATTRIBUTE(X) __attribute__((aligned(X)))
 #define TF_LIB_GTL_ALIGN_OF(T) __alignof__(T)
 #endif
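
The simplified branch keeps `__attribute__((aligned(X)))` / `__alignof__` for every non-MSVC compiler. In standard C++ the same request is spelled with `alignas`/`alignof`; a tiny sketch of the equivalent effect, not the macro itself:

#include <cstdio>

// Request 16-byte alignment for a small vector type, as the GCC/Clang branch
// of TF_LIB_GTL_ALIGN_ATTRIBUTE(16) would.
struct alignas(16) Vec4 {
  float v[4];
};

int main() { std::printf("%zu\n", alignof(Vec4)); }  // Prints 16.
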
diff --git a/tensorflow/core/ops/compat/ops_history_v2/ComputeDedupDataSize.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/ComputeDedupDataSize.pbtxt
index 0d4ec98..b9f68e2 100644
--- a/tensorflow/core/ops/compat/ops_history_v2/ComputeDedupDataSize.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history_v2/ComputeDedupDataSize.pbtxt
@@ -1,4 +1,4 @@
-op {
+op 	 {
   name: "ComputeDedupDataSize"
   output_arg {
     name: "num_elements"
diff --git a/tensorflow/core/ops/compat/ops_history_v2/ConvertToCooTensor.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/ConvertToCooTensor.pbtxt
new file mode 100644
index 0000000..31aeb0d
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/ConvertToCooTensor.pbtxt
@@ -0,0 +1,38 @@
+op 	 {
+  name: "ConvertToCooTensor"
+  input_arg {
+    name: "indices_or_row_splits"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "values"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "weights"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "row_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "col_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "gains"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "sample_count"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "combiner"
+    type: "string"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/GetMinibatchSplitsWithPhysicalReplica.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/GetMinibatchSplitsWithPhysicalReplica.pbtxt
new file mode 100644
index 0000000..764a898
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/GetMinibatchSplitsWithPhysicalReplica.pbtxt
@@ -0,0 +1,86 @@
+op 	 {
+  name: "GetMinibatchSplitsWithPhysicalReplica"
+  input_arg {
+    name: "program_key"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "row_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "col_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gains"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "sorted_row_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_col_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "splits"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "id_counts"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "max_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "max_uniques"
+    type: DT_INT32
+  }
+  attr {
+    name: "sample_count"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_replica"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "table_vocab_size"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sc_per_chip"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+  attr {
+    name: "mini_batch_splits"
+    type: "string"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/GetMinibatchesInCsrWithPhysicalReplica.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/GetMinibatchesInCsrWithPhysicalReplica.pbtxt
new file mode 100644
index 0000000..45c3b53
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/GetMinibatchesInCsrWithPhysicalReplica.pbtxt
@@ -0,0 +1,106 @@
+op 	 {
+  name: "GetMinibatchesInCsrWithPhysicalReplica"
+  input_arg {
+    name: "program_key"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "row_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "col_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "splits"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "id_counts"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "row_pointers_unpadded_size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "ids_unpadded_size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  attr {
+    name: "sample_count"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_replica"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "max_minibatches_per_sc"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "max_ids_per_chip_per_sample"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "table_vocab_size"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sc_per_chip"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+  attr {
+    name: "mini_batch_in_csr"
+    type: "string"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/StoreMinibatchStatisticsInFdo.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/StoreMinibatchStatisticsInFdo.pbtxt
new file mode 100644
index 0000000..2250ba0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/StoreMinibatchStatisticsInFdo.pbtxt
@@ -0,0 +1,48 @@
+op 	 {
+  name: "StoreMinibatchStatisticsInFdo"
+  input_arg {
+    name: "program_key"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "max_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "max_uniques"
+    type: DT_INT32
+  }
+  attr {
+    name: "sample_count"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_replica"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sc_per_chip"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+  attr {
+    name: "mini_batch_splits"
+    type: "string"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/TPUAnnotateTensorsWithDynamicShape.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/TPUAnnotateTensorsWithDynamicShape.pbtxt
new file mode 100644
index 0000000..09d484e
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/TPUAnnotateTensorsWithDynamicShape.pbtxt
@@ -0,0 +1,18 @@
+op 	 {
+  name: "TPUAnnotateTensorsWithDynamicShape"
+  input_arg {
+    name: "tensors"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "tpu_tensors"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/TPUCopyWithDynamicShape.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/TPUCopyWithDynamicShape.pbtxt
new file mode 100644
index 0000000..8b897ff
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/TPUCopyWithDynamicShape.pbtxt
@@ -0,0 +1,28 @@
+op 	 {
+  name: "TPUCopyWithDynamicShape"
+  input_arg {
+    name: "tensors"
+    type_list_attr: "T"
+  }
+  input_arg {
+    name: "unpadded_sizes"
+    type: DT_INT32
+    number_attr: "N"
+  }
+  output_arg {
+    name: "tpu_tensors"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseCoreAdagrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseCoreAdagrad.pbtxt
new file mode 100644
index 0000000..9cf626b
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseCoreAdagrad.pbtxt
@@ -0,0 +1,36 @@
+op 	 {
+  name: "XlaSparseCoreAdagrad"
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gradient"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulator"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_accumulator"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseCoreAdagradMomentum.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseCoreAdagradMomentum.pbtxt
new file mode 100644
index 0000000..b644604
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseCoreAdagradMomentum.pbtxt
@@ -0,0 +1,64 @@
+op 	 {
+  name: "XlaSparseCoreAdagradMomentum"
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gradient"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "beta_1"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "epsilon"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulator"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "momentum"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_accumulator"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_momentum"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+  }
+  attr {
+    name: "beta_2"
+    type: "float"
+  }
+  attr {
+    name: "exponent"
+    type: "float"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseCoreAdam.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseCoreAdam.pbtxt
new file mode 100644
index 0000000..38af8af
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseCoreAdam.pbtxt
@@ -0,0 +1,60 @@
+op 	 {
+  name: "XlaSparseCoreAdam"
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gradient"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "momentum"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "velocity"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "beta_1"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "beta_2"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "epsilon"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_velocity"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_momentum"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+  }
+  attr {
+    name: "use_sum_inside_sqrt"
+    type: "bool"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseCoreFtrl.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseCoreFtrl.pbtxt
new file mode 100644
index 0000000..afbf9e0
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseCoreFtrl.pbtxt
@@ -0,0 +1,64 @@
+op 	 {
+  name: "XlaSparseCoreFtrl"
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulator"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "linear"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gradient"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "beta"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate_power"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "l2_regularization_strength"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_accumulator"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_linear"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+  }
+  attr {
+    name: "multiply_linear_by_learning_rate"
+    type: "bool"
+  }
+  attr {
+    name: "l1_regularization_strength"
+    type: "float"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseCoreSgd.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseCoreSgd.pbtxt
new file mode 100644
index 0000000..7f507c7
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseCoreSgd.pbtxt
@@ -0,0 +1,28 @@
+op 	 {
+  name: "XlaSparseCoreSgd"
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gradient"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+  }
+  is_stateful: true
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmul.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmul.pbtxt
new file mode 100644
index 0000000..5ecf0c2
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmul.pbtxt
@@ -0,0 +1,58 @@
+op 	 {
+  name: "XlaSparseDenseMatmul"
+  input_arg {
+    name: "row_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "col_ids"
+    type: DT_UINT32
+  }
+  input_arg {
+    name: "values"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "offsets"
+    type: DT_UINT32
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "activations"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_embedding_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "max_ids_per_partition"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "max_unique_ids_per_partition"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "input_size"
+    type: "int"
+    has_minimum: true
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt
new file mode 100644
index 0000000..e13cbfc
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt
@@ -0,0 +1,116 @@
+op 	 {
+  name: "XlaSparseDenseMatmulGradWithAdagradAndCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "activation_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulator"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_accumulator"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
+op {
+  name: "XlaSparseDenseMatmulGradWithAdagradAndCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "activation_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulator"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_accumulator"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "clip_weight_min"
+    type: "float"
+    default_value {
+      f: -inf
+    }
+  }
+  attr {
+    name: "clip_weight_max"
+    type: "float"
+    default_value {
+      f: inf
+    }
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt
new file mode 100644
index 0000000..e6f2eed
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt
@@ -0,0 +1,172 @@
+op 	 {
+  name: "XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "activation_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulator"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "momenta"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_accumulator"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_momenta"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+  }
+  attr {
+    name: "exponent"
+    type: "float"
+  }
+  attr {
+    name: "beta1"
+    type: "float"
+  }
+  attr {
+    name: "beta2"
+    type: "float"
+  }
+  attr {
+    name: "epsilon"
+    type: "float"
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
+op {
+  name: "XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "activation_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulator"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "momenta"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_accumulator"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_momenta"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+  }
+  attr {
+    name: "exponent"
+    type: "float"
+  }
+  attr {
+    name: "beta1"
+    type: "float"
+  }
+  attr {
+    name: "beta2"
+    type: "float"
+  }
+  attr {
+    name: "epsilon"
+    type: "float"
+  }
+  attr {
+    name: "clip_weight_min"
+    type: "float"
+    default_value {
+      f: -inf
+    }
+  }
+  attr {
+    name: "clip_weight_max"
+    type: "float"
+    default_value {
+      f: inf
+    }
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt
new file mode 100644
index 0000000..202e6f4
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt
@@ -0,0 +1,164 @@
+op 	 {
+  name: "XlaSparseDenseMatmulGradWithAdamAndCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "activation_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "momenta"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "velocity"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_momenta"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_velocity"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "use_sum_inside_sqrt"
+    type: "bool"
+  }
+  attr {
+    name: "beta1"
+    type: "float"
+  }
+  attr {
+    name: "beta2"
+    type: "float"
+  }
+  attr {
+    name: "epsilon"
+    type: "float"
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
+op {
+  name: "XlaSparseDenseMatmulGradWithAdamAndCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "activation_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "momenta"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "velocity"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_momenta"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_velocity"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "use_sum_inside_sqrt"
+    type: "bool"
+  }
+  attr {
+    name: "beta1"
+    type: "float"
+  }
+  attr {
+    name: "beta2"
+    type: "float"
+  }
+  attr {
+    name: "epsilon"
+    type: "float"
+  }
+  attr {
+    name: "clip_weight_min"
+    type: "float"
+    default_value {
+      f: -inf
+    }
+  }
+  attr {
+    name: "clip_weight_max"
+    type: "float"
+    default_value {
+      f: inf
+    }
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt
new file mode 100644
index 0000000..96121a6
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt
@@ -0,0 +1,172 @@
+op {
+  name: "XlaSparseDenseMatmulGradWithFtrlAndCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "activation_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulator"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "linear"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_accumulator"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_linear"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "multiply_linear_by_learning_rate"
+    type: "bool"
+  }
+  attr {
+    name: "beta"
+    type: "float"
+  }
+  attr {
+    name: "learning_rate_power"
+    type: "float"
+  }
+  attr {
+    name: "l1_regularization_strength"
+    type: "float"
+  }
+  attr {
+    name: "l2_regularization_strength"
+    type: "float"
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
+op {
+  name: "XlaSparseDenseMatmulGradWithFtrlAndCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "activation_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulator"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "linear"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_accumulator"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_linear"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "multiply_linear_by_learning_rate"
+    type: "bool"
+  }
+  attr {
+    name: "beta"
+    type: "float"
+  }
+  attr {
+    name: "learning_rate_power"
+    type: "float"
+  }
+  attr {
+    name: "l1_regularization_strength"
+    type: "float"
+  }
+  attr {
+    name: "l2_regularization_strength"
+    type: "float"
+  }
+  attr {
+    name: "clip_weight_min"
+    type: "float"
+    default_value {
+      f: -inf
+    }
+  }
+  attr {
+    name: "clip_weight_max"
+    type: "float"
+    default_value {
+      f: inf
+    }
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt
new file mode 100644
index 0000000..3ad518f
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt
@@ -0,0 +1,100 @@
+op {
+  name: "XlaSparseDenseMatmulGradWithSgdAndCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "activation_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
+op {
+  name: "XlaSparseDenseMatmulGradWithSgdAndCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "activation_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "clip_weight_min"
+    type: "float"
+    default_value {
+      f: -inf
+    }
+  }
+  attr {
+    name: "clip_weight_max"
+    type: "float"
+    default_value {
+      f: inf
+    }
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulWithCsrInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulWithCsrInput.pbtxt
new file mode 100644
index 0000000..1aa1743
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/XlaSparseDenseMatmulWithCsrInput.pbtxt
@@ -0,0 +1,53 @@
+op {
+  name: "XlaSparseDenseMatmulWithCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "activations"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "input_size"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "quantization_config_low"
+    type: "float"
+  }
+  attr {
+    name: "quantization_config_high"
+    type: "float"
+  }
+  attr {
+    name: "quantization_config_num_buckets"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
diff --git a/tensorflow/core/ops/mkl_array_ops.cc b/tensorflow/core/ops/mkl_array_ops.cc
index b5e368c..6dad4c0 100644
--- a/tensorflow/core/ops/mkl_array_ops.cc
+++ b/tensorflow/core/ops/mkl_array_ops.cc
@@ -102,7 +102,7 @@
     .Input("input: T")
     .Input("min_range: float")
     .Input("max_range: float")
-    .Output("output: float")
+    .Output("output: dtype")
     .Attr("T: quantizedtype")
     .Attr("narrow_range: bool = false")
     .Attr("axis: int = -1")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 7e30b21..2bcdc8b 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -10649,6 +10649,44 @@
   }
 }
 op {
+  name: "ConvertToCooTensor"
+  input_arg {
+    name: "indices_or_row_splits"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "values"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "weights"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "row_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "col_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "gains"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "sample_count"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "combiner"
+    type: "string"
+  }
+  is_stateful: true
+}
+op {
   name: "Copy"
   input_arg {
     name: "input"
@@ -21746,6 +21784,198 @@
   }
 }
 op {
+  name: "GetMinibatchSplitsWithPhysicalReplica"
+  input_arg {
+    name: "program_key"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "row_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "col_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gains"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "sorted_row_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_col_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "splits"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "id_counts"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "max_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "max_uniques"
+    type: DT_INT32
+  }
+  attr {
+    name: "sample_count"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_replica"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "table_vocab_size"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sc_per_chip"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+  attr {
+    name: "mini_batch_splits"
+    type: "string"
+  }
+  is_stateful: true
+}
+op {
+  name: "GetMinibatchesInCsrWithPhysicalReplica"
+  input_arg {
+    name: "program_key"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "row_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "col_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "splits"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "id_counts"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "row_pointers_unpadded_size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "ids_unpadded_size"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  attr {
+    name: "sample_count"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_replica"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "max_minibatches_per_sc"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "max_ids_per_chip_per_sample"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "table_vocab_size"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sc_per_chip"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+  attr {
+    name: "mini_batch_in_csr"
+    type: "string"
+  }
+  is_stateful: true
+}
+op {
   name: "GetOptions"
   input_arg {
     name: "input_dataset"
@@ -58542,6 +58772,54 @@
   }
 }
 op {
+  name: "StoreMinibatchStatisticsInFdo"
+  input_arg {
+    name: "program_key"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "max_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "max_uniques"
+    type: DT_INT32
+  }
+  attr {
+    name: "sample_count"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_replica"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "num_sc_per_chip"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+  attr {
+    name: "mini_batch_splits"
+    type: "string"
+  }
+  is_stateful: true
+}
+op {
   name: "StridedSlice"
   input_arg {
     name: "input"
@@ -59501,6 +59779,24 @@
   is_stateful: true
 }
 op {
+  name: "TPUAnnotateTensorsWithDynamicShape"
+  input_arg {
+    name: "tensors"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "tpu_tensors"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
   name: "TPUCompilationResult"
   output_arg {
     name: "output"
@@ -59566,6 +59862,34 @@
   is_stateful: true
 }
 op {
+  name: "TPUCopyWithDynamicShape"
+  input_arg {
+    name: "tensors"
+    type_list_attr: "T"
+  }
+  input_arg {
+    name: "unpadded_sizes"
+    type: DT_INT32
+    number_attr: "N"
+  }
+  output_arg {
+    name: "tpu_tensors"
+    type_list_attr: "T"
+  }
+  attr {
+    name: "N"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
   name: "TPUEmbeddingActivations"
   input_arg {
     name: "embedding_variable"
@@ -65847,6 +66171,766 @@
   is_stateful: true
 }
 op {
+  name: "XlaSparseCoreAdagrad"
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gradient"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulator"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_accumulator"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+  }
+  is_stateful: true
+}
+op {
+  name: "XlaSparseCoreAdagradMomentum"
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gradient"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "beta_1"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "epsilon"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulator"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "momentum"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_accumulator"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_momentum"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+  }
+  attr {
+    name: "beta_2"
+    type: "float"
+  }
+  attr {
+    name: "exponent"
+    type: "float"
+  }
+  is_stateful: true
+}
+op {
+  name: "XlaSparseCoreAdam"
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gradient"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "momentum"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "velocity"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "beta_1"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "beta_2"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "epsilon"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_velocity"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_momentum"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+  }
+  attr {
+    name: "use_sum_inside_sqrt"
+    type: "bool"
+  }
+  is_stateful: true
+}
+op {
+  name: "XlaSparseCoreFtrl"
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulator"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "linear"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gradient"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "beta"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate_power"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "l2_regularization_strength"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_accumulator"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_linear"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+  }
+  attr {
+    name: "multiply_linear_by_learning_rate"
+    type: "bool"
+  }
+  attr {
+    name: "l1_regularization_strength"
+    type: "float"
+  }
+  is_stateful: true
+}
+op {
+  name: "XlaSparseCoreSgd"
+  input_arg {
+    name: "indices"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "gradient"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "feature_width"
+    type: "int"
+  }
+  is_stateful: true
+}
+op {
+  name: "XlaSparseDenseMatmul"
+  input_arg {
+    name: "row_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "col_ids"
+    type: DT_UINT32
+  }
+  input_arg {
+    name: "values"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "offsets"
+    type: DT_UINT32
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "activations"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_embedding_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "max_ids_per_partition"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "max_unique_ids_per_partition"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "input_size"
+    type: "int"
+    has_minimum: true
+  }
+}
+op {
+  name: "XlaSparseDenseMatmulGradWithAdagradAndCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "activation_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulator"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_accumulator"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "clip_weight_min"
+    type: "float"
+    default_value {
+      f: -inf
+    }
+  }
+  attr {
+    name: "clip_weight_max"
+    type: "float"
+    default_value {
+      f: inf
+    }
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
+op {
+  name: "XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "activation_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulator"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "momenta"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_accumulator"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_momenta"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "use_nesterov"
+    type: "bool"
+  }
+  attr {
+    name: "exponent"
+    type: "float"
+  }
+  attr {
+    name: "beta1"
+    type: "float"
+  }
+  attr {
+    name: "beta2"
+    type: "float"
+  }
+  attr {
+    name: "epsilon"
+    type: "float"
+  }
+  attr {
+    name: "clip_weight_min"
+    type: "float"
+    default_value {
+      f: -inf
+    }
+  }
+  attr {
+    name: "clip_weight_max"
+    type: "float"
+    default_value {
+      f: inf
+    }
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
+op {
+  name: "XlaSparseDenseMatmulGradWithAdamAndCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "activation_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "momenta"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "velocity"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_momenta"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_velocity"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "use_sum_inside_sqrt"
+    type: "bool"
+  }
+  attr {
+    name: "beta1"
+    type: "float"
+  }
+  attr {
+    name: "beta2"
+    type: "float"
+  }
+  attr {
+    name: "epsilon"
+    type: "float"
+  }
+  attr {
+    name: "clip_weight_min"
+    type: "float"
+    default_value {
+      f: -inf
+    }
+  }
+  attr {
+    name: "clip_weight_max"
+    type: "float"
+    default_value {
+      f: inf
+    }
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
+op {
+  name: "XlaSparseDenseMatmulGradWithFtrlAndCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "activation_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "accumulator"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "linear"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_accumulator"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "updated_linear"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "multiply_linear_by_learning_rate"
+    type: "bool"
+  }
+  attr {
+    name: "beta"
+    type: "float"
+  }
+  attr {
+    name: "learning_rate_power"
+    type: "float"
+  }
+  attr {
+    name: "l1_regularization_strength"
+    type: "float"
+  }
+  attr {
+    name: "l2_regularization_strength"
+    type: "float"
+  }
+  attr {
+    name: "clip_weight_min"
+    type: "float"
+    default_value {
+      f: -inf
+    }
+  }
+  attr {
+    name: "clip_weight_max"
+    type: "float"
+    default_value {
+      f: inf
+    }
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
+op {
+  name: "XlaSparseDenseMatmulGradWithSgdAndCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "activation_gradients"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "learning_rate"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "updated_embedding_table"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "clip_weight_min"
+    type: "float"
+    default_value {
+      f: -inf
+    }
+  }
+  attr {
+    name: "clip_weight_max"
+    type: "float"
+    default_value {
+      f: inf
+    }
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
+op {
+  name: "XlaSparseDenseMatmulWithCsrInput"
+  input_arg {
+    name: "row_pointers"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_sample_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_token_ids"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "sorted_gains"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "embedding_table"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_minibatches_per_physical_sparse_core"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "activations"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "input_size"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "quantization_config_low"
+    type: "float"
+  }
+  attr {
+    name: "quantization_config_high"
+    type: "float"
+  }
+  attr {
+    name: "quantization_config_num_buckets"
+    type: "int"
+    has_minimum: true
+  }
+  attr {
+    name: "table_name"
+    type: "string"
+  }
+}
+op {
   name: "XlaSplitND"
   input_arg {
     name: "input"
diff --git a/tensorflow/core/platform/prefetch.h b/tensorflow/core/platform/prefetch.h
index 85ffcd8..019493f 100644
--- a/tensorflow/core/platform/prefetch.h
+++ b/tensorflow/core/platform/prefetch.h
@@ -24,8 +24,6 @@
 using ::tsl::port::prefetch;
 using ::tsl::port::PREFETCH_HINT_NTA;
 using ::tsl::port::PREFETCH_HINT_T0;
-using ::tsl::port::PREFETCH_HINT_T1;
-using ::tsl::port::PREFETCH_HINT_T2;
 using ::tsl::port::PrefetchHint;
 // NOLINTEND(misc-unused-using-decls)
 }  // namespace port
diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD
index 7dfbd5e..3590422 100644
--- a/tensorflow/core/profiler/convert/BUILD
+++ b/tensorflow/core/profiler/convert/BUILD
@@ -932,11 +932,14 @@
         "//tensorflow/core/profiler/convert/trace_viewer:trace_events_util",
         "//tensorflow/core/profiler/protobuf:trace_events_proto_cc",
         "//tensorflow/core/profiler/protobuf:trace_events_raw_proto_cc",
+        "//tensorflow/core/profiler/utils:xplane_utils",
         "@com_google_absl//absl/strings",
         "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc",
         "@local_tsl//tsl/profiler/utils:tf_xplane_visitor",
         "@local_tsl//tsl/profiler/utils:timespan",
+        "@local_tsl//tsl/profiler/utils:trace_utils",
         "@local_tsl//tsl/profiler/utils:xplane_schema",
+        "@local_tsl//tsl/profiler/utils:xplane_utils",
         "@local_tsl//tsl/profiler/utils:xplane_visitor",
     ],
 )
diff --git a/tensorflow/core/profiler/convert/op_profile_builder.cc b/tensorflow/core/profiler/convert/op_profile_builder.cc
index cfa12a9..2111ea4 100644
--- a/tensorflow/core/profiler/convert/op_profile_builder.cc
+++ b/tensorflow/core/profiler/convert/op_profile_builder.cc
@@ -150,18 +150,13 @@
 // This is only for convolutions, not other HLOs, categories or whole programs.
 // TODO(b/243596435) Find a permanent fix to this problem.
 int64_t GetComputationSize(Node node) {
-  int64_t computation_size = 0;
-  for (const auto& child : node.children()) {
-    if (GetComputationSize(child) != 0) {
-      computation_size = GetComputationSize(child);
-    }
+  if (node.has_xla() && node.xla().computation_primitive_size() > 0) {
+    return node.xla().computation_primitive_size();
   }
-  if (node.has_xla()) {
-    if (node.xla().computation_primitive_size() > 0) {
-      return node.xla().computation_primitive_size();
-    } else {
+  for (auto child_iter = node.children().rbegin();
+       child_iter != node.children().rend(); ++child_iter) {
+    if (const int64_t computation_size = GetComputationSize(*child_iter))
       return computation_size;
-    }
   }
   return 0;
 }
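
A small standalone sketch (not part of this change) of the new traversal order in GetComputationSize(): the node's own computation_primitive_size wins when present, otherwise the search walks the children from last to first and returns the first non-zero size it finds. FakeNode below is a hypothetical stand-in for the op_profile Node proto, used only for illustration.

#include <cstdint>
#include <vector>

// Hypothetical stand-in for the op_profile Node proto.
struct FakeNode {
  int64_t primitive_size = 0;  // node.xla().computation_primitive_size()
  std::vector<FakeNode> children;
};

// Mirrors the refactored GetComputationSize(): prefer the node's own size,
// otherwise scan children from last to first and return the first non-zero
// result found in that subtree.
int64_t ComputationSize(const FakeNode& node) {
  if (node.primitive_size > 0) return node.primitive_size;
  for (auto it = node.children.rbegin(); it != node.children.rend(); ++it) {
    if (const int64_t size = ComputationSize(*it)) return size;
  }
  return 0;
}
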
diff --git a/tensorflow/core/profiler/convert/op_stats_combiner.cc b/tensorflow/core/profiler/convert/op_stats_combiner.cc
index e46e5b2..a4ce52b 100644
--- a/tensorflow/core/profiler/convert/op_stats_combiner.cc
+++ b/tensorflow/core/profiler/convert/op_stats_combiner.cc
@@ -104,7 +104,9 @@
   dst->mutable_hostnames()->insert(src.hostnames().begin(),
                                    src.hostnames().end());
   dst->set_host_count(dst->hostnames_size());
-  if (src.device_type() != "CPU") {
+  // Ignore the CPU and unknown ("Device") device types so they do not
+  // overwrite a real device type already set on the destination.
+  if (src.device_type() != "CPU" && src.device_type() != "Device") {
     dst->set_device_type(src.device_type());
     dst->set_device_core_count(src.device_core_count() +
                                dst->device_core_count());
diff --git a/tensorflow/core/profiler/convert/op_stats_combiner_test.cc b/tensorflow/core/profiler/convert/op_stats_combiner_test.cc
index 8a43da8..517fd22 100644
--- a/tensorflow/core/profiler/convert/op_stats_combiner_test.cc
+++ b/tensorflow/core/profiler/convert/op_stats_combiner_test.cc
@@ -60,6 +60,26 @@
                      .profile_duration_ms());
 }
 
+TEST(CombineAllOpStatsTest, CombineRunEnvironmentWithUnknownDevice) {
+  OpStats dst_op_stats, op_stats_1, op_stats_2;
+  op_stats_1.mutable_run_environment()->set_device_type("TPU");
+  op_stats_2.mutable_run_environment()->set_device_type("Device");
+  OpStatsInfo op_stats_info_1(&op_stats_1, TPU, 0),
+      op_stats_info_2(&op_stats_2, TPU, 0);
+  std::vector<OpStatsInfo> all_op_stats_info = {op_stats_info_1,
+                                                op_stats_info_2};
+
+  // Construct dummy step_intersection.
+  StepDatabaseResult dummy_step_db_result;
+  absl::flat_hash_map<uint32 /*=host_id*/, const StepDatabaseResult*> result;
+  result.insert({0, &dummy_step_db_result});
+  StepIntersection dummy_step_intersection = StepIntersection(1, result);
+
+  CombineAllOpStats(all_op_stats_info, dummy_step_intersection, &dst_op_stats);
+
+  EXPECT_EQ("TPU", dst_op_stats.run_environment().device_type());
+}
+
 }  // namespace
 }  // namespace profiler
 }  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/op_stats_to_op_profile.cc b/tensorflow/core/profiler/convert/op_stats_to_op_profile.cc
index 9b74af2..5b9a4bf 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_op_profile.cc
+++ b/tensorflow/core/profiler/convert/op_stats_to_op_profile.cc
@@ -88,18 +88,15 @@
                          /*exclude_idle_ops=*/true, op_profile_limit,
                          profile.mutable_by_category_exclude_idle());
 
-  // Don't generate per program profile if there's only a single program.
-  if (op_stats.program_id_to_name_map_size() > 1) {
-    BuildOpProfileNodeTree(op_stats,
-                           /*group_by_program=*/true,
-                           /*exclude_idle_ops=*/false, op_profile_limit,
-                           profile.mutable_by_program());
+  BuildOpProfileNodeTree(op_stats,
+                         /*group_by_program=*/true,
+                         /*exclude_idle_ops=*/false, op_profile_limit,
+                         profile.mutable_by_program());
 
-    BuildOpProfileNodeTree(op_stats,
-                           /*group_by_program=*/true,
-                           /*exclude_idle_ops=*/true, op_profile_limit,
-                           profile.mutable_by_program_exclude_idle());
-  }
+  BuildOpProfileNodeTree(op_stats,
+                         /*group_by_program=*/true,
+                         /*exclude_idle_ops=*/true, op_profile_limit,
+                         profile.mutable_by_program_exclude_idle());
 }
 
 }  // namespace profiler
diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc b/tensorflow/core/profiler/convert/xplane_to_trace_container.cc
index 67b311e..53f9437 100644
--- a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_trace_container.cc
@@ -19,6 +19,7 @@
 #include <memory>
 #include <optional>
 #include <string>
+#include <vector>
 
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/profiler/convert/trace_viewer/trace_event_arguments_builder.h"
@@ -27,13 +28,17 @@
 #include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h"
 #include "tsl/profiler/utils/tf_xplane_visitor.h"
 #include "tsl/profiler/utils/timespan.h"
+#include "tsl/profiler/utils/trace_utils.h"
 #include "tsl/profiler/utils/xplane_schema.h"
+#include "tsl/profiler/utils/xplane_utils.h"
 #include "tsl/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
 namespace profiler {
 namespace {
 
+using tsl::profiler::FindPlanesWithPrefix;
+using tsl::profiler::FindPlaneWithName;
 using tsl::profiler::HostEventType;
 using tsl::profiler::StatType;
 using tsl::profiler::XEventVisitor;
@@ -185,12 +190,10 @@
   });
 }
 
-}  // namespace
-
-void ConvertXPlaneToTraceEventsContainer(absl::string_view hostname,
+void ConvertXPlaneToTraceEventsContainer(uint64_t device_id,
+                                         absl::string_view hostname,
                                          const XPlane& xplane,
                                          TraceEventsContainer* container) {
-  uint64_t device_id = xplane.id();
   XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&xplane);
   std::unique_ptr<ResourceGrouperInterface> resource_grouper =
       CreateDefaultResourceGrouper(device_id, plane.Name());
@@ -211,11 +214,35 @@
   });
 }
 
+}  // namespace
+
 void ConvertXSpaceToTraceEventsContainer(absl::string_view hostname,
                                          const XSpace& space,
                                          TraceEventsContainer* container) {
-  for (const auto& plane : space.planes()) {
-    ConvertXPlaneToTraceEventsContainer(hostname, plane, container);
+  const XPlane* host_plane =
+      FindPlaneWithName(space, tsl::profiler::kHostThreadsPlaneName);
+  if (host_plane != nullptr) {
+    ConvertXPlaneToTraceEventsContainer(tsl::profiler::kHostThreadsDeviceId,
+                                        hostname, *host_plane, container);
+  }
+
+  std::vector<const XPlane*> device_planes =
+      FindPlanesWithPrefix(space, tsl::profiler::kGpuPlanePrefix);
+
+  if (device_planes.empty()) {
+    device_planes = FindPlanesWithPrefix(space, tsl::profiler::kTpuPlanePrefix);
+  }
+
+  for (const XPlane* device_plane : device_planes) {
+    ConvertXPlaneToTraceEventsContainer(
+        tsl::profiler::kFirstDeviceId + device_plane->id(), hostname,
+        *device_plane, container);
+  }
+  for (const XPlane* custom_plane :
+       FindPlanesWithPrefix(space, tsl::profiler::kCustomPlanePrefix)) {
+    ConvertXPlaneToTraceEventsContainer(
+        tsl::profiler::kCustomPlaneDeviceId + custom_plane->id(), hostname,
+        *custom_plane, container);
   }
 }
 
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index d52b59d..6ebc351 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -108,7 +108,7 @@
 
 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 1654  // Updated: 2023/10/19
+#define TF_GRAPH_DEF_VERSION 1668  // Updated: 2023/11/2
 
 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //
diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD
index 7cc355f..38653fe 100644
--- a/tensorflow/core/tfrt/graph_executor/BUILD
+++ b/tensorflow/core/tfrt/graph_executor/BUILD
@@ -1,7 +1,7 @@
 # Placeholder: load py_proto_library
-load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library")
 load("//tensorflow:tensorflow.bzl", "if_google", "tf_cc_shared_object", "tf_cc_test")
 load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
+load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -16,6 +16,7 @@
         # copybara:uncomment "//learning/brain/tfrt/...",
         # copybara:uncomment "//learning/serving/servables/tfrt/...",
         # copybara:uncomment "//smartass/brain/inference/...",
+        # copybara:uncomment "//tensorflow/compiler/mlir/tfrt/...",
         "//tensorflow/core/tfrt/...",
         "//tensorflow/core/tfrt/graph_executor/python/...",
         # copybara:uncomment "//tensorflow_serving/servables/tensorflow/...",
diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.cc b/tensorflow/core/tfrt/graph_executor/graph_executor.cc
index 2b8430f..2b719d9 100644
--- a/tensorflow/core/tfrt/graph_executor/graph_executor.cc
+++ b/tensorflow/core/tfrt/graph_executor/graph_executor.cc
@@ -196,7 +196,7 @@
     tfrt::ResourceContext* client_graph_resource_context,
     OpKernelRunnerTable* runner_table,
     tfd::FallbackResourceArray* resource_array,
-    const tensorflow::tfrt_stub::FallbackState& fallback_state,
+    tensorflow::tfrt_stub::FallbackState& fallback_state,
     const tensorflow::ProcessFunctionLibraryRuntime&
         process_function_library_runtime,
     CostRecorder* cost_recorder) {
@@ -279,7 +279,7 @@
     tfrt::ResourceContext* client_graph_resource_context,
     OpKernelRunnerTable* runner_table,
     tfd::FallbackResourceArray* resource_array, const Runtime& runtime,
-    const FallbackState& fallback_state,
+    FallbackState& fallback_state,
     const tensorflow::ProcessFunctionLibraryRuntime&
         process_function_library_runtime,
     tfrt::RequestDeadlineTracker* req_deadline_tracker,
@@ -453,13 +453,13 @@
 }
 
 GraphExecutor::GraphExecutor(
-    Options options, const FallbackState& fallback_state,
+    Options options, std::unique_ptr<FallbackState> fallback_state,
     std::unique_ptr<tfrt::ResourceContext> resource_context,
     std::unique_ptr<tensorflow::tfrt_stub::TfrtGraphExecutionState>
         graph_execution_state,
     std::unique_ptr<mlrt::KernelRegistry> kernel_registry)
     : options_(std::move(options)),
-      fallback_state_(fallback_state),
+      fallback_state_(std::move(fallback_state)),
       graph_execution_state_(std::move(graph_execution_state)),
       req_deadline_tracker_(options_.runtime->core_runtime()->GetHostContext()),
       kernel_registry_(std::move(kernel_registry)),
@@ -469,7 +469,7 @@
 }
 
 StatusOr<std::unique_ptr<GraphExecutor>> GraphExecutor::Create(
-    Options options, const FallbackState& fallback_state,
+    Options options, std::unique_ptr<FallbackState> fallback_state,
     std::unique_ptr<tfrt::ResourceContext> resource_context,
     tensorflow::GraphDef graph_def,
     std::unique_ptr<mlrt::KernelRegistry> kernel_registry) {
@@ -491,10 +491,11 @@
   TF_ASSIGN_OR_RETURN(
       auto graph_execution_state,
       TfrtGraphExecutionState::Create(graph_execution_state_options,
-                                      std::move(graph_def), fallback_state));
+                                      std::move(graph_def), *fallback_state));
   return std::make_unique<GraphExecutor>(
-      std::move(options), fallback_state, std::move(resource_context),
-      std::move(graph_execution_state), std::move(kernel_registry));
+      std::move(options), std::move(fallback_state),
+      std::move(resource_context), std::move(graph_execution_state),
+      std::move(kernel_registry));
 }
 
 namespace {
@@ -605,7 +606,7 @@
       &flat_outputs, resource_context_.get(),
       &executable_context->resource_context,
       &loaded_client_graph.runner_table(),
-      &loaded_client_graph.resource_array(), runtime(), fallback_state_,
+      &loaded_client_graph.resource_array(), runtime(), fallback_state(),
       loaded_client_graph.process_function_library_runtime(),
       &req_deadline_tracker_, loaded_client_graph.stream_callback_id(),
       cost_recorder));
diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.h b/tensorflow/core/tfrt/graph_executor/graph_executor.h
index 0e9a7bc..f3c0ef5 100644
--- a/tensorflow/core/tfrt/graph_executor/graph_executor.h
+++ b/tensorflow/core/tfrt/graph_executor/graph_executor.h
@@ -87,8 +87,7 @@
     tfrt::ResourceContext* resource_context,
     tfrt::ResourceContext* client_graph_resource_context,
     OpKernelRunnerTable* runner_table,
-    tfd::FallbackResourceArray* resource_array,
-    const FallbackState& fallback_state,
+    tfd::FallbackResourceArray* resource_array, FallbackState& fallback_state,
     const ProcessFunctionLibraryRuntime& process_function_library_runtime,
     CostRecorder* cost_recorder = nullptr);
 
@@ -110,7 +109,7 @@
     tfrt::ResourceContext* client_graph_resource_context,
     OpKernelRunnerTable* runner_table,
     tfd::FallbackResourceArray* resource_array, const Runtime& runtime,
-    const FallbackState& fallback_state,
+    FallbackState& fallback_state,
     const tensorflow::ProcessFunctionLibraryRuntime&
         process_function_library_runtime,
     tfrt::RequestDeadlineTracker* req_deadline_tracker,
@@ -232,13 +231,13 @@
 
   // Creates a `GraphExecutor` given the args.
   static StatusOr<std::unique_ptr<GraphExecutor>> Create(
-      Options options, const FallbackState& fallback_state,
+      Options options, std::unique_ptr<FallbackState> fallback_state,
       std::unique_ptr<tfrt::ResourceContext> resource_context,
       tensorflow::GraphDef graph_def,
       std::unique_ptr<mlrt::KernelRegistry> kernel_registry);
 
   // Ctor. Public for `Create()`. Do not use directly.
-  GraphExecutor(Options options, const FallbackState& fallback_state,
+  GraphExecutor(Options options, std::unique_ptr<FallbackState> fallback_state,
                 std::unique_ptr<tfrt::ResourceContext> resource_context,
                 std::unique_ptr<tensorflow::tfrt_stub::TfrtGraphExecutionState>
                     graph_execution_state,
@@ -282,7 +281,8 @@
   tfrt::ResourceContext& resource_context() { return *resource_context_; }
 
   const Options& options() const { return options_; }
-  const FallbackState& fallback_state() const { return fallback_state_; }
+  const FallbackState& fallback_state() const { return *fallback_state_; }
+  FallbackState& fallback_state() { return *fallback_state_; }
 
   // Compiles graph for `graph_name` and runs any initializers.
   tensorflow::Status CompileGraph(
@@ -329,7 +329,7 @@
       TF_LOCKS_EXCLUDED(loaded_client_graphs_mu_);
 
   Options options_;
-  std::reference_wrapper<const FallbackState> fallback_state_;
+  std::unique_ptr<FallbackState> fallback_state_;
 
   std::unique_ptr<tensorflow::tfrt_stub::TfrtGraphExecutionState>
       graph_execution_state_;
diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc
index c6155ce..7088231 100644
--- a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc
+++ b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc
@@ -96,7 +96,7 @@
   auto resource_context = std::make_unique<tfrt::ResourceContext>();
   TF_ASSERT_OK_AND_ASSIGN(
       auto graph_executor,
-      GraphExecutor::Create(std::move(options), *fallback_state,
+      GraphExecutor::Create(std::move(options), std::move(fallback_state),
                             std::move(resource_context), graph_def,
                             GetKernelRegistry()));
 
@@ -136,7 +136,7 @@
   auto resource_context = std::make_unique<tfrt::ResourceContext>();
   TF_ASSERT_OK_AND_ASSIGN(
       auto graph_executor_base,
-      GraphExecutor::Create(std::move(options), *fallback_state,
+      GraphExecutor::Create(std::move(options), std::move(fallback_state),
                             std::move(resource_context), graph_def,
                             GetKernelRegistry()));
   auto graph_executor = std::unique_ptr<GraphExecutorForTestingCostAnalysis>(
@@ -190,7 +190,7 @@
   auto resource_context = std::make_unique<tfrt::ResourceContext>();
   TF_ASSERT_OK_AND_ASSIGN(
       auto graph_executor_base,
-      GraphExecutor::Create(std::move(options), *fallback_state,
+      GraphExecutor::Create(std::move(options), std::move(fallback_state),
                             std::move(resource_context), graph_def,
                             GetKernelRegistry()));
   auto graph_executor = std::unique_ptr<GraphExecutorForTestingCostAnalysis>(
@@ -234,7 +234,7 @@
   auto resource_context = std::make_unique<tfrt::ResourceContext>();
   TF_ASSERT_OK_AND_ASSIGN(
       auto graph_executor_base,
-      GraphExecutor::Create(std::move(options), *fallback_state,
+      GraphExecutor::Create(std::move(options), std::move(fallback_state),
                             std::move(resource_context), graph_def,
                             GetKernelRegistry()));
   auto graph_executor = std::unique_ptr<GraphExecutorForTestingCostAnalysis>(
@@ -273,7 +273,7 @@
   auto resource_context = std::make_unique<tfrt::ResourceContext>();
   TF_ASSERT_OK_AND_ASSIGN(
       auto graph_executor_base,
-      GraphExecutor::Create(std::move(options), *fallback_state,
+      GraphExecutor::Create(std::move(options), std::move(fallback_state),
                             std::move(resource_context), graph_def,
                             GetKernelRegistry()));
   auto graph_executor = std::unique_ptr<GraphExecutorForTestingCostAnalysis>(
@@ -402,7 +402,7 @@
   auto resource_context = std::make_unique<tfrt::ResourceContext>();
   TF_ASSERT_OK_AND_ASSIGN(
       auto graph_executor,
-      GraphExecutor::Create(std::move(options), *fallback_state,
+      GraphExecutor::Create(std::move(options), std::move(fallback_state),
                             std::move(resource_context), graph_def,
                             GetKernelRegistry()));
   {
@@ -458,7 +458,7 @@
   auto resource_context = std::make_unique<tfrt::ResourceContext>();
   TF_ASSERT_OK_AND_ASSIGN(
       auto graph_executor,
-      GraphExecutor::Create(std::move(options), *fallback_state,
+      GraphExecutor::Create(std::move(options), std::move(fallback_state),
                             std::move(resource_context), graph_def,
                             GetKernelRegistry()));
 
@@ -503,7 +503,7 @@
   auto resource_context = std::make_unique<tfrt::ResourceContext>();
   TF_ASSERT_OK_AND_ASSIGN(
       auto graph_executor,
-      GraphExecutor::Create(std::move(options), *fallback_state,
+      GraphExecutor::Create(std::move(options), std::move(fallback_state),
                             std::move(resource_context), graph_def,
                             GetKernelRegistry()));
 
@@ -550,7 +550,7 @@
   auto resource_context = std::make_unique<tfrt::ResourceContext>();
   TF_ASSERT_OK_AND_ASSIGN(
       auto graph_executor,
-      GraphExecutor::Create(std::move(options), *fallback_state,
+      GraphExecutor::Create(std::move(options), std::move(fallback_state),
                             std::move(resource_context), graph_def,
                             GetKernelRegistry()));
 
diff --git a/tensorflow/core/tfrt/ifrt/BUILD b/tensorflow/core/tfrt/ifrt/BUILD
new file mode 100644
index 0000000..68ac79c
--- /dev/null
+++ b/tensorflow/core/tfrt/ifrt/BUILD
@@ -0,0 +1,39 @@
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = [
+        # copybara:uncomment "//learning/brain/experimental/tfrt:__subpackages__",
+        # copybara:uncomment "//learning/brain/tfrt:__subpackages__",
+        # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__",
+        "//tensorflow/compiler/mlir/tfrt:__subpackages__",
+        "//tensorflow/compiler/mlir/tfrt/transforms/mlrt:__subpackages__",
+        "//tensorflow/compiler/mlir/tfrt/translate/mlrt:__subpackages__",
+        "//tensorflow/core/tfrt:__subpackages__",
+        "//tensorflow/core/tfrt/mlrt:__subpackages__",
+    ],
+)
+
+cc_library(
+    name = "ifrt_executable_registry",
+    srcs = ["ifrt_executable_registry.cc"],
+    hdrs = ["ifrt_executable_registry.h"],
+    deps = [
+        "//tensorflow/compiler/mlir/tfrt:ifrt_serving_executable",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
+    ],
+)
+
+cc_library(
+    name = "ifrt_model_context",
+    hdrs = ["ifrt_model_context.h"],
+    deps = [
+        ":ifrt_executable_registry",
+        "@com_google_absl//absl/strings",
+    ],
+)
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_executable_registry.cc b/tensorflow/core/tfrt/ifrt/ifrt_executable_registry.cc
new file mode 100644
index 0000000..6cae8be
--- /dev/null
+++ b/tensorflow/core/tfrt/ifrt/ifrt_executable_registry.cc
@@ -0,0 +1,106 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h"
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <utility>
+
+#include "absl/base/attributes.h"
+#include "absl/base/const_init.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_serving_executable.h"
+
+namespace tensorflow {
+namespace ifrt_serving {
+ServingExecutableRegistry::Handle::Handle(Handle&& other) {
+  *this = std::move(other);
+}
+
+ServingExecutableRegistry::Handle& ServingExecutableRegistry::Handle::operator=(
+    Handle&& other) {
+  if (this != &other) {
+    program_id_ = std::move(other.program_id_);
+    other.program_id_ = std::nullopt;
+  }
+  return *this;
+}
+
+ServingExecutableRegistry::Handle::~Handle() { Release(); }
+
+void ServingExecutableRegistry::Handle::Release() {
+  if (!program_id_.has_value()) {
+    return;
+  }
+
+  absl::MutexLock l(&ServingExecutableRegistry::mu_);
+
+  const auto it = ServingExecutableRegistry::executables_->find(*program_id_);
+  if (it == ServingExecutableRegistry::executables_->end()) {
+    LOG(ERROR) << "Program " << *program_id_ << " not found in the registry";
+    return;
+  }
+
+  VLOG(1) << "Unregistering program " << *program_id_ << " from signature '"
+          << it->second->signature_name() << "' of model '"
+          << it->second->model_name() << "'";
+  ServingExecutableRegistry::executables_->erase(it);
+
+  program_id_ = std::nullopt;
+}
+
+ServingExecutableRegistry::Handle::Handle(int64_t program_id)
+    : program_id_(program_id) {}
+
+absl::StatusOr<ServingExecutableRegistry::Handle>
+ServingExecutableRegistry::Register(
+    int64_t program_id, std::shared_ptr<IfrtServingExecutable> executable) {
+  absl::MutexLock l(&mu_);
+  VLOG(1) << "Registering program " << program_id << " from signature '"
+          << executable->signature_name() << "' of model '"
+          << executable->model_name() << "'"
+          << ", address is " << executable.get();
+  if (!executables_->insert({program_id, std::move(executable)}).second) {
+    return absl::AlreadyExistsError(absl::StrCat(
+        "Program ", program_id, " already exists in the program registry"));
+  }
+  return Handle(program_id);
+}
+
+std::shared_ptr<IfrtServingExecutable> ServingExecutableRegistry::Lookup(
+    int64_t program_id) {
+  absl::ReaderMutexLock l(&mu_);
+  VLOG(1) << "Looking up program " << program_id;
+  const auto it = executables_->find(program_id);
+  return it != executables_->end() ? it->second : nullptr;
+}
+
+ABSL_CONST_INIT absl::Mutex ServingExecutableRegistry::mu_(absl::kConstInit);
+
+absl::flat_hash_map<int64_t, std::shared_ptr<IfrtServingExecutable>>* const
+    ServingExecutableRegistry::executables_ =
+        new absl::flat_hash_map<int64_t,
+                                std::shared_ptr<IfrtServingExecutable>>();
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h b/tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h
new file mode 100644
index 0000000..a5942d7
--- /dev/null
+++ b/tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h
@@ -0,0 +1,97 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_EXECUTABLE_REGISTRY_H_
+#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_EXECUTABLE_REGISTRY_H_
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_serving_executable.h"
+
+namespace tensorflow {
+namespace ifrt_serving {
+
+// Maintains a process-wide map from program ids to executables. Used by the
+// `IfrtCall` TensorFlow op to look up executables and invoke them.
+//
+// Invoking a TPU program inside an `IfrtCall` TF op requires being
+// able to retrieve an executable for the given program. Since there's no easy
+// way to pass non-serializable attributes to TF ops, we encode a program id
+// (that is unique within a process) as an attribute of an `IfrtCall` op and
+// use this registry class to let the `IfrtCall` op look up an executable
+// during TF op execution.
+class ServingExecutableRegistry {
+ public:
+  // RAII handle for registered executables.
+  class Handle {
+   public:
+    Handle();  // Constructs an empty handle.
+
+    // Move only.
+    Handle(Handle&& other);
+    Handle& operator=(Handle&& other);
+    Handle(const Handle&) = delete;
+    Handle& operator=(const Handle&) = delete;
+
+    ~Handle();
+
+    // Returns the program id that the handle represents, or `std::nullopt` if
+    // the handle is empty.
+    std::optional<int64_t> program_id() const { return program_id_; }
+
+    // Unregisters the owned executable, if any, early (before the destructor).
+    // Calling this method multiple times is a no-op.
+    void Release();
+
+   private:
+    friend class ServingExecutableRegistry;
+
+    // Can only be constructed by `ServingExecutableRegistry::Register()`.
+    explicit Handle(int64_t program_id);
+
+    // Program id. `std::nullopt` if the handle is already released.
+    std::optional<int64_t> program_id_;
+  };
+
+  // Registers an executable under the given program id. Returns an RAII handle
+  // that unregisters the program when the handle is destroyed.
+  static absl::StatusOr<Handle> Register(
+      int64_t program_id, std::shared_ptr<IfrtServingExecutable> executable);
+
+  // Looks up an executable registered under the given program id, or returns
+  // nullptr if there's no such program.
+  static std::shared_ptr<IfrtServingExecutable> Lookup(int64_t program_id);
+
+ private:
+  friend class Handle;
+
+  static absl::Mutex mu_;
+
+  // Mapping from program ids to executables.
+  static absl::flat_hash_map<int64_t,
+                             std::shared_ptr<IfrtServingExecutable>>* const
+      executables_ ABSL_GUARDED_BY(&mu_);
+};
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TFRT_IFRT_IFRT_EXECUTABLE_REGISTRY_H_
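For readers of this change, here is a minimal usage sketch of the registry API declared above. The wrapper function below is hypothetical; only `Register()`, `Lookup()`, and `Handle` come from this patch.

// Hypothetical caller, for illustration only.
absl::StatusOr<tensorflow::ifrt_serving::ServingExecutableRegistry::Handle>
RegisterForServing(
    int64_t program_id,
    std::shared_ptr<tensorflow::ifrt_serving::IfrtServingExecutable> exe) {
  using tensorflow::ifrt_serving::ServingExecutableRegistry;
  // Register() returns AlreadyExistsError if the program id is already taken.
  TF_ASSIGN_OR_RETURN(auto handle, ServingExecutableRegistry::Register(
                                       program_id, std::move(exe)));
  // The `IfrtCall` op later resolves the same id during op execution.
  CHECK(ServingExecutableRegistry::Lookup(program_id) != nullptr);
  // The caller must keep the handle alive; dropping it unregisters the program.
  return handle;
}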
diff --git a/tensorflow/core/tfrt/ifrt/ifrt_model_context.h b/tensorflow/core/tfrt/ifrt/ifrt_model_context.h
new file mode 100644
index 0000000..631044c
--- /dev/null
+++ b/tensorflow/core/tfrt/ifrt/ifrt_model_context.h
@@ -0,0 +1,46 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_CONTEXT_H_
+#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_CONTEXT_H_
+
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h"
+
+namespace tensorflow {
+namespace ifrt_serving {
+
+inline constexpr absl::string_view kIfrtModelContextName = "IfrtModelContext";
+
+// The runtime context for IFRT used in TFRT serving.
+//
+// This class is thread compatible.
+class IfrtModelContext {
+ public:
+  void RegisterHandle(ServingExecutableRegistry::Handle handle) {
+    handles_.push_back(std::move(handle));
+  }
+
+ private:
+  std::vector<ServingExecutableRegistry::Handle> handles_;
+};
+
+}  // namespace ifrt_serving
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_CONTEXT_H_
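A short, hypothetical sketch of how a model loader could hand registration handles to the model context above so that executables stay registered for the model's lifetime; the function name is illustrative and not part of this change.

// Illustrative only.
void KeepRegistered(tensorflow::ifrt_serving::IfrtModelContext& model_context,
                    tensorflow::ifrt_serving::ServingExecutableRegistry::Handle
                        handle) {
  // The context now owns the handle; when the context is destroyed, the
  // handle's destructor unregisters the executable.
  model_context.RegisterHandle(std::move(handle));
}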
diff --git a/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc
index e9e5a79..52e8757 100644
--- a/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc
+++ b/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc
@@ -24,6 +24,7 @@
 #include "absl/base/optimization.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
 #include "tensorflow/core/platform/statusor.h"
 #include "tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h"
@@ -383,6 +384,13 @@
     Name(kMlrtBatchFunctionName).Device(DEVICE_CPU),
     tfrt_stub::BatchFunctionFallbackKernel<MlrtBatchResource>);
 
+// TFRT does not depend on the device annotation.
+// The MLRT batch function does not actually execute on GPU; it runs on CPU.
+// The kernel is registered for the accelerator device only to pass the check.
+REGISTER_KERNEL_BUILDER(
+    Name(kMlrtBatchFunctionName).Device(DEVICE_GPU),
+    tfrt_stub::BatchFunctionFallbackKernel<MlrtBatchResource>);
+
 // Identical to BatchFunction except it has 2 extra TFRT attributes and it does
 // not have `f` attribute. Users will not invoke this op directly.
 REGISTER_OP(kMlrtBatchFunctionName)
diff --git a/tensorflow/core/tfrt/runtime/runtime.h b/tensorflow/core/tfrt/runtime/runtime.h
index 1c22508..aa53c65 100644
--- a/tensorflow/core/tfrt/runtime/runtime.h
+++ b/tensorflow/core/tfrt/runtime/runtime.h
@@ -66,6 +66,12 @@
     flib_def_ = flib_def;
   }
 
+  bool is_local_session() const { return is_local_session_; }
+
+  void set_is_local_session(bool is_local_session) {
+    is_local_session_ = is_local_session;
+  }
+
   tfrt::ResourceContext& resource_context() { return *resource_context_; }
 
   const GraphExecutionOptions& graph_execution_options() const {
@@ -81,6 +87,8 @@
   tfrt::ResourceContext* resource_context_ = nullptr;
 
   FunctionLibraryDefinition* flib_def_ = nullptr;
+
+  bool is_local_session_ = false;
 };
 
 // This defines the runtime abstraction in tensorflow for TFRT. It is supposed
diff --git a/tensorflow/core/tfrt/saved_model/BUILD b/tensorflow/core/tfrt/saved_model/BUILD
index 6830567..3d69e74 100644
--- a/tensorflow/core/tfrt/saved_model/BUILD
+++ b/tensorflow/core/tfrt/saved_model/BUILD
@@ -18,6 +18,7 @@
         # copybara:uncomment "//learning/pathways/serving/runtime/...",
         "//tensorflow/core/runtime_fallback/...",
         "//tensorflow/core/tfrt/mlrt/application/tensorflow/tests/...",
+        "//tensorflow/compiler/mlir/tfrt/...",
         "//tensorflow/core/tfrt/...",
         "//tensorflow_serving/...",
         "//tensorflow/core/tfrt/saved_model/python/...",
@@ -44,7 +45,9 @@
         "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils",
         "//tensorflow/compiler/mlir/tensorflow:translate_lib",
         "//tensorflow/compiler/mlir/tfrt:import_model",
+        "//tensorflow/compiler/mlir/tfrt:tfrt_compile_options",
         "//tensorflow/compiler/mlir/tfrt:tfrt_pipeline_options",
+        "//tensorflow/compiler/mlir/tfrt/transforms/mlrt:import_model",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/core:core_cpu_base",
         "//tensorflow/core:framework",
@@ -62,8 +65,9 @@
         "//tensorflow/core/tfrt/graph_executor",
         "//tensorflow/core/tfrt/graph_executor:export_mlir",
         "//tensorflow/core/tfrt/graph_executor:graph_execution_options",
+        "//tensorflow/core/tfrt/mlrt/bytecode",
         "//tensorflow/core/tfrt/runtime",
-        "//tensorflow/core/tfrt/saved_model/utils:serialize_bef_utils",
+        "//tensorflow/core/tfrt/saved_model/utils:serialize_utils",
         "//tensorflow/core/tfrt/utils",
         "//tensorflow/core/tpu:virtual_device",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -98,6 +102,7 @@
     deps = [
         ":saved_model_util",
         "//tensorflow/cc/saved_model:reader",
+        "//tensorflow/compiler/jit:flags_headers",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:import_model",
         "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils",
@@ -135,7 +140,7 @@
         "//tensorflow/core/tfrt/mlrt/kernel:batch_kernel",
         "//tensorflow/core/tfrt/runtime",
         "//tensorflow/core/tfrt/runtime:work_queue_interface",
-        "//tensorflow/core/tfrt/saved_model/utils:serialize_bef_utils",
+        "//tensorflow/core/tfrt/saved_model/utils:serialize_utils",
         "//tensorflow/core/tfrt/stubs:model_config_stub",
         "//tensorflow/core/tfrt/utils",
         "//tensorflow/core/tfrt/utils:error_util",
@@ -280,7 +285,7 @@
         "//tensorflow/core/tfrt/graph_executor",
         "//tensorflow/core/tfrt/graph_executor:graph_execution_options",
         "//tensorflow/core/tfrt/runtime",
-        "//tensorflow/core/tfrt/saved_model/utils:serialize_bef_utils",
+        "//tensorflow/core/tfrt/saved_model/utils:serialize_utils",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
diff --git a/tensorflow/core/tfrt/saved_model/python/BUILD b/tensorflow/core/tfrt/saved_model/python/BUILD
index 517f1b2..eaa54fd 100644
--- a/tensorflow/core/tfrt/saved_model/python/BUILD
+++ b/tensorflow/core/tfrt/saved_model/python/BUILD
@@ -34,25 +34,6 @@
 # )
 # copybara:uncomment_end
 
-tf_python_pybind_extension(
-    name = "_pywrap_saved_model_aot_compile",
-    srcs = ["saved_model_aot_compile_wrapper.cc"],
-    enable_stub_generation = True,
-    module_name = "_pywrap_saved_model_aot_compile",
-    pytype_srcs = [
-        "_pywrap_saved_model_aot_compile.pyi",
-    ],
-    deps = [
-        "//tensorflow/core/tfrt/graph_executor:graph_execution_options",
-        "//tensorflow/core/tfrt/runtime",
-        "//tensorflow/core/tfrt/saved_model:saved_model_aot_compile",
-        "//tensorflow/python/lib/core:pybind11_lib",
-        "@pybind11",
-        "@pybind11_abseil//pybind11_abseil:absl_casters",
-        "@pybind11_abseil//pybind11_abseil:status_casters",
-    ],
-)
-
 cc_library(
     name = "saved_model_load_and_run",
     srcs = ["saved_model_load_and_run.cc"],
diff --git a/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_wrapper.cc b/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_wrapper.cc
deleted file mode 100644
index 1dd6257..0000000
--- a/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_wrapper.cc
+++ /dev/null
@@ -1,38 +0,0 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "pybind11/pybind11.h"  // from @pybind11
-#include "pybind11_abseil/status_casters.h"  // from @pybind11_abseil
-#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
-#include "tensorflow/core/tfrt/runtime/runtime.h"
-#include "tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h"
-#include "tensorflow/python/lib/core/pybind11_lib.h"
-
-namespace py = pybind11;
-
-PYBIND11_MODULE(_pywrap_saved_model_aot_compile, m) {
-  py::google::ImportStatusModule();
-
-  py::class_<tensorflow::tfrt_stub::AotOptions>(m, "AotOptions")
-      .def(py::init<>());
-  m.doc() = "pybind11 AotOptions Python - C++ Wrapper";
-
-  m.def("AotCompileSavedModel",
-        &tensorflow::tfrt_stub::AotCompileSavedModelAndSaveResult,
-        py::arg("input_model_dir") = absl::string_view(),
-        py::arg("aot_options") = tensorflow::tfrt_stub::AotOptions(),
-        py::arg("output_model_dir") = absl::string_view());
-  m.doc() = "pybind11 AotCompileSavedModel Python - C++ Wrapper";
-}
diff --git a/tensorflow/core/tfrt/saved_model/saved_model.cc b/tensorflow/core/tfrt/saved_model/saved_model.cc
index 29b91db..487c774 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model.cc
+++ b/tensorflow/core/tfrt/saved_model/saved_model.cc
@@ -40,6 +40,7 @@
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/OwningOpRef.h"  // from @llvm-project
 #include "tensorflow/cc/saved_model/reader.h"
+#include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
@@ -76,7 +77,7 @@
 #include "tensorflow/core/tfrt/runtime/runtime.h"
 #include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
 #include "tensorflow/core/tfrt/saved_model/saved_model_util.h"
-#include "tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils.h"
+#include "tensorflow/core/tfrt/saved_model/utils/serialize_utils.h"
 #include "tensorflow/core/tfrt/stubs/model_config_stub.h"
 #include "tensorflow/core/tfrt/utils/error_util.h"
 #include "tensorflow/core/tfrt/utils/fallback_tensor.h"
@@ -136,8 +137,7 @@
     const InitializersAndSignatures& initializers_and_signatures,
     const mlrt::LoadedExecutable& loaded_executable,
     tfrt::ResourceContext* resource_context, OpKernelRunnerTable* runner_table,
-    tfd::FallbackResourceArray* resource_array,
-    const FallbackState& fallback_state) {
+    tfd::FallbackResourceArray* resource_array, FallbackState& fallback_state) {
   TF_ASSIGN_OR_RETURN(
       auto request_info,
       CreateRequestInfo(options, /*run_options=*/{},
@@ -182,8 +182,7 @@
     const InitializersAndSignatures& initializers_and_signatures,
     tfrt::BEFFile* bef_file, tfrt::ResourceContext* resource_context,
     OpKernelRunnerTable* runner_table,
-    tfd::FallbackResourceArray* resource_array,
-    const FallbackState& fallback_state) {
+    tfd::FallbackResourceArray* resource_array, FallbackState& fallback_state) {
   DCHECK(options.runtime);
   TF_ASSIGN_OR_RETURN(
       auto request_info,
@@ -450,11 +449,12 @@
   options.graph_execution_options.compile_options.saved_model_dir =
       saved_model_dir;
 
+  const bool aot_exist = AotPackageExists(saved_model_dir);
   // Register TFRT dialects
   mlir::DialectRegistry registry;
-  if (AotPackageExists(saved_model_dir)) {
+  if (aot_exist) {
     LOG(INFO) << "Found AoT package. Register required dialects.";
-    RegisterTFRTDialectsForAoT(registry);
+    RegisterTfrtDialectsForAot(registry);
   }
   RegisterMlirDialect(registry);
   mlir::MLIRContext context(registry);
@@ -491,11 +491,11 @@
   }
 
   mlir::OwningOpRef<mlir::ModuleOp> mlir_module;
-  if (AotPackageExists(saved_model_dir)) {
+  if (aot_exist) {
     LOG(INFO) << "Found AoT package. Load and deserialize MLIR module.";
 
     TF_RETURN_IF_ERROR(
-        DeserializeAoTMlirModule(saved_model_dir, &context, &mlir_module));
+        DeserializeAotMlirModule(saved_model_dir, &context, &mlir_module));
   } else {
     ASSIGN_OR_RETURN_IN_IMPORT(
         mlir_module,
@@ -555,13 +555,18 @@
 
   mlrt::bc::Buffer bytecode;
   tfrt::BefBuffer bef;
-  if (AotPackageExists(saved_model_dir)) {
+  if (aot_exist) {
     LOG(INFO) << "Found AoT package. Load and deserialize BEF.";
+    if (options.graph_execution_options.enable_mlrt) {
+      // TODO(b/303504882): Add deserialization for mlrt path
+      return absl::InternalError("AOT is not supported in MLRT");
+    } else {
+      ASSIGN_OR_RETURN_IN_COMPILE(
+          bef, LoadBefAndMlir(options.graph_execution_options.compile_options,
+                              mlir_module.get(), saved_model_dir_string,
+                              fallback_state.get()));
+    }
 
-    ASSIGN_OR_RETURN_IN_COMPILE(
-        bef, LoadAotPackages(options.graph_execution_options.compile_options,
-                             mlir_module.get(), saved_model_dir_string,
-                             fallback_state.get()));
   } else {
     tensorflow::tf_mlrt::RegisterTfMlrtKernels(*kernel_registry);
     tensorflow::tf_mlrt::RegisterTfMlrtBatchKernels(*kernel_registry);
@@ -575,12 +580,18 @@
       RETURN_IF_ERROR_IN_COMPILE(tensorflow::ConvertTfMlirToBef(
           options.graph_execution_options.compile_options, mlir_module.get(),
           &bef, model_context, fallback_state.get()));
+      if (options.graph_execution_options.compile_options
+              .serialize_bef_to_aot_packages) {
+        TF_RETURN_IF_ERROR(SerializeBEF(
+            bef, options.graph_execution_options.compile_options.aot_bef_file));
+      }
     }
   }
 
   ASSIGN_OR_RETURN_WITH_STAGE_INFO(
       "graph_executor creation", auto graph_executor,
-      GraphExecutor::Create(options.graph_execution_options, *fallback_state,
+      GraphExecutor::Create(options.graph_execution_options,
+                            std::move(fallback_state),
                             std::move(resource_context),
                             std::move(*meta_graph_def.mutable_graph_def()),
                             std::move(kernel_registry)));
@@ -610,13 +621,14 @@
     RETURN_IF_ERROR_IN_INIT(RunBytecodeInitializers(
         graph_executor->options(), initializers_and_signatures,
         *loaded_executable, &graph_executor->resource_context(),
-        runner_table.get(), resource_array.get(), *fallback_state));
+        runner_table.get(), resource_array.get(),
+        graph_executor->fallback_state()));
   } else {
     DCHECK(bef_file);
     RETURN_IF_ERROR_IN_INIT(RunBefInitializers(
         graph_executor->options(), initializers_and_signatures, bef_file.get(),
         &graph_executor->resource_context(), runner_table.get(),
-        resource_array.get(), *fallback_state));
+        resource_array.get(), graph_executor->fallback_state()));
   }
 
   const auto init_duration = absl::Now() - init_start_time;
@@ -625,14 +637,27 @@
   LOG(INFO) << "TFRT finished initializing savedmodel. Took "
             << absl::ToInt64Milliseconds(init_duration) << " ms.";
 
+  if (aot_exist) {
+    // Set persistent cache directory so that the binaries can be loaded from
+    // the AOT directory.
+    const std::string persistent_cache_directory =
+        GetAotPackagePath(saved_model_dir);
+    tensorflow::GetMarkForCompilationPassFlags()
+        ->tf_xla_persistent_cache_directory = persistent_cache_directory;
+    tensorflow::GetMarkForCompilationPassFlags()
+        ->tf_xla_persistent_cache_read_only = true;
+    LOG(INFO) << "Set persistent cache directory to "
+              << persistent_cache_directory << ", and set it to read-only.";
+  }
+
   // Finally, create the saved model.
   return {std::make_unique<SavedModelImpl>(
       std::move(options), std::move(symbol_uids), std::move(meta_graph_def),
       std::move(bef), std::move(bef_file), std::move(bytecode),
       std::move(loaded_executable),
       std::move(initializers_and_signatures.signature_map),
-      std::move(fallback_state), std::move(runner_table),
-      std::move(resource_array), std::move(graph_executor))};
+      std::move(runner_table), std::move(resource_array),
+      std::move(graph_executor))};
 }
 
 SavedModelImpl::SavedModelImpl(
@@ -640,8 +665,7 @@
     tensorflow::MetaGraphDef meta_graph_def, tfrt::BefBuffer bef,
     tfrt::RCReference<tfrt::BEFFile> bef_file, mlrt::bc::Buffer bytecode,
     std::optional<mlrt::LoadedExecutable> loaded_executable,
-    SignatureMap signatures, std::unique_ptr<FallbackState> fallback_state,
-    std::unique_ptr<OpKernelRunnerTable> runner_table,
+    SignatureMap signatures, std::unique_ptr<OpKernelRunnerTable> runner_table,
     std::unique_ptr<tfd::FallbackResourceArray> resource_array,
     std::unique_ptr<GraphExecutor> graph_executor)
     : SavedModel(std::move(options), std::move(graph_executor)),
@@ -655,7 +679,6 @@
           options_.graph_execution_options.runtime->core_runtime()
               ->GetHostContext()),
       signatures_(std::move(signatures)),
-      fallback_state_(std::move(fallback_state)),
       runner_table_(std::move(runner_table)),
       resource_array_(std::move(resource_array)) {}
 
@@ -755,7 +778,7 @@
       options_.graph_execution_options, run_options, name, *symbol_uids, func,
       loaded_executable, inputs, outputs, resource_context,
       client_graph_resource_context, runner_table, resource_array, runtime(),
-      *fallback_state_, fallback_state_->process_function_library_runtime(),
+      fallback_state(), fallback_state().process_function_library_runtime(),
       &req_deadline_tracker_, /*stream_callback_id=*/std::nullopt);
 }
 
@@ -991,7 +1014,7 @@
     ASSIGN_OR_RETURN_IN_COMPILE(
         loading_result->bytecode_buffer,
         tensorflow::mlrt_compiler::ConvertTfMlirToBytecode(
-            options_.graph_execution_options.compile_options, *fallback_state_,
+            options_.graph_execution_options.compile_options, fallback_state(),
             module.get(), model_context));
     mlrt::bc::Executable executable(loading_result->bytecode_buffer.data());
     loading_result->bytecode_executable =
@@ -1002,11 +1025,11 @@
         *loading_result->bytecode_executable,
         &graph_executor_->resource_context(),
         loading_result->runner_table.get(),
-        loading_result->resource_array.get(), *fallback_state_));
+        loading_result->resource_array.get(), fallback_state()));
   } else {
     TF_RETURN_IF_ERROR(tensorflow::ConvertTfMlirToBef(
         options_.graph_execution_options.compile_options, module.get(),
-        &loading_result->bef, model_context, fallback_state_.get()));
+        &loading_result->bef, model_context, &fallback_state()));
     ASSIGN_OR_RETURN_IN_COMPILE(
         loading_result->bef_file,
         tfrt::CreateBefFileFromBefBuffer(
@@ -1016,7 +1039,7 @@
         /*initializers_and_signatures=*/{}, loading_result->bef_file.get(),
         &graph_executor_->resource_context(),
         loading_result->runner_table.get(),
-        loading_result->resource_array.get(), *fallback_state_));
+        loading_result->resource_array.get(), fallback_state()));
   }
   symbol_uids.tfrt_symbol_uid = MaybeUploadMlirToXsymbol(module.get());
   loading_result->symbol_uids = std::move(symbol_uids);
diff --git a/tensorflow/core/tfrt/saved_model/saved_model.h b/tensorflow/core/tfrt/saved_model/saved_model.h
index d16ba7b..8fc7f30 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model.h
+++ b/tensorflow/core/tfrt/saved_model/saved_model.h
@@ -181,6 +181,11 @@
       std::vector<tensorflow::Tensor>* outputs) = 0;
 
  protected:
+  const FallbackState& fallback_state() const {
+    return graph_executor_->fallback_state();
+  }
+  FallbackState& fallback_state() { return graph_executor_->fallback_state(); }
+
   const Options options_;
   std::unique_ptr<GraphExecutor> graph_executor_;
 };
@@ -216,7 +221,6 @@
       tfrt::RCReference<tfrt::BEFFile> bef_file, mlrt::bc::Buffer bytecode,
       std::optional<mlrt::LoadedExecutable> loaded_executable,
       absl::flat_hash_map<std::string, internal::Signature> signatures,
-      std::unique_ptr<FallbackState> fallback_state,
       std::unique_ptr<OpKernelRunnerTable> runner_table,
       std::unique_ptr<tfd::FallbackResourceArray> resource_array,
       std::unique_ptr<GraphExecutor> graph_executor);
@@ -311,7 +315,6 @@
 
   tfrt::RequestDeadlineTracker req_deadline_tracker_;
   absl::flat_hash_map<std::string, internal::Signature> signatures_;
-  std::unique_ptr<FallbackState> fallback_state_;
   std::unique_ptr<OpKernelRunnerTable> runner_table_;
   std::unique_ptr<tfd::FallbackResourceArray> resource_array_;
   tensorflow::mutex loading_result_cache_mu_;
diff --git a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc
index fc02ed9..17b1ba3 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc
+++ b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc
@@ -28,6 +28,9 @@
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "tensorflow/cc/saved_model/constants.h"
 #include "tensorflow/compiler/jit/device_compilation_cluster_signature.h"
@@ -37,8 +40,10 @@
 #include "tensorflow/compiler/jit/xla_platform_info.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
+#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h"
 #include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
+#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
 #include "xla/pjrt/gpu/se_gpu_pjrt_compiler.h"
@@ -58,9 +63,10 @@
 #include "tensorflow/core/tfrt/graph_executor/export_mlir.h"
 #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
 #include "tensorflow/core/tfrt/graph_executor/graph_executor.h"
+#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
 #include "tensorflow/core/tfrt/runtime/runtime.h"
 #include "tensorflow/core/tfrt/saved_model/saved_model_util.h"
-#include "tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils.h"
+#include "tensorflow/core/tfrt/saved_model/utils/serialize_utils.h"
 #include "tensorflow/core/tfrt/utils/utils.h"
 #include "tensorflow/core/tpu/virtual_device.h"
 #include "tsl/platform/casts.h"
@@ -189,85 +195,6 @@
 
 AotOptions::AotOptions() : graph_execution_options(nullptr) {}
 
-Status AotCompileSavedModelAndSaveResult(absl::string_view input_model_dir,
-                                         AotOptions aot_options,
-                                         absl::string_view output_model_dir) {
-  // Create aot_packages directory.
-  Env* env = Env::Default();
-  const bool new_directory = !output_model_dir.empty();
-  std::string output_dir;
-  if (!new_directory) {
-    output_dir = std::string(input_model_dir);
-  } else {
-    // TODO(chrisminge) modify to copy everything in input directory
-    output_dir = std::string(output_model_dir);
-    TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(output_dir, {}));
-  }
-  const std::string aot_directory =
-      io::JoinPath(output_dir, kAoTPackagesDirectory);
-  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(aot_directory));
-
-  if (aot_options.graph_execution_options == nullptr) {
-    // Since we are not going to actually run the model during AoT
-    // compilation and optimization, we choose a value of 4 inter_op_threads
-    // which is commonly used for testing.
-    SetGlobalRuntime(tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4));
-
-    GraphExecutionOptions graph_execution_options(GetGlobalRuntime());
-
-    graph_execution_options.enable_tfrt_gpu = true;
-    graph_execution_options.enable_grappler_function_optimizer = true;
-    graph_execution_options.compile_options.enable_grappler = true;
-    graph_execution_options.compile_options.device_target =
-        TfrtDeviceInfraTarget::kGpu;
-    graph_execution_options.compile_options.hoist_invariant_ops = true;
-    graph_execution_options.compile_options
-        .serialize_mlir_module_to_aot_packages = true;
-    graph_execution_options.compile_options.aot_mlir_module_file =
-        io::JoinPath(aot_directory, kMLIRModuleFilename);
-
-    aot_options.graph_execution_options =
-        std::make_shared<GraphExecutionOptions>(graph_execution_options);
-  }
-
-  if (aot_options.tags.empty()) {
-    aot_options.tags = {"serve", "gpu"};
-  }
-
-  TF_ASSIGN_OR_RETURN(AotResult result,
-                      AotCompileSavedModel(input_model_dir, aot_options));
-
-  const std::string warmup_requests_path = io::JoinPath(
-      input_model_dir, "assets.extra", "tf_serving_warmup_requests");
-  TF_RETURN_IF_ERROR(env->FileExists(warmup_requests_path));
-
-  const std::string saved_model_pb_path =
-      io::JoinPath(input_model_dir, kSavedModelFilenamePb);
-  const std::string saved_model_pbtxt_path =
-      io::JoinPath(input_model_dir, kSavedModelFilenamePbTxt);
-  bool pb_found = env->FileExists(saved_model_pb_path).ok();
-  bool pbtxt_found = env->FileExists(saved_model_pbtxt_path).ok();
-  if (!pb_found && !pbtxt_found) {
-    return absl::NotFoundError(absl::StrCat(
-        "saved_model not found in input directory: ", input_model_dir));
-  }
-
-  // Serialize BEF buffer to a file under aot_packages
-  const std::string serialized_bef_path =
-      io::JoinPath(aot_directory, kBefBufferFilenameMLIRBEF);
-  TF_RETURN_IF_ERROR(SerializeBEF(result.bef, serialized_bef_path));
-
-  if (pb_found) {
-    const std::string output_file_directory =
-        io::JoinPath(std::string(output_model_dir), kSavedModelFilenamePb);
-    return env->CopyFile(saved_model_pb_path, output_file_directory);
-  } else {
-    const std::string output_file_directory =
-        io::JoinPath(std::string(output_model_dir), kSavedModelFilenamePbTxt);
-    return env->CopyFile(saved_model_pbtxt_path, output_file_directory);
-  }
-}
-
 StatusOr<AotResult> AotCompileSavedModel(absl::string_view input_model_dir,
                                          AotOptions aot_options) {
   TF_ASSIGN_OR_RETURN(tensorflow::MetaGraphDef meta_graph_def,
@@ -322,11 +249,30 @@
 
   tfrt::BefBuffer bef;
   std::vector<std::string> xla_function_names;
-  RETURN_IF_ERROR_IN_COMPILE(tensorflow::ConvertTfMlirToBef(
-      aot_options.graph_execution_options->compile_options, mlir_module.get(),
-      &bef, model_context, fallback_state.get(), &xla_function_names));
-  if (bef.empty()) {
-    return absl::InternalError("BefBuffer is empty.");
+
+  mlrt::bc::Buffer bytecode_buffer;
+  if (aot_options.graph_execution_options->enable_mlrt) {
+    mlir::OwningOpRef<mlir::ModuleOp> module_with_op_keys;
+
+    ASSIGN_OR_RETURN_IN_COMPILE(
+        bytecode_buffer,
+        tensorflow::mlrt_compiler::ConvertTfMlirToBytecode(
+            aot_options.graph_execution_options->compile_options,
+            *fallback_state, mlir_module.get(), model_context,
+            &module_with_op_keys, &xla_function_names));
+
+    if (bytecode_buffer.empty()) {
+      LOG(ERROR) << "MLRT byte buffer is empty.";
+      return absl::InternalError("bytecode_buffer is empty.");
+    }
+  } else {
+    RETURN_IF_ERROR_IN_COMPILE(tensorflow::ConvertTfMlirToBef(
+        aot_options.graph_execution_options->compile_options, mlir_module.get(),
+        &bef, model_context, fallback_state.get(), &xla_function_names));
+    if (bef.empty()) {
+      LOG(ERROR) << "BEF byte buffer is empty.";
+      return absl::InternalError("BefBuffer is empty.");
+    }
   }
 
   const FunctionLibraryDefinition& flib_def = fallback_state->func_lib_def();
@@ -340,7 +286,9 @@
     }
     xla_functions.push_back(*xla_func_def);
   }
-
+  if (aot_options.graph_execution_options->enable_mlrt) {
+    return AotResult{std::move(bytecode_buffer), std::move(xla_functions)};
+  }
   return AotResult{std::move(bef), std::move(xla_functions)};
 }
 
diff --git a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h
index ec68b15..6f1015e 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h
+++ b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h
@@ -19,6 +19,7 @@
 #include <memory>
 #include <string>
 #include <unordered_set>
+#include <variant>
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
@@ -32,6 +33,7 @@
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/protobuf/meta_graph.pb.h"
 #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
+#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
 #include "tensorflow/core/tfrt/runtime/runtime.h"
 #include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
 
@@ -49,7 +51,7 @@
   using ExecutableMap =
       absl::flat_hash_map<DeviceCompilationClusterSignature, std::string,
                           DeviceCompilationClusterSignature::Hash>;
-  tfrt::BefBuffer bef;
+  std::variant<tfrt::BefBuffer, mlrt::bc::Buffer> buffer;
   // TODO(b/296466237): Investigate whether the whole FunctionDefLibrary should
   // be put here.
   // XLA cluster functions corresponding to `XlaLaunch` op, generated during
@@ -62,14 +64,6 @@
 StatusOr<AotResult> AotCompileSavedModel(absl::string_view input_model_dir,
                                          AotOptions aot_options = {});
 
-// AOT compiles saved_model in input_model_dir, writing output
-// saved_model and aot packages to output_model_dir, or
-// "{input_model_dir}/aot_packages" if output dir provided. Warmup requests
-// should be present in input_model_dir
-Status AotCompileSavedModelAndSaveResult(
-    absl::string_view input_model_dir, AotOptions aot_options = {},
-    absl::string_view output_model_dir = "");
-
 // TODO(b/296466237): Add unit test.
 // Runs bridge and compiles the generated XLA functions corresponding to the
 // signature function with name `siganture_name` in MetaGraphDef.
@@ -105,7 +99,6 @@
     int graph_def_version, const std::vector<XlaCompiler::Argument>& args,
     bool has_ref_vars, bool may_alias_resource_update,
     XlaCompiler::CompilationResult** compilation_result);
-
 }  // namespace tensorflow::tfrt_stub
 
 #endif  // TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_AOT_COMPILE_H_
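Because `AotResult::buffer` is now a `std::variant`, callers have to branch on which alternative is populated (BEF for the default path, MLRT bytecode when `enable_mlrt` is set). A hedged sketch with a hypothetical consumer:

// Illustrative only: inspect the variant held by AotResult.
void LogAotResult(const tensorflow::tfrt_stub::AotResult& result) {
  if (const auto* bef = std::get_if<tfrt::BefBuffer>(&result.buffer)) {
    LOG(INFO) << "AotResult holds a BEF buffer of " << bef->size() << " bytes.";
  } else {
    const auto& bytecode = std::get<mlrt::bc::Buffer>(result.buffer);
    LOG(INFO) << "AotResult holds MLRT bytecode of " << bytecode.size()
              << " bytes.";
  }
}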
diff --git a/tensorflow/core/tfrt/saved_model/saved_model_util.cc b/tensorflow/core/tfrt/saved_model/saved_model_util.cc
index 6cdbe4f..07d5762 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model_util.cc
+++ b/tensorflow/core/tfrt/saved_model/saved_model_util.cc
@@ -53,7 +53,7 @@
 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
 #include "tensorflow/core/tfrt/fallback/fallback_state.h"
 #include "tensorflow/core/tfrt/saved_model/saved_model_import_input.h"
-#include "tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils.h"
+#include "tensorflow/core/tfrt/saved_model/utils/serialize_utils.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/path.h"
@@ -231,25 +231,25 @@
 }
 
 std::string GetAotPackagePath(absl::string_view saved_model_dir) {
-  return tsl::io::JoinPath(std::string(saved_model_dir), kAoTPackagesDirectory);
+  return tsl::io::JoinPath(std::string(saved_model_dir), kAotPackagesDirectory);
 }
 
-std::string GetBEFFilePath(std::string aot_package_directory) {
+std::string GetBefFilePath(std::string aot_package_directory) {
   return tsl::io::JoinPath(aot_package_directory,
-                           std::string(kBefBufferFilenameMLIRBEF));
+                           std::string(kBefBufferFileName));
 }
 
 std::string GetMlirFilePath(const std::string& aot_package_directory) {
-  return tsl::io::JoinPath(aot_package_directory, kMLIRModuleFilename);
+  return tsl::io::JoinPath(aot_package_directory, kMlirModuleFilename);
 }
 
-absl::StatusOr<tfrt::BefBuffer> LoadAotPackages(
+absl::StatusOr<tfrt::BefBuffer> LoadBefAndMlir(
     const TfrtCompileOptions& options, mlir::ModuleOp mlir_module,
     const std::string& saved_model_dir,
     tfrt_stub::FallbackState* fallback_state) {
   const std::string aot_package_directory = GetAotPackagePath(saved_model_dir);
   const std::string bef_file_path =
-      tfrt_stub::GetBEFFilePath(aot_package_directory);
+      tfrt_stub::GetBefFilePath(aot_package_directory);
   TF_ASSIGN_OR_RETURN(tfrt::BefBuffer bef, DeserializeBEFBuffer(bef_file_path));
 
   if (bef.empty()) {
@@ -263,7 +263,7 @@
   return bef;
 }
 
-absl::Status DeserializeAoTMlirModule(
+absl::Status DeserializeAotMlirModule(
     absl::string_view saved_model_dir, mlir::MLIRContext* context,
     mlir::OwningOpRef<mlir::ModuleOp>* mlir_module) {
   const std::string aot_package_directory = GetAotPackagePath(saved_model_dir);
@@ -276,7 +276,7 @@
   return absl::OkStatus();
 }
 
-void RegisterTFRTDialectsForAoT(mlir::DialectRegistry& registry) {
+void RegisterTfrtDialectsForAot(mlir::DialectRegistry& registry) {
   tfrt::RegisterTFRTDialects(registry);
   registry.insert<tfrt::fallback::FallbackDialect>();
   registry.insert<tfrt::fallback_async::FallbackAsyncDialect>();
diff --git a/tensorflow/core/tfrt/saved_model/saved_model_util.h b/tensorflow/core/tfrt/saved_model/saved_model_util.h
index a3a3c09..b306a35 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model_util.h
+++ b/tensorflow/core/tfrt/saved_model/saved_model_util.h
@@ -47,13 +47,16 @@
 namespace tfrt_stub {
 
 // Filename for serialized BEF Buffer.
-inline constexpr char kBefBufferFilenameMLIRBEF[] = "serialized_bef.mlir.bef";
+inline constexpr char kBefBufferFileName[] = "serialized_bef.mlir.bef";
+
+// Filename for serialized MLRT bytecode Buffer.
+inline constexpr char kMlrtBufferFileName[] = "serialized_mlrt.mlir.mlrt";
 
 // Filename for serialized MLIR_MODULE.
-inline constexpr char kMLIRModuleFilename[] = "serialized_mlir.mlir";
+inline constexpr char kMlirModuleFilename[] = "serialized_mlir.mlir";
 
 // Subdirectory where AoT Packages are saved
-inline constexpr char kAoTPackagesDirectory[] = "aot_packages";
+inline constexpr char kAotPackagesDirectory[] = "aot_packages";
 
 // TODO(tfrt-dev): Replace tfrt::TensorSpec with tensorflow::TensorSpec once the
 // latter is checked in.
@@ -117,22 +120,22 @@
 
 std::string GetAotPackagePath(absl::string_view saved_model_dir);
 
-std::string GetBEFFilePath(std::string aot_package_directory);
+std::string GetBefFilePath(std::string aot_package_directory);
 
 std::string GetMlirFilePath(const std::string& aot_package_directory);
 
 // TODO(b/295241000): Implement MLIR deserialization to skip it during AoT and
 // remove redundant steps.
-absl::StatusOr<tfrt::BefBuffer> LoadAotPackages(
+absl::StatusOr<tfrt::BefBuffer> LoadBefAndMlir(
     const TfrtCompileOptions& options, mlir::ModuleOp mlir_module,
     const std::string& saved_model_dir,
     tfrt_stub::FallbackState* fallback_state);
 
-absl::Status DeserializeAoTMlirModule(
+absl::Status DeserializeAotMlirModule(
     absl::string_view saved_model_dir, mlir::MLIRContext* context,
     mlir::OwningOpRef<mlir::ModuleOp>* mlir_module);
 
-void RegisterTFRTDialectsForAoT(mlir::DialectRegistry& registry);
+void RegisterTfrtDialectsForAot(mlir::DialectRegistry& registry);
 
 }  // namespace tfrt_stub
 }  // namespace tensorflow
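Putting the renamed constants and helpers together, the AOT artifact paths the loader resolves look roughly like this; `LogAotPaths` is illustrative and not part of the change.

// Illustrative only.
void LogAotPaths(absl::string_view saved_model_dir) {
  const std::string aot_dir =
      tensorflow::tfrt_stub::GetAotPackagePath(saved_model_dir);
  // <saved_model_dir>/aot_packages
  LOG(INFO) << "AOT package dir: " << aot_dir;
  // <saved_model_dir>/aot_packages/serialized_bef.mlir.bef
  LOG(INFO) << "BEF: " << tensorflow::tfrt_stub::GetBefFilePath(aot_dir);
  // <saved_model_dir>/aot_packages/serialized_mlir.mlir
  LOG(INFO) << "MLIR: " << tensorflow::tfrt_stub::GetMlirFilePath(aot_dir);
}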
diff --git a/tensorflow/core/tfrt/saved_model/tests/BUILD b/tensorflow/core/tfrt/saved_model/tests/BUILD
index 44acbab..3495c11 100644
--- a/tensorflow/core/tfrt/saved_model/tests/BUILD
+++ b/tensorflow/core/tfrt/saved_model/tests/BUILD
@@ -14,6 +14,7 @@
 package_group(
     name = "internal",
     packages = [
+        # copybara:uncomment "//learning/brain/tfrt/...",
         "//tensorflow/core/tfrt/saved_model/tests/...",
         "//tensorflow/core/tfrt/tfrt_session/...",
         "//tensorflow/core/tfrt/utils/debug/...",
@@ -348,13 +349,11 @@
     deps = [
         ":disable_tf2",  # build_cleaner: keep; go/disable_tf2
         "//tensorflow/python/client:session",
-        "//tensorflow/python/eager:def_function",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:array_ops",
         "//tensorflow/python/ops:lookup_ops",
-        "//tensorflow/python/ops:math_ops",
-        "//tensorflow/python/ops:variable_scope",
+        "//tensorflow/python/ops:resource_variables_toggle",
         "//tensorflow/python/ops:variables",
         "//tensorflow/python/platform:gfile",
         "//tensorflow/python/saved_model:builder",
@@ -362,7 +361,6 @@
         "//tensorflow/python/saved_model:signature_def_utils",
         "//tensorflow/python/saved_model:tag_constants",
         "//tensorflow/python/saved_model:utils",
-        "//tensorflow/python/training",
         "@absl_py//absl:app",
         "@absl_py//absl/flags",
     ],
diff --git a/tensorflow/core/tfrt/saved_model/tests/gen_hash_table_asset_v1.py b/tensorflow/core/tfrt/saved_model/tests/gen_hash_table_asset_v1.py
index a3695f2..8f75e3b 100644
--- a/tensorflow/core/tfrt/saved_model/tests/gen_hash_table_asset_v1.py
+++ b/tensorflow/core/tfrt/saved_model/tests/gen_hash_table_asset_v1.py
@@ -25,7 +25,7 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import lookup_ops
-from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import resource_variables_toggle
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.saved_model import builder
@@ -54,7 +54,7 @@
 
   shutil.rmtree(FLAGS.saved_model_path)
 
-  variable_scope.enable_resource_variables()
+  resource_variables_toggle.enable_resource_variables()
 
   # Create the graph
   table_initializer = lookup_ops.TextFileInitializer(
diff --git a/tensorflow/core/tfrt/saved_model/utils/BUILD b/tensorflow/core/tfrt/saved_model/utils/BUILD
index a216250..0427100 100644
--- a/tensorflow/core/tfrt/saved_model/utils/BUILD
+++ b/tensorflow/core/tfrt/saved_model/utils/BUILD
@@ -11,46 +11,53 @@
     packages = [
         # Authorized users go here.
         "//tensorflow/core/tfrt/saved_model/...",
+        "//tensorflow/compiler/mlir/tensorflow/...",
         "//learning/brain/tfrt/cpp_tests/gpu_inference/...",
         "//tensorflow/compiler/mlir/tfrt/...",
+        "//tensorflow/compiler/mlir/tfrt/translate/...",
     ],
 )
 
 cc_library(
-    name = "serialize_bef_utils",
-    srcs = ["serialize_bef_utils.cc"],
-    hdrs = ["serialize_bef_utils.h"],
+    name = "serialize_utils",
+    srcs = ["serialize_utils.cc"],
+    hdrs = ["serialize_utils.h"],
     deps = [
         "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
         "//tensorflow/core/platform:status",
+        "//tensorflow/core/tfrt/mlrt/bytecode",
+        "//tensorflow/core/tfrt/mlrt/bytecode:executable",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:Support",
         "@local_tsl//tsl/platform:env",
         "@tf_runtime//:bef",
-        "@tf_runtime//:befexecutor",
     ],
 )
 
 tf_cc_shared_test(
-    name = "serialize_bef_utils_test",
-    srcs = ["serialize_bef_utils_test.cc"],
+    name = "serialize_utils_test",
+    srcs = ["serialize_utils_test.cc"],
     data = [
         "//tensorflow/compiler/mlir/tfrt/tests/saved_model:testdata",
     ],
     tags = ["no_oss"],
     deps = [
-        ":serialize_bef_utils",
+        ":serialize_utils",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tfrt:import_model",
+        "//tensorflow/compiler/mlir/tfrt/transforms/mlrt:import_model",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/platform:path",
         "//tensorflow/core/platform:resource_loader",
+        "//tensorflow/core/tfrt/fallback:fallback_state",
+        "//tensorflow/core/tfrt/mlrt/bytecode",
         "//tensorflow/core/tfrt/saved_model:saved_model_testutil",
         "//tensorflow/core/tfrt/utils",
         "@com_google_googletest//:gtest_main",
+        "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
         "@local_tsl//tsl/lib/core:status_test_util",
         "@tf_runtime//:bef",
diff --git a/tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils.cc b/tensorflow/core/tfrt/saved_model/utils/serialize_utils.cc
similarity index 76%
rename from tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils.cc
rename to tensorflow/core/tfrt/saved_model/utils/serialize_utils.cc
index b4564bd..eb62445 100644
--- a/tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils.cc
+++ b/tensorflow/core/tfrt/saved_model/utils/serialize_utils.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils.h"
+#include "tensorflow/core/tfrt/saved_model/utils/serialize_utils.h"
 
 #include <memory>
 #include <string>
@@ -23,6 +23,7 @@
 #include "mlir/Support/FileUtilities.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
 #include "tsl/platform/env.h"
 #include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
 
@@ -49,5 +50,20 @@
   return bef;
 }
 
+absl::Status SerializeMLRTBytecode(const mlrt::bc::Buffer &bytecode,
+                                   const std::string &filepath) {
+  std::string errorMessage;
+  auto output = mlir::openOutputFile(filepath, &errorMessage);
+  if (!output) {
+    return absl::InternalError(errorMessage);
+  }
+  (output->os())
+      .write(reinterpret_cast<const char *>(bytecode.data()), bytecode.size());
+  output->keep();
+  LOG(INFO) << "Completed serializing MLRTBytecode to: " << filepath;
+
+  return absl::OkStatus();
+}
+
 }  // namespace tfrt_stub
 }  // namespace tensorflow
diff --git a/tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils.h b/tensorflow/core/tfrt/saved_model/utils/serialize_utils.h
similarity index 71%
rename from tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils.h
rename to tensorflow/core/tfrt/saved_model/utils/serialize_utils.h
index f359a8b..a025851 100644
--- a/tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils.h
+++ b/tensorflow/core/tfrt/saved_model/utils/serialize_utils.h
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CORE_TFRT_SAVED_MODEL_UTILS_SERIALIZE_BEF_UTILS_H_
-#define TENSORFLOW_CORE_TFRT_SAVED_MODEL_UTILS_SERIALIZE_BEF_UTILS_H_
+#ifndef TENSORFLOW_CORE_TFRT_SAVED_MODEL_UTILS_SERIALIZE_UTILS_H_
+#define TENSORFLOW_CORE_TFRT_SAVED_MODEL_UTILS_SERIALIZE_UTILS_H_
 
 #include <memory>
 #include <string>
@@ -25,20 +25,26 @@
 #include "mlir/Support/FileUtilities.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/tfrt/mlrt/bytecode/executable.h"
 #include "tsl/platform/env.h"
 #include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
 
 namespace tensorflow {
 namespace tfrt_stub {
 
-// Serializes the BefBuffer into a file
+// Serializes the BefBuffer into a file.
 absl::Status SerializeBEF(const tfrt::BefBuffer &bef,
                           const std::string &filepath);
 
-// Deserializes BEF file from filepath into a BEFBuffer
+// Deserializes BEF file from filepath into a BEFBuffer.
 absl::StatusOr<tfrt::BefBuffer> DeserializeBEFBuffer(
     const std::string &filepath);
+
+// Serializes the MLRT bytecode buffer into a file.
+absl::Status SerializeMLRTBytecode(const mlrt::bc::Buffer &bytecode,
+                                   const std::string &filepath);
+
 }  // namespace tfrt_stub
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_CORE_TFRT_SAVED_MODEL_UTILS_SERIALIZE_BEF_UTILS_H_
+#endif  // TENSORFLOW_CORE_TFRT_SAVED_MODEL_UTILS_SERIALIZE_UTILS_H_
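A hedged round-trip sketch for the serialization utilities declared above; the wrapper function, variable names, and directory argument are illustrative (BEF has a deserialization path, while MLRT bytecode deserialization is still a TODO per b/303504882 earlier in this change).

// Illustrative only.
absl::Status PersistAotBuffers(const tfrt::BefBuffer& bef,
                               const mlrt::bc::Buffer& bytecode,
                               const std::string& aot_dir) {
  const std::string bef_path =
      tsl::io::JoinPath(aot_dir, "serialized_bef.mlir.bef");
  TF_RETURN_IF_ERROR(tensorflow::tfrt_stub::SerializeBEF(bef, bef_path));
  TF_RETURN_IF_ERROR(tensorflow::tfrt_stub::SerializeMLRTBytecode(
      bytecode, tsl::io::JoinPath(aot_dir, "serialized_mlrt.mlir.mlrt")));
  // Only BEF can currently be read back.
  TF_ASSIGN_OR_RETURN(tfrt::BefBuffer reloaded,
                      tensorflow::tfrt_stub::DeserializeBEFBuffer(bef_path));
  return reloaded.empty() ? absl::InternalError("Reloaded BEF is empty.")
                          : absl::OkStatus();
}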
diff --git a/tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils_test.cc b/tensorflow/core/tfrt/saved_model/utils/serialize_utils_test.cc
similarity index 62%
rename from tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils_test.cc
rename to tensorflow/core/tfrt/saved_model/utils/serialize_utils_test.cc
index 3850e93..9ca6b53 100644
--- a/tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils_test.cc
+++ b/tensorflow/core/tfrt/saved_model/utils/serialize_utils_test.cc
@@ -13,18 +13,22 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils.h"
+#include "tensorflow/core/tfrt/saved_model/utils/serialize_utils.h"
 
 #include <memory>
 #include <string>
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
 #include "mlir/Parser/Parser.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
+#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h"
 #include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
 #include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/resource_loader.h"
+#include "tensorflow/core/tfrt/fallback/fallback_state.h"
+#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
 #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h"
 #include "tensorflow/core/tfrt/utils/utils.h"
 #include "tsl/lib/core/status_test_util.h"
@@ -87,6 +91,51 @@
                    *default_options.graph_execution_options.runtime, bef)
                    .status());
 }
+
+TEST(SerializeMLRTTest, HandlesSerializeProcess) {
+  // Create an empty MLRT bytecode buffer; serialization should succeed even
+  // when the buffer is empty.
+  mlrt::bc::Buffer old_byteCode;
+
+  // Load a test MLIR module and convert it to MLRT bytecode.
+
+  const std::string saved_model_mlir_path =
+      "third_party/tensorflow/compiler/mlir/tfrt/tests/saved_model/testdata/"
+      "test.mlir";
+
+  mlir::DialectRegistry registry;
+  mlir::RegisterAllTensorFlowDialects(registry);
+  mlir::MLIRContext context(registry);
+  auto module =
+      mlir::parseSourceFile<mlir::ModuleOp>(saved_model_mlir_path, &context);
+  ASSERT_TRUE(module);
+  mlir::OwningOpRef<mlir::ModuleOp> module_with_op_keys;
+  std::unique_ptr<Runtime> runtime =
+      tensorflow::tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/1);
+  tfrt_stub::GraphExecutionOptions options(runtime.get());
+  options.enable_mlrt = true;
+  tfrt::ResourceContext resource_context;
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<tfrt_stub::FallbackState> fallback_state,
+      tfrt_stub::FallbackState::Create(SessionOptions(), FunctionDefLibrary()));
+  tfrt_stub::ModelRuntimeContext model_context(
+      &options, options.compile_options.saved_model_dir, &resource_context);
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto buffer, mlrt_compiler::ConvertTfMlirToBytecode(
+                       options.compile_options, *fallback_state, module.get(),
+                       model_context, &module_with_op_keys));
+
+  // Create the output file path for the serialized .mlir.mlrt file.
+  const std::string filepath =
+      io::JoinPath(getenv("TEST_UNDECLARED_OUTPUTS_DIR"),
+                   std::string("serialized_mlrt.mlir.mlrt"));
+
+  // Serialize MLRT Bytecode
+  TF_ASSERT_OK(
+      tensorflow::tfrt_stub::SerializeMLRTBytecode(old_byteCode, filepath));
+  // Check that the converted MLRT bytecode buffer is not empty.
+  ASSERT_NE(buffer.size(), 0);
+}
 }  // namespace
 }  // namespace tfrt_stub
 }  // namespace tensorflow
diff --git a/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc b/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc
index 4b15be0..6971c4f 100644
--- a/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc
+++ b/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc
@@ -41,6 +41,7 @@
 #include "tensorflow/core/tfrt/mlrt/kernel/batch_kernel.h"
 #include "tensorflow/core/tfrt/mlrt/kernel/kernel.h"
 #include "tensorflow/core/tfrt/run_handler_thread_pool/run_handler_concurrent_work_queue.h"
+#include "tensorflow/core/tfrt/runtime/runtime.h"
 #include "tensorflow/core/tfrt/runtime/tf_threadpool_concurrent_work_queue.h"
 #include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
 #include "tensorflow/core/tfrt/utils/utils.h"
@@ -123,12 +124,8 @@
         device_target_{device_target},
         tpu_use_tpu_runner_{tpu_use_tpu_runner},
         inter_op_thread_pools_{std::move(inter_op_thread_pools)},
-        model_metadata_(options.config.experimental().session_metadata()),
-        optimize_for_static_graph_(
-            options.config.experimental().optimize_for_static_graph()),
-        disable_optimize_for_static_graph_(
-            options.config.experimental().disable_optimize_for_static_graph()),
-        enable_mlrt_(enable_mlrt) {}
+        enable_mlrt_(enable_mlrt),
+        options_{options} {}
 
   Status Create(const GraphDef& graph) override {
     return Create(GraphDef(graph));
@@ -168,10 +165,11 @@
     auto session_options =
         tensorflow::tfrt_stub::CreateDefaultSessionOptions(options);
     session_options.config.mutable_experimental()
-        ->set_optimize_for_static_graph(optimize_for_static_graph_);
+        ->set_optimize_for_static_graph(
+            options_.config.experimental().optimize_for_static_graph());
     session_options.config.mutable_experimental()
         ->set_disable_optimize_for_static_graph(
-            disable_optimize_for_static_graph_);
+            options_.config.experimental().disable_optimize_for_static_graph());
     LOG_FIRST_N(INFO, 10) << "SessionOptions: "
                           << session_options.config.DebugString();
 
@@ -179,7 +177,7 @@
     // without applying placer or grappler, it is OK for now because it's only
     // used for captured functions in certain tf.data ops
     const auto& fdef_lib = graph.library();
-    TF_ASSIGN_OR_RETURN(fallback_state_,
+    TF_ASSIGN_OR_RETURN(auto fallback_state,
                         tensorflow::tfrt_stub::FallbackState::Create(
                             session_options, fdef_lib));
 
@@ -191,6 +189,12 @@
     auto resource_context = std::make_unique<tfrt::ResourceContext>();
     tfrt_stub::ModelRuntimeContext model_context(
         &options, /*export_dir=*/"unknown_export_dir", resource_context.get());
+    MetaGraphDef meta_graph_def;
+    *meta_graph_def.mutable_graph_def() = graph;
+    model_context.set_meta_graph_def(&meta_graph_def);
+    // TODO(b/300474723): Add functionality supporting Pathways initialization
+    // through TFRT Session.
+    model_context.set_is_local_session(true);
     TF_RETURN_IF_ERROR(options.runtime->CreateRuntimeResources(model_context));
 
     // `GraphExecutor::Create()` will preprocess the graph (e.g., apply
@@ -200,7 +204,7 @@
     TF_ASSIGN_OR_RETURN(
         graph_executor_,
         tensorflow::tfrt_stub::GraphExecutor::Create(
-            options, *fallback_state_, std::move(resource_context),
+            options, std::move(fallback_state), std::move(resource_context),
             std::move(graph), std::move(kernel_registry)));
 
     session_state_ = SessionState::kCreated;
@@ -437,7 +441,7 @@
     // implementation that supports the premapped memory optimization.
     compile_options.use_tpu_host_allocator_for_inputs = tpu_use_tpu_runner_;
 
-    options.model_metadata = model_metadata_;
+    options.model_metadata = options_.config.experimental().session_metadata();
     options.enable_mlrt = enable_mlrt_;
 
     return options;
@@ -472,17 +476,13 @@
   const bool tpu_use_tpu_runner_;
   TfrtSessionInterOpThreadPools inter_op_thread_pools_;
 
-  std::unique_ptr<tfrt_stub::FallbackState> fallback_state_;
-
   mutable absl::Mutex callables_lock_;
   CallableHandle next_callable_handle_ TF_GUARDED_BY(callables_lock_) = 0;
   absl::flat_hash_map<CallableHandle, Callable> callables_
       TF_GUARDED_BY(callables_lock_);
 
-  const tensorflow::SessionMetadata model_metadata_;
-  const bool optimize_for_static_graph_ = true;
-  const bool disable_optimize_for_static_graph_ = false;
   bool enable_mlrt_ = false;
+  SessionOptions options_ = SessionOptions();
 };
 
 std::unique_ptr<tensorflow::tfrt_stub::WorkQueueInterface>
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index 722302d..317060b 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -234,10 +234,12 @@
         ":tpu_defs",
         "//tensorflow/compiler/jit:flags_headers",
         "//tensorflow/compiler/jit:shape_inference",
+        "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:layout_util",
         "//tensorflow/compiler/tf2xla:tf2xla_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/core:core_cpu_base",
+        "//tensorflow/core:framework",
         "//tensorflow/core/framework:attr_value_proto_cc",
         "//tensorflow/core/platform:status",
         "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
@@ -245,6 +247,7 @@
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/types:span",
+        "@local_xla//xla:literal_util",
         "@local_xla//xla:xla_data_proto_cc",
         "@local_xla//xla/client:compile_only_client",
     ],
diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc
index 8ca6606..d848a3e 100644
--- a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc
+++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc
@@ -1881,13 +1881,14 @@
   if (*rewritten) {
     FunctionDef rewritten_fdef;
     TF_RETURN_IF_ERROR(GraphToFunctionDef(
-        *(fbody.graph), fbody.fdef.signature().name(), &rewritten_fdef));
+        *(fbody.graph), fbody.record->fdef().signature().name(),
+        &rewritten_fdef));
     if (new_func_name) {
       rewritten_fdef.mutable_signature()->set_name(*new_func_name);
       TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
     } else {
-      TF_RETURN_IF_ERROR(
-          fld->ReplaceFunction(fbody.fdef.signature().name(), rewritten_fdef));
+      TF_RETURN_IF_ERROR(fld->ReplaceFunction(
+          fbody.record->fdef().signature().name(), rewritten_fdef));
     }
   }
 
@@ -2128,11 +2129,11 @@
       cond_fbody.get()));
 
   FunctionDef rewritten_cond_fdef;
-  TF_RETURN_IF_ERROR(GraphToFunctionDef(*(cond_fbody->graph),
-                                        cond_fbody->fdef.signature().name(),
-                                        &rewritten_cond_fdef));
-  TF_RETURN_IF_ERROR(fld->ReplaceFunction(cond_fbody->fdef.signature().name(),
-                                          rewritten_cond_fdef));
+  TF_RETURN_IF_ERROR(GraphToFunctionDef(
+      *(cond_fbody->graph), cond_fbody->record->fdef().signature().name(),
+      &rewritten_cond_fdef));
+  TF_RETURN_IF_ERROR(fld->ReplaceFunction(
+      cond_fbody->record->fdef().signature().name(), rewritten_cond_fdef));
 
   // For body func, remove _Retval nodes, and replace _Arg nodes with
   // Placeholder nodes.
@@ -2146,11 +2147,11 @@
   CleanUpRetvalsForWhileBody(index_mapping, dtypes, body_fbody.get());
 
   FunctionDef rewritten_body_fdef;
-  TF_RETURN_IF_ERROR(GraphToFunctionDef(*(body_fbody->graph),
-                                        body_fbody->fdef.signature().name(),
-                                        &rewritten_body_fdef));
-  TF_RETURN_IF_ERROR(fld->ReplaceFunction(body_fbody->fdef.signature().name(),
-                                          rewritten_body_fdef));
+  TF_RETURN_IF_ERROR(GraphToFunctionDef(
+      *(body_fbody->graph), body_fbody->record->fdef().signature().name(),
+      &rewritten_body_fdef));
+  TF_RETURN_IF_ERROR(fld->ReplaceFunction(
+      body_fbody->record->fdef().signature().name(), rewritten_body_fdef));
 
   // Remove edges from lifted args to While node, and change "T" attr of the
   // While node.
@@ -2204,11 +2205,13 @@
       then_branch_fbody.get()));
 
   FunctionDef rewritten_then_branch_fdef;
-  TF_RETURN_IF_ERROR(GraphToFunctionDef(
-      *(then_branch_fbody->graph), then_branch_fbody->fdef.signature().name(),
-      &rewritten_then_branch_fdef));
-  TF_RETURN_IF_ERROR(fld->ReplaceFunction(
-      then_branch_fbody->fdef.signature().name(), rewritten_then_branch_fdef));
+  TF_RETURN_IF_ERROR(
+      GraphToFunctionDef(*(then_branch_fbody->graph),
+                         then_branch_fbody->record->fdef().signature().name(),
+                         &rewritten_then_branch_fdef));
+  TF_RETURN_IF_ERROR(
+      fld->ReplaceFunction(then_branch_fbody->record->fdef().signature().name(),
+                           rewritten_then_branch_fdef));
 
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<FunctionBody> else_branch_fbody,
@@ -2219,11 +2222,13 @@
       else_branch_fbody.get()));
 
   FunctionDef rewritten_else_branch_fdef;
-  TF_RETURN_IF_ERROR(GraphToFunctionDef(
-      *(else_branch_fbody->graph), else_branch_fbody->fdef.signature().name(),
-      &rewritten_else_branch_fdef));
-  TF_RETURN_IF_ERROR(fld->ReplaceFunction(
-      else_branch_fbody->fdef.signature().name(), rewritten_else_branch_fdef));
+  TF_RETURN_IF_ERROR(
+      GraphToFunctionDef(*(else_branch_fbody->graph),
+                         else_branch_fbody->record->fdef().signature().name(),
+                         &rewritten_else_branch_fdef));
+  TF_RETURN_IF_ERROR(
+      fld->ReplaceFunction(else_branch_fbody->record->fdef().signature().name(),
+                           rewritten_else_branch_fdef));
 
   // Remove edges from lifted args to If node, and change "Tin" attr of the
   // If node.
@@ -2292,9 +2297,10 @@
   // might be defined by user and we should not modify it.
   FunctionDef rewritten_fdef;
   TF_RETURN_IF_ERROR(GraphToFunctionDef(
-      *(fbody->graph), fbody->fdef.signature().name(), &rewritten_fdef));
+      *(fbody->graph), fbody->record->fdef().signature().name(),
+      &rewritten_fdef));
   std::string new_func_name =
-      fld->UniqueFunctionName(fbody->fdef.signature().name());
+      fld->UniqueFunctionName(fbody->record->fdef().signature().name());
   rewritten_fdef.mutable_signature()->set_name(new_func_name);
   TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
 
@@ -2381,8 +2387,8 @@
       TF_ASSIGN_OR_RETURN(function_fbody,
                           InstantiateAssociatedFunction(*call_node, "f", fld));
       bool func_rewritten = false;
-      std::string new_func_name =
-          fld->UniqueFunctionName(function_fbody->fdef.signature().name());
+      std::string new_func_name = fld->UniqueFunctionName(
+          function_fbody->record->fdef().signature().name());
       TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
           *function_fbody, flr, fld, lifted_arg_count, new_func_name,
           &func_rewritten));
@@ -2403,8 +2409,8 @@
       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, call_node->attrs(), fld,
                                                  &function_fbody));
       bool func_rewritten = false;
-      std::string new_func_name =
-          fld->UniqueFunctionName(function_fbody->fdef.signature().name());
+      std::string new_func_name = fld->UniqueFunctionName(
+          function_fbody->record->fdef().signature().name());
       TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
           *function_fbody, flr, fld, lifted_arg_count, new_func_name,
           &func_rewritten));
diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index 9521cdb..86dc7ae 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -55,6 +55,7 @@
         ":outfeed_ops",
         ":replication_ops",
         ":sharding_util_ops",
+        ":sparse_core_ops",
         ":topk_ops",
         ":tpu_compile_op",
         ":tpu_configuration_ops",
@@ -176,6 +177,7 @@
         "//tensorflow/core/platform:errors",
         "//tensorflow/core/platform:statusor",
         "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "@local_xla//xla:literal_util",
@@ -183,7 +185,6 @@
         "@local_xla//xla:xla_data_proto_cc",
         "@local_xla//xla/client:xla_builder",
         "@local_xla//xla/client:xla_computation",
-        "@local_xla//xla/client/lib:constants",
         "@local_xla//xla/client/lib:slicing",
         "@local_xla//xla/stream_executor/tpu:c_api_decl",
         "@local_xla//xla/stream_executor/tpu:status_helper",
@@ -1103,8 +1104,13 @@
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/core:framework",
+        "//tensorflow/core/platform:statusor",
+        "//tensorflow/core/platform:types",
         "//tensorflow/core/tpu:tpu_defs",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_xla//xla:shape_util",
         "@local_xla//xla/client:xla_builder",
         "@local_xla//xla/client/lib:constants",
     ],
@@ -1429,8 +1435,6 @@
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@local_tsl//tsl/platform:fingerprint",
         "@local_tsl//tsl/platform:stringpiece",
     ],
 )
@@ -1466,7 +1470,6 @@
     deps = [
         ":sparse_core_preprocess_ops",
         "//tensorflow/compiler/jit:xla_device",
-        "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla:xla_op_registry",
diff --git a/tensorflow/core/tpu/kernels/image_resize_ops.cc b/tensorflow/core/tpu/kernels/image_resize_ops.cc
index 85ad60b..2e7c4b9 100644
--- a/tensorflow/core/tpu/kernels/image_resize_ops.cc
+++ b/tensorflow/core/tpu/kernels/image_resize_ops.cc
@@ -13,9 +13,10 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <cstdint>
 #include <vector>
 
-#include "absl/strings/match.h"
+#include "absl/log/check.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
@@ -23,10 +24,14 @@
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "xla/client/lib/constants.h"
 #include "xla/client/xla_builder.h"
-#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
 #include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/op_requires.h"
+#include "tensorflow/core/platform/statusor.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/tpu/tpu_defs.h"
+#include "tsl/platform/statusor.h"
 
 namespace tensorflow {
 
@@ -38,14 +43,17 @@
                    ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
   }
 
-  xla::Shape GetOutputShape(XlaOpKernelContext* ctx) const {
+  StatusOr<xla::Shape> GetOutputShape(XlaOpKernelContext* ctx) const {
     std::vector<int64_t> out_size;
     auto status = ctx->ConstantInputAsIntVector(1, &out_size);
     CHECK_EQ(out_size.size(), 2) << status;
+    TF_ASSIGN_OR_RETURN(xla::Shape input_shape, ctx->InputXlaShape(0));
     xla::Shape output_shape =
         TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(0));
     output_shape.mutable_dimensions()[1] = out_size[0];
     output_shape.mutable_dimensions()[2] = out_size[1];
+    output_shape.set_dynamic_dimension(0, input_shape.is_dynamic_dimension(0));
+    output_shape.set_dynamic_dimension(3, input_shape.is_dynamic_dimension(3));
     return output_shape;
   }
 
@@ -76,7 +84,7 @@
   }
 
   void CompileForward(XlaOpKernelContext* ctx, const char* target) {
-    auto output_shape = GetOutputShape(ctx);
+    OP_REQUIRES_VALUE(auto output_shape, ctx, GetOutputShape(ctx));
     if (ctx->InputShape(0).dim_size(1) == output_shape.dimensions(1) &&
         ctx->InputShape(0).dim_size(2) == output_shape.dimensions(2)) {
       ctx->SetOutput(
@@ -85,8 +93,10 @@
     }
     if (ctx->InputShape(0).dim_size(1) == 1 &&
         ctx->InputShape(0).dim_size(2) == 1) {
-      ctx->SetOutput(0,
-                     ctx->Input(0) + xla::Zeros(ctx->builder(), output_shape));
+      ctx->SetOutput(
+          0, ctx->Input(0) +
+                 xla::Zeros(ctx->builder(),
+                            xla::ShapeUtil::MakeStaticShape(output_shape)));
       return;
     }
     ctx->SetOutput(0, xla::CustomCall(ctx->builder(), target, {ctx->Input(0)},
@@ -121,7 +131,8 @@
   explicit TpuResizeNearestNeighborGradOp(OpKernelConstruction* ctx)
       : TpuCustomResizeOp(ctx) {}
   void Compile(XlaOpKernelContext* ctx) override {
-    CompileGrad(ctx, "ResizeNearestGrad", GetOutputShape(ctx));
+    OP_REQUIRES_VALUE(xla::Shape output_shape, ctx, GetOutputShape(ctx));
+    CompileGrad(ctx, "ResizeNearestGrad", output_shape);
   }
 };
 
@@ -130,8 +141,11 @@
   explicit TpuResizeBilinearGradOp(OpKernelConstruction* ctx)
       : TpuCustomResizeOp(ctx) {}
   void Compile(XlaOpKernelContext* ctx) override {
-    auto output_shape =
+    OP_REQUIRES_VALUE(xla::Shape input_shape, ctx, ctx->InputXlaShape(1));
+    xla::Shape output_shape =
         TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(1));
+    output_shape.set_dynamic_dimension(0, input_shape.is_dynamic_dimension(0));
+    output_shape.set_dynamic_dimension(3, input_shape.is_dynamic_dimension(3));
     CompileGrad(ctx, "ResizeBilinearGrad", output_shape);
   }
 };
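Note (illustrative sketch, not part of the patch): the image_resize_ops.cc hunks above make GetOutputShape return a StatusOr and forward the dynamic-dimension flags of the NHWC batch (dim 0) and channel (dim 3) dimensions from the input XLA shape to the resized output, falling back to a fully static copy of that shape only on the 1x1-input broadcast path. A minimal standalone C++ sketch of the flag propagation, using a hypothetical SimpleShape struct in place of xla::Shape:

    // Standalone sketch only; SimpleShape is a stand-in for xla::Shape.
    #include <array>
    #include <cstdint>
    #include <iostream>

    struct SimpleShape {
      std::array<int64_t, 4> dims;     // NHWC
      std::array<bool, 4> is_dynamic;  // per-dimension dynamism flags
    };

    // Mirrors the idea in GetOutputShape: spatial dims come from the requested
    // output size, batch/channel dynamism is copied from the input shape.
    SimpleShape ResizedOutputShape(const SimpleShape& input, int64_t out_h,
                                   int64_t out_w) {
      SimpleShape output = input;
      output.dims[1] = out_h;
      output.dims[2] = out_w;
      output.is_dynamic = {input.is_dynamic[0], false, false,
                           input.is_dynamic[3]};
      return output;
    }

    int main() {
      SimpleShape input{{8, 224, 224, 3}, {true, false, false, false}};
      SimpleShape output = ResizedOutputShape(input, 112, 112);
      std::cout << "batch dynamic: " << output.is_dynamic[0] << "\n";  // 1
    }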
diff --git a/tensorflow/core/tpu/kernels/sparse_core_layout.cc b/tensorflow/core/tpu/kernels/sparse_core_layout.cc
index 54979f9..4b3a961 100644
--- a/tensorflow/core/tpu/kernels/sparse_core_layout.cc
+++ b/tensorflow/core/tpu/kernels/sparse_core_layout.cc
@@ -24,11 +24,9 @@
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
 #include "absl/strings/substitute.h"
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/core/tpu/kernels/sparse_core_layout.pb.h"
-#include "tsl/platform/fingerprint.h"
 #include "tsl/platform/stringpiece.h"
 
 namespace tensorflow {
@@ -80,19 +78,20 @@
             << variable_shard_bytes_limit_;
   }
 
+  VLOG(1) << "Table " << table_name << ":";
   int64_t samples_per_sparse_core =
       output_samples / sparse_cores_per_partition_;
   int64_t padded_width = NextLargestMultiple(table_width, 8);
   int64_t padded_height =
       NextLargestMultiple(table_height, num_sparse_cores_ * 8);
-
+  VLOG(1) << "  Original size: " << table_height << "x" << table_width
+          << " padded size: " << padded_height << "x" << padded_width;
   // Find a stack to fit in.
   int64_t activation_mem_bytes =
       sizeof(float) * padded_width * samples_per_sparse_core;
   int64_t variable_shard_bytes =
       sizeof(float) * padded_width * padded_height / num_partitions_;
-  VLOG(1) << "Table " << table_name
-          << ": activation mem = " << activation_mem_bytes
+  VLOG(1) << "  activation mem = " << activation_mem_bytes
           << ", variable shard bytes = " << variable_shard_bytes;
 
   std::vector<TableStack> &candidate_stacks =
@@ -102,16 +101,21 @@
     for (TableStack &ts : candidate_stacks) {
       // Make sure we haven't exceeded the maximum stack memory.
       if (activation_mem_bytes_limit_ != 0 &&
-          ts.total_activation_mem_bytes + activation_mem_bytes >
+          ts.total_activation_mem_bytes + activation_mem_bytes >=
               activation_mem_bytes_limit_) {
         continue;
       }
       if (variable_shard_bytes_limit_ != 0 &&
-          ts.total_variable_shard_bytes + variable_shard_bytes >
+          ts.total_variable_shard_bytes + variable_shard_bytes >=
               variable_shard_bytes_limit_) {
         continue;
       }
 
+      if (row_limit_ != 0 &&
+          ts.unsharded_height + padded_height >= row_limit_) {
+        continue;
+      }
+
       // We found a stack we can put it in.
       stack = &ts;
       break;
@@ -147,8 +151,9 @@
  // four. Note that the Python library is currently written only to advance by
  // sparse core, so the maximum shift is bounded by the number of sparse cores,
  // not the number of rows.
-  layout.set_rotation_offset(
-      (num_tables_ * num_sparse_cores_ / num_partitions_) % num_sparse_cores_);
+  layout.set_sparse_core_shard_rotation(((stack->incomplete_tables.size() - 1) *
+                                         num_sparse_cores_ / num_partitions_) %
+                                        num_sparse_cores_);
 
   // Can't set total_rows_per_sparse_core_shard yet because we may add more
   // tables to this stack.
@@ -156,7 +161,6 @@
   stack->total_variable_shard_bytes += variable_shard_bytes;
   stack->total_activation_mem_bytes += activation_mem_bytes;
 
-  ++num_tables_;
   return absl::OkStatus();
 }
 
@@ -183,15 +187,6 @@
         absl::StrAppend(&stacked_table_name, incomplete_layout.table_name());
       }
 
-      // If the table name is too long, shorten it and replace it with a hash.
-      // The stacked table name turns into a variable name, and for some
-      // systems, variable names that are too long can cause problems.
-      if (stacked_table_name.size() > 100) {
-        stacked_table_name = absl::StrCat(
-            stacked_table_name.substr(0, 80),
-            absl::StrFormat("_%x", tsl::Fingerprint64(stacked_table_name)));
-      }
-
       for (const SparseCoreTableLayout &incomplete_layout :
            stack.incomplete_tables) {
         SparseCoreTableLayout *out_layout = layouts.add_tables();
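Note (illustrative sketch, not part of the patch): AddTable above scans the candidate stacks in a group and picks the first one whose accumulated activation memory, variable-shard bytes, and unsharded row count would stay under the configured limits; a limit of 0 disables that check, and the comparisons now use >= so the configured value itself cannot be reached. A standalone sketch of that selection loop with hypothetical names:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    struct TableStack {
      int64_t total_activation_mem_bytes = 0;
      int64_t total_variable_shard_bytes = 0;
      int64_t unsharded_height = 0;
    };

    // Returns the first stack the new table fits into, or nullptr if a fresh
    // stack is needed. A limit of 0 means "no limit"; the checks use >= so the
    // configured value itself is not reachable, matching the patch above.
    TableStack* FindStack(std::vector<TableStack>& stacks,
                          int64_t activation_mem_bytes, int64_t shard_bytes,
                          int64_t padded_height, int64_t activation_limit,
                          int64_t shard_limit, int64_t row_limit) {
      for (TableStack& ts : stacks) {
        if (activation_limit != 0 &&
            ts.total_activation_mem_bytes + activation_mem_bytes >=
                activation_limit)
          continue;
        if (shard_limit != 0 &&
            ts.total_variable_shard_bytes + shard_bytes >= shard_limit)
          continue;
        if (row_limit != 0 && ts.unsharded_height + padded_height >= row_limit)
          continue;
        return &ts;
      }
      return nullptr;
    }

    int main() {
      std::vector<TableStack> stacks(1);
      stacks[0].unsharded_height = 3LL << 29;  // three 2^29-row tables stacked
      const int64_t kRowLimit = (1LL << 31) - 1;
      // A fourth 2^29-row table would push the stack to 2^31 rows: rejected.
      std::cout << (FindStack(stacks, 0, 0, 1LL << 29, 0, 0, kRowLimit) ==
                    nullptr)
                << "\n";  // prints 1
    }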
diff --git a/tensorflow/core/tpu/kernels/sparse_core_layout.h b/tensorflow/core/tpu/kernels/sparse_core_layout.h
index 5ee44a1..5c9c15d 100644
--- a/tensorflow/core/tpu/kernels/sparse_core_layout.h
+++ b/tensorflow/core/tpu/kernels/sparse_core_layout.h
@@ -54,6 +54,10 @@
     CHECK(stacks_by_group_.empty()) << "must call before AddTable";
     stacking_enabled_ = stacking_enabled;
   }
+  void SetStackingRowLimit(int64_t row_limit) {
+    CHECK(stacks_by_group_.empty()) << "must call before AddTable";
+    row_limit_ = row_limit;
+  }
 
   // Add a new table.  Arguments:
   //   table_name: How this table will be referred to.
@@ -102,7 +106,9 @@
   bool stacking_enabled_ = true;
   int64_t activation_mem_bytes_limit_ = 0;
   int64_t variable_shard_bytes_limit_ = 0;
-  int num_tables_ = 0;
+  // Sparse core ops use a signed 32-bit int for row numbers, so stacking must
+  // not exceed this limit.
+  int64_t row_limit_ = (1LL << 31) - 1;
 
   // All the stacks that we currently know about. Note that we use a btree_map
   // rather than a flat_hash_map so the resulting order is deterministic as long
diff --git a/tensorflow/core/tpu/kernels/sparse_core_layout.proto b/tensorflow/core/tpu/kernels/sparse_core_layout.proto
index 3625fa4..6b7bbd9 100644
--- a/tensorflow/core/tpu/kernels/sparse_core_layout.proto
+++ b/tensorflow/core/tpu/kernels/sparse_core_layout.proto
@@ -40,13 +40,12 @@
 // To find out which row of the saved variable corresponds to a given row of an
 // embedding table, you will need the information contained in this message. If
 // table_row is the row in the original unsharded table, then:
-//    rotated_row = (table_row + rotation_offset) % unsharded_padded_shape[0]
-//    sparse_core_shard = rotated_row % num_sparse_cores
+//    sparse_core_shard = (row + sparse_core_shard_rotation) % num_sparse_cores
 //    sparse_cores_per_partition = num_sparse_cores / num_partitions
 //    partition = sparse_core_shard // sparse_cores_per_partition
 //    sparse_core_shard_within_partition =
 //        sparse_core_shard % sparse_cores_per_partition
-//    row_within_sparse_core_shard = rotated_row // num_sparse_cores +
+//    row_within_sparse_core_shard = row // num_sparse_cores +
 //        sparse_core_shard_row_offset
 //    row_within_partition =
 //        total_rows_per_sparse_core_shard * sparse_core_shard_within_partition
@@ -86,13 +85,15 @@
 
   // It's common that row 0 of an embedding will be particularly hot. To
   // prevent this from landing in the same sparse core shard for all tables,
-  // different tables stack have their rows rotated around:
-  //    row_index = (row + rotation_offset) % unsharded_padded_rows[0].
+  // different tables in a stack rotate which sparse core they go on:
+  //    sparse_core_shard = (row + sparse_core_shard_rotation) %
+  //    num_sparse_cores
+  // Note that this is rotating around the sparse core shards, not rotating
+  // around the whole table.
   // As of 2023, this is usually set so different tables have row 0 on different
-  // partitions.  Because of mod sharding, this means that
-  //    sparse_cores_per_partition = num_sparse_cores / num_partitions
-  //    rotation_offset = table_index * sparse_cores_per_partition
-  int64 rotation_offset = 9;
+  // partitions.
+  //    sparse_core_shard_rotation = table_index * sparse_cores_per_partition
+  int64 sparse_core_shard_rotation = 9;
 }
 
 message SparseCoreTableLayouts {
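Note (illustrative sketch, not part of the patch): the comment block above defines how a row of the original unsharded table maps to a sparse core shard, a partition, and a row within that partition. A standalone C++ sketch of the lookup; field names follow the proto, the structs are hypothetical, and `row` is assumed to mean the same unsharded `table_row` used earlier in the comment:

    #include <cstdint>
    #include <iostream>

    struct SparseCoreTableLayoutFields {
      int64_t num_sparse_cores;
      int64_t num_partitions;
      int64_t sparse_core_shard_row_offset;
      int64_t sparse_core_shard_rotation;
      int64_t total_rows_per_sparse_core_shard;
    };

    struct ShardedRow {
      int64_t sparse_core_shard;
      int64_t partition;
      int64_t row_within_partition;
    };

    // Implements the formulas from the proto comment above.
    ShardedRow LocateRow(const SparseCoreTableLayoutFields& l,
                         int64_t table_row) {
      ShardedRow r;
      r.sparse_core_shard =
          (table_row + l.sparse_core_shard_rotation) % l.num_sparse_cores;
      const int64_t sparse_cores_per_partition =
          l.num_sparse_cores / l.num_partitions;
      r.partition = r.sparse_core_shard / sparse_cores_per_partition;
      const int64_t shard_within_partition =
          r.sparse_core_shard % sparse_cores_per_partition;
      const int64_t row_within_shard =
          table_row / l.num_sparse_cores + l.sparse_core_shard_row_offset;
      r.row_within_partition =
          l.total_rows_per_sparse_core_shard * shard_within_partition +
          row_within_shard;
      return r;
    }

    int main() {
      SparseCoreTableLayoutFields layout{/*num_sparse_cores=*/8,
                                         /*num_partitions=*/2,
                                         /*sparse_core_shard_row_offset=*/16,
                                         /*sparse_core_shard_rotation=*/4,
                                         /*total_rows_per_sparse_core_shard=*/24};
      ShardedRow r = LocateRow(layout, /*table_row=*/10);
      std::cout << r.sparse_core_shard << " " << r.partition << " "
                << r.row_within_partition << "\n";  // 6 1 65
    }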
diff --git a/tensorflow/core/tpu/kernels/sparse_core_layout_test.cc b/tensorflow/core/tpu/kernels/sparse_core_layout_test.cc
index 5bcd219..a77a762 100644
--- a/tensorflow/core/tpu/kernels/sparse_core_layout_test.cc
+++ b/tensorflow/core/tpu/kernels/sparse_core_layout_test.cc
@@ -40,7 +40,7 @@
                   unsharded_shape: [ 100, 6 ]
                   unsharded_padded_shape: [ 128, 8 ]
                   sparse_core_shard_row_offset: 0
-                  rotation_offset: 0
+                  sparse_core_shard_rotation: 0
                 }
                 tables {
                   table_name: 'table2'
@@ -51,7 +51,7 @@
                   unsharded_shape: [ 50, 5 ]
                   unsharded_padded_shape: [ 64, 8 ]
                   sparse_core_shard_row_offset: 16  # = 128/8
-                  rotation_offset: 4
+                  sparse_core_shard_rotation: 4
                 }
               )pb")));
 }
@@ -71,7 +71,7 @@
                   unsharded_shape: [ 100, 6 ]
                   unsharded_padded_shape: [ 128, 8 ]
                   sparse_core_shard_row_offset: 0
-                  rotation_offset: 0
+                  sparse_core_shard_rotation: 0
                 }
                 tables {
                   table_name: 'table2'
@@ -82,14 +82,14 @@
                   unsharded_shape: [ 50, 5 ]
                   unsharded_padded_shape: [ 64, 8 ]
                   sparse_core_shard_row_offset: 0
-                  rotation_offset: 4
+                  sparse_core_shard_rotation: 0
                 }
               )pb")));
 }
 
 TEST(SparseCoreLayoutStacker, RespectsActivationMemLimit) {
   SparseCoreLayoutStacker stacker(2);
-  stacker.SetActivationMemoryBytesLimit(16384);
+  stacker.SetActivationMemoryBytesLimit(16384 + 1);
 
   // Here there are several identical tables with an activation memory limit of
   //    sizeof (float) * 8 * 1024 = 8192 per table.
@@ -111,7 +111,7 @@
 
 TEST(SparseCoreLayoutStacker, RespectsVariableShardLimit) {
   SparseCoreLayoutStacker stacker(2);
-  stacker.SetVariableShardBytesLimit(4096);
+  stacker.SetVariableShardBytesLimit(4096 + 1);
 
   // Here there are several identical tables that contribute
   //    sizeof (float) * 8 * 128 / 2 = 2048 bytes to each shard.
@@ -131,6 +131,35 @@
       )pb"))));
 }
 
+TEST(SparseCoreLayoutStacker, RespectsRowLimit) {
+  SparseCoreLayoutStacker stacker(2);
+  // Disable the other limits.
+  stacker.SetActivationMemoryBytesLimit(0);
+  stacker.SetVariableShardBytesLimit(0);
+
+  // Here there are several identical tables, each contributing 2^29 rows.
+  // Since the default row limit is 2^31 - 1, only three of them can be
+  // stacked together; the fourth must go into its own stack.
+  ASSERT_OK(stacker.AddTable("table1", 1 << 29, 8, "stack1", 1024));
+  ASSERT_OK(stacker.AddTable("table2", 1 << 29, 8, "stack1", 1024));
+  ASSERT_OK(stacker.AddTable("table3", 1 << 29, 8, "stack1", 1024));
+  ASSERT_OK(stacker.AddTable("table4", 1 << 29, 8, "stack1", 1024));
+  EXPECT_THAT(stacker.GetLayouts(), IsOkAndHolds(Partially(EqualsProto(R"pb(
+                tables {
+                  table_name: 'table1'
+                  stacked_table_name: 'table1_table2_table3'
+                }
+                tables {
+                  table_name: 'table2'
+                  stacked_table_name: 'table1_table2_table3'
+                }
+                tables {
+                  table_name: 'table3'
+                  stacked_table_name: 'table1_table2_table3'
+                }
+                tables { table_name: 'table4' stacked_table_name: 'table4' }
+              )pb"))));
+}
+
 }  // namespace
 }  // namespace tpu
 }  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc
index 53698e1..54feed5 100644
--- a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc
+++ b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc
@@ -28,9 +28,9 @@
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/types/span.h"
-#include "highway/hwy/base.h"  // from @com_google_highway
-#include "highway/hwy/contrib/sort/order.h"  // from @com_google_highway
-#include "highway/hwy/contrib/sort/vqsort.h"  // from @com_google_highway
+#include "hwy/base.h"  // from @com_google_highway
+#include "hwy/contrib/sort/order.h"  // from @com_google_highway
+#include "hwy/contrib/sort/vqsort.h"  // from @com_google_highway
 #include "xla/stream_executor/tpu/tpu_api.h"
 #include "xla/stream_executor/tpu/tpu_ops_c_api.h"
 #include "xla/util.h"
diff --git a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc
index b83e33b..81d5680 100644
--- a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc
+++ b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc
@@ -21,6 +21,7 @@
 #include <vector>
 
 #include "absl/log/log.h"
+#include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -337,6 +338,15 @@
   explicit XlaSparseDenseMatmulGradWithCsrInputBase(OpKernelConstruction* ctx)
       : XlaOpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_name_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("clip_weight_min", &clip_weight_min_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("clip_weight_max", &clip_weight_max_));
+
+    OP_REQUIRES(ctx, clip_weight_min_ <= clip_weight_max_,
+                absl::InvalidArgumentError(
+                    absl::StrCat("clip_weight_min must be smaller than or "
+                                 "equal to "
+                                 "clip_weight_max but got clip_weight_min as ",
+                                 clip_weight_min_, " and clip_weight_max as ",
+                                 clip_weight_max_, ".")));
   }
 
   ~XlaSparseDenseMatmulGradWithCsrInputBase() override = default;
@@ -350,6 +360,15 @@
 
   virtual xla::Shape get_tables_shape(xla::Shape embedding_table_shape) = 0;
 
+  xla::XlaOp apply_weight_clipping_to_table(xla::XlaBuilder* builder,
+                                            xla::XlaOp table) {
+    xla::XlaOp clip_weight_min = xla::ConstantR0(builder, clip_weight_min_);
+    xla::XlaOp clip_weight_max = xla::ConstantR0(builder, clip_weight_max_);
+    xla::XlaOp clipped_table =
+        xla::Clamp(clip_weight_min, table, clip_weight_max);
+    return clipped_table;
+  }
+
   void Compile(XlaOpKernelContext* ctx) override {
     xla::XlaBuilder* builder = ctx->builder();
 
@@ -451,6 +470,10 @@
     builder->SetFrontendAttributes(original_frontend_attributes);
   }
 
+ protected:
+  float clip_weight_min_;
+  float clip_weight_max_;
+
  private:
   std::string table_name_;
 
@@ -497,8 +520,12 @@
           xla::GetTupleElement(tables, 0) -
           xla::GetTupleElement(hyperparameters, 0) * gradient;
 
+      // Apply the weight clipping.
+      xla::XlaOp clipped_embedding_table = apply_weight_clipping_to_table(
+          sgd_optimizer_builder.get(), updated_embedding_table);
+
       xla::XlaOp updated_tables =
-          xla::Tuple(sgd_optimizer_builder.get(), {updated_embedding_table});
+          xla::Tuple(sgd_optimizer_builder.get(), {clipped_embedding_table});
 
       return sgd_optimizer_builder->Build(updated_tables).value();
     }();
@@ -574,9 +601,13 @@
           embedding_table -
           learning_rate * gradient / xla::Sqrt(new_accumulator);
 
+      // Apply the weight clipping.
+      xla::XlaOp clipped_embedding_table = apply_weight_clipping_to_table(
+          adagrad_optimizer_builder.get(), updated_embedding_table);
+
       xla::XlaOp updated_tables =
           xla::Tuple(adagrad_optimizer_builder.get(),
-                     {updated_embedding_table, new_accumulator});
+                     {clipped_embedding_table, new_accumulator});
       return adagrad_optimizer_builder->Build(updated_tables).value();
     }();
 
@@ -699,9 +730,13 @@
         updated_embedding_table = embedding_table - learning_rate * new_momenta;
       }
 
+      // Apply the weight clipping.
+      xla::XlaOp clipped_embedding_table = apply_weight_clipping_to_table(
+          adagrad_momentum_optimizer_builder.get(), updated_embedding_table);
+
       xla::XlaOp updated_tables =
           xla::Tuple(adagrad_momentum_optimizer_builder.get(),
-                     {updated_embedding_table, new_accumulator, new_momenta});
+                     {clipped_embedding_table, new_accumulator, new_momenta});
       return adagrad_momentum_optimizer_builder->Build(updated_tables).value();
     }();
 
@@ -768,7 +803,6 @@
 
       xla::XlaOp gradient = xla::Parameter(adam_optimizer_builder.get(), 0,
                                            per_row_shape, "gradient");
-
       xla::XlaOp tables =
           xla::Parameter(adam_optimizer_builder.get(), 1,
                          xla::ShapeUtil::MakeTupleShape(
@@ -814,9 +848,13 @@
           embedding_table -
           learning_rate * new_momenta / (xla::Sqrt(new_velocity + e1) + e2);
 
+      // Apply the weight clipping.
+      xla::XlaOp clipped_embedding_table = apply_weight_clipping_to_table(
+          adam_optimizer_builder.get(), updated_embedding_table);
+
       xla::XlaOp updated_tables =
           xla::Tuple(adam_optimizer_builder.get(),
-                     {updated_embedding_table, new_momenta, new_velocity});
+                     {clipped_embedding_table, new_momenta, new_velocity});
       return adam_optimizer_builder->Build(updated_tables).value();
     }();
 
@@ -973,9 +1011,13 @@
       }
       xla::XlaOp updated_embedding_table = numer / denom;
 
+      // Apply the weight clipping.
+      xla::XlaOp clipped_embedding_table = apply_weight_clipping_to_table(
+          ftrl_optimizer_builder.get(), updated_embedding_table);
+
       xla::XlaOp updated_tables =
           xla::Tuple(ftrl_optimizer_builder.get(),
-                     {updated_embedding_table, new_accumulator, new_linear});
+                     {clipped_embedding_table, new_accumulator, new_linear});
       return ftrl_optimizer_builder->Build(updated_tables).value();
     }();
 
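Note (illustrative sketch, not part of the patch): each optimizer body in sparse_core_xla_ops.cc above now clamps the updated embedding table to [clip_weight_min, clip_weight_max] with xla::Clamp before packing it into the result tuple, and the new attrs default to -inf/+inf so existing graphs keep their behavior. A standalone sketch of the same clipping applied to one embedding row, with std::clamp standing in for xla::Clamp:

    #include <algorithm>
    #include <iostream>
    #include <limits>
    #include <vector>

    // Clamps an updated embedding row to [clip_weight_min, clip_weight_max].
    // With the default attrs (-inf, +inf) this is the identity, matching the
    // op registration defaults in sparse_core_ops.cc.
    std::vector<float> ClipRow(std::vector<float> row, float clip_weight_min,
                               float clip_weight_max) {
      for (float& w : row) w = std::clamp(w, clip_weight_min, clip_weight_max);
      return row;
    }

    int main() {
      const float inf = std::numeric_limits<float>::infinity();
      std::vector<float> updated = {-3.0f, 0.25f, 7.5f};
      auto unclipped = ClipRow(updated, -inf, inf);  // unchanged
      auto clipped = ClipRow(updated, -1.0f, 1.0f);  // {-1, 0.25, 1}
      std::cout << clipped[0] << " " << clipped[2] << "\n";
    }

The OP_REQUIRES check added in the constructor guarantees clip_weight_min <= clip_weight_max, which is also the precondition std::clamp (and xla::Clamp) relies on.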
diff --git a/tensorflow/core/tpu/ops/BUILD b/tensorflow/core/tpu/ops/BUILD
index 5bb2593..79e0d83 100644
--- a/tensorflow/core/tpu/ops/BUILD
+++ b/tensorflow/core/tpu/ops/BUILD
@@ -11,8 +11,11 @@
     linkstatic = 1,
     deps = [
         ":host_compute_ops",
+        ":sparse_core_ops",
+        ":sparse_core_preprocess_ops",
         ":topk_ops",
         ":tpu_compile_op",
+        ":tpu_copy_with_dynamic_shape_op",
         ":tpu_embedding_ops",
         ":tpu_execute_op",
         ":tpu_handle_to_key_op",
@@ -186,7 +189,6 @@
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
-        "//tensorflow/core/tpu/kernels:sparse_core_ops_utils",
         "@local_xla//xla:util",
     ],
     alwayslink = 1,
@@ -202,12 +204,6 @@
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils",
-        "//tensorflow/core/tpu/ops:tpu_compile_op",
-        "//tensorflow/core/tpu/ops:tpu_embedding_ops",
-        "//tensorflow/core/tpu/ops:tpu_execute_op",
-        "//tensorflow/core/tpu/ops:tpu_handle_to_key_op",
-        "//tensorflow/core/tpu/ops:tpu_partitioned_ops",
-        "//tensorflow/core/tpu/ops:tpu_round_robin_op",
         "@com_google_absl//absl/strings",
     ],
     alwayslink = 1,
diff --git a/tensorflow/core/tpu/ops/sparse_core_ops.cc b/tensorflow/core/tpu/ops/sparse_core_ops.cc
index a10da3e..f9b9d64 100644
--- a/tensorflow/core/tpu/ops/sparse_core_ops.cc
+++ b/tensorflow/core/tpu/ops/sparse_core_ops.cc
@@ -106,6 +106,8 @@
     .Input("embedding_table: float32")
     .Input("num_minibatches_per_physical_sparse_core: int32")
     .Output("updated_embedding_table: float32")
+    .Attr("clip_weight_min: float = -inf")
+    .Attr("clip_weight_max: float = inf")
     .Attr("table_name: string")
     .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
       c->set_output(0, c->input(6));
@@ -124,6 +126,8 @@
     .Input("num_minibatches_per_physical_sparse_core: int32")
     .Output("updated_embedding_table: float32")
     .Output("updated_accumulator: float32")
+    .Attr("clip_weight_min: float = -inf")
+    .Attr("clip_weight_max: float = inf")
     .Attr("table_name: string")
     .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
       c->set_output(0, c->input(6));
@@ -150,6 +154,8 @@
     .Attr("beta1: float")
     .Attr("beta2: float")
     .Attr("epsilon: float")
+    .Attr("clip_weight_min: float = -inf")
+    .Attr("clip_weight_max: float = inf")
     .Attr("table_name: string")
     .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
       c->set_output(0, c->input(6));
@@ -176,6 +182,8 @@
     .Attr("beta1: float")
     .Attr("beta2: float")
     .Attr("epsilon: float")
+    .Attr("clip_weight_min: float = -inf")
+    .Attr("clip_weight_max: float = inf")
     .Attr("table_name: string")
     .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
       c->set_output(0, c->input(6));
@@ -203,6 +211,8 @@
     .Attr("learning_rate_power: float")
     .Attr("l1_regularization_strength: float")
     .Attr("l2_regularization_strength: float")
+    .Attr("clip_weight_min: float = -inf")
+    .Attr("clip_weight_max: float = inf")
     .Attr("table_name: string")
     .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
       c->set_output(0, c->input(6));
diff --git a/tensorflow/core/tpu/ops/sparse_core_preprocess_ops.cc b/tensorflow/core/tpu/ops/sparse_core_preprocess_ops.cc
index d25c9d1..a6d1fa9 100644
--- a/tensorflow/core/tpu/ops/sparse_core_preprocess_ops.cc
+++ b/tensorflow/core/tpu/ops/sparse_core_preprocess_ops.cc
@@ -17,7 +17,6 @@
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/tpu/kernels/sparse_core_ops_utils.h"
 #include "tsl/platform/errors.h"
 
 namespace tensorflow {
@@ -131,22 +130,9 @@
       c->set_output(0, c->UnknownShapeOfRank(1));
       c->set_output(1, c->UnknownShapeOfRank(1));
       c->set_output(2, c->UnknownShapeOfRank(1));
-      int32 num_replica;
-      TF_RETURN_IF_ERROR(c->GetAttr("num_replica", &num_replica));
-
-      int32 num_sc_per_chip;
-      TF_RETURN_IF_ERROR(c->GetAttr("num_sc_per_chip", &num_sc_per_chip));
-
-      const int max_division_level = GetMinibatchMaxDivisionLevel();
-
-      const int num_physical_replica = num_replica * num_sc_per_chip;
-
-      const int32 kMaxDivisions = 1 << max_division_level;
-
       c->set_output(3, c->Scalar());
-      c->set_output(
-          4, c->MakeShape(
-                 {num_physical_replica * kMaxDivisions * num_sc_per_chip + 1}));
+      // Depends on max division level, which is currently passed by flag.
+      c->set_output(4, c->UnknownShapeOfRank(1));
       c->set_output(5, c->Scalar());
       c->set_output(6, c->Scalar());
       return OkStatus();
diff --git a/tensorflow/core/tpu/ops/tpu_copy_with_dynamic_shape_op.cc b/tensorflow/core/tpu/ops/tpu_copy_with_dynamic_shape_op.cc
index df7312e..e402805 100644
--- a/tensorflow/core/tpu/ops/tpu_copy_with_dynamic_shape_op.cc
+++ b/tensorflow/core/tpu/ops/tpu_copy_with_dynamic_shape_op.cc
@@ -35,11 +35,7 @@
         c->set_output(i, c->input(i));
       }
       return OkStatus();
-    })
-    .Doc(R"(
-Op that copies host tensor to device with dynamic shape support.
-For internal use only.
-)");
+    });
 
 REGISTER_OP("TPUAnnotateTensorsWithDynamicShape")
     .Input("tensors: T")
diff --git a/tensorflow/core/tpu/tpu_compile.cc b/tensorflow/core/tpu/tpu_compile.cc
index 2264252..c504d17 100644
--- a/tensorflow/core/tpu/tpu_compile.cc
+++ b/tensorflow/core/tpu/tpu_compile.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/core/tpu/tpu_compile.h"
 
+#include <algorithm>
 #include <map>
 #include <memory>
 #include <optional>
@@ -29,15 +30,20 @@
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/shape_inference.h"
 #include "tensorflow/compiler/tf2xla/layout_util.h"
+#include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "xla/client/compile_only_client.h"
+#include "xla/literal_util.h"
 #include "xla/xla_data.pb.h"
 #include "tensorflow/core/common_runtime/function_utils.h"
 #include "tensorflow/core/common_runtime/graph_constructor.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
 #include "tensorflow/core/tpu/tpu_defs.h"
 
@@ -277,6 +283,28 @@
   return OkStatus();
 }
 
+// If the metadata specifies any bounded dynamic shapes in the arg then create
+// the matching Tensor values for the Argument.
+Status MaybeBuildBoundedDynamicArgValues(
+    const tpu::TPUCompileMetadataProto::Arg& proto_arg,
+    const TensorShape& shape, XlaCompiler::Argument& arg) {
+  // If any entry in the is_bounded_dynamic_dim list is true, populate the
+  // value_bound and value_dynamism fields to record the shape bounds and
+  // which dimensions are dynamic.
+  auto is_dynamic_dim = absl::MakeConstSpan(proto_arg.is_bounded_dynamic_dim());
+  if (std::any_of(is_dynamic_dim.begin(), is_dynamic_dim.end(),
+                  [](bool v) { return v; })) {
+    // Assume that the values in the shape are the maximums.
+    arg.value_bound = Tensor(arg.type, shape);
+    // Build a literal tensor of Bools to hold which Dims are dynamic.
+    auto literal = xla::LiteralUtil::CreateR1(is_dynamic_dim);
+    Tensor dynamism_tensor(DT_BOOL);
+    TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, DT_BOOL, &dynamism_tensor));
+    arg.value_dynamism = dynamism_tensor;
+  }
+  return OkStatus();
+}
+
 // Populates the arguments, core mapping and per core argument shape for the
 // computation.
 Status BuildComputationArgumentDescriptions(
@@ -305,6 +333,10 @@
     switch (proto_arg.kind()) {
       case tpu::TPUCompileMetadataProto::Arg::PARAMETER:
         arg.kind = XlaCompiler::Argument::kParameter;
+        // TODO(b/308845592) Maybe do this with the XlaCompileOnDemand version
+        // of this method and maybe move whole method to a shared location.
+        TF_RETURN_IF_ERROR(
+            MaybeBuildBoundedDynamicArgValues(proto_arg, arg_shapes[i], arg));
         break;
       case tpu::TPUCompileMetadataProto::Arg::VARIABLE:
         arg.kind = XlaCompiler::Argument::kResource;
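Note (illustrative sketch, not part of the patch): MaybeBuildBoundedDynamicArgValues above only takes effect when at least one entry of is_bounded_dynamic_dim is true; the declared argument shape is then treated as the per-dimension upper bound (value_bound) and a boolean mask records which dimensions are dynamic (value_dynamism). A standalone sketch of that guard, with plain hypothetical types in place of Tensor and the XLA literal:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <optional>
    #include <vector>

    struct BoundedDynamicInfo {
      std::vector<int64_t> upper_bounds;  // the declared shape, read as maximums
      std::vector<bool> is_dynamic;       // which dimensions are dynamic
    };

    // Returns info only if some dimension is marked bounded-dynamic, mirroring
    // the std::any_of guard in MaybeBuildBoundedDynamicArgValues.
    std::optional<BoundedDynamicInfo> MaybeBoundedDynamic(
        const std::vector<int64_t>& shape,
        const std::vector<bool>& is_bounded_dynamic_dim) {
      if (std::none_of(is_bounded_dynamic_dim.begin(),
                       is_bounded_dynamic_dim.end(),
                       [](bool v) { return v; })) {
        return std::nullopt;  // fully static arg: leave the argument untouched
      }
      return BoundedDynamicInfo{shape, is_bounded_dynamic_dim};
    }

    int main() {
      auto info = MaybeBoundedDynamic({128, 1024}, {true, false});
      std::cout << (info ? "bounded-dynamic arg" : "static arg") << "\n";
    }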
diff --git a/tensorflow/core/transforms/utils/eval_utils.cc b/tensorflow/core/transforms/utils/eval_utils.cc
index 32cab6d..6e3b4ca 100644
--- a/tensorflow/core/transforms/utils/eval_utils.cc
+++ b/tensorflow/core/transforms/utils/eval_utils.cc
@@ -26,6 +26,7 @@
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/control_flow.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/ir/importexport/convert_tensor.h"
 #include "tensorflow/core/ir/importexport/graphdef_export.h"
 #include "tensorflow/core/platform/logging.h"
@@ -99,7 +100,7 @@
   absl::InlinedVector<tensorflow::TensorValue, 4> input_tensor_values(
       operands.size());
   // For each operand, convert its ElementsAttr to a Tensor and the Tensor will
-  // be referenced by a TensorValue. To ensure Tensor/TensorValue have thier
+  // be referenced by a TensorValue. To ensure Tensor/TensorValue have their
   // lifecycle across the later evaluation. They are stored in
   // `input_tensors`\`input_tensor_values` respectively. The following loop zips
   // them together so that the bundled values are related. Note that the
@@ -112,8 +113,8 @@
 
   tensorflow::Status status;
   std::unique_ptr<tensorflow::OpKernel> op_kernel = tensorflow::CreateOpKernel(
-      "CPU", cpu_device, cpu_device->GetAllocator({}), node_def,
-      TF_GRAPH_DEF_VERSION, &status);
+      tensorflow::DEVICE_CPU, cpu_device, cpu_device->GetAllocator({}),
+      node_def, TF_GRAPH_DEF_VERSION, &status);
   if (!status.ok()) {
     VLOG(3) << status.message();
     return failure();
diff --git a/tensorflow/core/transforms/utils/eval_utils.h b/tensorflow/core/transforms/utils/eval_utils.h
index d11e022..972ce49 100644
--- a/tensorflow/core/transforms/utils/eval_utils.h
+++ b/tensorflow/core/transforms/utils/eval_utils.h
@@ -25,6 +25,7 @@
 #include "tensorflow/core/framework/device_base.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/ir/tf_op_wrapper.h"
 
 namespace Eigen {
@@ -49,10 +50,13 @@
   tensorflow::Allocator* GetAllocator(
       tensorflow::AllocatorAttributes attr) override;
 
+  const std::string& device_type() const override { return device_type_; }
+
  private:
   std::unique_ptr<tensorflow::thread::ThreadPool> eigen_worker_;
   tensorflow::DeviceBase::CpuWorkerThreads eigen_worker_threads_;
   std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
+  const std::string device_type_ = tensorflow::DEVICE_CPU;
 };
 
 // Attempts to evaluate an MLIR Operation with the op's registered kernel. The op
diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc
index 284524a..221f347 100644
--- a/tensorflow/core/util/command_line_flags_test.cc
+++ b/tensorflow/core/util/command_line_flags_test.cc
@@ -303,4 +303,47 @@
   usage = Flags::Usage(tool_name, {});
   ASSERT_EQ(MatchWithAnyWhitespace(usage, " usage: some_tool_name\n"), true);
 }
+
+namespace {
+template <typename T, typename ExpectationFun>
+void PrefixTestTempl(ExpectationFun expectation_fun, const T &value0,
+                     const T &value1, string str0, string str1) {
+  int argc = 3;
+  std::vector<string> argv_strings = {
+      "program_name",
+      "--hello" + str0,
+      "--hello_world" + str1,
+  };
+  std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+
+  T hello{};
+  T hello_world{};
+  bool parsed_ok = Flags::Parse(
+      &argc, argv_array.data(),
+      {
+          Flag("hello", &hello, "usage of hello"),
+          Flag("hello_world", &hello_world, "usage of hello world"),
+      });
+
+  EXPECT_EQ(true, parsed_ok);
+  expectation_fun(value0, hello);
+  expectation_fun(value1, hello_world);
+  EXPECT_EQ(argc, 1);
+}
+}  // namespace
+
+TEST(CommandLineFlagsTest, OneArgumentIsAPrefixOfAnother) {
+  auto expect_eq = [](auto a, auto b) { EXPECT_EQ(a, b); };
+  auto expect_near = [](auto a, auto b) { EXPECT_NEAR(a, b, 1e-5f); };
+
+  PrefixTestTempl<int32_t>(expect_eq, 1, 2, "=1", "=2");
+  PrefixTestTempl<int64_t>(expect_eq, 1, 2, "=1", "=2");
+  PrefixTestTempl<bool>(expect_eq, false, true, "=false", "=true");
+  PrefixTestTempl<bool>(expect_eq, false, true, "=false", "");
+  PrefixTestTempl<bool>(expect_eq, true, false, "=true", "=false");
+  PrefixTestTempl<bool>(expect_eq, true, false, "", "=false");
+  PrefixTestTempl<string>(expect_eq, "a", "b", "=a", "=b");
+  PrefixTestTempl<float>(expect_near, 0.1f, 0.2f, "=0.1", "=0.2");
+}
+
 }  // namespace tensorflow
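Note (illustrative sketch, not part of the patch): the PrefixTestTempl cases above check that a flag whose name is a prefix of another ("hello" vs "hello_world") does not capture the longer flag's value. A standalone sketch of the matching rule the test implies: "--name" may only match when it is followed by '=' or by the end of the argument (the boolean form):

    #include <iostream>
    #include <string>

    // Returns true if `arg` (e.g. "--hello_world=2") is addressed to the flag
    // named `flag`. "--hello" must not match "--hello_world=2".
    bool MatchesFlag(const std::string& arg, const std::string& flag) {
      const std::string prefix = "--" + flag;
      if (arg.compare(0, prefix.size(), prefix) != 0) return false;
      return arg.size() == prefix.size() || arg[prefix.size()] == '=';
    }

    int main() {
      std::cout << MatchesFlag("--hello_world=2", "hello") << "\n";      // 0
      std::cout << MatchesFlag("--hello=1", "hello") << "\n";            // 1
      std::cout << MatchesFlag("--hello_world", "hello_world") << "\n";  // 1
    }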
diff --git a/tensorflow/core/util/dump_graph.cc b/tensorflow/core/util/dump_graph.cc
index 51fd114..e8e7cb6 100644
--- a/tensorflow/core/util/dump_graph.cc
+++ b/tensorflow/core/util/dump_graph.cc
@@ -270,4 +270,13 @@
                     });
 }
 
+string DumpProtoToFile(const string& name,
+                       tensorflow::protobuf::Message const& proto,
+                       const string& dirname) {
+  return DumpToFile(name, dirname, ".pbtxt", proto.GetTypeName(),
+                    [&](WritableFile* file) {
+                      return WriteTextProtoToUniqueFile(proto, file);
+                    });
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/util/dump_graph.h b/tensorflow/core/util/dump_graph.h
index 0a13bf2..aea03d4 100644
--- a/tensorflow/core/util/dump_graph.h
+++ b/tensorflow/core/util/dump_graph.h
@@ -58,6 +58,12 @@
 string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef,
                              const string& dirname = "");
 
+// Similar to DumpGraphDefToFile, but dumps a proto of any type. Returns the
+// file name chosen.
+string DumpProtoToFile(const string& name,
+                       tensorflow::protobuf::Message const& proto,
+                       const string& dirname = "");
+
 // Sets a custom Graph dumper. If set, this dumper will be used to dump graphs
 // instead via DumpGraphToFile. As the custom dumper may not produce protobufs,
 // allow specifying a file suffix/extension too.
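Note (illustrative usage sketch, not part of the patch; assumes a TensorFlow build): the new DumpProtoToFile declared above accepts any protobuf Message and, when dirname is empty, resolves the destination from TF_DUMP_GRAPH_PREFIX, as the DumpProtoToFileSuccess test below exercises:

    #include <cstdlib>
    #include <iostream>
    #include <string>

    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/util/dump_graph.h"

    int main() {
      // Mirrors the unit test: point the dump directory at /tmp via the env var.
      setenv("TF_DUMP_GRAPH_PREFIX", "/tmp", 1);

      tensorflow::NodeDef node;
      node.set_name("foo");

      // Empty dirname -> destination comes from TF_DUMP_GRAPH_PREFIX; the
      // returned string is the chosen .pbtxt path.
      std::string path = tensorflow::DumpProtoToFile("node_def", node);
      std::cout << "dumped to " << path << "\n";
    }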
diff --git a/tensorflow/core/util/dump_graph_test.cc b/tensorflow/core/util/dump_graph_test.cc
index c7e510e..9942ba1 100644
--- a/tensorflow/core/util/dump_graph_test.cc
+++ b/tensorflow/core/util/dump_graph_test.cc
@@ -60,5 +60,19 @@
   EXPECT_EQ(ret, io::JoinPath(testing::TmpDir(), "function.pbtxt"));
 }
 
+TEST(DumpGraph, DumpProtoToFileSuccess) {
+  NodeDef ndef_in;
+  ndef_in.set_name("foo");
+
+  setenv("TF_DUMP_GRAPH_PREFIX", testing::TmpDir().c_str(), 1);
+  string expected_filepath = io::JoinPath(testing::TmpDir(), "node_def.pbtxt");
+  string actual_filepath = DumpProtoToFile("node_def", ndef_in);
+  EXPECT_EQ(expected_filepath, actual_filepath);
+
+  NodeDef ndef_out;
+  TF_ASSERT_OK(ReadTextProto(Env::Default(), expected_filepath, &ndef_out));
+  EXPECT_EQ(ndef_in.DebugString(), ndef_out.DebugString());
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/dtensor/mlir/tests/cluster_function_conversion.mlir b/tensorflow/dtensor/mlir/tests/cluster_function_conversion.mlir
index c63ffda..dc4e622 100644
--- a/tensorflow/dtensor/mlir/tests/cluster_function_conversion.mlir
+++ b/tensorflow/dtensor/mlir/tests/cluster_function_conversion.mlir
@@ -18,8 +18,8 @@
 func.func @check_layouts_retvals_attached_in_layout_op() -> tensor<i32> {
   // CHECK-NOT:       "tf_device.cluster_func"()
   // CHECK:           %[[SPC_OUT:.*]] = "tf.StatefulPartitionedCall"()
-  // CHECK-SAME:      _layout = ["sharding_specs: mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"]
   // CHECK-SAME:      config = "|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"
+  // CHECK-SAME:      _layout = ["sharding_specs: mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"]
   %0 = "tf_device.cluster_func"() {func = @single_in_out, _mesh="|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"} : () -> tensor<i32>
   func.return %0 : tensor<i32>
 }
@@ -33,8 +33,8 @@
 func.func @check_layouts_retval_attached_with_multi_in_op(%arg0: tensor<i64>, %arg1: tensor<1xf32> {tf._layout = "sharding_specs:scalar mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3" }, %arg2: tensor<1xf32> {tf._layout = "mesh:CPU,x=2,y=2 layout:scalar" }) -> tensor<1xf32> {
   // CHECK-NOT:       "tf_device.cluster_func"()
   // CHECK-NEXT:      %[[SPC_OUT:.*]] = "tf.StatefulPartitionedCall"(%arg1, %arg2)
-  // CHECK-SAME:      _layout = ["sharding_specs:unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"]
   // CHECK-SAME:      config = "|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"
+  // CHECK-SAME:      _layout = ["sharding_specs:unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"]
   %0 = "tf_device.cluster_func"(%arg1, %arg2) {_mesh = "|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3", func = @multi_in} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
   func.return %0 : tensor<1xf32>
 }
@@ -50,11 +50,11 @@
 func.func @check_input_resource_layouts_attached_in_call_op() -> tensor<i32> {
   // CHECK-NOT:       "tf_device.cluster_func"()
   // CHECK:           %[[SPC_OUT:.*]] = "tf.StatefulPartitionedCall"()
+  // CHECK-SAME:      config = "|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"
   // CHECK-SAME:      _inferred_resource_indices = dense<1> : vector<1xi32>
   // CHECK-SAME:      _inferred_resource_layouts
   // CHECK-SAME:      "sharding_specs:unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"
   // CHECK-SAME:      _layout = ["sharding_specs:unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"]
-  // CHECK-SAME:      config = "|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"
   %0 = "tf_device.cluster_func"() {func = @single_in_out, _mesh="|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3", _inferred_resource_indices = dense<1> : vector<1xi32>,
     _inferred_resource_layouts = ["sharding_specs:unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"]} : () -> tensor<i32>
   func.return %0 : tensor<i32>
@@ -71,9 +71,9 @@
 func.func @check_nested_stateful_partitioned_call() -> (tensor<i32>, tensor<i32>) {
   // CHECK-NOT:       "tf_device.cluster_func"()
   // CHECK:           %[[SPC_OUT:.*]] = "tf.StatefulPartitionedCall"()
+  // CHECK-SAME:      config = "|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"
   // CHECK-SAME:      _layout = ["sharding_specs:unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"
   // CHECK-SAME:      "sharding_specs:unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"]
-  // CHECK-SAME:      config = "|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"
   %0:2 = "tf_device.cluster_func"() {func = @nested_stateful_partitioned_call, _mesh="|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"} : () -> (tensor<i32>, tensor<i32>)
   func.return %0#0, %0#1 : tensor<i32>, tensor<i32>
 }
@@ -100,9 +100,9 @@
 func.func @check_var_handle_op_skip_compilation() -> tensor<!tf_type.resource<tensor<i32>>> {
   // CHECK-NOT:       "tf_device.cluster_func"()
   // CHECK:           %[[SPC_OUT:.*]] = "tf.StatefulPartitionedCall"()
+  // CHECK-SAME:      config = "TPU|x=2,y=1|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1"
   // CHECK-SAME:      _layout = ["sharding_specs: mesh:TPU|x=2,y=1|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1"]
   // CHECK-SAME:      _skip_xla_compilation = true
-  // CHECK-SAME:      config = "TPU|x=2,y=1|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1"
   %0 = "tf_device.cluster_func"() {func = @var_handle_op, _mesh="TPU|x=2,y=1|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1"} : () -> tensor<!tf_type.resource<tensor<i32>>>
   func.return %0 : tensor<!tf_type.resource<tensor<i32>>>
 }
diff --git a/tensorflow/dtensor/mlir/tests/dtensor_allreduce_combine_optimization.mlir b/tensorflow/dtensor/mlir/tests/dtensor_allreduce_combine_optimization.mlir
index de72a60..9230e34 100644
--- a/tensorflow/dtensor/mlir/tests/dtensor_allreduce_combine_optimization.mlir
+++ b/tensorflow/dtensor/mlir/tests/dtensor_allreduce_combine_optimization.mlir
@@ -4,11 +4,11 @@
 // CHECK-LABEL: func @main
 func.func @main() {
   // CHECK:      %[[VAL_1:.*]] = "tf.Const"
-  // CHECK-SAME:   {value = dense<{{.*}}> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
+  // CHECK-SAME:   <{value = dense<{{.*}}> : tensor<4x4xf32>}> : () -> tensor<4x4xf32>
   // CHECK:      %[[GROUP_ASSIGNMENT:.*]] = "tf.Const"()
-  // CHECK-SAME:   {value = dense<{{.*}}> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  // CHECK-SAME:   <{value = dense<{{.*}}> : tensor<2x2xi32>}> : () -> tensor<2x2xi32>
   // CHECK:      %[[VAL_2:.*]] = "tf.Const"
-  // CHECK-SAME:   {value = dense<{{.*}}> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
+  // CHECK-SAME:   <{value = dense<{{.*}}> : tensor<4x4xf32>}> : () -> tensor<4x4xf32>
   //
   // CHECK:      %[[FILL:.*]] = "tf.Fill"
   // CHECK:      %[[FLATTEN_1:.*]] = "tf.Reshape"(%[[VAL_1]], %cst_{{[0-9]*}})
@@ -41,11 +41,11 @@
 // CHECK-LABEL: func @main
 func.func @main() {
   // CHECK:      %[[VAL_1:.*]] = "tf.Const"
-  // CHECK-SAME:   {value = dense<{{.*}}> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
+  // CHECK-SAME:   <{value = dense<{{.*}}> : tensor<4x4xf32>}> : () -> tensor<4x4xf32>
   // CHECK:      %[[GROUP_ASSIGNMENT:.*]] = "tf.Const"()
-  // CHECK-SAME:   {value = dense<{{.*}}> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  // CHECK-SAME:   <{value = dense<{{.*}}> : tensor<2x2xi32>}> : () -> tensor<2x2xi32>
   // CHECK:      %[[VAL_2:.*]] = "tf.Const"
-  // CHECK-SAME:   {value = dense<{{1.0.*}}> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
+  // CHECK-SAME:   <{value = dense<{{1.0.*}}> : tensor<4x4xf32>}> : () -> tensor<4x4xf32>
   //
   //
   // CHECK:      %[[ALL_REDUCE_0:.*]] = "tf.DTensorAllReduce"(%[[VAL_2]], %[[GROUP_ASSIGNMENT]])
@@ -97,9 +97,9 @@
 // CHECK-LABEL: func @main
 func.func @main() {
   // CHECK:      %[[VAL:.*]] = "tf.Const"
-  // CHECK-SAME:   {value = dense<{{.*}}> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
+  // CHECK-SAME:   <{value = dense<{{.*}}> : tensor<4x4xf32>}> : () -> tensor<4x4xf32>
   // CHECK:      %[[GROUP_ASSIGNMENT:.*]] = "tf.Const"()
-  // CHECK-SAME:   {value = dense<{{.*}}> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  // CHECK-SAME:   <{value = dense<{{.*}}> : tensor<2x2xi32>}> : () -> tensor<2x2xi32>
   // CHECK:      %[[ALL_REDUCE_1:.*]] = "tf.DTensorAllReduce"(%[[VAL]], %[[GROUP_ASSIGNMENT]])
   // CHECK-SAME:   (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32>
   //
@@ -272,4 +272,4 @@
     }) : () -> tensor<4x4xf32>
     "func.return"() : () -> ()
   }
-}
\ No newline at end of file
+}
diff --git a/tensorflow/dtensor/mlir/tests/dtensor_allreduce_lowering.mlir b/tensorflow/dtensor/mlir/tests/dtensor_allreduce_lowering.mlir
index ada3ccb..1c1aa3c 100644
--- a/tensorflow/dtensor/mlir/tests/dtensor_allreduce_lowering.mlir
+++ b/tensorflow/dtensor/mlir/tests/dtensor_allreduce_lowering.mlir
@@ -22,12 +22,12 @@
   // CHECK:      "tf_device.cluster"
   // CHECK:       %[[DEVICE_ID_RESHAPE:.*]] = "tf.Reshape"(%arg0
   // CHECK:       %[[RELATIVE_DEVICE_ID:.*]] = "tf.Sub"(%[[DEVICE_ID_RESHAPE]]
-  // CHECK-DAG:   %[[CONST_1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
-  // CHECK-DAG:   %[[DEVICE_ID_TO_GROUP_KEY:.*]] = "tf.Const"() {value = dense<[[[GROUP_KEYS:.*]]]> : tensor<8xi32>}
+  // CHECK-DAG:   %[[CONST_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
+  // CHECK-DAG:   %[[DEVICE_ID_TO_GROUP_KEY:.*]] = "tf.Const"() <{value = dense<[[[GROUP_KEYS:.*]]]> : tensor<8xi32>}>
   // CHECK:       %[[GROUP_KEY_SLICE:.*]] = "tf.Slice"(%[[DEVICE_ID_TO_GROUP_KEY]], %[[RELATIVE_DEVICE_ID]], %[[CONST_1]]
   // CHECK:       %[[GROUP_KEY_RESHAPE:.*]] = "tf.Reshape"(%[[GROUP_KEY_SLICE]]
-  // CHECK-DAG:   %[[INSTANCE_KEY:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} 
-  // CHECK-DAG:   %[[GROUP_SIZE:.*]] = "tf.Const"() {value = dense<2> : tensor<i32>}
+  // CHECK-DAG:   %[[INSTANCE_KEY:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG:   %[[GROUP_SIZE:.*]] = "tf.Const"() <{value = dense<2> : tensor<i32>}>
   // CHECK:       %[[REDUCE_OUT:.*]] = "tf.CollectiveReduceV2"(%arg1, %[[GROUP_SIZE]], %[[GROUP_KEY_RESHAPE]], %[[INSTANCE_KEY]])
   // CHECK-SAME:  final_op = "Id"
   // CHECK-SAME:  merge_op = "Add"
diff --git a/tensorflow/dtensor/mlir/tests/dtensor_allreduce_scatter_optimization.mlir b/tensorflow/dtensor/mlir/tests/dtensor_allreduce_scatter_optimization.mlir
index c572968..c562f59 100644
--- a/tensorflow/dtensor/mlir/tests/dtensor_allreduce_scatter_optimization.mlir
+++ b/tensorflow/dtensor/mlir/tests/dtensor_allreduce_scatter_optimization.mlir
@@ -19,10 +19,10 @@
 
 // CHECK-LABEL: func @all_reduce_scatter_2d_major_dim
 func.func @all_reduce_scatter_2d_major_dim() {
-    // CHECK:               %[[INPUT:.*]] = "tf.Const"() {value = dense<0.0
-    // CHECK:               %[[GROUP:.*]] = "tf.Const"() {value =
+    // CHECK:               %[[INPUT:.*]] = "tf.Const"() <{value = dense<0.0
+    // CHECK:               %[[GROUP:.*]] = "tf.Const"() <{value =
     // CHECK-SAME{LITERAL}: dense<[[0, 2], [1, 3]]>
-    // CHECK:               %[[SCATTER_DIM:.*]] = "tf.Const"() {value = dense<0>
+    // CHECK:               %[[SCATTER_DIM:.*]] = "tf.Const"() <{value = dense<0>
     // CHECK:               "tf.DTensorReduceScatter"(%[[INPUT]], %[[GROUP]], %[[SCATTER_DIM]])
     // CHECK-SAME:          reduce_op = "Add"
     // CHECK-NOT:           "tf.DTensorAllReduce"
@@ -41,10 +41,10 @@
 
 // CHECK-LABEL: func @all_reduce_scatter_2d_minor_dim
 func.func @all_reduce_scatter_2d_minor_dim() {
-    // CHECK:               %[[INPUT:.*]] = "tf.Const"() {value = dense<0.0
-    // CHECK:               %[[GROUP:.*]] = "tf.Const"() {value =
+    // CHECK:               %[[INPUT:.*]] = "tf.Const"() <{value = dense<0.0
+    // CHECK:               %[[GROUP:.*]] = "tf.Const"() <{value =
     // CHECK-SAME{LITERAL}: dense<[[0, 2], [1, 3]]>
-    // CHECK:               %[[SCATTER_DIM:.*]] = "tf.Const"() {value = dense<1>
+    // CHECK:               %[[SCATTER_DIM:.*]] = "tf.Const"() <{value = dense<1>
     // CHECK:               "tf.DTensorReduceScatter"(%[[INPUT]], %[[GROUP]], %[[SCATTER_DIM]])
     // CHECK-SAME:          reduce_op = "Add"
     // CHECK-NOT:           "tf.DTensorAllReduce"
diff --git a/tensorflow/dtensor/mlir/tests/dtensor_allreduce_sum_optimization.mlir b/tensorflow/dtensor/mlir/tests/dtensor_allreduce_sum_optimization.mlir
index b964b7b..28c01e0 100644
--- a/tensorflow/dtensor/mlir/tests/dtensor_allreduce_sum_optimization.mlir
+++ b/tensorflow/dtensor/mlir/tests/dtensor_allreduce_sum_optimization.mlir
@@ -149,13 +149,13 @@
   // CHECK:         "tf.A"
   // CHECK-NEXT:    "tf.Yield"
   // CHECK:         ^bb0(%[[BARG0:.*]]: tensor<4xf32>, %[[BARG1:.*]]: tensor<i32>)
-  // CHECK:          %[[INPUT0:.*]] = "tf.Const"() {value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
-  // CHECK-NEXT:     %[[GROUP:.*]] = "tf.Const"() {value = dense<0> : tensor<2x64xi32>} : () -> tensor<2x64xi32>
+  // CHECK:          %[[INPUT0:.*]] = "tf.Const"() <{value = dense<0> : tensor<4xi32>}> : () -> tensor<4xi32>
+  // CHECK-NEXT:     %[[GROUP:.*]] = "tf.Const"() <{value = dense<0> : tensor<2x64xi32>}> : () -> tensor<2x64xi32>
   // CHECK-NEXT:     %[[CAST_OUT:.*]] = "tf.Cast"(%[[INPUT0]])
   // CHECK-NEXT:     %[[ADD_OUT:.*]] = "tf.AddV2"(%[[CAST_OUT]], %[[BARG0]])
   // CHECK-NEXT:     %[[OUT:.*]] = "tf.Identity"(%[[ADD_OUT]])
   // CHECK-NEXT:     "tf.Yield"
-  // CHECK:      %[[NEW_GROUP_ASSIGN:.*]] = "tf.Const"() {value = dense<0> : tensor<2x64xi32>} : () -> tensor<2x64xi32>
+  // CHECK:      %[[NEW_GROUP_ASSIGN:.*]] = "tf.Const"() <{value = dense<0> : tensor<2x64xi32>}> : () -> tensor<2x64xi32>
   // CHECK:      %[[ALL_REDUCE_OUT:.*]] = "tf.DTensorAllReduce"(%[[WHILE_OUT]]#0, %[[NEW_GROUP_ASSIGN]])
   %0 = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32>
   %2 = "tf.Identity"(%0) : (tensor<4xf32>) -> tensor<4xf32>
diff --git a/tensorflow/dtensor/mlir/tests/dtensor_reduce_scatter_lowering.mlir b/tensorflow/dtensor/mlir/tests/dtensor_reduce_scatter_lowering.mlir
index c57183a..946720b 100644
--- a/tensorflow/dtensor/mlir/tests/dtensor_reduce_scatter_lowering.mlir
+++ b/tensorflow/dtensor/mlir/tests/dtensor_reduce_scatter_lowering.mlir
@@ -39,12 +39,12 @@
   // CHECK:      "tf_device.cluster"
   // CHECK:       %[[DEVICE_ID_RESHAPE:.*]] = "tf.Reshape"(%arg0
   // CHECK:       %[[RELATIVE_DEVICE_ID:.*]] = "tf.Sub"(%[[DEVICE_ID_RESHAPE]]
-  // CHECK-DAG:   %[[CONST_1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
-  // CHECK-DAG:   %[[DEVICE_ID_TO_GROUP_KEY:.*]] = "tf.Const"() {value = dense<[0, 0, 1, 1, 2, 2, 3, 3]> : tensor<8xi32>}
+  // CHECK-DAG:   %[[CONST_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
+  // CHECK-DAG:   %[[DEVICE_ID_TO_GROUP_KEY:.*]] = "tf.Const"() <{value = dense<[0, 0, 1, 1, 2, 2, 3, 3]> : tensor<8xi32>}>
   // CHECK:       %[[GROUP_KEY_SLICE:.*]] = "tf.Slice"(%[[DEVICE_ID_TO_GROUP_KEY]], %[[RELATIVE_DEVICE_ID]], %[[CONST_1]]
   // CHECK:       %[[GROUP_KEY_RESHAPE:.*]] = "tf.Reshape"(%[[GROUP_KEY_SLICE]]
-  // CHECK-DAG:   %[[INSTANCE_KEY:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} 
-  // CHECK-DAG:   %[[GROUP_SIZE:.*]] = "tf.Const"() {value = dense<2> : tensor<i32>}
+  // CHECK-DAG:   %[[INSTANCE_KEY:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}>
+  // CHECK-DAG:   %[[GROUP_SIZE:.*]] = "tf.Const"() <{value = dense<2> : tensor<i32>}>
   // CHECK:       %[[REDUCE_SCATTER_OUT:.*]] = "tf.CollectiveReduceScatterV2"(%arg1, %[[GROUP_SIZE]], %[[GROUP_KEY_RESHAPE]], %[[INSTANCE_KEY]])
   // CHECK-SAME:  final_op = "Id"
   // CHECK-SAME:  merge_op = "Add"
diff --git a/tensorflow/dtensor/mlir/tests/layout_propagation_v2.mlir b/tensorflow/dtensor/mlir/tests/layout_propagation_v2.mlir
index 2e9b079..7bd63d9 100644
--- a/tensorflow/dtensor/mlir/tests/layout_propagation_v2.mlir
+++ b/tensorflow/dtensor/mlir/tests/layout_propagation_v2.mlir
@@ -4,7 +4,7 @@
 // CHECK-LABEL: func @main
 func.func @main() {
     // CHECK:        "tf_device.cluster"()
-    // CHECK-NEXT:     %[[CONST_OUT:.*]] = "tf.Const"() {_global_shape = [#tf_type.shape<>], value = dense<10> : tensor<i32>}
+    // CHECK-NEXT:     %[[CONST_OUT:.*]] = "tf.Const"() <{value = dense<10> : tensor<i32>}> {_global_shape = [#tf_type.shape<>]}
     // CHECK-NEXT:     %[[DTENSOR_LAYOUT_OUT:.*]] = "tf.DTensorLayout"(%[[CONST_OUT]])
     // CHECK-SAME:     layout = #dtensor.layout<sharding_specs: mesh:CPU|x=4,y=1|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3>
     // CHECK-NEXT:     %[[NEG_OUT:.*]] = "tf.Neg"(%[[DTENSOR_LAYOUT_OUT]])
@@ -56,10 +56,10 @@
 func.func @main() {
    %6, %7 = "tf_device.cluster"() ({
     // CHECK:        "tf_device.cluster"()
-    // CHECK-NEXT:     %[[CONST_OUT_1:.*]] = "tf.Const"() {_global_shape = [#tf_type.shape<2x2>], value = dense<10> : tensor<2x2xi32>}
+    // CHECK-NEXT:     %[[CONST_OUT_1:.*]] = "tf.Const"() <{value = dense<10> : tensor<2x2xi32>}> {_global_shape = [#tf_type.shape<2x2>]}
     // CHECK-NEXT:     %[[DTENSOR_LAYOUT_OUT:.*]] = "tf.DTensorLayout"(%[[CONST_OUT_1]])
     // CHECK-SAME:     layout = #dtensor.layout<sharding_specs:x,y, mesh:CPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3>
-    // CHECK-NEXT:     %[[CONST_OUT_2:.*]] = "tf.Const"() {_global_shape = [#tf_type.shape<2x2>], value = dense<10> : tensor<2x2xi32>}
+    // CHECK-NEXT:     %[[CONST_OUT_2:.*]] = "tf.Const"() <{value = dense<10> : tensor<2x2xi32>}> {_global_shape = [#tf_type.shape<2x2>]}
     // CHECK-NEXT:     %[[DTENSOR_LAYOUT_OUT:.*]] = "tf.DTensorLayout"(%[[CONST_OUT_2]])
     // CHECK-SAME:     layout = #dtensor.layout<sharding_specs:x,unsharded, mesh:CPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3>
     %1 = "tf.Const"() {value = dense<10> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
@@ -896,7 +896,8 @@
                 %arg3: tensor<8x32x32x32x3xf32>) {
   // CHECK:      "tf_device.cluster"
   // CHECK:      %[[CONV_OUT:.*]] = "tf.Conv3DBackpropInput"
-  // CHECK-SAME: data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]
+  // CHECK-SAME: dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]
+  // CHECK-SAME: data_format = "NDHWC"
   // CHECK:      "tf.DTensorLayout"(%[[CONV_OUT]])
   // CHECK-SAME: layout = #dtensor.layout<sharding_specs:x,unsharded,unsharded,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
   %0 = "tf_device.cluster"() ({
@@ -1018,7 +1019,8 @@
                 %arg3: tensor<8x32x32x32x3xf32>) {
   // CHECK:      "tf_device.cluster"
   // CHECK:      %[[CONV_OUT:.*]] = "tf.Conv3DBackpropFilter"
-  // CHECK-SAME: data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]
+  // CHECK-SAME: dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]
+  // CHECK-SAME: data_format = "NDHWC"
   // CHECK:      "tf.DTensorLayout"(%[[CONV_OUT]])
   // CHECK-SAME: layout = #dtensor.layout<sharding_specs:unsharded,unsharded,unsharded,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
   %0 = "tf_device.cluster"() ({
diff --git a/tensorflow/dtensor/mlir/tests/lower_send_recv.mlir b/tensorflow/dtensor/mlir/tests/lower_send_recv.mlir
index 5a06d5d..552faaf 100644
--- a/tensorflow/dtensor/mlir/tests/lower_send_recv.mlir
+++ b/tensorflow/dtensor/mlir/tests/lower_send_recv.mlir
@@ -11,22 +11,22 @@
   // CHECK-DAG:   %[[RECV_DEVICE_ORDINAL:.*]] = "tf.Slice"(%[[RECV_ID_TO_ORDINAL:.*]], %[[RECV_DEVICE_ID]], %[[RECV_SLICE_SIZE:[^)]*]])
   // CHECK-DAG:   %[[RECV_DEVICE_ORDINAL_SCALAR:.*]] = "tf.Reshape"(%[[RECV_DEVICE_ORDINAL]], %[[RECV_SCALAR_TYPE:[^)]*]])
   // CHECK-DAG:   %[[RECV_DEVICE_ORDINAL_SCALAR_64:.*]] = "tf.Cast"(%[[RECV_DEVICE_ORDINAL_SCALAR]])
-  // CHECK-DAG:    %[[RECV_ID_TO_ORDINAL]] = "tf.Const"() {value = dense<0> : tensor<1xi32>}
-  // CHECK-DAG:    %[[RECV_SIZE_TYPE]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
-  // CHECK-DAG:    %[[RECV_SLICE_SIZE]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
-  // CHECK-DAG:    %[[RECV_SCALAR_TYPE]] = "tf.Const"() {value = dense<> : tensor<0xi32>}
+  // CHECK-DAG:    %[[RECV_ID_TO_ORDINAL]] = "tf.Const"() <{value = dense<0> : tensor<1xi32>}>
+  // CHECK-DAG:    %[[RECV_SIZE_TYPE]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
+  // CHECK-DAG:    %[[RECV_SLICE_SIZE]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
+  // CHECK-DAG:    %[[RECV_SCALAR_TYPE]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}>
  // COMMENT: Recv and Send separated by the output tensor.
   // CHECK:   %[[PROGRAM_KEY:.*]] = "tf._XlaCompileMlirPlaceholderProgramKey"
-  // CHECK-NEXT:   %[[CONST_OUT:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>}
+  // CHECK-NEXT:   %[[CONST_OUT:.*]] = "tf.Const"() <{value = dense<10> : tensor<1xi32>}>
   // CHECK-NEXT:   %[[LAYOUT_OUT:.*]] = "tf.DTensorLayout"(%[[CONST_OUT]])
   // CHECK-DAG:   %[[SEND_DEVICE_ID:.*]] = "tf.Reshape"(%[[DEVICE_ID]], %[[SEND_SIZE_TYPE:[^)]*]])
   // CHECK-DAG:   %[[SEND_DEVICE_ORDINAL:.*]] = "tf.Slice"(%[[SEND_ID_TO_ORDINAL:.*]], %[[SEND_DEVICE_ID]], %[[SEND_SLICE_SIZE:[^)]*]])
   // CHECK-DAG:   %[[SEND_DEVICE_ORDINAL_SCALAR:.*]] = "tf.Reshape"(%[[SEND_DEVICE_ORDINAL]], %[[SEND_SCALAR_TYPE:[^)]*]])
   // CHECK-DAG:   %[[SEND_DEVICE_ORDINAL_SCALAR_64:.*]] = "tf.Cast"(%[[SEND_DEVICE_ORDINAL_SCALAR]])
-  // CHECK-DAG:    %[[SEND_ID_TO_ORDINAL]] = "tf.Const"() {value = dense<0> : tensor<1xi32>}
-  // CHECK-DAG:    %[[SEND_SIZE_TYPE]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
-  // CHECK-DAG:    %[[SEND_SLICE_SIZE]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
-  // CHECK-DAG:    %[[SEND_SCALAR_TYPE]] = "tf.Const"() {value = dense<> : tensor<0xi32>}
+  // CHECK-DAG:    %[[SEND_ID_TO_ORDINAL]] = "tf.Const"() <{value = dense<0> : tensor<1xi32>}>
+  // CHECK-DAG:    %[[SEND_SIZE_TYPE]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
+  // CHECK-DAG:    %[[SEND_SLICE_SIZE]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}>
+  // CHECK-DAG:    %[[SEND_SCALAR_TYPE]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}>
   // CHECK:   "tf._XlaSendFromHostV2"(%[[LAYOUT_OUT]], %[[PROGRAM_KEY]], %[[SEND_DEVICE_ORDINAL_SCALAR_64]])
   // CHECK-NEXT:   %[[RECV_OUT:.*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_KEY]], %[[RECV_DEVICE_ORDINAL_SCALAR_64]])
   // CHECK-SAME:   key = "CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_2"
diff --git a/tensorflow/dtensor/mlir/tests/move_compilation_to_host.mlir b/tensorflow/dtensor/mlir/tests/move_compilation_to_host.mlir
index 9295cbf..06f9b2f 100644
--- a/tensorflow/dtensor/mlir/tests/move_compilation_to_host.mlir
+++ b/tensorflow/dtensor/mlir/tests/move_compilation_to_host.mlir
@@ -6,11 +6,11 @@
   // CHECK-LABEL: func @main
   func.func @main(%arg0: tensor<i32>,%arg1: tensor<4xi32> {tf._layout = "sharding_specs:unsharded, mesh:|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1", tf._mesh = "|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1"}) -> (tensor<f32> {tf._global_shape = #tf_type.shape<>}) attributes {tf.entry_function = {control_outputs = "eager_operation", inputs = "device_id,op_input_0", outputs = "op_output_0"}} {
     // CHECK:       "tf.StatefulPartitionedCall"
-    // CHECK-SAME:  _mesh = "|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1"
     // CHECK-SAME:  f = @_func_0
+    // CHECK-SAME:  _mesh = "|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1"
     // CHECK-NEXT:  "tf.StatefulPartitionedCall"
-    // CHECK-SAME:  _mesh = "|x=1|0|0|/job:localhost/replica:0/task:0/device:CPU:0"
     // CHECK-SAME:  f = @_func_1
+    // CHECK-SAME:  _mesh = "|x=1|0|0|/job:localhost/replica:0/task:0/device:CPU:0"
     "tf.StatefulPartitionedCall"(%arg0, %arg1) {_layout = [], _mesh = "|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1", config = "|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1", config_proto = "", executor_type = "", f = @_func_0} : (tensor<i32>, tensor<4xi32>) -> ()
     %0 = "tf.StatefulPartitionedCall"(%arg0) {_layout = ["sharding_specs: mesh:|x=1|0|0|/job:localhost/replica:0/task:0/device:CPU:0"], _mesh = "|x=1|0|0|/job:localhost/replica:0/task:0/device:CPU:0", config = "|x=1|0|0|/job:localhost/replica:0/task:0/device:CPU:0", config_proto = "", executor_type = "", f = @_func_1} : (tensor<i32>) -> tensor<f32>
     func.return %0 : tensor<f32>
@@ -53,31 +53,29 @@
   // CHECK-LABEL: func private @_func_1
   // CHECK-SAME:  %[[ARG0:.*]]: tensor<i32>
   func.func private @_func_1(%arg0: tensor<i32>) -> tensor<f32> {
-    // CHECK:      %[[COMPILE_OUT:.*]]:2 = "tf_device.launch"()
+    // CHECK:      %[[COMPILE_OUT:.*]]:2 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
     // CHECK-NEXT:   %[[COMPILATION_STATUS:.*]], %[[PROGRAM_KEY:.*]] = "tf._TPUCompileMlir"()
     // CHECK-NEXT:   "tf._HostSend"(%[[PROGRAM_KEY]])
-    // CHECK-SAME:   device = "/job:localhost/replica:0/task:0/device:CPU:0"
     // CHECK-SAME:   recv_device = "/job:localhost/replica:0/task:0/device:CPU:0"
     // CHECK-SAME:   send_device = "/job:localhost/replica:0/task:0/device:CPU:0"
-    // CHECK-NEXT:   "tf._HostSend"(%[[PROGRAM_KEY]])
     // CHECK-SAME:   device = "/job:localhost/replica:0/task:0/device:CPU:0"
+    // CHECK-NEXT:   "tf._HostSend"(%[[PROGRAM_KEY]])
     // CHECK-SAME:   recv_device = "/job:localhost/replica:0/task:0/device:TPU:0"
     // CHECK-SAME:   send_device = "/job:localhost/replica:0/task:0/device:CPU:0"
     // CHECK-SAME:   send_device_incarnation = 0
     // CHECK-SAME:   tensor_name = "compilation_send_recv_key_0
-    // CHECK-NEXT:   "tf._HostSend"(%[[PROGRAM_KEY]])
     // CHECK-SAME:   device = "/job:localhost/replica:0/task:0/device:CPU:0"
+    // CHECK-NEXT:   "tf._HostSend"(%[[PROGRAM_KEY]])
     // CHECK-SAME:   recv_device = "/job:localhost/replica:0/task:0/device:TPU:1"
     // CHECK-SAME:   send_device = "/job:localhost/replica:0/task:0/device:CPU:0"
     // CHECK-SAME:   send_device_incarnation = 0
     // CHECK-SAME:   tensor_name = "compilation_send_recv_key_1
+    // CHECK-SAME:   device = "/job:localhost/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   tf_device.return %[[COMPILATION_STATUS]], %[[PROGRAM_KEY]]
-    // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
-    // CHECK-NEXT: "tf_device.launch"()
+    // CHECK:      "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
     // CHECK-NEXT:   "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUT]]#0)
     // CHECK-NEXT:   tf_device.return
-    // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
-    // CHECK-NEXT:   %[[ID_TO_ORDINAL:.*]] = "tf.Const"
+    // CHECK:        %[[ID_TO_ORDINAL:.*]] = "tf.Const"
     // CHECK-SAME:   value = dense<0>
     // CHECK-NEXT:   %[[SIZE_TYPE:.*]] = "tf.Const"
     // CHECK-SAME:   value = dense<1>
@@ -118,11 +116,11 @@
   // CHECK-LABEL: func @main
   func.func @main(%arg0: tensor<i32>,%arg1: tensor<*x!tf_type.resource<tensor<4xf32>>> {tf._layout = "sharding_specs:unsharded, mesh:|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1", tf._mesh = "|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1"}) -> (tensor<f32> {tf._global_shape = #tf_type.shape<>}) attributes {tf.entry_function = {control_outputs = "eager_operation", inputs = "device_id,op_input_0", outputs = "op_output_0"}} {
     // CHECK:       "tf.StatefulPartitionedCall"
-    // CHECK-SAME:  _mesh = "|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1"
     // CHECK-SAME:  f = @_func_0
+    // CHECK-SAME:  _mesh = "|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1"
     // CHECK-NEXT:  "tf.StatefulPartitionedCall"
-    // CHECK-SAME:  _mesh = "|x=1|0|0|/job:localhost/replica:0/task:0/device:CPU:0"
     // CHECK-SAME:  f = @_func_1
+    // CHECK-SAME:  _mesh = "|x=1|0|0|/job:localhost/replica:0/task:0/device:CPU:0"
     "tf.StatefulPartitionedCall"(%arg0, %arg1) {_layout = [], _mesh = "|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1", config = "|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1", config_proto = "", executor_type = "", f = @_func_0} : (tensor<i32>, tensor<*x!tf_type.resource<tensor<4xf32>>>) -> ()
     %0 = "tf.StatefulPartitionedCall"(%arg0) {_layout = ["sharding_specs: mesh:|x=1|0|0|/job:localhost/replica:0/task:0/device:CPU:0"], _mesh = "|x=1|0|0|/job:localhost/replica:0/task:0/device:CPU:0", config = "|x=1|0|0|/job:localhost/replica:0/task:0/device:CPU:0", config_proto = "", executor_type = "", f = @_func_1} : (tensor<i32>) -> tensor<f32>
     func.return %0 : tensor<f32>
@@ -165,32 +163,30 @@
   // CHECK-LABEL: func private @_func_1
   // CHECK-SAME:  %[[ARG0:.*]]: tensor<i32>
   func.func private @_func_1(%arg0: tensor<i32>) -> tensor<f32> {
-    // CHECK:      %[[COMPILE_OUT:.*]]:2 = "tf_device.launch"()
+    // CHECK:      %[[COMPILE_OUT:.*]]:2 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
     // CHECK-NEXT:   %[[COMPILATION_STATUS:.*]], %[[PROGRAM_KEY:.*]] = "tf._TPUCompileMlir"()
     // CHECK-NEXT:   "tf._HostSend"(%[[PROGRAM_KEY]])
-    // CHECK-SAME:   device = "/job:localhost/replica:0/task:0/device:CPU:0"
     // CHECK-SAME:   recv_device = "/job:localhost/replica:0/task:0/device:CPU:0"
     // CHECK-SAME:   send_device = "/job:localhost/replica:0/task:0/device:CPU:0"
     // CHECK-SAME:   send_device_incarnation = 0
-    // CHECK-NEXT:   "tf._HostSend"(%[[PROGRAM_KEY]])
     // CHECK-SAME:   device = "/job:localhost/replica:0/task:0/device:CPU:0"
+    // CHECK-NEXT:   "tf._HostSend"(%[[PROGRAM_KEY]])
     // CHECK-SAME:   recv_device = "/job:localhost/replica:0/task:0/device:TPU:0"
     // CHECK-SAME:   send_device = "/job:localhost/replica:0/task:0/device:CPU:0"
     // CHECK-SAME:   send_device_incarnation = 0
     // CHECK-SAME:   tensor_name = "compilation_send_recv_key_0
-    // CHECK-NEXT:   "tf._HostSend"(%[[PROGRAM_KEY]])
     // CHECK-SAME:   device = "/job:localhost/replica:0/task:0/device:CPU:0"
+    // CHECK-NEXT:   "tf._HostSend"(%[[PROGRAM_KEY]])
     // CHECK-SAME:   recv_device = "/job:localhost/replica:0/task:0/device:TPU:1"
     // CHECK-SAME:   send_device = "/job:localhost/replica:0/task:0/device:CPU:0"
     // CHECK-SAME:   send_device_incarnation = 0
     // CHECK-SAME:   tensor_name = "compilation_send_recv_key_1
+    // CHECK-SAME:   device = "/job:localhost/replica:0/task:0/device:CPU:0"
     // CHECK-NEXT:   tf_device.return %[[COMPILATION_STATUS]], %[[PROGRAM_KEY]]
-    // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
-    // CHECK-NEXT: "tf_device.launch"()
+    // CHECK:      "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}>
     // CHECK-NEXT:   "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUT]]#0)
     // CHECK-NEXT:   tf_device.return
-    // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
-    // CHECK-NEXT: %[[ID_TO_ORDINAL:.*]] = "tf.Const"
+    // CHECK:      %[[ID_TO_ORDINAL:.*]] = "tf.Const"
     // CHECK-SAME: value = dense<0>
     // CHECK-NEXT: %[[SIZE_TYPE:.*]] = "tf.Const"
     // CHECK-SAME: value = dense<1>
diff --git a/tensorflow/dtensor/mlir/tests/multi_device_expansion.mlir b/tensorflow/dtensor/mlir/tests/multi_device_expansion.mlir
index fd7ab20..091a9b2 100644
--- a/tensorflow/dtensor/mlir/tests/multi_device_expansion.mlir
+++ b/tensorflow/dtensor/mlir/tests/multi_device_expansion.mlir
@@ -24,11 +24,11 @@
   // CHECK: %arg7: tensor<8xi32> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:7"}
   // CHECK: tf.entry_function = {inputs = "input_0,input_1,input_2,input_3,input_4,input_5,input_6,input_7", outputs = "output_0,output_1,output_2,output_3,output_4,output_5,output_6,output_7"
   // CHECK: %[[RES:.*]]:8 = "tf.StatefulPartitionedCall"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7)
+  // CHECK-SAME: f = @_multi_device_func_5372333290171538790_8525065017853554746
   // CHECK-SAME: _layout = ["sharding_specs:unsharded, mesh:|x=2,y=4|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:0/device:CPU:5,/job:localhost/replica:0/task:0/device:CPU:6,/job:localhost/replica:0/task:0/device:CPU:7"]
-  // CHECK-SAME: f = @_multi_device_func_4093838507448400597_971647271862201157
   // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3, %[[RES]]#4, %[[RES]]#5, %[[RES]]#6, %[[RES]]#7
 
-  // CHECK-LABEL: func.func private @_multi_device_func_4093838507448400597_971647271862201157
+  // CHECK-LABEL: func.func private @_multi_device_func_5372333290171538790_8525065017853554746(
   // CHECK: %arg0: tensor<8xi32> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}
   // CHECK: %arg1: tensor<8xi32> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:1"}
   // CHECK: %arg2: tensor<8xi32> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:2"}
@@ -38,7 +38,7 @@
   // CHECK: %arg6: tensor<8xi32> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:6"}
   // CHECK: %arg7: tensor<8xi32> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:7"}
   // CHECK: tf.entry_function = {inputs = "input_0,input_1,input_2,input_3,input_4,input_5,input_6,input_7", outputs = "output_0,output_1,output_2,output_3,output_4,output_5,output_6,output_7"
-  // CHECK: %[[CST0:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[CST0:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
   // CHECK: %[[CST1:.*]] = "tf.Const"
   // CHECK: %[[CST2:.*]] = "tf.Const"
   // CHECK: %[[CST3:.*]] = "tf.Const"
@@ -133,25 +133,25 @@
 // CHECK-SAME: %arg0: tensor<1x2xi32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}
 // CHECK-SAME: %arg1: tensor<1x2xi32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"}
 // CHECK-SAME: -> (tensor<2xi32>, tensor<2xi32>) {
-// CHECK-NEXT:   %0:2 = "tf_device.launch"() ({
-// CHECK-NEXT:     %compilation_status, %program = "tf._TPUCompileMlir"() {metadata = ""} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
+// CHECK-NEXT:   %0:2 = "tf_device.launch"() <{device = ""}> ({
+// CHECK-NEXT:     %compilation_status, %program = "tf._TPUCompileMlir"() <{metadata = ""}> : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
 // CHECK-NEXT:     tf_device.return %compilation_status, %program : tensor<!tf_type.string>, tensor<3x!tf_type.string>
-// CHECK-NEXT:   }) {device = ""} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
-// CHECK-NEXT:   "tf_device.launch"() ({
+// CHECK-NEXT:   }) : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
+// CHECK-NEXT:   "tf_device.launch"() <{device = ""}> ({
 // CHECK-NEXT:     "tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf_type.string>) -> ()
 // CHECK-NEXT:     tf_device.return
-// CHECK-NEXT:   }) {device = ""} : () -> ()
+// CHECK-NEXT:   }) : () -> ()
 // CHECK-NEXT:   %1:2 = "tf_device.parallel_execute"() ({
-// CHECK-NEXT:     %2 = "tf_device.launch"() ({
+// CHECK-NEXT:     %2 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:TPU:0"}> ({
 // CHECK-NEXT:       %3 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<1x2xi32>, tensor<3x!tf_type.string>) -> tensor<2xi32>
 // CHECK-NEXT:       tf_device.return %3 : tensor<2xi32>
-// CHECK-NEXT:     }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<2xi32>
+// CHECK-NEXT:     }) : () -> tensor<2xi32>
 // CHECK-NEXT:     tf_device.return %2 : tensor<2xi32>
 // CHECK-NEXT:   }, {
-// CHECK-NEXT:     %2 = "tf_device.launch"() ({
+// CHECK-NEXT:     %2 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:TPU:1"}> ({
 // CHECK-NEXT:       %3 = "tf.TPUExecute"(%arg1, %0#1) : (tensor<1x2xi32>, tensor<3x!tf_type.string>) -> tensor<2xi32>
 // CHECK-NEXT:       tf_device.return %3 : tensor<2xi32>
-// CHECK-NEXT:     }) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : () -> tensor<2xi32>
+// CHECK-NEXT:     }) : () -> tensor<2xi32>
 // CHECK-NEXT:     tf_device.return %2 : tensor<2xi32>
 // CHECK-NEXT:   }) : () -> (tensor<2xi32>, tensor<2xi32>)
 // CHECK-NEXT:   return %1#0, %1#1 : tensor<2xi32>, tensor<2xi32>
@@ -189,29 +189,29 @@
 // CHECK-SAME: %arg1: tensor<i32> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"}
 // CHECK-SAME: %arg2: tensor<!tf_type.resource<tensor<i32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}
 // CHECK-SAME: %arg3: tensor<!tf_type.resource<tensor<i32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"}
-// CHECK-NEXT:   %0:2 = "tf_device.launch"() ({
-// CHECK-NEXT:     %compilation_status, %program = "tf._TPUCompileMlir"() {metadata = ""} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
+// CHECK-NEXT:   %0:2 = "tf_device.launch"() <{device = ""}> ({
+// CHECK-NEXT:     %compilation_status, %program = "tf._TPUCompileMlir"() <{metadata = ""}> : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
 // CHECK-NEXT:     tf_device.return %compilation_status, %program : tensor<!tf_type.string>, tensor<3x!tf_type.string>
-// CHECK-NEXT:   }) {device = ""} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
-// CHECK-NEXT:   "tf_device.launch"() ({
+// CHECK-NEXT:   }) : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
+// CHECK-NEXT:   "tf_device.launch"() <{device = ""}> ({
 // CHECK-NEXT:     "tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf_type.string>) -> ()
 // CHECK-NEXT:     tf_device.return
-// CHECK-NEXT:   }) {device = ""} : () -> ()
+// CHECK-NEXT:   }) : () -> ()
 // CHECK-NEXT:   %1:2 = "tf_device.parallel_execute"() ({
-// CHECK-NEXT:     %2 = "tf_device.launch"() ({
+// CHECK-NEXT:     %2 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:TPU:0"}> ({
 // CHECK-NEXT:       %3 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<i32>, tensor<3x!tf_type.string>) -> tensor<i32>
 // CHECK-NEXT:       tf_device.return %3 : tensor<i32>
-// CHECK-NEXT:     }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<i32>
+// CHECK-NEXT:     }) : () -> tensor<i32>
 // CHECK-NEXT:     tf_device.return %2 : tensor<i32>
 // CHECK-NEXT:   }, {
-// CHECK-NEXT:     %2 = "tf_device.launch"() ({
+// CHECK-NEXT:     %2 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:TPU:1"}> ({
 // CHECK-NEXT:       %3 = "tf.TPUExecute"(%arg1, %0#1) : (tensor<i32>, tensor<3x!tf_type.string>) -> tensor<i32>
 // CHECK-NEXT:       tf_device.return %3 : tensor<i32>
-// CHECK-NEXT:     }) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : () -> tensor<i32>
+// CHECK-NEXT:     }) : () -> tensor<i32>
 // CHECK-NEXT:     tf_device.return %2 : tensor<i32>
 // CHECK-NEXT:   }) : () -> (tensor<i32>, tensor<i32>)
-// CHECK-NEXT:   "tf.AssignVariableOp"(%arg2, %1#0) {_global_shape = [], _layout = [], device = "/job:localhost/replica:0/task:0/device:TPU:0", validate_shape = false} : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
-// CHECK-NEXT:   "tf.AssignVariableOp"(%arg3, %1#1) {_global_shape = [], _layout = [], device = "/job:localhost/replica:0/task:0/device:TPU:1", validate_shape = false} : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
+// CHECK-NEXT:   "tf.AssignVariableOp"(%arg2, %1#0) <{validate_shape = false}> {_global_shape = [], _layout = [], device = "/job:localhost/replica:0/task:0/device:TPU:0"} : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
+// CHECK-NEXT:   "tf.AssignVariableOp"(%arg3, %1#1) <{validate_shape = false}> {_global_shape = [], _layout = [], device = "/job:localhost/replica:0/task:0/device:TPU:1"} : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
 // CHECK-NEXT:   return
 // CHECK-NEXT: }
 
diff --git a/tensorflow/dtensor/mlir/tests/propagate_default_layout.mlir b/tensorflow/dtensor/mlir/tests/propagate_default_layout.mlir
index 18e49f7..0d4947c 100644
--- a/tensorflow/dtensor/mlir/tests/propagate_default_layout.mlir
+++ b/tensorflow/dtensor/mlir/tests/propagate_default_layout.mlir
@@ -12,12 +12,12 @@
   %arg1: tensor<i32>{ tf._layout = "sharding_specs:scalar mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"},
   %arg2: tensor<i32>{ tf._layout = "sharding_specs:scalar mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"}) -> (tensor<i32>) {
   // CHECK:      %[[ARG1_OUT:[a-z0-9]*]] = "tf.DTensorLayout"(%arg[[ARG_1]])
-  // CHECK-SAME: dtensor.from_arg_index = [[ARG_1]]
   // CHECK-SAME: layout = #dtensor.layout<sharding_specs: mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3>
+  // CHECK-SAME: dtensor.from_arg_index = [[ARG_1]]
   // CHECK-SAME: (tensor<i32>) -> tensor<i32>
   // CHECK-NEXT: %[[ARG0_OUT:[a-z0-9]*]] = "tf.DTensorLayout"(%arg[[ARG_0]])
-  // CHECK-SAME: dtensor.from_arg_index = [[ARG_0]]
   // CHECK-SAME: layout = #dtensor.layout<sharding_specs: mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3>
+  // CHECK-SAME: dtensor.from_arg_index = [[ARG_0]]
   // CHECK-SAME: (tensor<i32>) -> tensor<i32>
   // CHECK-NEXT: "tf.A"(%[[ARG0_OUT]], %[[ARG1_OUT]])
   // CHECK-NEXT: "tf.B"(%[[ARG1_OUT]])
diff --git a/tensorflow/dtensor/mlir/tests/restore_and_assign.mlir b/tensorflow/dtensor/mlir/tests/restore_and_assign.mlir
index ee6ca5b..d613c17 100644
--- a/tensorflow/dtensor/mlir/tests/restore_and_assign.mlir
+++ b/tensorflow/dtensor/mlir/tests/restore_and_assign.mlir
@@ -18,8 +18,8 @@
     // CHECK-NEXT:       "tf.DTensorLayout"
     // CHECK-NEXT:       "tf.DTensorLayout"
     // CHECK-NEXT:       %[[RESTORE:.*]] = "tf.RestoreV2"(%0, %1, %2) : (tensor<!tf_type.string>, tensor<!tf_type.string>, tensor<!tf_type.string>) -> tensor<4x8xf32>
-    // CHECK-NEXT:       %[[DLAYOUT:.*]] = "tf.DTensorLayout"(%[[RESTORE]]) {global_shape = #tf_type.shape<4x8>, layout = #dtensor.layout<sharding_specs:x,unsharded, mesh:|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1>} : (tensor<4x8xf32>) -> tensor<4x8xf32>
-    // CHECK-NEXT:       "tf.AssignVariableOp"(%3, %[[DLAYOUT]]) {validate_shape = true} : (tensor<*x!tf_type.resource<tensor<4x8xf32>>>, tensor<4x8xf32>) -> ()
+    // CHECK-NEXT:       %[[DLAYOUT:.*]] = "tf.DTensorLayout"(%[[RESTORE]]) <{global_shape = #tf_type.shape<4x8>, layout = #dtensor.layout<sharding_specs:x,unsharded, mesh:|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1>}> : (tensor<4x8xf32>) -> tensor<4x8xf32>
+    // CHECK-NEXT:       "tf.AssignVariableOp"(%3, %[[DLAYOUT]]) <{validate_shape = true}> : (tensor<*x!tf_type.resource<tensor<4x8xf32>>>, tensor<4x8xf32>) -> ()
     "tf_device.cluster"() ({
       %0 = "tf.DTensorLayout"(%arg1) {global_shape = #tf_type.shape<>, layout = #dtensor.layout<sharding_specs: mesh:|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1>} : (tensor<!tf_type.string>) -> tensor<!tf_type.string>
       %1 = "tf.DTensorLayout"(%arg2) {global_shape = #tf_type.shape<>, layout = #dtensor.layout<sharding_specs: mesh:|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1>} : (tensor<!tf_type.string>) -> tensor<!tf_type.string>
@@ -47,10 +47,10 @@
     // CHECK-NEXT:       "tf.DTensorLayout"
     // CHECK-NEXT:       "tf.DTensorLayout"
     // CHECK-NEXT:       %[[RESTORE:.*]]  = "tf.RestoreV2"(%0, %1, %2) : (tensor<!tf_type.string>, tensor<!tf_type.string>, tensor<!tf_type.string>) -> tensor<4x8xf32>
-    // CHECK-NEXT:       %[[DLAYOUT:.*]]  = "tf.DTensorLayout"(%[[RESTORE]]) {global_shape = #tf_type.shape<4x8>, layout = #dtensor.layout<sharding_specs:x,unsharded, mesh:|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1>} : (tensor<4x8xf32>) -> tensor<4x8xf32>
-    // CHECK-NEXT:       %[[CAST:.*]]     = "tf.Cast"(%[[DLAYOUT]]) {Truncate = false} : (tensor<4x8xf32>) -> tensor<4x8xf64>
-    // CHECK-NEXT:       %[[DLAYOUT2:.*]] = "tf.DTensorLayout"(%[[CAST]]) {global_shape = #tf_type.shape<4x8>, layout = #dtensor.layout<sharding_specs:x,unsharded, mesh:|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1>} : (tensor<4x8xf64>) -> tensor<4x8xf64>
-    // CHECK-NEXT:       "tf.AssignVariableOp"(%3, %[[DLAYOUT2]]) {validate_shape = true} : (tensor<*x!tf_type.resource<tensor<4x8xf64>>>, tensor<4x8xf64>) -> ()
+    // CHECK-NEXT:       %[[DLAYOUT:.*]]  = "tf.DTensorLayout"(%[[RESTORE]]) <{global_shape = #tf_type.shape<4x8>, layout = #dtensor.layout<sharding_specs:x,unsharded, mesh:|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1>}> : (tensor<4x8xf32>) -> tensor<4x8xf32>
+    // CHECK-NEXT:       %[[CAST:.*]]     = "tf.Cast"(%[[DLAYOUT]]) <{Truncate = false}> : (tensor<4x8xf32>) -> tensor<4x8xf64>
+    // CHECK-NEXT:       %[[DLAYOUT2:.*]] = "tf.DTensorLayout"(%[[CAST]]) <{global_shape = #tf_type.shape<4x8>, layout = #dtensor.layout<sharding_specs:x,unsharded, mesh:|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1>}> : (tensor<4x8xf64>) -> tensor<4x8xf64>
+    // CHECK-NEXT:       "tf.AssignVariableOp"(%3, %[[DLAYOUT2]]) <{validate_shape = true}> : (tensor<*x!tf_type.resource<tensor<4x8xf64>>>, tensor<4x8xf64>) -> ()
     "tf_device.cluster"() ({
       %0 = "tf.DTensorLayout"(%arg1) {global_shape = #tf_type.shape<>, layout = #dtensor.layout<sharding_specs: mesh:|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1>} : (tensor<!tf_type.string>) -> tensor<!tf_type.string>
       %1 = "tf.DTensorLayout"(%arg2) {global_shape = #tf_type.shape<>, layout = #dtensor.layout<sharding_specs: mesh:|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1>} : (tensor<!tf_type.string>) -> tensor<!tf_type.string>
@@ -75,7 +75,7 @@
   %arg4: tensor<*x!tf_type.resource<tensor<4x8xf32>>>) {
     // CHECK:        "tf_device.cluster"
     // CHECK-NEXT:       %[[RESOURCE:.*]] = "tf.DTensorLayout"(%arg4)
-    // CHECK-NEXT:       %[[RECV:.*]] = "tf.DTensorRecv"() {
+    // CHECK-NEXT:       %[[RECV:.*]] = "tf.DTensorRecv"() <{
     // CHECK-SAME:       key = "communication_key_|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1"
     // CHECK-SAME:       mesh = #dtensor.mesh<|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1>
     // CHECK-SAME:       shape = #tf_type.shape<4x8>
@@ -84,7 +84,7 @@
     // CHECK-NEXT:       %[[RECV_DL:.*]] = "tf.DTensorLayout"(%[[RECV]])
     // CHECK-NEXT:       %[[IDENTITY:.*]] = "tf.Identity"(%[[RECV_DL]]) : (tensor<4x8xf32>) -> tensor<4x8xf32>
     // CHECK-NEXT:       %[[IDENTITY_DL:.*]] = "tf.DTensorLayout"(%[[IDENTITY]])
-    // CHECK-NEXT:       "tf.AssignVariableOp"(%[[RESOURCE]], %[[IDENTITY_DL]]) {validate_shape = true} : (tensor<*x!tf_type.resource<tensor<4x8xf32>>>, tensor<4x8xf32>) -> ()
+    // CHECK-NEXT:       "tf.AssignVariableOp"(%[[RESOURCE]], %[[IDENTITY_DL]]) <{validate_shape = true}> : (tensor<*x!tf_type.resource<tensor<4x8xf32>>>, tensor<4x8xf32>) -> ()
     // CHECK-NEXT:       tf_device.return
     "tf_device.cluster"() ({
       %4 = "tf.DTensorLayout"(%arg4) {global_shape = #tf_type.shape<4x8>, layout = #dtensor.layout<sharding_specs:x,unsharded, mesh:|x=2|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1>} : (tensor<*x!tf_type.resource<tensor<4x8xf32>>>) -> tensor<*x!tf_type.resource<tensor<4x8xf32>>>
@@ -99,7 +99,7 @@
     // CHECK-NEXT:       %[[DL2:.*]] = "tf.DTensorLayout"(%arg2)
     // CHECK-NEXT:       %[[DL3:.*]] = "tf.DTensorLayout"(%arg3)
     // CHECK-NEXT:       %[[RESTORE:.*]] = "tf.RestoreV2"(%[[DL1]], %[[DL2]], %[[DL3]]) : (tensor<!tf_type.string>, tensor<!tf_type.string>, tensor<!tf_type.string>) -> tensor<4x8xf32>
-    // CHECK-NEXT:       "tf.DTensorLayout"(%[[RESTORE]]) {global_shape = #tf_type.shape<4x8>, layout = #dtensor.layout<sharding_specs:x,unsharded, mesh:CPU|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1>} : (tensor<4x8xf32>) -> tensor<4x8xf32>
+    // CHECK-NEXT:       "tf.DTensorLayout"(%[[RESTORE]]) <{global_shape = #tf_type.shape<4x8>, layout = #dtensor.layout<sharding_specs:x,unsharded, mesh:CPU|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1>}> : (tensor<4x8xf32>) -> tensor<4x8xf32>
     // CHECK-NEXT:       "tf.DTensorSend"
     // CHECK-NEXT:       tf_device.return
     "tf_device.cluster"() ({
diff --git a/tensorflow/dtensor/mlir/tests/restore_shape_inference.mlir b/tensorflow/dtensor/mlir/tests/restore_shape_inference.mlir
index 6f78b55..a925f5e 100644
--- a/tensorflow/dtensor/mlir/tests/restore_shape_inference.mlir
+++ b/tensorflow/dtensor/mlir/tests/restore_shape_inference.mlir
@@ -3,9 +3,9 @@
 // Check the tf.RestoreV2Op's and all connected ops' resulting types are inferred from the AssignVariableOps in a single mesh. All unknown shapes should be known after this pass.
 func.func @main(%arg0: tensor<i32>, %arg1: tensor<!tf_type.string>, %arg2: tensor<2x!tf_type.string>, %arg3: tensor<2x!tf_type.string>, %arg4: tensor<*x!tf_type.resource<tensor<4x8xf32>>>, %arg5: tensor<*x!tf_type.resource<tensor<i64>>>) {
     // CHECK:        %0:2 = "tf.RestoreV2"(%arg1, %arg2, %arg3) : (tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<4x8xf32>, tensor<i64>)
-    // CHECK-NEXT:   "tf.AssignVariableOp"(%arg4, %0#0) {validate_shape = true} : (tensor<*x!tf_type.resource<tensor<4x8xf32>>>, tensor<4x8xf32>) -> ()
+    // CHECK-NEXT:   "tf.AssignVariableOp"(%arg4, %0#0) <{validate_shape = true}> : (tensor<*x!tf_type.resource<tensor<4x8xf32>>>, tensor<4x8xf32>) -> ()
     // CHECK:        %1 = "tf.Identity"(%0#1) : (tensor<i64>) -> tensor<i64>
-    // CHECK-NEXT:   "tf.AssignVariableOp"(%arg5, %1) {validate_shape = false} : (tensor<*x!tf_type.resource<tensor<i64>>>, tensor<i64>) -> ()
+    // CHECK-NEXT:   "tf.AssignVariableOp"(%arg5, %1) <{validate_shape = false}> : (tensor<*x!tf_type.resource<tensor<i64>>>, tensor<i64>) -> ()
     %0:2 = "tf.RestoreV2"(%arg1, %arg2, %arg3): (tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<*xf32>, tensor<*xi64>)
     "tf.AssignVariableOp"(%arg4, %0#0) {validate_shape = true} : (tensor<*x!tf_type.resource<tensor<4x8xf32>>>, tensor<*xf32>) -> ()
     %1 = "tf.Identity"(%0#1) {} : (tensor<*xi64>) -> tensor<*xi64>
@@ -19,9 +19,9 @@
 // Check the tf.RestoreV2Op's and all connected ops' resulting types are inferred from the AssignVariableOps in cross mesh cluster. All unknown shapes should be known after this pass.
 func.func @main(%arg0: tensor<i32>, %arg1: tensor<!tf_type.string>, %arg2: tensor<2x!tf_type.string>, %arg3: tensor<2x!tf_type.string>, %arg4: tensor<*x!tf_type.resource<tensor<4x8xf32>>>, %arg5: tensor<*x!tf_type.resource<tensor<i64>>>) {
     // CHECK:        "tf_device.cluster"
-    // CHECK-NEXT:       %2 = "tf.DTensorRecv"() {key = "communication_key_|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3", mesh = #dtensor.mesh<|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3>, shape = #tf_type.shape<4x8>} : () -> tensor<4x8xf32>
+    // CHECK-NEXT:       %2 = "tf.DTensorRecv"() <{key = "communication_key_|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3", mesh = #dtensor.mesh<|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3>, shape = #tf_type.shape<4x8>}> : () -> tensor<4x8xf32>
     // CHECK-NEXT:       %3 = "tf.Identity"(%2) : (tensor<4x8xf32>) -> tensor<4x8xf32>
-    // CHECK-NEXT:       "tf.AssignVariableOp"(%arg4, %3) {validate_shape = true} : (tensor<*x!tf_type.resource<tensor<4x8xf32>>>, tensor<4x8xf32>) -> ()
+    // CHECK-NEXT:       "tf.AssignVariableOp"(%arg4, %3) <{validate_shape = true}> : (tensor<*x!tf_type.resource<tensor<4x8xf32>>>, tensor<4x8xf32>) -> ()
     // CHECK-NEXT:       tf_device.return
     "tf_device.cluster"() ({
       %1 = "tf.DTensorRecv"() {key = "communication_key_|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3", mesh = #dtensor.mesh<|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3>, shape = #tf_type.shape<*>} : () -> tensor<*xf32>
@@ -33,8 +33,8 @@
     // CHECK:        "tf_device.cluster"
     // CHECK-NEXT:       %2:2 = "tf.RestoreV2"(%arg1, %arg2, %arg3) : (tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<4x8xf32>, tensor<i64>)
     // CHECK-NEXT:       %3 = "tf.Identity"(%2#1) : (tensor<i64>) -> tensor<i64>
-    // CHECK-NEXT:       "tf.AssignVariableOp"(%arg5, %3) {validate_shape = false} : (tensor<*x!tf_type.resource<tensor<i64>>>, tensor<i64>) -> ()
-    // CHECK-NEXT:       "tf.DTensorSend"(%2#0) {key = "communication_key_|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3", target_mesh = #dtensor.mesh<|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3>} : (tensor<4x8xf32>) -> ()
+    // CHECK-NEXT:       "tf.AssignVariableOp"(%arg5, %3) <{validate_shape = false}> : (tensor<*x!tf_type.resource<tensor<i64>>>, tensor<i64>) -> ()
+    // CHECK-NEXT:       "tf.DTensorSend"(%2#0) <{key = "communication_key_|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3", target_mesh = #dtensor.mesh<|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3>}> : (tensor<4x8xf32>) -> ()
     // CHECK-NEXT:       tf_device.return
     "tf_device.cluster"() ({
       %6:2 = "tf.RestoreV2"(%arg1, %arg2, %arg3) {} : (tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<*xf32>, tensor<*xi64>)
@@ -51,9 +51,9 @@
 // Check correctness of shape inference and element type propagation of a graph containing tf.Cast ops.
 func.func @main(%arg0: tensor<i32>, %arg1: tensor<!tf_type.string>, %arg2: tensor<2x!tf_type.string>, %arg3: tensor<2x!tf_type.string>, %arg4: tensor<*x!tf_type.resource<tensor<4x8xf32>>>, %arg5: tensor<*x!tf_type.resource<tensor<f32>>>) {
     // CHECK:        %0:2 = "tf.RestoreV2"(%arg1, %arg2, %arg3) : (tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<4x8xf32>, tensor<bf16>)
-    // CHECK-NEXT:   "tf.AssignVariableOp"(%arg4, %0#0) {validate_shape = true} : (tensor<*x!tf_type.resource<tensor<4x8xf32>>>, tensor<4x8xf32>) -> ()
-    // CHECK:        %1 = "tf.Cast"(%0#1) {Truncate = false} : (tensor<bf16>) -> tensor<f32>
-    // CHECK-NEXT:   "tf.AssignVariableOp"(%arg5, %1) {validate_shape = false} : (tensor<*x!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
+    // CHECK-NEXT:   "tf.AssignVariableOp"(%arg4, %0#0) <{validate_shape = true}> : (tensor<*x!tf_type.resource<tensor<4x8xf32>>>, tensor<4x8xf32>) -> ()
+    // CHECK:        %1 = "tf.Cast"(%0#1) <{Truncate = false}> : (tensor<bf16>) -> tensor<f32>
+    // CHECK-NEXT:   "tf.AssignVariableOp"(%arg5, %1) <{validate_shape = false}> : (tensor<*x!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
     %0:2 = "tf.RestoreV2"(%arg1, %arg2, %arg3): (tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<*xf32>, tensor<*xbf16>)
     "tf.AssignVariableOp"(%arg4, %0#0) {validate_shape = true} : (tensor<*x!tf_type.resource<tensor<4x8xf32>>>, tensor<*xf32>) -> ()
     %1 = "tf.Cast"(%0#1) {} : (tensor<*xbf16>) -> tensor<*xf32>
diff --git a/tensorflow/dtensor/mlir/tests/spmd_concat.mlir b/tensorflow/dtensor/mlir/tests/spmd_concat.mlir
index 8989bda..0f84bfc 100644
--- a/tensorflow/dtensor/mlir/tests/spmd_concat.mlir
+++ b/tensorflow/dtensor/mlir/tests/spmd_concat.mlir
@@ -39,14 +39,14 @@
   // CHECK:       "tf_device.cluster"
   // CHECK-NEXT:    %[[AXIS:.*]] = "tf.Const"()
   // CHECK-NEXT:    %[[ARG1_RELAYOUT:.*]] = "tf.DTensorAllGather"(%[[ARG1]])
-  // CHECK-SAME:      _layout = ["sharding_specs:unsharded,unsharded,unsharded, mesh:|x=4|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      input_layout = #dtensor.layout<sharding_specs:unsharded,x,unsharded, mesh:|x=4|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
   // CHECK-SAME:      output_layout = #dtensor.layout<sharding_specs:unsharded,unsharded,unsharded, mesh:|x=4|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
+  // CHECK-SAME:      _layout = ["sharding_specs:unsharded,unsharded,unsharded, mesh:|x=4|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      (tensor<8x2x32xf32>) -> tensor<8x8x32xf32>
   // CHECK-NEXT:    %[[ARG2_RELAYOUT:.*]] = "tf.DTensorAllGather"(%[[ARG2]])
-  // CHECK-SAME:      _layout = ["sharding_specs:unsharded,unsharded,unsharded, mesh:|x=4|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      input_layout = #dtensor.layout<sharding_specs:unsharded,x,unsharded, mesh:|x=4|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
   // CHECK-SAME:      output_layout = #dtensor.layout<sharding_specs:unsharded,unsharded,unsharded, mesh:|x=4|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
+  // CHECK-SAME:      _layout = ["sharding_specs:unsharded,unsharded,unsharded, mesh:|x=4|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      (tensor<8x4x32xf32>) -> tensor<8x16x32xf32>
   // CHECK-NEXT:    %[[CONCAT_OUT:.*]] = "tf.ConcatV2"(%[[ARG0]], %[[ARG1_RELAYOUT]], %[[ARG2_RELAYOUT]], %[[AXIS]])
   // CHECK-SAME:      _layout = ["sharding_specs:unsharded,unsharded,unsharded, mesh:|x=4|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
@@ -80,14 +80,14 @@
   // CHECK:       "tf_device.cluster"
   // CHECK-NEXT:    %[[AXIS:.*]] = "tf.Const"()
   // CHECK-NEXT:    %[[ARG0_RELAYOUT:.*]] = "tf.DTensorAllScatter"(%[[ARG0]])
-  // CHECK-SAME:      _layout = ["sharding_specs:x,unsharded,y, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      input_layout = #dtensor.layout<sharding_specs:x,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
   // CHECK-SAME:      output_layout = #dtensor.layout<sharding_specs:x,unsharded,y, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
+  // CHECK-SAME:      _layout = ["sharding_specs:x,unsharded,y, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      (tensor<4x4x32xf32>) -> tensor<4x4x16xf32>
   // CHECK-NEXT:    %[[ARG1_RELAYOUT:.*]] = "tf.DTensorAllScatter"(%[[ARG1]])
-  // CHECK-SAME:      _layout = ["sharding_specs:x,unsharded,y, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      input_layout = #dtensor.layout<sharding_specs:unsharded,unsharded,y, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
   // CHECK-SAME:      output_layout = #dtensor.layout<sharding_specs:x,unsharded,y, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
+  // CHECK-SAME:      _layout = ["sharding_specs:x,unsharded,y, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      (tensor<8x8x16xf32>) -> tensor<4x8x16xf32>
   // CHECK-NEXT:    %[[CONCAT_OUT:.*]] = "tf.ConcatV2"(%[[ARG0_RELAYOUT]], %[[ARG1_RELAYOUT]], %[[ARG2]], %[[AXIS]])
   // CHECK-SAME:      _layout = ["sharding_specs:x,unsharded,y, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
@@ -118,14 +118,14 @@
   // CHECK:       "tf_device.cluster"
   // CHECK-NEXT:    %[[AXIS:.*]] = "tf.Const"()
   // CHECK-NEXT:    %[[ARG0_RELAYOUT:.*]] = "tf.DTensorAllGather"(%[[ARG0]])
-  // CHECK-SAME:      _layout = ["sharding_specs:unsharded,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      input_layout = #dtensor.layout<sharding_specs:unsharded,x,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
   // CHECK-SAME:      output_layout = #dtensor.layout<sharding_specs:unsharded,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
+  // CHECK-SAME:      _layout = ["sharding_specs:unsharded,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      (tensor<8x4x32xf32>) -> tensor<8x8x32xf32>
   // CHECK-NEXT:    %[[ARG1_RELAYOUT:.*]] = "tf.DTensorAllGather"(%[[ARG1]])
-  // CHECK-SAME:      _layout = ["sharding_specs:unsharded,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      input_layout = #dtensor.layout<sharding_specs:unsharded,unsharded,x, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
   // CHECK-SAME:      output_layout = #dtensor.layout<sharding_specs:unsharded,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
+  // CHECK-SAME:      _layout = ["sharding_specs:unsharded,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      (tensor<16x8x16xf32>) -> tensor<16x8x32xf32>
   // CHECK-NEXT:    %[[CONCAT_OUT:.*]] = "tf.ConcatV2"(%[[ARG0_RELAYOUT]], %[[ARG1_RELAYOUT]], %[[AXIS]])
   // CHECK-SAME:      _layout = ["sharding_specs:unsharded,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
@@ -158,19 +158,19 @@
   // CHECK:       "tf_device.cluster"
   // CHECK-NEXT:    %[[AXIS:.*]] = "tf.Const"()
   // CHECK-NEXT:    %[[ARG1_SCATTER:.*]] = "tf.DTensorAllScatter"(%[[ARG1]])
-  // CHECK-SAME:      _layout = ["sharding_specs:x,y,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      input_layout = #dtensor.layout<sharding_specs:unsharded,y,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
   // CHECK-SAME:      output_layout = #dtensor.layout<sharding_specs:x,y,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
+  // CHECK-SAME:      _layout = ["sharding_specs:x,y,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      (tensor<8x4x32xf32>) -> tensor<4x4x32xf32>
   // CHECK-NEXT:    %[[ARG1_RELAYOUT:.*]] = "tf.DTensorAllGather"(%[[ARG1_SCATTER]])
-  // CHECK-SAME:      _layout = ["sharding_specs:x,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      input_layout = #dtensor.layout<sharding_specs:x,y,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
   // CHECK-SAME:      output_layout = #dtensor.layout<sharding_specs:x,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
+  // CHECK-SAME:      _layout = ["sharding_specs:x,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      (tensor<4x4x32xf32>) -> tensor<4x8x32xf32>
   // CHECK-NEXT:    %[[ARG2_RELAYOUT:.*]] = "tf.DTensorAllGather"(%[[ARG2]])
-  // CHECK-SAME:      _layout = ["sharding_specs:x,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      input_layout = #dtensor.layout<sharding_specs:x,y,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
   // CHECK-SAME:      output_layout = #dtensor.layout<sharding_specs:x,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3>
+  // CHECK-SAME:      _layout = ["sharding_specs:x,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:      (tensor<4x8x32xf32>) -> tensor<4x16x32xf32>
   // CHECK-NEXT:    %[[CONCAT_OUT:.*]] = "tf.ConcatV2"(%[[ARG0]], %[[ARG1_RELAYOUT]], %[[ARG2_RELAYOUT]], %[[AXIS]])
   // CHECK-SAME:      _layout = ["sharding_specs:x,unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
diff --git a/tensorflow/dtensor/mlir/tests/spmd_conv.mlir b/tensorflow/dtensor/mlir/tests/spmd_conv.mlir
index 64da631..9c776a9 100644
--- a/tensorflow/dtensor/mlir/tests/spmd_conv.mlir
+++ b/tensorflow/dtensor/mlir/tests/spmd_conv.mlir
@@ -221,8 +221,8 @@
   // CHECK:         "tf_device.cluster"
 
   // Build left halo on height dim.
-  // CHECK:           %[[SLICE_H_LEFT_BEGIN:.*]] = "tf.Const"() {value = dense<[0, 3, 0, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-  // CHECK-NEXT:      %[[SLICE_H_LEFT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 1, 4, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  // CHECK:           %[[SLICE_H_LEFT_BEGIN:.*]] = "tf.Const"() <{value = dense<[0, 3, 0, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+  // CHECK-NEXT:      %[[SLICE_H_LEFT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 1, 4, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
   // CHECK-NEXT:      %[[SLICE_H_LEFT:.*]] = "tf.Slice"(%arg1, %[[SLICE_H_LEFT_BEGIN]], %[[SLICE_H_LEFT_SIZE]])
   // CHECK-SAME:          (tensor<8x4x4x3xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<8x1x4x3xf32>
   // CHECK-NEXT:      %[[HALO_H_LEFT:.*]] = "tf.SelectV2"
@@ -231,8 +231,8 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_H_LEFT:.*]] = "tf.CollectivePermute"(%[[HALO_H_LEFT]], %[[PAIRS_H_LEFT]])
   // CHECK-SAME:          (tensor<8x1x4x3xf32>, tensor<4x2xi32>) -> tensor<8x1x4x3xf32>
   // Build right halo on height dim.
-  // CHECK:           %[[SLICE_H_RIGHT_BEGIN:.*]] = "tf.Const"() {value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
-  // CHECK-NEXT:      %[[SLICE_H_RIGHT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 1, 4, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  // CHECK:           %[[SLICE_H_RIGHT_BEGIN:.*]] = "tf.Const"() <{value = dense<0> : tensor<4xi32>}> : () -> tensor<4xi32>
+  // CHECK-NEXT:      %[[SLICE_H_RIGHT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 1, 4, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
   // CHECK-NEXT:      %[[SLICE_H_RIGHT:.*]] = "tf.Slice"(%arg1, %[[SLICE_H_RIGHT_BEGIN]], %[[SLICE_H_RIGHT_SIZE]])
   // CHECK-SAME:          (tensor<8x4x4x3xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<8x1x4x3xf32>
   // CHECK-NEXT:      %[[HALO_H_RIGHT:.*]] = "tf.SelectV2"
@@ -241,13 +241,13 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_H_RIGHT:.*]] = "tf.CollectivePermute"(%[[HALO_H_RIGHT]], %[[PAIRS_H_RIGHT]])
   // CHECK-SAME:          (tensor<8x1x4x3xf32>, tensor<4x2xi32>) -> tensor<8x1x4x3xf32>
   // Concat the halos with the shard on the height dim.
-  // CHECK-NEXT:      %[[CONCAT_H_AXIS:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-NEXT:      %[[CONCAT_H_AXIS:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
   // CHECK-NEXT:      %[[CONCAT_H_TENSOR:.*]] = "tf.ConcatV2"(%[[EXCHANGED_HALO_H_LEFT]], %arg1, %[[EXCHANGED_HALO_H_RIGHT]], %[[CONCAT_H_AXIS]])
   // CHECK-SAME:          (tensor<8x1x4x3xf32>, tensor<8x4x4x3xf32>, tensor<8x1x4x3xf32>, tensor<i64>) -> tensor<8x6x4x3xf32>
 
   // Build left halo on width dim.
-  // CHECK:           %[[SLICE_W_LEFT_BEGIN:.*]] = "tf.Const"() {value = dense<[0, 0, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-  // CHECK-NEXT:      %[[SLICE_W_LEFT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 6, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  // CHECK:           %[[SLICE_W_LEFT_BEGIN:.*]] = "tf.Const"() <{value = dense<[0, 0, 3, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+  // CHECK-NEXT:      %[[SLICE_W_LEFT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 6, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
   // CHECK-NEXT:      %[[SLICE_W_LEFT:.*]] = "tf.Slice"(%[[CONCAT_H_TENSOR]], %[[SLICE_W_LEFT_BEGIN]], %[[SLICE_W_LEFT_SIZE]])
   // CHECK-SAME:          (tensor<8x6x4x3xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<8x6x1x3xf32>
   // CHECK-NEXT:      %[[HALO_W_LEFT:.*]] = "tf.SelectV2"
@@ -256,8 +256,8 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_W_LEFT:.*]] = "tf.CollectivePermute"(%[[HALO_W_LEFT]], %[[PAIRS_W_LEFT]])
   // CHECK-SAME:          (tensor<8x6x1x3xf32>, tensor<4x2xi32>) -> tensor<8x6x1x3xf32>
   // Build right halo on width dim.
-  // CHECK:           %[[SLICE_W_RIGHT_BEGIN:.*]] = "tf.Const"() {value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
-  // CHECK-NEXT:      %[[SLICE_W_RIGHT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 6, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  // CHECK:           %[[SLICE_W_RIGHT_BEGIN:.*]] = "tf.Const"() <{value = dense<0> : tensor<4xi32>}> : () -> tensor<4xi32>
+  // CHECK-NEXT:      %[[SLICE_W_RIGHT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 6, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
   // CHECK-NEXT:      %[[SLICE_W_RIGHT:.*]] = "tf.Slice"(%[[CONCAT_H_TENSOR]], %[[SLICE_W_RIGHT_BEGIN]], %[[SLICE_W_RIGHT_SIZE]])
   // CHECK-SAME:          (tensor<8x6x4x3xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<8x6x1x3xf32>
   // CHECK-NEXT:      %[[HALO_W_RIGHT:.*]] = "tf.SelectV2"
@@ -266,7 +266,7 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_W_RIGHT:.*]] = "tf.CollectivePermute"(%[[HALO_W_RIGHT]], %[[PAIRS_W_RIGHT]])
   // CHECK-SAME:          (tensor<8x6x1x3xf32>, tensor<4x2xi32>) -> tensor<8x6x1x3xf32>
   // Concat the halos with the shard on the width dim.
-  // CHECK-NEXT:      %[[CONCAT_W_AXIS:.*]] = "tf.Const"() {value = dense<2> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-NEXT:      %[[CONCAT_W_AXIS:.*]] = "tf.Const"() <{value = dense<2> : tensor<i64>}> : () -> tensor<i64>
   // CHECK-NEXT:      %[[CONCAT_HW_TENSOR:.*]] = "tf.ConcatV2"(%[[EXCHANGED_HALO_W_LEFT]], %[[CONCAT_H_TENSOR]], %[[EXCHANGED_HALO_W_RIGHT]], %[[CONCAT_W_AXIS]])
   // CHECK-SAME:          (tensor<8x6x1x3xf32>, tensor<8x6x4x3xf32>, tensor<8x6x1x3xf32>, tensor<i64>) -> tensor<8x6x6x3xf32>
 
@@ -293,8 +293,8 @@
   // CHECK:         "tf_device.cluster"
 
   // Build left halo on depth dim.
-  // CHECK:           %[[SLICE_D_LEFT_BEGIN:.*]] = "tf.Const"() {value = dense<[0, 3, 0, 0, 0]> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-NEXT:      %[[SLICE_D_LEFT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 1, 4, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK:           %[[SLICE_D_LEFT_BEGIN:.*]] = "tf.Const"() <{value = dense<[0, 3, 0, 0, 0]> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[SLICE_D_LEFT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 1, 4, 4, 3]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK-NEXT:      %[[SLICE_D_LEFT:.*]] = "tf.Slice"(%arg1, %[[SLICE_D_LEFT_BEGIN]], %[[SLICE_D_LEFT_SIZE]])
   // CHECK-SAME:          (tensor<8x4x4x4x3xf32>, tensor<5xi32>, tensor<5xi32>) -> tensor<8x1x4x4x3xf32>
   // CHECK-NEXT:      %[[HALO_D_LEFT:.*]] = "tf.SelectV2"
@@ -303,8 +303,8 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_D_LEFT:.*]] = "tf.CollectivePermute"(%[[HALO_D_LEFT]], %[[PAIRS_D_LEFT]])
   // CHECK-SAME:          (tensor<8x1x4x4x3xf32>, tensor<8x2xi32>) -> tensor<8x1x4x4x3xf32>
   // Build right halo on depth dim.
-  // CHECK:           %[[SLICE_D_RIGHT_BEGIN:.*]] = "tf.Const"() {value = dense<0> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-NEXT:      %[[SLICE_D_RIGHT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 1, 4, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK:           %[[SLICE_D_RIGHT_BEGIN:.*]] = "tf.Const"() <{value = dense<0> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[SLICE_D_RIGHT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 1, 4, 4, 3]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK-NEXT:      %[[SLICE_D_RIGHT:.*]] = "tf.Slice"(%arg1, %[[SLICE_D_RIGHT_BEGIN]], %[[SLICE_D_RIGHT_SIZE]])
   // CHECK-SAME:          (tensor<8x4x4x4x3xf32>, tensor<5xi32>, tensor<5xi32>) -> tensor<8x1x4x4x3xf32>
   // CHECK-NEXT:      %[[HALO_D_RIGHT:.*]] = "tf.SelectV2"
@@ -313,13 +313,13 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_D_RIGHT:.*]] = "tf.CollectivePermute"(%[[HALO_D_RIGHT]], %[[PAIRS_D_RIGHT]])
   // CHECK-SAME:          (tensor<8x1x4x4x3xf32>, tensor<8x2xi32>) -> tensor<8x1x4x4x3xf32>
   // Concat the halos with the shard on the depth dim.
-  // CHECK-NEXT:      %[[CONCAT_D_AXIS:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-NEXT:      %[[CONCAT_D_AXIS:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
   // CHECK-NEXT:      %[[CONCAT_D_TENSOR:.*]] = "tf.ConcatV2"(%[[EXCHANGED_HALO_D_LEFT]], %arg1, %[[EXCHANGED_HALO_D_RIGHT]], %[[CONCAT_D_AXIS]])
   // CHECK-SAME:          (tensor<8x1x4x4x3xf32>, tensor<8x4x4x4x3xf32>, tensor<8x1x4x4x3xf32>, tensor<i64>) -> tensor<8x6x4x4x3xf32>
 
   // Build left halo on height dim.
-  // CHECK:           %[[SLICE_H_LEFT_BEGIN:.*]] = "tf.Const"() {value = dense<[0, 0, 3, 0, 0]> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-NEXT:      %[[SLICE_H_LEFT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 6, 1, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK:           %[[SLICE_H_LEFT_BEGIN:.*]] = "tf.Const"() <{value = dense<[0, 0, 3, 0, 0]> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[SLICE_H_LEFT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 6, 1, 4, 3]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK-NEXT:      %[[SLICE_H_LEFT:.*]] = "tf.Slice"(%[[CONCAT_D_TENSOR]], %[[SLICE_H_LEFT_BEGIN]], %[[SLICE_H_LEFT_SIZE]])
   // CHECK-SAME:          (tensor<8x6x4x4x3xf32>, tensor<5xi32>, tensor<5xi32>) -> tensor<8x6x1x4x3xf32>
   // CHECK-NEXT:      %[[HALO_H_LEFT:.*]] = "tf.SelectV2"
@@ -328,8 +328,8 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_H_LEFT:.*]] = "tf.CollectivePermute"(%[[HALO_H_LEFT]], %[[PAIRS_H_LEFT]])
   // CHECK-SAME:          (tensor<8x6x1x4x3xf32>, tensor<8x2xi32>) -> tensor<8x6x1x4x3xf32>
   // Build right halo on height dim.
-  // CHECK:           %[[SLICE_H_RIGHT_BEGIN:.*]] = "tf.Const"() {value = dense<0> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-NEXT:      %[[SLICE_H_RIGHT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 6, 1, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK:           %[[SLICE_H_RIGHT_BEGIN:.*]] = "tf.Const"() <{value = dense<0> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[SLICE_H_RIGHT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 6, 1, 4, 3]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK-NEXT:      %[[SLICE_H_RIGHT:.*]] = "tf.Slice"(%[[CONCAT_D_TENSOR]], %[[SLICE_H_RIGHT_BEGIN]], %[[SLICE_H_RIGHT_SIZE]])
   // CHECK-SAME:          (tensor<8x6x4x4x3xf32>, tensor<5xi32>, tensor<5xi32>) -> tensor<8x6x1x4x3xf32>
   // CHECK-NEXT:      %[[HALO_H_RIGHT:.*]] = "tf.SelectV2"
@@ -338,13 +338,13 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_H_RIGHT:.*]] = "tf.CollectivePermute"(%[[HALO_H_RIGHT]], %[[PAIRS_H_RIGHT]])
   // CHECK-SAME:          (tensor<8x6x1x4x3xf32>, tensor<8x2xi32>) -> tensor<8x6x1x4x3xf32>
   // Concat the halos with the shard on the height dim.
-  // CHECK-NEXT:      %[[CONCAT_H_AXIS:.*]] = "tf.Const"() {value = dense<2> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-NEXT:      %[[CONCAT_H_AXIS:.*]] = "tf.Const"() <{value = dense<2> : tensor<i64>}> : () -> tensor<i64>
   // CHECK-NEXT:      %[[CONCAT_DH_TENSOR:.*]] = "tf.ConcatV2"(%[[EXCHANGED_HALO_H_LEFT]], %[[CONCAT_D_TENSOR]], %[[EXCHANGED_HALO_H_RIGHT]], %[[CONCAT_H_AXIS]])
   // CHECK-SAME:          (tensor<8x6x1x4x3xf32>, tensor<8x6x4x4x3xf32>, tensor<8x6x1x4x3xf32>, tensor<i64>) -> tensor<8x6x6x4x3xf32>
 
   // Build left halo on width dim.
-  // CHECK:           %[[SLICE_W_LEFT_BEGIN:.*]] = "tf.Const"() {value = dense<[0, 0, 0, 3, 0]> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-NEXT:      %[[SLICE_W_LEFT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 6, 6, 1, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK:           %[[SLICE_W_LEFT_BEGIN:.*]] = "tf.Const"() <{value = dense<[0, 0, 0, 3, 0]> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[SLICE_W_LEFT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 6, 6, 1, 3]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK-NEXT:      %[[SLICE_W_LEFT:.*]] = "tf.Slice"(%[[CONCAT_DH_TENSOR]], %[[SLICE_W_LEFT_BEGIN]], %[[SLICE_W_LEFT_SIZE]])
   // CHECK-SAME:          (tensor<8x6x6x4x3xf32>, tensor<5xi32>, tensor<5xi32>) -> tensor<8x6x6x1x3xf32>
   // CHECK-NEXT:      %[[HALO_W_LEFT:.*]] = "tf.SelectV2"
@@ -353,8 +353,8 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_W_LEFT:.*]] = "tf.CollectivePermute"(%[[HALO_W_LEFT]], %[[PAIRS_W_LEFT]])
   // CHECK-SAME:          (tensor<8x6x6x1x3xf32>, tensor<8x2xi32>) -> tensor<8x6x6x1x3xf32>
   // Build right halo on width dim.
-  // CHECK:           %[[SLICE_W_RIGHT_BEGIN:.*]] = "tf.Const"() {value = dense<0> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-NEXT:      %[[SLICE_W_RIGHT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 6, 6, 1, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK:           %[[SLICE_W_RIGHT_BEGIN:.*]] = "tf.Const"() <{value = dense<0> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[SLICE_W_RIGHT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 6, 6, 1, 3]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK-NEXT:      %[[SLICE_W_RIGHT:.*]] = "tf.Slice"(%[[CONCAT_DH_TENSOR]], %[[SLICE_W_RIGHT_BEGIN]], %[[SLICE_W_RIGHT_SIZE]])
   // CHECK-SAME:          (tensor<8x6x6x4x3xf32>, tensor<5xi32>, tensor<5xi32>) -> tensor<8x6x6x1x3xf32>
   // CHECK-NEXT:      %[[HALO_W_RIGHT:.*]] = "tf.SelectV2"
@@ -363,7 +363,7 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_W_RIGHT:.*]] = "tf.CollectivePermute"(%[[HALO_W_RIGHT]], %[[PAIRS_W_RIGHT]])
   // CHECK-SAME:          (tensor<8x6x6x1x3xf32>, tensor<8x2xi32>) -> tensor<8x6x6x1x3xf32>
   // Concat the halos with the shard on the width dim.
-  // CHECK-NEXT:      %[[CONCAT_W_AXIS:.*]] = "tf.Const"() {value = dense<3> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-NEXT:      %[[CONCAT_W_AXIS:.*]] = "tf.Const"() <{value = dense<3> : tensor<i64>}> : () -> tensor<i64>
   // CHECK-NEXT:      %[[CONCAT_DHW_TENSOR:.*]] = "tf.ConcatV2"(%[[EXCHANGED_HALO_W_LEFT]], %[[CONCAT_DH_TENSOR]], %[[EXCHANGED_HALO_W_RIGHT]], %[[CONCAT_W_AXIS]])
   // CHECK-SAME:          (tensor<8x6x6x1x3xf32>, tensor<8x6x6x4x3xf32>, tensor<8x6x6x1x3xf32>, tensor<i64>) -> tensor<8x6x6x6x3xf32>
 
@@ -390,8 +390,8 @@
   // CHECK:         "tf_device.cluster"
 
   // Build left halo on height dim.
-  // CHECK:           %[[SLICE_H_LEFT_BEGIN:.*]] = "tf.Const"() {value = dense<[0, 3, 0, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-  // CHECK-NEXT:      %[[SLICE_H_LEFT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 1, 4, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  // CHECK:           %[[SLICE_H_LEFT_BEGIN:.*]] = "tf.Const"() <{value = dense<[0, 3, 0, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+  // CHECK-NEXT:      %[[SLICE_H_LEFT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 1, 4, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
   // CHECK-NEXT:      %[[SLICE_H_LEFT:.*]] = "tf.Slice"(%arg1, %[[SLICE_H_LEFT_BEGIN]], %[[SLICE_H_LEFT_SIZE]])
   // CHECK-SAME:          (tensor<8x4x4x3xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<8x1x4x3xf32>
   // CHECK-NEXT:      %[[HALO_H_LEFT:.*]] = "tf.SelectV2"
@@ -400,8 +400,8 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_H_LEFT:.*]] = "tf.CollectivePermute"(%[[HALO_H_LEFT]], %[[PAIRS_H_LEFT]])
   // CHECK-SAME:          (tensor<8x1x4x3xf32>, tensor<4x2xi32>) -> tensor<8x1x4x3xf32>
   // Build right halo on height dim.
-  // CHECK:           %[[SLICE_H_RIGHT_BEGIN:.*]] = "tf.Const"() {value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
-  // CHECK-NEXT:      %[[SLICE_H_RIGHT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 1, 4, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  // CHECK:           %[[SLICE_H_RIGHT_BEGIN:.*]] = "tf.Const"() <{value = dense<0> : tensor<4xi32>}> : () -> tensor<4xi32>
+  // CHECK-NEXT:      %[[SLICE_H_RIGHT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 1, 4, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
   // CHECK-NEXT:      %[[SLICE_H_RIGHT:.*]] = "tf.Slice"(%arg1, %[[SLICE_H_RIGHT_BEGIN]], %[[SLICE_H_RIGHT_SIZE]])
   // CHECK-SAME:          (tensor<8x4x4x3xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<8x1x4x3xf32>
   // CHECK-NEXT:      %[[HALO_H_RIGHT:.*]] = "tf.SelectV2"
@@ -410,21 +410,21 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_H_RIGHT:.*]] = "tf.CollectivePermute"(%[[HALO_H_RIGHT]], %[[PAIRS_H_RIGHT]])
   // CHECK-SAME:          (tensor<8x1x4x3xf32>, tensor<4x2xi32>) -> tensor<8x1x4x3xf32>
   // Concat the halos with the shard on the height dim.
-  // CHECK-NEXT:      %[[CONCAT_H_AXIS:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-NEXT:      %[[CONCAT_H_AXIS:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
   // CHECK-NEXT:      %[[CONCAT_H_TENSOR:.*]] = "tf.ConcatV2"(%[[EXCHANGED_HALO_H_LEFT]], %arg1, %[[EXCHANGED_HALO_H_RIGHT]], %[[CONCAT_H_AXIS]])
   // CHECK-SAME:          (tensor<8x1x4x3xf32>, tensor<8x4x4x3xf32>, tensor<8x1x4x3xf32>, tensor<i64>) -> tensor<8x6x4x3xf32>
   // Dynamically slice the concatenated tensor to get correct size for VALID padding.
-  // CHECK-NEXT:      %[[HALO_SIZES_H:.*]] = "tf.Const"() {value = dense<[0, 1, 0, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-  // CHECK-NEXT:      %[[HALO_INCREMENTS_H:.*]] = "tf.Const"() {value = dense<[0, 1, 0, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
+  // CHECK-NEXT:      %[[HALO_SIZES_H:.*]] = "tf.Const"() <{value = dense<[0, 1, 0, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+  // CHECK-NEXT:      %[[HALO_INCREMENTS_H:.*]] = "tf.Const"() <{value = dense<[0, 1, 0, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
   // CHECK-NEXT:      %[[VALID_OFFSET_H:.*]] = "tf.Mul"
   // CHECK-NEXT:      %[[VALID_SLICE_BEGIN_H:.*]] = "tf.Sub"(%[[HALO_SIZES_H]], %[[VALID_OFFSET_H]])
-  // CHECK-NEXT:      %[[VALID_SLICE_SIZE_H:.*]] = "tf.Const"() {value = dense<[8, 5, 4, 3]> : tensor<4xi64>} : () -> tensor<4xi64>
-  // CHECK-NEXT:      %[[VALID_SLICE_BEGIN_CAST_I64_H:.*]] = "tf.Cast"(%[[VALID_SLICE_BEGIN_H]]) {Truncate = false} : (tensor<4xi32>) -> tensor<4xi64>
+  // CHECK-NEXT:      %[[VALID_SLICE_SIZE_H:.*]] = "tf.Const"() <{value = dense<[8, 5, 4, 3]> : tensor<4xi64>}> : () -> tensor<4xi64>
+  // CHECK-NEXT:      %[[VALID_SLICE_BEGIN_CAST_I64_H:.*]] = "tf.Cast"(%[[VALID_SLICE_BEGIN_H]]) <{Truncate = false}> : (tensor<4xi32>) -> tensor<4xi64>
   // CHECK-NEXT:      %[[VALID_SLICE_H_TENSOR:.*]] = "tf.Slice"(%[[CONCAT_H_TENSOR]], %[[VALID_SLICE_BEGIN_CAST_I64_H]], %[[VALID_SLICE_SIZE_H]])
 
   // Build left halo on width dim.
-  // CHECK:           %[[SLICE_W_LEFT_BEGIN:.*]] = "tf.Const"() {value = dense<[0, 0, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-  // CHECK-NEXT:      %[[SLICE_W_LEFT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 5, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  // CHECK:           %[[SLICE_W_LEFT_BEGIN:.*]] = "tf.Const"() <{value = dense<[0, 0, 3, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+  // CHECK-NEXT:      %[[SLICE_W_LEFT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 5, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
   // CHECK-NEXT:      %[[SLICE_W_LEFT:.*]] = "tf.Slice"(%[[VALID_SLICE_H_TENSOR]], %[[SLICE_W_LEFT_BEGIN]], %[[SLICE_W_LEFT_SIZE]])
   // CHECK-SAME:          (tensor<8x5x4x3xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<8x5x1x3xf32>
   // CHECK-NEXT:      %[[HALO_W_LEFT:.*]] = "tf.SelectV2"
@@ -433,8 +433,8 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_W_LEFT:.*]] = "tf.CollectivePermute"(%[[HALO_W_LEFT]], %[[PAIRS_W_LEFT]])
   // CHECK-SAME:          (tensor<8x5x1x3xf32>, tensor<4x2xi32>) -> tensor<8x5x1x3xf32>
   // Build right halo on width dim.
-  // CHECK:           %[[SLICE_W_RIGHT_BEGIN:.*]] = "tf.Const"() {value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
-  // CHECK-NEXT:      %[[SLICE_W_RIGHT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 5, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  // CHECK:           %[[SLICE_W_RIGHT_BEGIN:.*]] = "tf.Const"() <{value = dense<0> : tensor<4xi32>}> : () -> tensor<4xi32>
+  // CHECK-NEXT:      %[[SLICE_W_RIGHT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 5, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
   // CHECK-NEXT:      %[[SLICE_W_RIGHT:.*]] = "tf.Slice"(%[[VALID_SLICE_H_TENSOR]], %[[SLICE_W_RIGHT_BEGIN]], %[[SLICE_W_RIGHT_SIZE]])
   // CHECK-SAME:          (tensor<8x5x4x3xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<8x5x1x3xf32>
   // CHECK-NEXT:      %[[HALO_W_RIGHT:.*]] = "tf.SelectV2"
@@ -443,16 +443,16 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_W_RIGHT:.*]] = "tf.CollectivePermute"(%[[HALO_W_RIGHT]], %[[PAIRS_W_RIGHT]])
   // CHECK-SAME:          (tensor<8x5x1x3xf32>, tensor<4x2xi32>) -> tensor<8x5x1x3xf32>
   // Concat the halos with the shard on the width dim.
-  // CHECK-NEXT:      %[[CONCAT_W_AXIS:.*]] = "tf.Const"() {value = dense<2> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-NEXT:      %[[CONCAT_W_AXIS:.*]] = "tf.Const"() <{value = dense<2> : tensor<i64>}> : () -> tensor<i64>
   // CHECK-NEXT:      %[[CONCAT_HW_TENSOR:.*]] = "tf.ConcatV2"(%[[EXCHANGED_HALO_W_LEFT]], %[[VALID_SLICE_H_TENSOR]], %[[EXCHANGED_HALO_W_RIGHT]], %[[CONCAT_W_AXIS]])
   // CHECK-SAME:          (tensor<8x5x1x3xf32>, tensor<8x5x4x3xf32>, tensor<8x5x1x3xf32>, tensor<i64>) -> tensor<8x5x6x3xf32>
   // Dynamically slice the concatenated tensor to get correct size for VALID padding.
-  // CHECK-NEXT:      %[[HALO_SIZES_W:.*]] = "tf.Const"() {value = dense<[0, 0, 1, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-  // CHECK-NEXT:      %[[HALO_INCREMENTS_W:.*]] = "tf.Const"() {value = dense<[0, 0, 1, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
+  // CHECK-NEXT:      %[[HALO_SIZES_W:.*]] = "tf.Const"() <{value = dense<[0, 0, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+  // CHECK-NEXT:      %[[HALO_INCREMENTS_W:.*]] = "tf.Const"() <{value = dense<[0, 0, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
   // CHECK-NEXT:      %[[VALID_OFFSET_W:.*]] = "tf.Mul"
   // CHECK-NEXT:      %[[VALID_SLICE_BEGIN_W:.*]] = "tf.Sub"(%[[HALO_SIZES_W]], %[[VALID_OFFSET_W]])
-  // CHECK-NEXT:      %[[VALID_SLICE_SIZE_W:.*]] = "tf.Const"() {value = dense<[8, 5, 5, 3]> : tensor<4xi64>} : () -> tensor<4xi64>
-  // CHECK-NEXT:      %[[VALID_SLICE_BEGIN_CAST_I64_W:.*]] = "tf.Cast"(%[[VALID_SLICE_BEGIN_W]]) {Truncate = false} : (tensor<4xi32>) -> tensor<4xi64>
+  // CHECK-NEXT:      %[[VALID_SLICE_SIZE_W:.*]] = "tf.Const"() <{value = dense<[8, 5, 5, 3]> : tensor<4xi64>}> : () -> tensor<4xi64>
+  // CHECK-NEXT:      %[[VALID_SLICE_BEGIN_CAST_I64_W:.*]] = "tf.Cast"(%[[VALID_SLICE_BEGIN_W]]) <{Truncate = false}> : (tensor<4xi32>) -> tensor<4xi64>
   // CHECK-NEXT:      %[[VALID_SLICE_HW_TENSOR:.*]] = "tf.Slice"(%[[CONCAT_HW_TENSOR]], %[[VALID_SLICE_BEGIN_CAST_I64_W]], %[[VALID_SLICE_SIZE_W]])
 
   // CHECK-NEXT:      "tf.Conv2D"(%[[VALID_SLICE_HW_TENSOR]], %arg2)
@@ -478,8 +478,8 @@
   // CHECK:         "tf_device.cluster"
 
   // Build left halo on depth dim.
-  // CHECK:           %[[SLICE_D_LEFT_BEGIN:.*]] = "tf.Const"() {value = dense<[0, 3, 0, 0, 0]> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-NEXT:      %[[SLICE_D_LEFT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 1, 4, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK:           %[[SLICE_D_LEFT_BEGIN:.*]] = "tf.Const"() <{value = dense<[0, 3, 0, 0, 0]> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[SLICE_D_LEFT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 1, 4, 4, 3]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK-NEXT:      %[[SLICE_D_LEFT:.*]] = "tf.Slice"(%arg1, %[[SLICE_D_LEFT_BEGIN]], %[[SLICE_D_LEFT_SIZE]])
   // CHECK-SAME:          (tensor<8x4x4x4x3xf32>, tensor<5xi32>, tensor<5xi32>) -> tensor<8x1x4x4x3xf32>
   // CHECK-NEXT:      %[[HALO_D_LEFT:.*]] = "tf.SelectV2"
@@ -488,8 +488,8 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_D_LEFT:.*]] = "tf.CollectivePermute"(%[[HALO_D_LEFT]], %[[PAIRS_D_LEFT]])
   // CHECK-SAME:          (tensor<8x1x4x4x3xf32>, tensor<8x2xi32>) -> tensor<8x1x4x4x3xf32>
   // Build right halo on depth dim.
-  // CHECK:           %[[SLICE_D_RIGHT_BEGIN:.*]] = "tf.Const"() {value = dense<0> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-NEXT:      %[[SLICE_D_RIGHT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 1, 4, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK:           %[[SLICE_D_RIGHT_BEGIN:.*]] = "tf.Const"() <{value = dense<0> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[SLICE_D_RIGHT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 1, 4, 4, 3]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK-NEXT:      %[[SLICE_D_RIGHT:.*]] = "tf.Slice"(%arg1, %[[SLICE_D_RIGHT_BEGIN]], %[[SLICE_D_RIGHT_SIZE]])
   // CHECK-SAME:          (tensor<8x4x4x4x3xf32>, tensor<5xi32>, tensor<5xi32>) -> tensor<8x1x4x4x3xf32>
   // CHECK-NEXT:      %[[HALO_D_RIGHT:.*]] = "tf.SelectV2"
@@ -498,21 +498,21 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_D_RIGHT:.*]] = "tf.CollectivePermute"(%[[HALO_D_RIGHT]], %[[PAIRS_D_RIGHT]])
   // CHECK-SAME:          (tensor<8x1x4x4x3xf32>, tensor<8x2xi32>) -> tensor<8x1x4x4x3xf32>
   // Concat the halos with the shard on the depth dim.
-  // CHECK-NEXT:      %[[CONCAT_D_AXIS:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-NEXT:      %[[CONCAT_D_AXIS:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
   // CHECK-NEXT:      %[[CONCAT_D_TENSOR:.*]] = "tf.ConcatV2"(%[[EXCHANGED_HALO_D_LEFT]], %arg1, %[[EXCHANGED_HALO_D_RIGHT]], %[[CONCAT_D_AXIS]])
   // CHECK-SAME:          (tensor<8x1x4x4x3xf32>, tensor<8x4x4x4x3xf32>, tensor<8x1x4x4x3xf32>, tensor<i64>) -> tensor<8x6x4x4x3xf32>
   // Dynamically slice the concatenated tensor to get correct size for VALID padding.
-  // CHECK-NEXT:      %[[HALO_SIZES_D:.*]] = "tf.Const"() {value = dense<[0, 1, 0, 0, 0]> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-NEXT:      %[[HALO_INCREMENTS_D:.*]] = "tf.Const"() {value = dense<[0, 1, 0, 0, 0]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[HALO_SIZES_D:.*]] = "tf.Const"() <{value = dense<[0, 1, 0, 0, 0]> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[HALO_INCREMENTS_D:.*]] = "tf.Const"() <{value = dense<[0, 1, 0, 0, 0]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK-NEXT:      %[[VALID_OFFSET_D:.*]] = "tf.Mul"
   // CHECK-NEXT:      %[[VALID_SLICE_BEGIN_D:.*]] = "tf.Sub"(%[[HALO_SIZES_D]], %[[VALID_OFFSET_D]])
-  // CHECK-NEXT:      %[[VALID_SLICE_SIZE_D:.*]] = "tf.Const"() {value = dense<[8, 5, 4, 4, 3]> : tensor<5xi64>} : () -> tensor<5xi64>
-  // CHECK-NEXT:      %[[VALID_SLICE_BEGIN_CAST_I64_D:.*]] = "tf.Cast"(%[[VALID_SLICE_BEGIN_D]]) {Truncate = false} : (tensor<5xi32>) -> tensor<5xi64>
+  // CHECK-NEXT:      %[[VALID_SLICE_SIZE_D:.*]] = "tf.Const"() <{value = dense<[8, 5, 4, 4, 3]> : tensor<5xi64>}> : () -> tensor<5xi64>
+  // CHECK-NEXT:      %[[VALID_SLICE_BEGIN_CAST_I64_D:.*]] = "tf.Cast"(%[[VALID_SLICE_BEGIN_D]]) <{Truncate = false}> : (tensor<5xi32>) -> tensor<5xi64>
   // CHECK-NEXT:      %[[VALID_SLICE_D_TENSOR:.*]] = "tf.Slice"(%[[CONCAT_D_TENSOR]], %[[VALID_SLICE_BEGIN_CAST_I64_D]], %[[VALID_SLICE_SIZE_D]])
 
   // Build left halo on height dim.
-  // CHECK:           %[[SLICE_H_LEFT_BEGIN:.*]] = "tf.Const"() {value = dense<[0, 0, 3, 0, 0]> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-NEXT:      %[[SLICE_H_LEFT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 5, 1, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK:           %[[SLICE_H_LEFT_BEGIN:.*]] = "tf.Const"() <{value = dense<[0, 0, 3, 0, 0]> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[SLICE_H_LEFT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 5, 1, 4, 3]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK-NEXT:      %[[SLICE_H_LEFT:.*]] = "tf.Slice"(%[[VALID_SLICE_D_TENSOR]], %[[SLICE_H_LEFT_BEGIN]], %[[SLICE_H_LEFT_SIZE]])
   // CHECK-SAME:          (tensor<8x5x4x4x3xf32>, tensor<5xi32>, tensor<5xi32>) -> tensor<8x5x1x4x3xf32>
   // CHECK-NEXT:      %[[HALO_H_LEFT:.*]] = "tf.SelectV2"
@@ -521,8 +521,8 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_H_LEFT:.*]] = "tf.CollectivePermute"(%[[HALO_H_LEFT]], %[[PAIRS_H_LEFT]])
   // CHECK-SAME:          (tensor<8x5x1x4x3xf32>, tensor<8x2xi32>) -> tensor<8x5x1x4x3xf32>
   // Build right halo on height dim.
-  // CHECK:           %[[SLICE_H_RIGHT_BEGIN:.*]] = "tf.Const"() {value = dense<0> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-NEXT:      %[[SLICE_H_RIGHT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 5, 1, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK:           %[[SLICE_H_RIGHT_BEGIN:.*]] = "tf.Const"() <{value = dense<0> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[SLICE_H_RIGHT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 5, 1, 4, 3]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK-NEXT:      %[[SLICE_H_RIGHT:.*]] = "tf.Slice"(%[[VALID_SLICE_D_TENSOR]], %[[SLICE_H_RIGHT_BEGIN]], %[[SLICE_H_RIGHT_SIZE]])
   // CHECK-SAME:          (tensor<8x5x4x4x3xf32>, tensor<5xi32>, tensor<5xi32>) -> tensor<8x5x1x4x3xf32>
   // CHECK-NEXT:      %[[HALO_H_RIGHT:.*]] = "tf.SelectV2"
@@ -531,21 +531,21 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_H_RIGHT:.*]] = "tf.CollectivePermute"(%[[HALO_H_RIGHT]], %[[PAIRS_H_RIGHT]])
   // CHECK-SAME:          (tensor<8x5x1x4x3xf32>, tensor<8x2xi32>) -> tensor<8x5x1x4x3xf32>
   // Concat the halos with the shard on the height dim.
-  // CHECK-NEXT:      %[[CONCAT_H_AXIS:.*]] = "tf.Const"() {value = dense<2> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-NEXT:      %[[CONCAT_H_AXIS:.*]] = "tf.Const"() <{value = dense<2> : tensor<i64>}> : () -> tensor<i64>
   // CHECK-NEXT:      %[[CONCAT_DH_TENSOR:.*]] = "tf.ConcatV2"(%[[EXCHANGED_HALO_H_LEFT]], %[[VALID_SLICE_D_TENSOR]], %[[EXCHANGED_HALO_H_RIGHT]], %[[CONCAT_H_AXIS]])
   // CHECK-SAME:          (tensor<8x5x1x4x3xf32>, tensor<8x5x4x4x3xf32>, tensor<8x5x1x4x3xf32>, tensor<i64>) -> tensor<8x5x6x4x3xf32>
   // Dynamically slice the concatenated tensor to get correct size for VALID padding.
-  // CHECK-NEXT:      %[[HALO_SIZES_H:.*]] = "tf.Const"() {value = dense<[0, 0, 1, 0, 0]> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-NEXT:      %[[HALO_INCREMENTS_H:.*]] = "tf.Const"() {value = dense<[0, 0, 1, 0, 0]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[HALO_SIZES_H:.*]] = "tf.Const"() <{value = dense<[0, 0, 1, 0, 0]> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[HALO_INCREMENTS_H:.*]] = "tf.Const"() <{value = dense<[0, 0, 1, 0, 0]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK-NEXT:      %[[VALID_OFFSET_H:.*]] = "tf.Mul"
   // CHECK-NEXT:      %[[VALID_SLICE_BEGIN_H:.*]] = "tf.Sub"(%[[HALO_SIZES_H]], %[[VALID_OFFSET_H]])
-  // CHECK-NEXT:      %[[VALID_SLICE_SIZE_H:.*]] = "tf.Const"() {value = dense<[8, 5, 5, 4, 3]> : tensor<5xi64>} : () -> tensor<5xi64>
-  // CHECK-NEXT:      %[[VALID_SLICE_BEGIN_CAST_I64_H:.*]] = "tf.Cast"(%[[VALID_SLICE_BEGIN_H]]) {Truncate = false} : (tensor<5xi32>) -> tensor<5xi64>
+  // CHECK-NEXT:      %[[VALID_SLICE_SIZE_H:.*]] = "tf.Const"() <{value = dense<[8, 5, 5, 4, 3]> : tensor<5xi64>}> : () -> tensor<5xi64>
+  // CHECK-NEXT:      %[[VALID_SLICE_BEGIN_CAST_I64_H:.*]] = "tf.Cast"(%[[VALID_SLICE_BEGIN_H]]) <{Truncate = false}> : (tensor<5xi32>) -> tensor<5xi64>
   // CHECK-NEXT:      %[[VALID_SLICE_DH_TENSOR:.*]] = "tf.Slice"(%[[CONCAT_DH_TENSOR]], %[[VALID_SLICE_BEGIN_CAST_I64_H]], %[[VALID_SLICE_SIZE_H]])
 
   // Build left halo on width dim.
-  // CHECK:           %[[SLICE_W_LEFT_BEGIN:.*]] = "tf.Const"() {value = dense<[0, 0, 0, 3, 0]> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-NEXT:      %[[SLICE_W_LEFT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 5, 5, 1, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK:           %[[SLICE_W_LEFT_BEGIN:.*]] = "tf.Const"() <{value = dense<[0, 0, 0, 3, 0]> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[SLICE_W_LEFT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 5, 5, 1, 3]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK-NEXT:      %[[SLICE_W_LEFT:.*]] = "tf.Slice"(%[[VALID_SLICE_DH_TENSOR]], %[[SLICE_W_LEFT_BEGIN]], %[[SLICE_W_LEFT_SIZE]])
   // CHECK-SAME:          (tensor<8x5x5x4x3xf32>, tensor<5xi32>, tensor<5xi32>) -> tensor<8x5x5x1x3xf32>
   // CHECK-NEXT:      %[[HALO_W_LEFT:.*]] = "tf.SelectV2"
@@ -554,8 +554,8 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_W_LEFT:.*]] = "tf.CollectivePermute"(%[[HALO_W_LEFT]], %[[PAIRS_W_LEFT]])
   // CHECK-SAME:          (tensor<8x5x5x1x3xf32>, tensor<8x2xi32>) -> tensor<8x5x5x1x3xf32>
   // Build right halo on width dim.
-  // CHECK:           %[[SLICE_W_RIGHT_BEGIN:.*]] = "tf.Const"() {value = dense<0> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-NEXT:      %[[SLICE_W_RIGHT_SIZE:.*]] = "tf.Const"() {value = dense<[8, 5, 5, 1, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK:           %[[SLICE_W_RIGHT_BEGIN:.*]] = "tf.Const"() <{value = dense<0> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[SLICE_W_RIGHT_SIZE:.*]] = "tf.Const"() <{value = dense<[8, 5, 5, 1, 3]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK-NEXT:      %[[SLICE_W_RIGHT:.*]] = "tf.Slice"(%[[VALID_SLICE_DH_TENSOR]], %[[SLICE_W_RIGHT_BEGIN]], %[[SLICE_W_RIGHT_SIZE]])
   // CHECK-SAME:          (tensor<8x5x5x4x3xf32>, tensor<5xi32>, tensor<5xi32>) -> tensor<8x5x5x1x3xf32>
   // CHECK-NEXT:      %[[HALO_W_RIGHT:.*]] = "tf.SelectV2"
@@ -564,16 +564,16 @@
   // CHECK-NEXT:      %[[EXCHANGED_HALO_W_RIGHT:.*]] = "tf.CollectivePermute"(%[[HALO_W_RIGHT]], %[[PAIRS_W_RIGHT]])
   // CHECK-SAME:          (tensor<8x5x5x1x3xf32>, tensor<8x2xi32>) -> tensor<8x5x5x1x3xf32>
   // Concat the halos with the shard on the width dim.
-  // CHECK-NEXT:      %[[CONCAT_W_AXIS:.*]] = "tf.Const"() {value = dense<3> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-NEXT:      %[[CONCAT_W_AXIS:.*]] = "tf.Const"() <{value = dense<3> : tensor<i64>}> : () -> tensor<i64>
   // CHECK-NEXT:      %[[CONCAT_DHW_TENSOR:.*]] = "tf.ConcatV2"(%[[EXCHANGED_HALO_W_LEFT]], %[[VALID_SLICE_DH_TENSOR]], %[[EXCHANGED_HALO_W_RIGHT]], %[[CONCAT_W_AXIS]])
   // CHECK-SAME:          (tensor<8x5x5x1x3xf32>, tensor<8x5x5x4x3xf32>, tensor<8x5x5x1x3xf32>, tensor<i64>) -> tensor<8x5x5x6x3xf32>
   // Dynamically slice the concatenated tensor to get correct size for VALID padding.
-  // CHECK-NEXT:      %[[HALO_SIZES_W:.*]] = "tf.Const"() {value = dense<[0, 0, 0, 1, 0]> : tensor<5xi32>} : () -> tensor<5xi32>
-  // CHECK-NEXT:      %[[HALO_INCREMENTS_W:.*]] = "tf.Const"() {value = dense<[0, 0, 0, 1, 0]> : tensor<5xi32>} : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[HALO_SIZES_W:.*]] = "tf.Const"() <{value = dense<[0, 0, 0, 1, 0]> : tensor<5xi32>}> : () -> tensor<5xi32>
+  // CHECK-NEXT:      %[[HALO_INCREMENTS_W:.*]] = "tf.Const"() <{value = dense<[0, 0, 0, 1, 0]> : tensor<5xi32>}> : () -> tensor<5xi32>
   // CHECK-NEXT:      %[[VALID_OFFSET_W:.*]] = "tf.Mul"
   // CHECK-NEXT:      %[[VALID_SLICE_BEGIN_W:.*]] = "tf.Sub"(%[[HALO_SIZES_W]], %[[VALID_OFFSET_W]])
-  // CHECK-NEXT:      %[[VALID_SLICE_SIZE_W:.*]] = "tf.Const"() {value = dense<[8, 5, 5, 5, 3]> : tensor<5xi64>} : () -> tensor<5xi64>
-  // CHECK-NEXT:      %[[VALID_SLICE_BEGIN_CAST_I64_W:.*]] = "tf.Cast"(%[[VALID_SLICE_BEGIN_W]]) {Truncate = false} : (tensor<5xi32>) -> tensor<5xi64>
+  // CHECK-NEXT:      %[[VALID_SLICE_SIZE_W:.*]] = "tf.Const"() <{value = dense<[8, 5, 5, 5, 3]> : tensor<5xi64>}> : () -> tensor<5xi64>
+  // CHECK-NEXT:      %[[VALID_SLICE_BEGIN_CAST_I64_W:.*]] = "tf.Cast"(%[[VALID_SLICE_BEGIN_W]]) <{Truncate = false}> : (tensor<5xi32>) -> tensor<5xi64>
   // CHECK-NEXT:      %[[VALID_SLICE_DHW_TENSOR:.*]] = "tf.Slice"(%[[CONCAT_DHW_TENSOR]], %[[VALID_SLICE_BEGIN_CAST_I64_W]], %[[VALID_SLICE_SIZE_W]])
 
   // CHECK-NEXT:      "tf.Conv3D"(%[[VALID_SLICE_DHW_TENSOR]], %arg2)
diff --git a/tensorflow/dtensor/mlir/tests/spmd_expansion.mlir b/tensorflow/dtensor/mlir/tests/spmd_expansion.mlir
index c8ab88a..4e35ca8 100644
--- a/tensorflow/dtensor/mlir/tests/spmd_expansion.mlir
+++ b/tensorflow/dtensor/mlir/tests/spmd_expansion.mlir
@@ -171,7 +171,7 @@
 module @test_spmd_const_op_sharded_with_splat {
 func.func @main(%arg0: tensor<i32>) {
   // CHECK:        "tf_device.cluster"
-  // CHECK-NEXT:      %[[CONST_OUT:.*]] = "tf.Const"() {[[BEFORE_ATTR:.*]]value = dense<1> : tensor<1xi32>[[AFTER_ATTR:.*]]} : () -> tensor<1xi32>
+  // CHECK-NEXT:      %[[CONST_OUT:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> {{{.*}}} : () -> tensor<1xi32>
   // CHECK-NEXT:      tf_device.return
   %0 = "tf_device.cluster"() ({
    %1 = "tf.Const"() {_layout = ["sharding_specs:x, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"], value=dense<1>: tensor<2xi32>} : () -> tensor<2xi32>
@@ -340,8 +340,8 @@
 module @test_spmd_random_op_with_incomplete_shape_disallowed {
 func.func @main(%arg0: tensor<i32>) {
   %0 = "tf_device.cluster"() ({
-    // %1 = "tf.Const"() {_layout = ["sharding_specs:x,unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"], device = "", value = dense<16> : tensor<2xi32>} : () -> tensor<2xi32>
-    // %2 = "tf.Const"() {_layout = ["sharding_specs:x,unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"], device = "", value = dense<[123, 321]> : tensor<2xi32>} : () -> tensor<2xi32>
+    // %1 = "tf.Const"() <{value = dense<16> : tensor<2xi32>}> {_layout = ["sharding_specs:x,unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"], device = ""} : () -> tensor<2xi32>
+    // %2 = "tf.Const"() <{value = dense<[123, 321]> : tensor<2xi32>}> {_layout = ["sharding_specs:x,unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"], device = ""} : () -> tensor<2xi32>
     %1 = arith.constant dense<[16]> : tensor<1xi32>
     %2 = arith.constant dense<[2, 1]> : tensor<2xi32>
     // expected-error @+1 {{Sharding dimension of random op does not match rank of the random op}}
@@ -764,7 +764,7 @@
   // CHECK:      "tf_device.cluster"
   // CHECK-NEXT:   "tf.Const"()
   // CHECK-NEXT:   %[[INDICES:.*]] = "tf.Const"()
-  // CHECK-NEXT:   %[[NEW_SHAPE:.*]] = "tf.Const"() {value = dense<[16, 2, 4]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK-NEXT:   %[[NEW_SHAPE:.*]] = "tf.Const"() <{value = dense<[16, 2, 4]> : tensor<3xi32>}> : () -> tensor<3xi32>
   // CHECK-NEXT:   "tf.ScatterNd"(%[[INDICES]], %arg0, %[[NEW_SHAPE]])
   %0 = "tf_device.cluster"() ({
     %shape = "tf.Const"() {_global_shape = [#tf_type.shape<3>], value = dense<[16, 4, 4]> : tensor<3xi32>} : () -> tensor<3xi32>
diff --git a/tensorflow/dtensor/mlir/tests/spmd_iterator.mlir b/tensorflow/dtensor/mlir/tests/spmd_iterator.mlir
index ca2cc5f..aba9c8c 100644
--- a/tensorflow/dtensor/mlir/tests/spmd_iterator.mlir
+++ b/tensorflow/dtensor/mlir/tests/spmd_iterator.mlir
@@ -64,8 +64,8 @@
   // CHECK-SAME:     (tensor<*x!tf_type.resource>) -> tensor<8x16xf32>
   // CHECK:        "tf.WhileRegion"
   // CHECK:        %[[ITER_OPTIONAL_OUT:.*]] = "tf.IteratorGetNextAsOptional"(%arg1)
-  // CHECK-SAME:     _layout = ["sharding_specs: mesh:|x=4,y=2|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3,/job:localhost/task:0/device:TPU:4,/job:localhost/task:0/device:TPU:5,/job:localhost/task:0/device:TPU:6,/job:localhost/task:0/device:TPU:7"]
   // CHECK-SAME:     output_shapes = [#tf_type.shape<8x16>]
+  // CHECK-SAME:     _layout = ["sharding_specs: mesh:|x=4,y=2|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3,/job:localhost/task:0/device:TPU:4,/job:localhost/task:0/device:TPU:5,/job:localhost/task:0/device:TPU:6,/job:localhost/task:0/device:TPU:7"]
   // CHECK-SAME:     (tensor<*x!tf_type.resource>) -> tensor<!tf_type.variant>
   // CHECK-NEXT:   %[[HAS_VALUE:.*]] = "tf.OptionalHasValue"(%[[ITER_OPTIONAL_OUT]])
   // CHECK-SAME:     _layout = ["sharding_specs: mesh:|x=4,y=2|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3,/job:localhost/task:0/device:TPU:4,/job:localhost/task:0/device:TPU:5,/job:localhost/task:0/device:TPU:6,/job:localhost/task:0/device:TPU:7"]
diff --git a/tensorflow/dtensor/mlir/tests/spmd_matmul.mlir b/tensorflow/dtensor/mlir/tests/spmd_matmul.mlir
index da16613..2d85bf1 100644
--- a/tensorflow/dtensor/mlir/tests/spmd_matmul.mlir
+++ b/tensorflow/dtensor/mlir/tests/spmd_matmul.mlir
@@ -8,10 +8,10 @@
   // CHECK:    "tf_device.cluster"
   // CHECK:      %[[MATMUL_OUT:.*]] = "tf.BatchMatMulV2"(%arg1, %arg2)
   // CHECK-SAME: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
-  // CHECK:      %[[GROUP_ID:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  // CHECK:      %[[GROUP_ID:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<2x2xi32>}> : () -> tensor<2x2xi32>
   // CHECK:      %[[SUM_OUT:.*]] = "tf.DTensorAllReduce"(%[[MATMUL_OUT]], %[[GROUP_ID]])
-  // CHECK-SAME: _layout = ["sharding_specs:unsharded,unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME: reduce_op = "Add"
+  // CHECK-SAME: _layout = ["sharding_specs:unsharded,unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
   // CHECK-NEXT: tf_device.return
   // CHECK-SAME: %[[SUM_OUT]]
@@ -35,10 +35,10 @@
   // CHECK:    "tf_device.cluster"
   // CHECK:      %[[MATMUL_OUT:.*]] = "tf.BatchMatMulV2"
   // CHECK-SAME: (tensor<4x2x2xi32>, tensor<4x2x2xi32>) -> tensor<4x2x2xi32>
-  // CHECK:      %[[GROUP_ID:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  // CHECK:      %[[GROUP_ID:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<2x2xi32>}> : () -> tensor<2x2xi32>
   // CHECK:      %[[SUM_OUT:.*]] = "tf.DTensorAllReduce"(%[[MATMUL_OUT]], %[[GROUP_ID]])
-  // CHECK-SAME: _layout = ["sharding_specs:x,unsharded,unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME: reduce_op = "Add"
+  // CHECK-SAME: _layout = ["sharding_specs:x,unsharded,unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME: (tensor<4x2x2xi32>, tensor<2x2xi32>) -> tensor<4x2x2xi32>
   // CHECK-NEXT: tf_device.return
   // CHECK-SAME: %[[SUM_OUT]]
@@ -80,8 +80,8 @@
   // CHECK-SAME:   (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
   // CHECK-NEXT:   %[[GROUP_ID:.*]] = "tf.Const"()
   // CHECK-NEXT:   %[[SUM_OUT:.*]] = "tf.DTensorAllReduce"(%[[MATMUL_OUT]], %[[GROUP_ID]])
-  // CHECK-SAME:   _layout = ["sharding_specs:unsharded,unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME: reduce_op = "Add"
+  // CHECK-SAME:   _layout = ["sharding_specs:unsharded,unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME:   (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
   // CHECK-NEXT:   tf_device.return
   // CHECK-SAME:   %[[SUM_OUT]]
diff --git a/tensorflow/dtensor/mlir/tests/spmd_random.mlir b/tensorflow/dtensor/mlir/tests/spmd_random.mlir
index 257d7c6..183908d 100644
--- a/tensorflow/dtensor/mlir/tests/spmd_random.mlir
+++ b/tensorflow/dtensor/mlir/tests/spmd_random.mlir
@@ -30,7 +30,7 @@
   // CHECK:      "tf_device.cluster"
   // CHECK-NEXT: %[[MESH_SIZES:.*]] = "tf.Const"()
   // CHECK-SAME: 4, 2, 2
-  // CHECK-NEXT: %[[MESH_SIZES_RUNNING_PRODUCT:.*]] = "tf.Const"() {value =
+  // CHECK-NEXT: %[[MESH_SIZES_RUNNING_PRODUCT:.*]] = "tf.Const"() <{value =
   // CHECK-SAME: 4, 2, 1
   // CHECK-NEXT: %[[MESH_COORDS_PRE_MOD:.*]] = "tf.Div"(%arg0, %[[MESH_SIZES_RUNNING_PRODUCT]])
   // CHECK-NEXT: %[[MESH_COORDS:.*]] = "tf.FloorMod"(%[[MESH_COORDS_PRE_MOD]], %[[MESH_SIZES]])
@@ -38,15 +38,15 @@
   // CHECK-NEXT: %[[MESH_MULTIPLER:.*]] = "tf.Const"()
  // CHECK-SAME: [65536], [0], [262144]
   // CHECK-NEXT: %[[DEVICE_SEED:.*]] = "tf.MatMul"(%[[MESH_COORDS]], %[[MESH_MULTIPLER]])
-  // CHECK-NEXT: %[[PRIME:.*]] = "tf.Const"() {value = dense<65521>
+  // CHECK-NEXT: %[[PRIME:.*]] = "tf.Const"() <{value = dense<65521>
   // CHECK-NEXT: %[[DEVICE_SEED_PRIME:.*]] = "tf.AddV2"(%[[DEVICE_SEED]], %[[PRIME]])
-  // CHECK-NEXT: %[[DEVICE_SEED_SQUEEZE:.*]] = "tf.Squeeze"(%[[DEVICE_SEED_PRIME]]) {
+  // CHECK-NEXT: %[[DEVICE_SEED_SQUEEZE:.*]] = "tf.Squeeze"(%[[DEVICE_SEED_PRIME]]) <{
   // CHECK-NOT: dtensor.device_seed_for_mesh_dims
   // CHECK-SAME: }
   // CHECK-NEXT: %[[OLD_SHAPE:.*]] = "tf.Const"(
   // CHECK-NEXT: %[[DEVICE_SEED_CAST:.*]] = "tf.Cast"(%[[DEVICE_SEED_SQUEEZE]])
   // CHECK-NEXT: %[[NEW_SEED:.*]] = "tf.BitwiseXor"(%arg1, %[[DEVICE_SEED_CAST]])
-  // CHECK-NEXT: %[[NEW_SHAPE:.*]] = "tf.Const"() {value = dense<[8, 32, 32]>
+  // CHECK-NEXT: %[[NEW_SHAPE:.*]] = "tf.Const"() <{value = dense<[8, 32, 32]>
   // CHECK-NEXT: %[[RANDOM:.*]] = "tf.StatelessRandomUniform"(%[[NEW_SHAPE]], %[[NEW_SEED]])
   // CHECK-NEXT: tf_device.return
   // CHECK-SAME: %[[RANDOM]]
diff --git a/tensorflow/dtensor/mlir/tests/spmd_save_restore.mlir b/tensorflow/dtensor/mlir/tests/spmd_save_restore.mlir
index c315360..43e6714 100644
--- a/tensorflow/dtensor/mlir/tests/spmd_save_restore.mlir
+++ b/tensorflow/dtensor/mlir/tests/spmd_save_restore.mlir
@@ -5,9 +5,9 @@
 func.func @main(%arg0: tensor<i32>) {
   "tf_device.cluster"() ({
     // CHECK:      "tf.Case"
-    // CHECK-SAME: branches = [@tf.[[D0:.*]], @tf.[[D1:.*]]]
+    // CHECK-SAME: branches = [@tf.[[D0:.*]], @tf.[[D1:.*]]], is_stateless = false
     // CHECK:      func private @tf.[[D0]]
-    // CHECK:      %[[CST:.*]] = "tf.Const"() {value = dense<"_dev-0-of-2">
+    // CHECK:      %[[CST:.*]] = "tf.Const"() <{value = dense<"_dev-0-of-2">
     // CHECK:      "tf.Add"(%arg0, %[[CST]])
     // CHECK:      ""
     // CHECK:      func private @tf.[[D1]]
@@ -31,13 +31,13 @@
 func.func @main(%arg0: tensor<i32>) {
   "tf_device.cluster"() ({
     // CHECK:      tf.Case
-    // CHECK-SAME: branches = [@tf.[[D0:.*]], @tf.[[D1:.*]]]
+    // CHECK-SAME: branches = [@tf.[[D0:.*]], @tf.[[D1:.*]]], is_stateless = false
     // CHECK:      func private @tf.[[D0]]
-    // CHECK:      %[[CST:.*]] = "tf.Const"() {value = dense<"_dev-0-of-2">
+    // CHECK:      %[[CST:.*]] = "tf.Const"() <{value = dense<"_dev-0-of-2">
     // CHECK:      "tf.Add"(%arg0, %[[CST]])
     // CHECK:      "2 0,1"
     // CHECK:      func private @tf.[[D1]]
-    // CHECK:      %[[CST:.*]] = "tf.Const"() {value = dense<"_dev-1-of-2">
+    // CHECK:      %[[CST:.*]] = "tf.Const"() <{value = dense<"_dev-1-of-2">
     // CHECK:      "tf.Add"(%arg0, %[[CST]])
     // CHECK:      "2 1,1"
     %0 = "tf.Const"() {value = dense<"/dev/null"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string>
@@ -66,9 +66,9 @@
     // CHECK:      func private @tf.[[THEN]]
     // CHECK:      tf.NoOp
     // CHECK:      func private @tf.[[ELSE]]
-    // CHECK:      "tf.Const"() {value = dense<"_dev-0-of-2">
+    // CHECK:      "tf.Const"() <{value = dense<"_dev-0-of-2">
     // CHECK:      "tf.Add"
-    // CHECK:      "tf.Const"() {value = dense<"_dev-1-of-2">
+    // CHECK:      "tf.Const"() <{value = dense<"_dev-1-of-2">
     // CHECK:      "tf.Add"
     // CHECK:      "tf.Concat"
     // CHECK:      "tf.MergeV2Checkpoints"
@@ -109,11 +109,11 @@
       // CHECK:      "tf.Reshape"(%[[BRANCH_IDX]]
       // CHECK-SAME: (tensor<1x1xi32>, tensor<0xi32>) -> tensor<i32>
       // CHECK:      tf.Case
-      // CHECK-SAME: branches = [@tf.[[D0:.*]], @tf.[[D1:.*]]]
+      // CHECK-SAME: branches = [@tf.[[D0:.*]], @tf.[[D1:.*]]], is_stateless = false
       // CHECK:      func private @tf.[[D0]]
-      // CHECK:      "tf.Const"() {value = dense<["", "2 4 0,1:-"]>
+      // CHECK:      "tf.Const"() <{value = dense<["", "2 4 0,1:-"]>
       // CHECK:      func private @tf.[[D1]]
-      // CHECK:      "tf.Const"() {value = dense<["", "2 4 1,1:-"]>
+      // CHECK:      "tf.Const"() <{value = dense<["", "2 4 1,1:-"]>
       %1 = "tf.Const"() {_global_shape = [#tf_type.shape<2>], value = dense<""> : tensor<2x!tf_type.string>} : () -> tensor<2x!tf_type.string>
       %2 = "tf.Const"() {_global_shape = [#tf_type.shape<2>], value = dense<["model/r/.ATTRIBUTES/VARIABLE_VALUE", "model/s/.ATTRIBUTES/VARIABLE_VALUE"]> : tensor<2x!tf_type.string>} : () -> tensor<2x!tf_type.string>
       %3 = "tf.Const"() {_global_shape = [#tf_type.shape<>], value = dense<"/dev/null/ckpt-0"> : tensor<!tf_type.string>} : () -> tensor<!tf_type.string>
@@ -167,11 +167,11 @@
       // CHECK:      "tf.Reshape"(%[[BRANCH_IDX]]
       // CHECK-SAME: (tensor<1x1xi32>, tensor<0xi32>) -> tensor<i32>
       // CHECK:      tf.Case
-      // CHECK-SAME: branches = [@tf.[[D0:.*]], @tf.[[D1:.*]]]
+      // CHECK-SAME: branches = [@tf.[[D0:.*]], @tf.[[D1:.*]]], is_stateless = false
       // CHECK:      func private @tf.[[D0]]
-      // CHECK:      "tf.Const"() {value = dense<["", "2 4 0,1:-"]>
+      // CHECK:      "tf.Const"() <{value = dense<["", "2 4 0,1:-"]>
       // CHECK:      func private @tf.[[D1]]
-      // CHECK:      "tf.Const"() {value = dense<["", "2 4 1,1:-"]>
+      // CHECK:      "tf.Const"() <{value = dense<["", "2 4 1,1:-"]>
       %1 = "tf.Const"() {_global_shape = [#tf_type.shape<2>], value = dense<""> : tensor<2x!tf_type.string>} : () -> tensor<2x!tf_type.string>
       %2 = "tf.Const"() {_global_shape = [#tf_type.shape<2>], value = dense<["model/r/.ATTRIBUTES/VARIABLE_VALUE", "model/s/.ATTRIBUTES/VARIABLE_VALUE"]> : tensor<2x!tf_type.string>} : () -> tensor<2x!tf_type.string>
       %3 = "tf.Const"() {_global_shape = [#tf_type.shape<>], value = dense<"/dev/null/ckpt-0"> : tensor<!tf_type.string>} : () -> tensor<!tf_type.string>
diff --git a/tensorflow/dtensor/mlir/tests/spmd_segment_sum.mlir b/tensorflow/dtensor/mlir/tests/spmd_segment_sum.mlir
index 2122b72..6a96c60 100644
--- a/tensorflow/dtensor/mlir/tests/spmd_segment_sum.mlir
+++ b/tensorflow/dtensor/mlir/tests/spmd_segment_sum.mlir
@@ -11,8 +11,8 @@
   // CHECK:      %[[LOCAL_RESULT:.*]] = "tf.UnsortedSegmentSum"(%arg1, %arg2, %[[NUM_SEGMENTS]])
   // CHECK-SAME: (tensor<4x2xf32>, tensor<4xi32>, tensor<i32>) -> tensor<8x2xf32>
   // CHECK:      %[[RESULT:.*]] = "tf.DTensorAllReduce"(%[[LOCAL_RESULT]]
-  // CHECK-SAME: _layout = ["sharding_specs:unsharded,unsharded, mesh:TPU|x=4|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK-SAME: reduce_op = "Add"
+  // CHECK-SAME: _layout = ["sharding_specs:unsharded,unsharded, mesh:TPU|x=4|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"]
   // CHECK:      %[[FINAL_RESULT:.*]] = "tf.DTensorAllScatter"(%[[RESULT]]
   // CHECK-NEXT: tf_device.return
   // CHECK-SAME: %[[FINAL_RESULT]]
diff --git a/tensorflow/dtensor/mlir/tests/spmd_slice.mlir b/tensorflow/dtensor/mlir/tests/spmd_slice.mlir
index 6741b0c..fa1b556 100644
--- a/tensorflow/dtensor/mlir/tests/spmd_slice.mlir
+++ b/tensorflow/dtensor/mlir/tests/spmd_slice.mlir
@@ -93,7 +93,7 @@
            %arg1: tensor<2x4xf32> {tf._layout = "sharding_specs:unsharded,x, mesh:|x=2,y=2|*CPU"},
            %arg2: tensor<2xi64> {tf._layout = "sharding_specs:unsharded, mesh:|x=2,y=2|*CPU"}) -> tensor<1x4xf32> {
   // CHECK:      "tf_device.cluster"
-  // CHECK:        %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
+  // CHECK:        %[[SLICE_SIZE:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
   // CHECK-NEXT:   %[[SLICE:.*]] = "tf.Slice"(%arg1, %arg2, %[[SLICE_SIZE]])
   // CHECK-SAME:     _layout = ["sharding_specs:unsharded,x, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"]
   // CHECK-SAME:     (tensor<2x2xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
@@ -114,7 +114,7 @@
 // Check SPMD expansion of strided slice op with replicated input.
 func.func @main(%arg0: tensor<i32>, %arg1: tensor<2x4xf32> {tf._layout = "sharding_specs:unsharded,unsharded, mesh:|x=2,y=2|*CPU"}) -> tensor<2x2xf32> {
   // CHECK:      "tf_device.cluster"
-  // CHECK:      %cst_2 = "tf.Const"() {value = dense<2> : tensor<2xi32>}
+  // CHECK:      %cst_2 = "tf.Const"() <{value = dense<2> : tensor<2xi32>}>
   // CHECK:        "tf.StridedSlice"(%arg1, %cst, %cst_2, %cst_1)
   // CHECK-SAME:     _layout = ["sharding_specs:unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"]
   // CHECK-SAME:     (tensor<2x4xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xf32>
diff --git a/tensorflow/dtensor/mlir/tests/spmd_squeeze.mlir b/tensorflow/dtensor/mlir/tests/spmd_squeeze.mlir
index 5bfda19..401940d 100644
--- a/tensorflow/dtensor/mlir/tests/spmd_squeeze.mlir
+++ b/tensorflow/dtensor/mlir/tests/spmd_squeeze.mlir
@@ -41,8 +41,8 @@
 // CHECK-LABEL: func @main
 func.func @main(%arg0: tensor<i32> , %arg1: tensor<2x1xf32> { tf._layout = "sharding_specs:x,unsharded, mesh:|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"}) -> tensor<2xf32> {
   // CHECK:      "tf.Squeeze"(%arg1)
-  // CHECK-SAME: _layout = ["sharding_specs:x, mesh:|x=2,y=1|0,1|0,1|
   // CHECK-SAME: squeeze_dims = [1]
+  // CHECK-SAME: _layout = ["sharding_specs:x, mesh:|x=2,y=1|0,1|0,1|
   // CHECK-SAME: (tensor<1x1xf32>) -> tensor<1xf32>
   %0 = "tf_device.cluster"() ({
     %1 = "tf.DTensorLayout"(%arg1) {_global_shape = [#tf_type.shape<2x1>], global_shape = #tf_type.shape<2x1>,
diff --git a/tensorflow/dtensor/mlir/tests/undo_merge_const_across_mesh.mlir b/tensorflow/dtensor/mlir/tests/undo_merge_const_across_mesh.mlir
index b152d29..956a350 100644
--- a/tensorflow/dtensor/mlir/tests/undo_merge_const_across_mesh.mlir
+++ b/tensorflow/dtensor/mlir/tests/undo_merge_const_across_mesh.mlir
@@ -3,9 +3,9 @@
 // Check that constants with different meshes are duplicated.
 // CHECK-LABEL: func @check_undo_sccp
 func.func @check_undo_sccp() -> (tensor<4xi32>, tensor<4xi32>) {
-    // CHECK-DAG: "tf.DTensorLayout"(%[[CONST_A:.*]]) {global_shape = #tf_type.shape<4>, layout = #dtensor.layout<sharding_specs:unsharded, mesh:|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1>} : (tensor<4xi32>) -> tensor<4xi32>
+    // CHECK-DAG: "tf.DTensorLayout"(%[[CONST_A:.*]]) <{global_shape = #tf_type.shape<4>, layout = #dtensor.layout<sharding_specs:unsharded, mesh:|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1>}> : (tensor<4xi32>) -> tensor<4xi32>
     // CHECK-DAG: %[[CONST_A]] = "tf.Const"()
-    // CHECK-DAG: "tf.DTensorLayout"(%[[CONST_B:.*]]) {global_shape = #tf_type.shape<4>, layout = #dtensor.layout<sharding_specs:unsharded, mesh:|y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1>} : (tensor<4xi32>) -> tensor<4xi32>
+    // CHECK-DAG: "tf.DTensorLayout"(%[[CONST_B:.*]]) <{global_shape = #tf_type.shape<4>, layout = #dtensor.layout<sharding_specs:unsharded, mesh:|y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1>}> : (tensor<4xi32>) -> tensor<4xi32>
     // CHECK-DAG: %[[CONST_B]] = "tf.Const"()
 
     %cst = "tf.Const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<4xi32>
diff --git a/tensorflow/dtensor/mlir/tests/update_tpu_metadata.mlir b/tensorflow/dtensor/mlir/tests/update_tpu_metadata.mlir
index c7632ca..9832224 100644
--- a/tensorflow/dtensor/mlir/tests/update_tpu_metadata.mlir
+++ b/tensorflow/dtensor/mlir/tests/update_tpu_metadata.mlir
@@ -9,9 +9,9 @@
 
 func.func @f_callee() {
   // CHECK:    tf_device.launch
+  // CHECK:    device = ""
   // CHECK:      "tf._TPUCompileMlir"
   // CHECK-SAME:  metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\04 \01"
-  // CHECK:    device = ""
   %0:2 = "tf_device.launch"() ({
     %1, %2 = "tf._TPUCompileMlir"() {
       NumDynamicShapes = 0 : i64,
@@ -20,7 +20,7 @@
     tf_device.return %1, %2 : tensor<!tf_type.string>, tensor<2x!tf_type.string>
   })  {device = "tpu_host:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
 
-  // CHECK-NEXT: "tf.TPUExecute"
+  // CHECK: "tf.TPUExecute"
   "tf.TPUExecute"(%0#1) : (tensor<2x!tf_type.string>) -> ()
   func.return
 }
@@ -36,8 +36,8 @@
 
 func.func @f_callee() {
   // CHECK:    tf_device.launch
-  // CHECK:      "tf._TPUCompileMlir"
   // CHECK:    device = ""
+  // CHECK:      "tf._TPUCompileMlir"
   %0:2 = "tf_device.launch"() ({
     %1, %2 = "tf._TPUCompileMlir"() {
       NumDynamicShapes = 0 : i64,
@@ -47,8 +47,8 @@
   })  {device = "tpu_host:0"} : () -> (tensor<!tf_type.string>, tensor<2x!tf_type.string>)
 
   // CHECK:    tf_device.launch
-  // CHECK:      "tf.TPUExecute"
   // CHECK:    device = ""
+  // CHECK:      "tf.TPUExecute"
   "tf_device.launch"() ({
     "tf.TPUExecute"(%0#1) : (tensor<2x!tf_type.string>) -> ()
     tf_device.return
@@ -83,24 +83,24 @@
 
 // -----
 
-// Check for Xla Spmd mesh that TPUCompileOp has correct metadata proto and 
+// Check for Xla Spmd mesh that TPUCompileOp has correct metadata proto and
 // number of program outputs is equal to number of devices on mesh.
 
 // CHECK-LABEL: func @main
 func.func @main(%arg0: tensor<i32>, %arg1: tensor<12x24xf32>) -> (tensor<12x24xf32>) {
     %0 = "tf.StatefulPartitionedCall"(%arg1) {
-      config = "|x=2,y=4|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3,/job:localhost/replica:0/task:0/device:TPU:4,/job:localhost/replica:0/task:0/device:TPU:5,/job:localhost/replica:0/task:0/device:TPU:6,/job:localhost/replica:0/task:0/device:TPU:7|use_xla_spmd", 
-      config_proto = "", 
-      executor_type = "", 
+      config = "|x=2,y=4|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3,/job:localhost/replica:0/task:0/device:TPU:4,/job:localhost/replica:0/task:0/device:TPU:5,/job:localhost/replica:0/task:0/device:TPU:6,/job:localhost/replica:0/task:0/device:TPU:7|use_xla_spmd",
+      config_proto = "",
+      executor_type = "",
       f = @_xla_spmd_func} : (tensor<12x24xf32>) -> tensor<12x24xf32>
     return %0 : tensor<12x24xf32>
   }
 
 func.func private @_xla_spmd_func(%arg0: tensor<12x24xf32>) -> tensor<12x24xf32> {
   // CHECK:    tf_device.launch
+  // CHECK:    device = ""
   // CHECK:      %compilation_status, %program:8 = "tf._TPUCompileMlir"
   // CHECK-SAME:  metadata = "\0A\10\08\01\12\08\12\02\08\0C\12\02\08\18\18\01\22\00\12\02\0A\00\18\01 \08x\01\88\01\ED\91\DC\F5\C3\8C\95\B5\90\01"
-  // CHECK:    device = ""
   %0:2 = "tf_device.launch"() ({
     %compilation_status, %program = "tf._TPUCompileMlir"() {metadata = "\0A\18\08\01\12\08\12\02\08\0C\12\02\08\18\18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\01 \01\88\01\ED\91\DC\F5\C3\8C\95\B5\90\01", mlir_module = "#loc = loc(unknown)\0Amodule attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1345 : i32}} {\0A  func.func @main(%arg0: tensor<12x24xf32> {mhlo.sharding = \22\22} loc(unknown)) -> (tensor<12x24xf32> {mhlo.sharding = \22\22}) {\0A    %0 = \22tf.Identity\22(%arg0) : (tensor<12x24xf32>) -> tensor<12x24xf32> loc(#loc)\0A    return %0 : tensor<12x24xf32> loc(#loc)\0A  } loc(#loc)\0A} loc(#loc)\0A"} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
     tf_device.return %compilation_status, %program : tensor<!tf_type.string>, tensor<3x!tf_type.string>
diff --git a/tensorflow/dtensor/python/tests/BUILD b/tensorflow/dtensor/python/tests/BUILD
index 498642c..af1cd7b 100644
--- a/tensorflow/dtensor/python/tests/BUILD
+++ b/tensorflow/dtensor/python/tests/BUILD
@@ -242,6 +242,7 @@
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:errors",
         "//tensorflow/python/framework:ops",
+        "//tensorflow/python/framework:tensor_shape",
         "//tensorflow/python/ops:array_ops",
         "//tensorflow/python/ops:math_ops",
         "//tensorflow/python/ops:stateless_random_ops",
diff --git a/tensorflow/dtensor/python/tests/layout_test.py b/tensorflow/dtensor/python/tests/layout_test.py
index 5fdb2d1..2dabf15 100644
--- a/tensorflow/dtensor/python/tests/layout_test.py
+++ b/tensorflow/dtensor/python/tests/layout_test.py
@@ -30,6 +30,7 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import stateless_random_ops
@@ -388,6 +389,30 @@
         tensor_layout.num_shards(0), _2D_MESH.dim_size(_MESH_DIM_BATCH))
     self.assertEqual(tensor_layout.num_shards(1), 1)
 
+  def test_global_shape_from_local_shape(self):
+    tensor_layout = layout.Layout(
+        [_MESH_DIM_BATCH, _MESH_DIM_X, layout.UNSHARDED],
+        mesh=_2D_MESH,
+    )
+    self.assertEqual(
+        tensor_layout.global_shape_from_local_shape(
+            tensor_shape.TensorShape((1, 3, 5))
+        ),
+        (2, 6, 5),
+    )
+
+  def test_local_shape_from_global_shape(self):
+    tensor_layout = layout.Layout(
+        [_MESH_DIM_BATCH, _MESH_DIM_X, layout.UNSHARDED],
+        mesh=_2D_MESH,
+    )
+    self.assertEqual(
+        tensor_layout.local_shape_from_global_shape(
+            tensor_shape.TensorShape((2, 6, 5))
+        ),
+        (1, 3, 5),
+    )
+
   def test_single_device_layout(self):
     tensor_layout = layout.Layout.from_single_device_mesh(_SINGLE_DEVICE_MESH)
     tensor_layout2 = layout.Layout.from_device(
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 050e4dc..dffdc7d 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -51706,6 +51706,31 @@
 	return scope.AddOperation(opspec)
 }
 
+// Op that copies host tensor to device with dynamic shape support.
+// For internal use only.
+func TPUCopyWithDynamicShape(scope *Scope, tensors []tf.Output, unpadded_sizes []tf.Output) (tpu_tensors []tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "TPUCopyWithDynamicShape",
+		Input: []tf.Input{
+			tf.OutputList(tensors), tf.OutputList(unpadded_sizes),
+		},
+	}
+	op := scope.AddOperation(opspec)
+	if scope.Err() != nil {
+		return
+	}
+	var idx int
+	var err error
+	if tpu_tensors, idx, err = makeOutputList(op, idx, "tpu_tensors"); err != nil {
+		scope.UpdateErr("TPUCopyWithDynamicShape", err)
+		return
+	}
+	return tpu_tensors
+}
+
 // An op enabling differentiation of TPU Embeddings.
 //
 // This op simply returns its first input, which is assumed to have been sliced
diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD
index 7ad25cc..4c42ccc 100644
--- a/tensorflow/lite/BUILD
+++ b/tensorflow/lite/BUILD
@@ -3,7 +3,7 @@
 load("//tensorflow:tensorflow.bzl", "if_google", "if_not_windows", "if_oss", "tf_cc_test")
 load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
 load("//tensorflow/lite:build_def.bzl", "tflite_cc_shared_object", "tflite_copts", "tflite_copts_warnings", "tflite_linkopts_no_undefined", "tflite_self_contained_libs_test_suite")
-load("//tensorflow/lite:special_rules.bzl", "SPECIAL_RULES_DEPS", "internal_visibility_allowlist", "tflite_portable_test_suite")
+load("//tensorflow/lite:special_rules.bzl", "SPECIAL_RULES_DEPS", "internal_visibility_allowlist", "tflite_internal_cc_3p_api_deps_src_all_visibility_allowlist", "tflite_portable_test_suite")
 load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "alias_with_tflite", "cc_library_with_tflite")
 load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
 
@@ -117,6 +117,51 @@
     ],
 )
 
+filegroup(
+    name = "tflite_internal_cc_3p_api_deps_src_all",
+    srcs = [
+        ":tflite_internal_cc_3p_api_deps_src",
+        "//tensorflow/lite/core:macros.h",
+        "//tensorflow/lite/core/acceleration/configuration/c:tflite_internal_cc_3p_api_deps_src",
+        "//tensorflow/lite/core/api:tflite_internal_cc_3p_api_deps_src",
+        "//tensorflow/lite/core/c:tflite_internal_cc_3p_api_deps_src",
+        "//tensorflow/lite/kernels:tflite_internal_cc_3p_api_deps_src",
+        "//tensorflow/lite/kernels/internal:tflite_internal_cc_3p_api_deps_src",
+        "//tensorflow/lite/schema:tflite_internal_cc_3p_api_deps_src",
+    ],
+    visibility = tflite_internal_cc_3p_api_deps_src_all_visibility_allowlist(),
+)
+
+filegroup(
+    name = "tflite_internal_cc_3p_api_deps_src",
+    srcs = [
+        ":allocation.cc",
+        ":allocation.h",
+        ":array.cc",
+        ":array.h",
+        ":builtin_ops.h",
+        ":logger.h",
+        ":minimal_logging.cc",
+        ":minimal_logging.h",
+        ":minimal_logging_android.cc",
+        ":mmap_allocation.cc",
+        ":mutable_op_resolver.cc",
+        ":mutable_op_resolver.h",
+        ":op_resolver.h",
+        ":portable_type_to_tflitetype.h",
+        ":stderr_reporter.cc",
+        ":stderr_reporter.h",
+        ":tensorflow_profiler_logger.h",
+        ":tensorflow_profiler_logger_shim.cc",
+        ":type_to_tflitetype.h",
+        ":util.cc",
+        ":util.h",
+    ],
+    visibility = [
+        "//visibility:private",
+    ],
+)
+
 STABLE_FRAMEWORK_LIB_HDRS = [
     "allocation.h",
     "context_util.h",
@@ -446,9 +491,7 @@
     compatible_with = get_compatible_with_portable(),
     copts = tflite_copts_warnings(),
     deps = [
-        ":string",
         "//tensorflow/lite/core/api:error_reporter",
-        "//tensorflow/lite/core/c:common",
     ],
 )
 
diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt
index 83bc762..1b3d6bc 100644
--- a/tensorflow/lite/CMakeLists.txt
+++ b/tensorflow/lite/CMakeLists.txt
@@ -268,7 +268,11 @@
 endif()
 
 populate_tflite_source_vars("core" TFLITE_CORE_SRCS)
-populate_tflite_source_vars("core/acceleration/configuration" TFLITE_CORE_ACCELERATION_SRCS)
+populate_tflite_source_vars(
+  "core/acceleration/configuration" TFLITE_CORE_ACCELERATION_SRCS
+  FILTER "xnnpack_plugin.*"
+  FILTER "(_test)\\.(cc|h)$"
+)
 populate_tflite_source_vars("core/api" TFLITE_CORE_API_SRCS)
 populate_tflite_source_vars("core/async" TFLITE_CORE_ASYNC_SRCS)
 populate_tflite_source_vars("core/async/c" TFLITE_CORE_ASYNC_C_SRCS)
@@ -491,6 +495,8 @@
     XNNPACK
   )
   list(APPEND TFLITE_TARGET_PUBLIC_OPTIONS "-DTFLITE_BUILD_WITH_XNNPACK_DELEGATE")
+  list(APPEND TFLITE_TARGET_PUBLIC_OPTIONS "-DXNNPACK_DELEGATE_ENABLE_QS8")
+  list(APPEND TFLITE_TARGET_PUBLIC_OPTIONS "-DXNNPACK_DELEGATE_ENABLE_QU8")
 endif()
 if(TFLITE_ENABLE_EXTERNAL_DELEGATE)
   populate_tflite_source_vars("delegates/external"
diff --git a/tensorflow/lite/acceleration/configuration/BUILD b/tensorflow/lite/acceleration/configuration/BUILD
index 9dca4b0..4f1b2fe 100644
--- a/tensorflow/lite/acceleration/configuration/BUILD
+++ b/tensorflow/lite/acceleration/configuration/BUILD
@@ -17,7 +17,9 @@
 load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
 load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_copts_warnings")
 load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
-load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
+load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "cc_test_with_tflite")
+
+# copybara:uncomment load("//tools/build_defs/proto/cpp:cc_proto_library.bzl", "cc_proto_library")
 load(":build_defs.bzl", "flatbuffer_schema_compat_test")
 
 # copybara:comment_begin(oss-only)
@@ -322,29 +324,53 @@
     ],
 )
 
-cc_library(
+cc_library_with_tflite(
     name = "xnnpack_plugin",
-    srcs = ["xnnpack_plugin.cc"],
     compatible_with = get_compatible_with_portable(),
-    deps = [
-        ":configuration_fbs",
-        "//tensorflow/lite:minimal_logging",
-        "//tensorflow/lite/core/acceleration/configuration:delegate_registry",
-        "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
-        "@com_google_absl//absl/base:log_severity",
-        "@com_google_absl//absl/memory",
-    ],
-    alwayslink = 1,  # For registration to always run.
+    visibility = ["//visibility:public"],
+    deps = ["//tensorflow/lite/core/acceleration/configuration:xnnpack_plugin"],
 )
 
+cc_test_with_tflite(
+    name = "xnnpack_plugin_with_tflite_test",
+    srcs = ["xnnpack_plugin_test.cc"],
+    # The variant of this test that links against TF Lite in Play services
+    # isn't portable to iOS / Mac or Android, because it relies on a separate
+    # shared library that isn't included in the executable, and the testing
+    # infrastructure for iOS and Android doesn't propagate data dependencies
+    # to the test device.  So we disable this test on those devices.
+    # TODO(b/306161304): ideally we ought to apply these tags only to the
+    # variant for TF Lite in Play services.  In the meantime, we apply those
+    # tags to the whole test, but also duplicate the test below using cc_test
+    # without the tags.
+    tags = [
+        "no_mac",
+        "tflite_not_portable_android",
+        "tflite_not_portable_ios",
+    ],
+    tflite_deps = [
+        ":xnnpack_plugin",
+        "//tensorflow/lite:test_util",
+        "//tensorflow/lite/acceleration/configuration:delegate_registry",
+    ],
+    deps = [
+        ":configuration_fbs",
+        "@com_google_googletest//:gtest_main",
+        "@flatbuffers//:runtime_cc",
+        "@pthreadpool",
+    ],
+)
+
+# This duplicates xnnpack_plugin_with_tflite_test above, but without the tags,
+# to ensure that this test does get run on iOS and Android.
 cc_test(
     name = "xnnpack_plugin_test",
     srcs = ["xnnpack_plugin_test.cc"],
     deps = [
         ":configuration_fbs",
         ":xnnpack_plugin",
-        "//tensorflow/lite/core/acceleration/configuration:delegate_registry",
-        "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
+        "//tensorflow/lite:test_util",
+        "//tensorflow/lite/acceleration/configuration:delegate_registry",
         "@com_google_googletest//:gtest_main",
         "@flatbuffers//:runtime_cc",
         "@pthreadpool",
diff --git a/tensorflow/lite/acceleration/configuration/xnnpack_plugin_test.cc b/tensorflow/lite/acceleration/configuration/xnnpack_plugin_test.cc
index 2aa1d95..3138e7f 100644
--- a/tensorflow/lite/acceleration/configuration/xnnpack_plugin_test.cc
+++ b/tensorflow/lite/acceleration/configuration/xnnpack_plugin_test.cc
@@ -20,12 +20,12 @@
 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
 #include "pthreadpool.h"  // from @pthreadpool
 #include "tensorflow/lite/acceleration/configuration/configuration_generated.h"
-#include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h"
-#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
+#include "tensorflow/lite/acceleration/configuration/delegate_registry.h"
+#include "tensorflow/lite/test_util.h"
 
 namespace tflite {
 
-class XnnpackPluginTest : public testing::Test {
+class XnnpackPluginTest : public tflite::testing::Test {
  public:
   static constexpr int kNumThreadsForTest = 7;
   static constexpr tflite::XNNPackFlags kFlagsForTest =
@@ -70,22 +70,14 @@
 constexpr int XnnpackPluginTest::kNumThreadsForTest;
 
 TEST_F(XnnpackPluginTest, CanCreateAndDestroyDelegate) {
-  delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create();
+  delegates::TfLiteOpaqueDelegatePtr delegate = delegate_plugin_->Create();
   EXPECT_NE(delegate, nullptr);
 }
 
 TEST_F(XnnpackPluginTest, CanGetDelegateErrno) {
-  delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create();
+  delegates::TfLiteOpaqueDelegatePtr delegate = delegate_plugin_->Create();
   int error_number = delegate_plugin_->GetDelegateErrno(delegate.get());
   EXPECT_EQ(error_number, 0);
 }
 
-TEST_F(XnnpackPluginTest, SetsCorrectThreadCount) {
-  delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create();
-  pthreadpool_t threadpool = static_cast<pthreadpool_t>(
-      TfLiteXNNPackDelegateGetThreadPool(delegate.get()));
-  int thread_count = pthreadpool_get_threads_count(threadpool);
-  EXPECT_EQ(thread_count, kNumThreadsForTest);
-}
-
 }  // namespace tflite
diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD
index e10e26e..baf2102 100644
--- a/tensorflow/lite/c/BUILD
+++ b/tensorflow/lite/c/BUILD
@@ -463,6 +463,7 @@
     testonly = True,
     srcs = ["test_util.cc"],
     hdrs = ["test_util.h"],
+    generate_opaque_delegate_target = True,
 )
 
 tflite_self_contained_libs_test_suite(name = "self_contained_libs_test_suite")
diff --git a/tensorflow/lite/core/BUILD b/tensorflow/lite/core/BUILD
index d6c9f25..f405e37 100644
--- a/tensorflow/lite/core/BUILD
+++ b/tensorflow/lite/core/BUILD
@@ -513,8 +513,8 @@
         ":framework_stable",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:util",
+        "//tensorflow/lite/c:c_api_types",
         "//tensorflow/lite/kernels:builtin_ops",  # build_cleaner: keep
-        "//tensorflow/lite/testing:util",
         "@com_google_googletest//:gtest_main",
     ],
 )
diff --git a/tensorflow/lite/core/acceleration/configuration/BUILD b/tensorflow/lite/core/acceleration/configuration/BUILD
index a1c1b2d..cd2c147 100644
--- a/tensorflow/lite/core/acceleration/configuration/BUILD
+++ b/tensorflow/lite/core/acceleration/configuration/BUILD
@@ -1,7 +1,7 @@
+load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
+load("//tensorflow/lite:special_rules.bzl", "nnapi_plugin_impl_visibility_allowlist", "xnnpack_plugin_impl_visibility_allowlist")
 load("//tensorflow/lite/core:special_rules.bzl", "delegate_registry_visibility_allowlist")
 load("//tensorflow/lite/core/c:special_rules.bzl", "experimental_acceleration_api_allowlist")
-load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
-load("//tensorflow/lite:special_rules.bzl", "nnapi_plugin_impl_visibility_allowlist")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -88,3 +88,35 @@
         "@com_google_googletest//:gtest_main",
     ],
 )
+
+cc_library(
+    name = "xnnpack_plugin",
+    srcs = ["xnnpack_plugin.cc"],
+    compatible_with = get_compatible_with_portable(),
+    visibility = xnnpack_plugin_impl_visibility_allowlist() + [
+        "//tensorflow/lite:__subpackages__",
+    ],
+    deps = [
+        "//tensorflow/lite:minimal_logging",
+        "//tensorflow/lite/acceleration/configuration:configuration_fbs",
+        "//tensorflow/lite/core/acceleration/configuration:delegate_registry",
+        "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
+        "@com_google_absl//absl/base:log_severity",
+        "@com_google_absl//absl/memory",
+    ],
+    alwayslink = 1,  # For registration to always run.
+)
+
+cc_test(
+    name = "xnnpack_plugin_test",
+    srcs = ["xnnpack_plugin_test.cc"],
+    deps = [
+        ":xnnpack_plugin",
+        "//tensorflow/lite/acceleration/configuration:configuration_fbs",
+        "//tensorflow/lite/core/acceleration/configuration:delegate_registry",
+        "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
+        "@com_google_googletest//:gtest_main",
+        "@flatbuffers//:runtime_cc",
+        "@pthreadpool",
+    ],
+)
diff --git a/tensorflow/lite/core/acceleration/configuration/c/BUILD b/tensorflow/lite/core/acceleration/configuration/c/BUILD
index a0e1418..2e1632d 100644
--- a/tensorflow/lite/core/acceleration/configuration/c/BUILD
+++ b/tensorflow/lite/core/acceleration/configuration/c/BUILD
@@ -39,6 +39,16 @@
     licenses = ["notice"],
 )
 
+filegroup(
+    name = "tflite_internal_cc_3p_api_deps_src",
+    srcs = [
+        "stable_delegate.h",
+    ],
+    visibility = [
+        "//tensorflow/lite:__pkg__",
+    ],
+)
+
 # LINT.IfChange(tflite_acceleration_exported_headers)
 exports_files([
     "delegate_plugin.h",
diff --git a/tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h b/tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h
index 157fa9a..900a266 100644
--- a/tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h
+++ b/tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h
@@ -13,16 +13,16 @@
 limitations under the License.
 ==============================================================================*/
 // NOLINTBEGIN(whitespace/line_length)
-/// WARNING: Users of TensorFlow Lite should not include this file directly,
-/// but should instead include
-/// "third_party/tensorflow/lite/acceleration/configuration/c/delegate_plugin.h".
-/// Only the TensorFlow Lite implementation itself should include this
-/// file directly.
+// WARNING: Users of TensorFlow Lite should not include this file directly,
+// but should instead include
+// "third_party/tensorflow/lite/acceleration/configuration/c/delegate_plugin.h".
+// Only the TensorFlow Lite implementation itself should include this
+// file directly.
 // NOLINTEND(whitespace/line_length)
 #ifndef TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_DELEGATE_PLUGIN_H_
 #define TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_DELEGATE_PLUGIN_H_
 
-// C API types for TF Lite delegate plugins.
+/// C API types for TF Lite delegate plugins.
 
 #include "tensorflow/lite/core/c/common.h"
 
@@ -30,50 +30,60 @@
 extern "C" {
 #endif
 
-// Type of delegate creation function used to allocate and construct a delegate.
-//
-// The tflite_settings parameter passed to the delegate creation function should
-// be a pointer to a FlatBuffer table object of type tflite::TFLiteSettings.
-// We use 'const void *' here rather than 'const tflite::TFLiteSettings*' since
-// this is a C API so we don't want to directly reference C++ types such as
-// tflite::TFLiteSettings.  But note that this address should point to the
-// 'parsed' FlatBuffer object, not the raw byte buffer.
-// (Note that 'parsing' FlatBuffers is very cheap, it's just an offset load.)
-//
-// If you are using the FlatBuffers C API, then you can alternatively pass
-// in a value of type 'tflite_TFLiteSettings_table_t', which is a typedef for
-// 'const struct tflite_TFLiteSettings_table*' -- that is the corresponding
-// type for the 'parsed' FlatBuffer object in the FlatBuffers C API.
-//
-// Ownership of the tflite_settings flatbuffer remains with the caller.
-// The caller of a delegate creation function may end the lifetime of the
-// tflite_settings FlatBuffer immediately after the call to the function.
-// So the delegate creation function should ensure that any settings that the
-// delegate may need to reference later, after the delegate has been
-// constructed, are copied from the FlatBuffer into storage owned by the
-// delegate.
+// clang-format off
+// NOLINTBEGIN(whitespace/line_length)
+/** \defgroup delegate_plugin tensorflow/lite/acceleration/configuration/c/delegate_plugin.h
+ *  @{
+ */
+// NOLINTEND(whitespace/line_length)
+// clang-format on
+
+/// Type of delegate creation function used to allocate and construct a
+/// delegate.
+///
+/// The `tflite_settings` parameter passed to the delegate creation function
+/// should be a pointer to a FlatBuffer table object of type
+/// `tflite::TFLiteSettings`. We use `const void *` here rather than `const
+/// tflite::TFLiteSettings*` since this is a C API so we don't want to directly
+/// reference C++ types such as `tflite::TFLiteSettings`.  But note that this
+/// address should point to the 'parsed' FlatBuffer object, not the raw byte
+/// buffer. (Note that 'parsing' FlatBuffers is very cheap, it's just an offset
+/// load.)
+///
+/// If you are using the FlatBuffers C API, then you can alternatively pass
+/// in a value of type `tflite_TFLiteSettings_table_t`, which is a typedef for
+/// `const struct tflite_TFLiteSettings_table*` -- that is the corresponding
+/// type for the 'parsed' FlatBuffer object in the FlatBuffers C API.
+///
+/// Ownership of the `tflite_settings` flatbuffer remains with the caller.
+/// The caller of a delegate creation function may end the lifetime of the
+/// `tflite_settings` FlatBuffer immediately after the call to the function.
+/// So the delegate creation function should ensure that any settings that the
+/// delegate may need to reference later, after the delegate has been
+/// constructed, are copied from the FlatBuffer into storage owned by the
+/// delegate.
 typedef TfLiteDelegate *TfLiteDelegatePluginCreateFunc(
     const void *tflite_settings);
 
-// Type of function to destroy and deallocate a delegate.
-// The delegate argument must have been created with the corresponding
-// create function from the same delegate plugin.
+/// Type of function to destroy and deallocate a delegate.
+/// The delegate argument must have been created with the corresponding
+/// create function from the same delegate plugin.
 typedef void TfLiteDelegatePluginDestroyFunc(TfLiteDelegate *);
 
-// Type of function to return an error code for the last delegate operation.
-// The delegate argument must have been created with the corresponding
-// create function from the same delegate plugin.
+/// Type of function to return an error code for the last delegate operation.
+/// The delegate argument must have been created with the corresponding
+/// create function from the same delegate plugin.
 typedef int TfLiteDelegatePluginGetDelegateErrnoFunc(TfLiteDelegate *);
 
-// Struct to hold all the methods for a delegate plugin.
+/// Struct to hold all the methods for a delegate plugin.
 typedef struct TfLiteDelegatePlugin {
-  // Function to allocate and construct a delegate.
+  /// Function to allocate and construct a delegate.
   TfLiteDelegatePluginCreateFunc *create;
 
-  // Function to deallocate a delegate.
+  /// Function to deallocate a delegate.
   TfLiteDelegatePluginDestroyFunc *destroy;
 
-  // Function to return an error code for the last delegate operation.
+  /// Function to return an error code for the last delegate operation.
   TfLiteDelegatePluginGetDelegateErrnoFunc *get_delegate_errno;
 } TfLiteDelegatePlugin;
 
@@ -84,19 +94,20 @@
 // target. e.g. TFLite-in-Play Services initialization context.
 #if TFLITE_USE_OPAQUE_DELEGATE
 
-// Same as TfLiteDelegatePluginCreateFunc but uses truly opaque types.
+/// Same as TfLiteDelegatePluginCreateFunc but uses truly opaque types.
 typedef TfLiteOpaqueDelegateStruct *TfLiteOpaqueDelegatePluginCreateFunc(
     const void *tflite_settings);
 
-// Same as TfLiteDelegatePluginDestroyFunc but uses truly opaque types.
+/// Same as TfLiteDelegatePluginDestroyFunc but uses truly opaque types.
 typedef void TfLiteOpaqueDelegatePluginDestroyFunc(
     TfLiteOpaqueDelegateStruct *delegate);
 
-// Same as TfLiteDelegatePluginGetDelegateErrnoFunc but uses truly opaque types.
+/// Same as TfLiteDelegatePluginGetDelegateErrnoFunc but uses truly opaque
+/// types.
 typedef int TfLiteOpaqueDelegatePluginGetDelegateErrnoFunc(
     TfLiteOpaqueDelegateStruct *delegate);
 
-// Same as TfLiteDelegatePlugin but uses truly opaque types.
+/// Same as TfLiteDelegatePlugin but uses truly opaque types.
 typedef struct TfLiteOpaqueDelegatePlugin {
   TfLiteOpaqueDelegatePluginCreateFunc *create;
 
@@ -115,6 +126,8 @@
 
 #endif  // TFLITE_USE_OPAQUE_DELEGATE
 
+/** @} */
+
 #ifdef __cplusplus
 };  // extern "C"
 #endif
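
A minimal plugin-implementation sketch against the struct declared above; `MyDelegateCreate`, `MyDelegateDestroy`, and `MyDelegateLastErrno` are hypothetical stand-ins for a real delegate's C API, and the cast of `tflite_settings` follows the documented contract that it points at a parsed `tflite::TFLiteSettings` object, not the raw byte buffer:

    #include "tensorflow/lite/acceleration/configuration/configuration_generated.h"
    #include "tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h"

    // Hypothetical delegate C API, used only for illustration.
    TfLiteDelegate* MyDelegateCreate(const tflite::TFLiteSettings* settings);
    void MyDelegateDestroy(TfLiteDelegate* delegate);
    int MyDelegateLastErrno(TfLiteDelegate* delegate);

    static TfLiteDelegate* PluginCreate(const void* tflite_settings) {
      // The caller may free the FlatBuffer right after this returns, so the
      // delegate must copy any settings it needs during construction.
      const auto* settings =
          static_cast<const tflite::TFLiteSettings*>(tflite_settings);
      return MyDelegateCreate(settings);
    }

    static void PluginDestroy(TfLiteDelegate* delegate) {
      MyDelegateDestroy(delegate);
    }

    static int PluginGetDelegateErrno(TfLiteDelegate* delegate) {
      return MyDelegateLastErrno(delegate);
    }

    // Field order matches TfLiteDelegatePlugin: create, destroy,
    // get_delegate_errno.
    static const TfLiteDelegatePlugin kMyDelegatePlugin = {
        PluginCreate,
        PluginDestroy,
        PluginGetDelegateErrno,
    };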
diff --git a/tensorflow/lite/core/acceleration/configuration/c/gpu_plugin.h b/tensorflow/lite/core/acceleration/configuration/c/gpu_plugin.h
index cbee7f1..c30ce4d 100644
--- a/tensorflow/lite/core/acceleration/configuration/c/gpu_plugin.h
+++ b/tensorflow/lite/core/acceleration/configuration/c/gpu_plugin.h
@@ -13,25 +13,25 @@
 limitations under the License.
 ==============================================================================*/
 // NOLINTBEGIN(whitespace/line_length)
-/// WARNING: Users of TensorFlow Lite should not include this file directly,
-/// but should instead include
-/// "third_party/tensorflow/lite/acceleration/configuration/c/gpu_plugin.h".
-/// Only the TensorFlow Lite implementation itself should include this
-/// file directly.
+// WARNING: Users of TensorFlow Lite should not include this file directly,
+// but should instead include
+// "third_party/tensorflow/lite/acceleration/configuration/c/gpu_plugin.h".
+// Only the TensorFlow Lite implementation itself should include this
+// file directly.
 // NOLINTEND(whitespace/line_length)
 #ifndef TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_GPU_PLUGIN_H_
 #define TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_GPU_PLUGIN_H_
 
-// This header file is for the delegate plugin for GPU.
-//
-// For the C++ delegate plugin interface, the GPU delegate plugin is added to
-// the DelegatePluginRegistry by the side effect of a constructor for a static
-// object, so there's no public API needed for this plugin, other than the API
-// of tflite::delegates::DelegatePluginRegistry, which is declared in
-// delegate_registry.h.
-//
-// But to provide a C API to access the GPU delegate plugin, we do expose
-// some functions, which are declared below.
+/// This header file is for the delegate plugin for GPU.
+///
+/// For the C++ delegate plugin interface, the GPU delegate plugin is added to
+/// the `DelegatePluginRegistry` by the side effect of a constructor for a
+/// static object, so there's no public API needed for this plugin, other than
+/// the API of `tflite::delegates::DelegatePluginRegistry`, which is declared
+/// in delegate_registry.h.
+///
+/// But to provide a C API to access the GPU delegate plugin, we do expose
+/// some functions, which are declared below.
 
 #include "tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h"
 
@@ -39,10 +39,20 @@
 extern "C" {
 #endif
 
-// C API for the GPU delegate plugin.
-// Returns a pointer to a statically allocated table of function pointers.
+// clang-format off
+// NOLINTBEGIN(whitespace/line_length)
+/** \defgroup gpu_plugin tensorflow/lite/acceleration/configuration/c/gpu_plugin.h
+ *  @{
+ */
+// NOLINTEND(whitespace/line_length)
+// clang-format on
+
+/// C API for the GPU delegate plugin.
+/// Returns a pointer to a statically allocated table of function pointers.
 const TfLiteDelegatePlugin* TfLiteGpuDelegatePluginCApi();
 
+/** @} */
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
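
A short usage sketch for the C entry point declared above; `settings` is assumed to already point at a parsed `tflite::TFLiteSettings` FlatBuffer (see delegate_plugin.h), and treating a nonzero errno as a creation failure is an illustrative choice rather than something this header guarantees:

    #include "tensorflow/lite/core/acceleration/configuration/c/gpu_plugin.h"

    TfLiteDelegate* MakeGpuDelegate(const void* settings) {
      const TfLiteDelegatePlugin* plugin = TfLiteGpuDelegatePluginCApi();
      TfLiteDelegate* delegate = plugin->create(settings);
      if (delegate != nullptr && plugin->get_delegate_errno(delegate) != 0) {
        plugin->destroy(delegate);  // Creation reported an error; clean up.
        return nullptr;
      }
      return delegate;  // Release later with plugin->destroy(delegate).
    }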
diff --git a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h
index 224e786..fce48ff 100644
--- a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h
+++ b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h
@@ -13,25 +13,25 @@
 limitations under the License.
 ==============================================================================*/
 // NOLINTBEGIN(whitespace/line_length)
-/// WARNING: Users of TensorFlow Lite should not include this file directly,
-/// but should instead include
-/// "third_party/tensorflow/lite/acceleration/configuration/c/xnnpack_plugin.h".
-/// Only the TensorFlow Lite implementation itself should include this
-/// file directly.
+// WARNING: Users of TensorFlow Lite should not include this file directly,
+// but should instead include
+// "third_party/tensorflow/lite/acceleration/configuration/c/xnnpack_plugin.h".
+// Only the TensorFlow Lite implementation itself should include this
+// file directly.
 // NOLINTEND(whitespace/line_length)
 #ifndef TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_XNNPACK_PLUGIN_H_
 #define TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_XNNPACK_PLUGIN_H_
 
-// This header file is for the delegate plugin for XNNPACK.
-//
-// For the C++ delegate plugin interface, the XNNPACK delegate plugin is added
-// to the DelegatePluginRegistry by the side effect of a constructor for a
-// static object, so there's no public API needed for this plugin, other than
-// the API of tflite::delegates::DelegatePluginRegistry, which is declared in
-// delegate_registry.h.
-//
-// But to provide a C API to access the XNNPACK delegate plugin, we do expose
-// some functions, which are declared below.
+/// This header file is for the delegate plugin for XNNPACK.
+///
+/// For the C++ delegate plugin interface, the XNNPACK delegate plugin is added
+/// to the DelegatePluginRegistry by the side effect of a constructor for a
+/// static object, so there's no public API needed for this plugin, other than
+/// the API of `tflite::delegates::DelegatePluginRegistry`, which is declared in
+/// delegate_registry.h.
+///
+/// But to provide a C API to access the XNNPACK delegate plugin, we do expose
+/// some functions, which are declared below.
 
 #include "tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h"
 
@@ -39,10 +39,20 @@
 extern "C" {
 #endif
 
-// C API for the XNNPACK delegate plugin.
-// Returns a pointer to a statically allocated table of function pointers.
+// clang-format off
+// NOLINTBEGIN(whitespace/line_length)
+/** \defgroup xnnpack_plugin tensorflow/lite/acceleration/configuration/c/xnnpack_plugin.h
+ *  @{
+ */
+// NOLINTEND(whitespace/line_length)
+// clang-format on
+
+/// C API for the XNNPACK delegate plugin.
+/// Returns a pointer to a statically allocated table of function pointers.
 const TfLiteDelegatePlugin* TfLiteXnnpackDelegatePluginCApi();
 
+/** @} */
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
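
A usage sketch that mirrors the plugin test's SetUp(): build a `TFLiteSettings` FlatBuffer, hand its parsed root to the plugin's create function, and release the delegate through the same table. The thread count of 4 is an arbitrary illustrative value:

    #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
    #include "tensorflow/lite/acceleration/configuration/configuration_generated.h"
    #include "tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h"

    TfLiteDelegate* CreateXnnpackDelegate() {
      flatbuffers::FlatBufferBuilder builder;
      tflite::XNNPackSettingsBuilder xnnpack_settings_builder(builder);
      xnnpack_settings_builder.add_num_threads(4);  // Illustrative value.
      flatbuffers::Offset<tflite::XNNPackSettings> xnnpack_settings =
          xnnpack_settings_builder.Finish();
      tflite::TFLiteSettingsBuilder tflite_settings_builder(builder);
      tflite_settings_builder.add_xnnpack_settings(xnnpack_settings);
      tflite_settings_builder.add_delegate(tflite::Delegate_XNNPACK);
      builder.Finish(tflite_settings_builder.Finish());
      // `create` expects the parsed FlatBuffer object, not the raw byte buffer.
      const auto* settings = flatbuffers::GetRoot<tflite::TFLiteSettings>(
          builder.GetBufferPointer());

      const TfLiteDelegatePlugin* plugin = TfLiteXnnpackDelegatePluginCApi();
      TfLiteDelegate* delegate = plugin->create(settings);
      // The settings buffer may go away now; the delegate has copied what it
      // needs. Release the delegate later with plugin->destroy(delegate).
      return delegate;
    }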
diff --git a/tensorflow/lite/acceleration/configuration/xnnpack_plugin.cc b/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin.cc
similarity index 100%
rename from tensorflow/lite/acceleration/configuration/xnnpack_plugin.cc
rename to tensorflow/lite/core/acceleration/configuration/xnnpack_plugin.cc
diff --git a/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc b/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc
new file mode 100644
index 0000000..2aa1d95
--- /dev/null
+++ b/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc
@@ -0,0 +1,91 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Some very simple unit tests of the (C++) XNNPack Delegate Plugin.
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
+#include "pthreadpool.h"  // from @pthreadpool
+#include "tensorflow/lite/acceleration/configuration/configuration_generated.h"
+#include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h"
+#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
+
+namespace tflite {
+
+class XnnpackPluginTest : public testing::Test {
+ public:
+  static constexpr int kNumThreadsForTest = 7;
+  static constexpr tflite::XNNPackFlags kFlagsForTest =
+      tflite::XNNPackFlags::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QS8_QU8;
+  void SetUp() override {
+    // Construct a FlatBuffer that contains
+    //   TFLiteSettings {
+    //     delegate: Delegate.XNNPACK,
+    //     XNNPackSettings { num_threads: kNumThreadsForTest
+    //                       flags: TFLITE_XNNPACK_DELEGATE_FLAG_QS8 |
+    //                           TFLITE_XNNPACK_DELEGATE_FLAG_QU8
+    //     }
+    //   }.
+    XNNPackSettingsBuilder xnnpack_settings_builder(flatbuffer_builder_);
+    xnnpack_settings_builder.add_num_threads(kNumThreadsForTest);
+    xnnpack_settings_builder.add_flags(kFlagsForTest);
+    flatbuffers::Offset<XNNPackSettings> xnnpack_settings =
+        xnnpack_settings_builder.Finish();
+    TFLiteSettingsBuilder tflite_settings_builder(flatbuffer_builder_);
+    tflite_settings_builder.add_xnnpack_settings(xnnpack_settings);
+    tflite_settings_builder.add_delegate(Delegate_XNNPACK);
+    flatbuffers::Offset<TFLiteSettings> tflite_settings =
+        tflite_settings_builder.Finish();
+    flatbuffer_builder_.Finish(tflite_settings);
+    tflite_settings_ = flatbuffers::GetRoot<TFLiteSettings>(
+        flatbuffer_builder_.GetBufferPointer());
+    // Create an XNNPack delegate plugin using the settings from the flatbuffer.
+    delegate_plugin_ = delegates::DelegatePluginRegistry::CreateByName(
+        "XNNPackPlugin", *tflite_settings_);
+    ASSERT_NE(delegate_plugin_, nullptr);
+  }
+  void TearDown() override { delegate_plugin_.reset(); }
+  ~XnnpackPluginTest() override {}
+
+ protected:
+  // tflite_settings_ points into storage owned by flatbuffer_builder_.
+  flatbuffers::FlatBufferBuilder flatbuffer_builder_;
+  const TFLiteSettings *tflite_settings_;
+  std::unique_ptr<delegates::DelegatePluginInterface> delegate_plugin_;
+};
+
+constexpr int XnnpackPluginTest::kNumThreadsForTest;
+
+TEST_F(XnnpackPluginTest, CanCreateAndDestroyDelegate) {
+  delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create();
+  EXPECT_NE(delegate, nullptr);
+}
+
+TEST_F(XnnpackPluginTest, CanGetDelegateErrno) {
+  delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create();
+  int error_number = delegate_plugin_->GetDelegateErrno(delegate.get());
+  EXPECT_EQ(error_number, 0);
+}
+
+TEST_F(XnnpackPluginTest, SetsCorrectThreadCount) {
+  delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create();
+  pthreadpool_t threadpool = static_cast<pthreadpool_t>(
+      TfLiteXNNPackDelegateGetThreadPool(delegate.get()));
+  int thread_count = pthreadpool_get_threads_count(threadpool);
+  EXPECT_EQ(thread_count, kNumThreadsForTest);
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/core/api/BUILD b/tensorflow/lite/core/api/BUILD
index 916dbaf..a0e28f1 100644
--- a/tensorflow/lite/core/api/BUILD
+++ b/tensorflow/lite/core/api/BUILD
@@ -8,6 +8,21 @@
     licenses = ["notice"],
 )
 
+filegroup(
+    name = "tflite_internal_cc_3p_api_deps_src",
+    srcs = [
+        ":error_reporter.cc",
+        ":error_reporter.h",
+        ":op_resolver.cc",
+        ":op_resolver.h",
+        ":op_resolver_internal.h",
+        ":verifier.h",
+    ],
+    visibility = [
+        "//tensorflow/lite:__pkg__",
+    ],
+)
+
 cc_library(
     name = "api",
     srcs = [
@@ -138,7 +153,10 @@
     deps = [
         ":api",
         "//tensorflow/lite:string",
+        "//tensorflow/lite/c:c_api_types",
         "//tensorflow/lite/core/c:common",
+        "//tensorflow/lite/schema:schema_fbs",
         "@com_google_googletest//:gtest_main",
+        "@flatbuffers//:runtime_cc",
     ],
 )
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc
index a708918..f37e38a 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/lite/core/api/flatbuffer_conversions.h"
 
+#include <algorithm>
 #include <cstddef>
 #include <cstdint>
 #include <memory>
@@ -881,6 +882,10 @@
     case BuiltinOperator_STABLEHLO_GATHER: {
       return ParseStablehloGather(op, error_reporter, allocator, builtin_data);
     }
+    case BuiltinOperator_STABLEHLO_REDUCE_WINDOW: {
+      return ParseStablehloReduceWindow(op, error_reporter, allocator,
+                                        builtin_data);
+    }
     case BuiltinOperator_REDUCE_WINDOW: {
       auto params = safe_allocator.Allocate<TfLiteReduceWindowParams>();
       TF_LITE_ENSURE(error_reporter, params != nullptr);
@@ -949,7 +954,6 @@
     case BuiltinOperator_STABLEHLO_CONVERT:
     case BuiltinOperator_STABLEHLO_PAD:
     case BuiltinOperator_STABLEHLO_DOT_GENERAL:
-    case BuiltinOperator_STABLEHLO_REDUCE_WINDOW:
     case BuiltinOperator_STABLEHLO_SORT:
     case BuiltinOperator_STABLEHLO_WHILE:
     case BuiltinOperator_STABLEHLO_TRANSPOSE:
@@ -2096,6 +2100,98 @@
   return kTfLiteOk;
 }
 
+TfLiteStatus ParseStablehloReduceWindow(const Operator* op,
+                                        ErrorReporter* error_reporter,
+                                        BuiltinDataAllocator* allocator,
+                                        void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+  auto params = safe_allocator.Allocate<TfLiteStablehloReduceWindowParams>();
+
+  const StablehloReduceWindowOptions* schema_params =
+      op->builtin_options_2_as_StablehloReduceWindowOptions();
+  if (schema_params) {
+    if (!schema_params->window_dimensions() ||
+        schema_params->window_dimensions()->size() == 0) {
+      TF_LITE_REPORT_ERROR(error_reporter,
+                           "'window_dimensions' attribute is not optional for "
+                           "'stablehlo.reduce_window' and cannot be empty.");
+      return kTfLiteError;
+    }
+
+    const size_t rank = schema_params->window_dimensions()->size();
+
+    auto LoadAttr = [&error_reporter](
+                        auto& params_array, auto* const flatbuffer_vector,
+                        const char* attr_name, const size_t expected_size,
+                        const int64_t fill_value) -> TfLiteStatus {
+      if (flatbuffer_vector && flatbuffer_vector->size()) {
+        if (expected_size != 0 && flatbuffer_vector->size() != expected_size) {
+          TF_LITE_REPORT_ERROR(
+              error_reporter,
+              "'%s' attribute of 'stablehlo.reduce_window' does not have the "
+              "expected size (%llu != %llu).",
+              attr_name, flatbuffer_vector->size(), expected_size);
+          return kTfLiteError;
+        }
+        TfLiteStatus status = FlatBufferIntVectorToArray(
+            sizeof(params_array), flatbuffer_vector, params_array,
+            error_reporter, "stablehlo.reduce_window");
+        if (status != kTfLiteOk) {
+          TF_LITE_REPORT_ERROR(error_reporter, "Check the '%s' attribute.",
+                               attr_name);
+          return status;
+        }
+      } else {
+        std::fill_n(params_array,
+                    TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT,
+                    fill_value);
+      }
+      return kTfLiteOk;
+    };
+
+    if (TfLiteStatus status = LoadAttr(
+            params->window_dimensions, schema_params->window_dimensions(),
+            "window_dimensions", /*expected_size=*/rank, /*fill_value=*/1);
+        status != kTfLiteOk) {
+      return status;
+    }
+    if (TfLiteStatus status = LoadAttr(
+            params->window_strides, schema_params->window_strides(),
+            "window_strides", /*expected_size=*/rank, /*fill_value=*/1);
+        status != kTfLiteOk) {
+      return status;
+    }
+    if (TfLiteStatus status = LoadAttr(
+            params->base_dilations, schema_params->base_dilations(),
+            "base_dilations", /*expected_size=*/rank, /*fill_value=*/1);
+        status != kTfLiteOk) {
+      return status;
+    }
+    if (TfLiteStatus status = LoadAttr(
+            params->window_dilations, schema_params->window_dilations(),
+            "window_dilations", /*expected_size=*/rank, /*fill_value=*/1);
+        status != kTfLiteOk) {
+      return status;
+    }
+    if (TfLiteStatus status =
+            LoadAttr(params->padding, schema_params->padding(), "padding",
+                     /*expected_size=*/2 * rank, /*fill_value=*/0);
+        status != kTfLiteOk) {
+      return status;
+    }
+
+    params->body_subgraph_index = schema_params->body_subgraph_index();
+    *builtin_data = params.release();
+    return kTfLiteOk;
+  }
+  TF_LITE_REPORT_ERROR(
+      error_reporter,
+      "Could not get 'stablehlo.reduce_window' operation parameters.");
+  return kTfLiteError;
+}
+
 TfLiteStatus ParseStablehloScatter(const Operator* op,
                                    ErrorReporter* error_reporter,
                                    BuiltinDataAllocator* allocator,
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h
index 9c895b2..11e70a6 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.h
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.h
@@ -435,6 +435,11 @@
                                   BuiltinDataAllocator* allocator,
                                   void** builtin_data);
 
+TfLiteStatus ParseStablehloReduceWindow(const Operator* op,
+                                        ErrorReporter* error_reporter,
+                                        BuiltinDataAllocator* allocator,
+                                        void** builtin_data);
+
 }  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc
index 314a1f8..1fbe440 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc
+++ b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc
@@ -15,13 +15,30 @@
 
 #include "tensorflow/lite/core/api/flatbuffer_conversions.h"
 
+#include <cstdarg>
+#include <cstdint>
+#include <cstdio>
 #include <cstring>
+#include <tuple>
+#include <vector>
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "flatbuffers/buffer.h"  // from @flatbuffers
+#include "flatbuffers/flatbuffer_builder.h"  // from @flatbuffers
+#include "tensorflow/lite/c/c_api_types.h"
+#include "tensorflow/lite/core/api/error_reporter.h"
 #include "tensorflow/lite/core/c/builtin_op_data.h"
+#include "tensorflow/lite/schema/schema_generated.h"
 #include "tensorflow/lite/string_type.h"
 
+using testing::AllOf;
+using testing::Each;
+using testing::ElementsAre;
+using testing::Eq;
+using testing::HasSubstr;
+using testing::StrEq;
+
 namespace tflite {
 namespace {
 
@@ -29,13 +46,15 @@
  public:
   MockErrorReporter() : buffer_size_(0) {}
   int Report(const char* format, va_list args) override {
-    buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args);
+    buffer_size_ += vsnprintf(buffer_ + buffer_size_,
+                              kBufferSize - buffer_size_, format, args);
     return buffer_size_;
   }
-  char* GetBuffer() { return buffer_; }
-  int GetBufferSize() { return buffer_size_; }
+  const char* GetBuffer() const { return buffer_; }
+  int GetBufferSize() const { return buffer_size_; }
+  bool IsEmpty() const { return !buffer_size_; }
 
-  string GetAsString() const { return string(buffer_, buffer_size_); }
+  string GetString() const { return string(buffer_, buffer_size_); }
 
  private:
   static constexpr int kBufferSize = 256;
@@ -76,6 +95,22 @@
     return flatbuffers::GetRoot<Operator>(pointer);
   }
 
+  const Operator* BuildTestOperator(BuiltinOptions2 op_type,
+                                    flatbuffers::Offset<void> options) {
+    flatbuffers::Offset<Operator> offset = CreateOperatorDirect(
+        builder_, /*opcode_index=*/0, /*inputs=*/nullptr, /*outputs=*/nullptr,
+        /*builtin_options_type=*/tflite::BuiltinOptions_NONE,
+        /*builtin_options=*/0, /*custom_options=*/nullptr,
+        /*custom_options_format=*/tflite::CustomOptionsFormat_FLEXBUFFERS,
+        /*mutating_variable_inputs=*/nullptr, /*intermediates=*/nullptr,
+        /*large_custom_options_offset=*/0, /*large_custom_options_size=*/0,
+        /*builtin_options_2_type=*/op_type,
+        /*builtin_options_2=*/options);
+    builder_.Finish(offset);
+    void* pointer = builder_.GetBufferPointer();
+    return flatbuffers::GetRoot<Operator>(pointer);
+  }
+
  protected:
   MockErrorReporter mock_reporter_;
   MockDataAllocator mock_allocator_;
@@ -162,4 +197,536 @@
   EXPECT_EQ(kTfLiteInt4, type);
 }
 
+class StablehloReduceWindowFlatbufferConversionsTest
+    : public FlatbufferConversionsTest {
+ public:
+  static constexpr int kMaxDims =
+      TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT;
+  static constexpr int64_t kValidValue = 5;
+
+  auto ValidAttr() {
+    return builder_.CreateVector(std::vector<int64_t>(kMaxDims, kValidValue));
+  }
+
+  auto InvalidAttr() {
+    return builder_.CreateVector(
+        std::vector<int64_t>(kMaxDims + 1, kValidValue));
+  }
+
+  auto ValidPaddingAttr() {
+    return builder_.CreateVector(
+        std::vector<int64_t>(2 * kMaxDims, kValidValue));
+  }
+
+  auto InvalidPaddingAttr() {
+    return builder_.CreateVector(
+        std::vector<int64_t>(2 * kMaxDims + 1, kValidValue));
+  }
+
+  auto EmptyAttr() { return builder_.CreateVector<int64_t>({}); }
+};
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindow) {
+  const Operator* stablehlo_reduce_window_op = BuildTestOperator(
+      BuiltinOptions2_StablehloReduceWindowOptions,
+      CreateStablehloReduceWindowOptions(
+          builder_,
+          /*window_dimensions=*/builder_.CreateVector<int64_t>({1, 2}),
+          /*window_strides=*/builder_.CreateVector<int64_t>({3, 4}),
+          /*base_dilations=*/builder_.CreateVector<int64_t>({5, 6}),
+          /*window_dilations=*/builder_.CreateVector<int64_t>({7, 8}),
+          /*padding=*/builder_.CreateVector<int64_t>({9, 10, 11, 12}),
+          /*body_subgraph_index=*/13)
+          .Union());
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(
+      ParseOpData(stablehlo_reduce_window_op,
+                  BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                  &mock_allocator_, (void**)&output_data),
+      kTfLiteOk);
+
+  EXPECT_THAT(std::make_tuple(output_data->window_dimensions, 2),
+              ElementsAre(1, 2));
+  EXPECT_THAT(std::make_tuple(output_data->window_strides, 2),
+              ElementsAre(3, 4));
+  EXPECT_THAT(std::make_tuple(output_data->base_dilations, 2),
+              ElementsAre(5, 6));
+  EXPECT_THAT(std::make_tuple(output_data->window_dilations, 2),
+              ElementsAre(7, 8));
+  EXPECT_THAT(std::make_tuple(output_data->padding, 4),
+              ElementsAre(9, 10, 11, 12));
+  EXPECT_THAT(output_data->body_subgraph_index, Eq(13));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowDeathTests) {
+  const Operator* stablehlo_reduce_window_op = BuildTestOperator(
+      BuiltinOptions2_StablehloReduceWindowOptions,
+      CreateStablehloReduceWindowOptions(
+          builder_, /*window_dimensions=*/ValidAttr(),
+          /*window_strides=*/ValidAttr(),
+          /*base_dilations=*/ValidAttr(),
+          /*window_dilations=*/ValidAttr(),
+          /*padding=*/ValidPaddingAttr(), /*body_subgraph_index=*/13)
+          .Union());
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+#ifdef NDEBUG
+  GTEST_SKIP();
+#endif
+  EXPECT_DEATH(
+      ParseOpData(nullptr, BuiltinOperator_STABLEHLO_REDUCE_WINDOW,
+                  &mock_reporter_, &mock_allocator_, (void**)&output_data),
+      "");
+  EXPECT_DEATH(ParseOpData(stablehlo_reduce_window_op,
+                           BuiltinOperator_STABLEHLO_REDUCE_WINDOW, nullptr,
+                           &mock_allocator_, (void**)&output_data),
+               "");
+  EXPECT_DEATH(ParseOpData(stablehlo_reduce_window_op,
+                           BuiltinOperator_STABLEHLO_REDUCE_WINDOW,
+                           &mock_reporter_, nullptr, (void**)&output_data),
+               "");
+  EXPECT_DEATH(ParseOpData(stablehlo_reduce_window_op,
+                           BuiltinOperator_STABLEHLO_REDUCE_WINDOW,
+                           &mock_reporter_, &mock_allocator_, nullptr),
+               "");
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowFailsWithNoWindowDimensions) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/0,
+                                      /*window_strides=*/ValidAttr(),
+                                      /*base_dilations=*/ValidAttr(),
+                                      /*window_dilations=*/ValidAttr(),
+                                      /*padding=*/ValidPaddingAttr(),
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteError);
+  EXPECT_THAT(mock_reporter_.GetString(),
+              HasSubstr("'window_dimensions' attribute is not optional for "
+                        "'stablehlo.reduce_window' and cannot be empty."));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowSucceedsWithNoWindowStrides) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/ValidAttr(),
+                                      /*window_strides=*/0,
+                                      /*base_dilations=*/ValidAttr(),
+                                      /*window_dilations=*/ValidAttr(),
+                                      /*padding=*/ValidPaddingAttr(),
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteOk);
+  EXPECT_THAT(mock_reporter_.GetString(), StrEq(""));
+  EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), Each(1));
+  EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(output_data->body_subgraph_index, Eq(13));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowSucceedsWithNoBaseDilations) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/ValidAttr(),
+                                      /*window_strides=*/ValidAttr(),
+                                      /*base_dilations=*/0,
+                                      /*window_dilations=*/ValidAttr(),
+                                      /*padding=*/ValidPaddingAttr(),
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteOk);
+  EXPECT_THAT(mock_reporter_.GetString(), StrEq(""));
+  EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), Each(1));
+  EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(output_data->body_subgraph_index, Eq(13));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowSucceedsWithNoWindowDilations) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/ValidAttr(),
+                                      /*window_strides=*/ValidAttr(),
+                                      /*base_dilations=*/ValidAttr(),
+                                      /*window_dilations=*/0,
+                                      /*padding=*/ValidPaddingAttr(),
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteOk);
+  EXPECT_THAT(mock_reporter_.GetString(), StrEq(""));
+  EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims),
+              Each(1));
+  EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(output_data->body_subgraph_index, Eq(13));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowSucceedsWithNoPadding) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/ValidAttr(),
+                                      /*window_strides=*/ValidAttr(),
+                                      /*base_dilations=*/ValidAttr(),
+                                      /*window_dilations=*/ValidAttr(),
+                                      /*padding=*/0,
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteOk);
+  EXPECT_THAT(mock_reporter_.GetString(), StrEq(""));
+  EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), Each(0));
+  EXPECT_THAT(output_data->body_subgraph_index, Eq(13));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowFailsWithEmptyWindowDimensions) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/EmptyAttr(),
+                                      /*window_strides=*/ValidAttr(),
+                                      /*base_dilations=*/ValidAttr(),
+                                      /*window_dilations=*/ValidAttr(),
+                                      /*padding=*/ValidPaddingAttr(),
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteError);
+  EXPECT_THAT(mock_reporter_.GetString(),
+              HasSubstr("'window_dimensions' attribute is not optional for "
+                        "'stablehlo.reduce_window' and cannot be empty."));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowSucceedsWithEmptyWindowStrides) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/ValidAttr(),
+                                      /*window_strides=*/EmptyAttr(),
+                                      /*base_dilations=*/ValidAttr(),
+                                      /*window_dilations=*/ValidAttr(),
+                                      /*padding=*/ValidPaddingAttr(),
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteOk);
+  EXPECT_THAT(mock_reporter_.GetString(), StrEq(""));
+  EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims), Each(1));
+  EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(output_data->body_subgraph_index, Eq(13));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowSucceedsWithEmptyBaseDilations) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/ValidAttr(),
+                                      /*window_strides=*/ValidAttr(),
+                                      /*base_dilations=*/EmptyAttr(),
+                                      /*window_dilations=*/ValidAttr(),
+                                      /*padding=*/ValidPaddingAttr(),
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteOk);
+  EXPECT_THAT(mock_reporter_.GetString(), StrEq(""));
+  EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims), Each(1));
+  EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(output_data->body_subgraph_index, Eq(13));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowSucceedsWithEmptyWindowDilations) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/ValidAttr(),
+                                      /*window_strides=*/ValidAttr(),
+                                      /*base_dilations=*/ValidAttr(),
+                                      /*window_dilations=*/EmptyAttr(),
+                                      /*padding=*/ValidPaddingAttr(),
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteOk);
+  EXPECT_THAT(mock_reporter_.GetString(), StrEq(""));
+  EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims),
+              Each(1));
+  EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(output_data->body_subgraph_index, Eq(13));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowSucceedsWithEmptyPadding) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/ValidAttr(),
+                                      /*window_strides=*/ValidAttr(),
+                                      /*base_dilations=*/ValidAttr(),
+                                      /*window_dilations=*/ValidAttr(),
+                                      /*padding=*/EmptyAttr(),
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteOk);
+  EXPECT_THAT(mock_reporter_.GetString(), StrEq(""));
+  EXPECT_THAT(std::make_tuple(output_data->window_dimensions, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->window_strides, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->base_dilations, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->window_dilations, kMaxDims),
+              Each(kValidValue));
+  EXPECT_THAT(std::make_tuple(output_data->padding, 2 * kMaxDims), Each(0));
+  EXPECT_THAT(output_data->body_subgraph_index, Eq(13));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowSucceedsWithParamsAtMaxDims) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/ValidAttr(),
+                                      /*window_strides=*/ValidAttr(),
+                                      /*base_dilations=*/ValidAttr(),
+                                      /*window_dilations=*/ValidAttr(),
+                                      /*padding=*/ValidPaddingAttr(),
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteOk);
+  EXPECT_THAT(mock_reporter_.GetString(), StrEq(""));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowFailsWhenWindowDimensionsHasMoreThanMaxDims) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/InvalidAttr(),
+                                      /*window_strides=*/ValidAttr(),
+                                      /*base_dilations=*/ValidAttr(),
+                                      /*window_dilations=*/ValidAttr(),
+                                      /*padding=*/ValidPaddingAttr(),
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteError);
+  EXPECT_THAT(mock_reporter_.GetString(),
+              AllOf(HasSubstr("Found too many dimensions in the input array of "
+                              "operation 'stablehlo.reduce_window'."),
+                    HasSubstr("Check the 'window_dimensions' attribute.")));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowFailsWhenWindowStridesHasWrongDimCount) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/ValidAttr(),
+                                      /*window_strides=*/InvalidAttr(),
+                                      /*base_dilations=*/ValidAttr(),
+                                      /*window_dilations=*/ValidAttr(),
+                                      /*padding=*/ValidPaddingAttr(),
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteError);
+  EXPECT_THAT(
+      mock_reporter_.GetString(),
+      HasSubstr("'window_strides' attribute of 'stablehlo.reduce_window' does "
+                "not have the expected size"));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowFailsWhenBaseDilationsHasWrongDimCount) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/ValidAttr(),
+                                      /*window_strides=*/ValidAttr(),
+                                      /*base_dilations=*/InvalidAttr(),
+                                      /*window_dilations=*/ValidAttr(),
+                                      /*padding=*/ValidPaddingAttr(),
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteError);
+  EXPECT_THAT(
+      mock_reporter_.GetString(),
+      HasSubstr("'base_dilations' attribute of 'stablehlo.reduce_window' does "
+                "not have the expected size"));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowFailsWhenWindowDilationsHasWrongDimCount) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/ValidAttr(),
+                                      /*window_strides=*/ValidAttr(),
+                                      /*base_dilations=*/ValidAttr(),
+                                      /*window_dilations=*/InvalidAttr(),
+                                      /*padding=*/ValidPaddingAttr(),
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteError);
+  EXPECT_THAT(
+      mock_reporter_.GetString(),
+      HasSubstr(
+          "'window_dilations' attribute of 'stablehlo.reduce_window' does "
+          "not have the expected size"));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowFailsWhenPaddingHasWrongDimCount) {
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(ParseOpData(
+                BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions,
+                                  CreateStablehloReduceWindowOptions(
+                                      builder_,
+                                      /*window_dimensions=*/ValidAttr(),
+                                      /*window_strides=*/ValidAttr(),
+                                      /*base_dilations=*/ValidAttr(),
+                                      /*window_dilations=*/ValidAttr(),
+                                      /*padding=*/InvalidPaddingAttr(),
+                                      /*body_subgraph_index=*/13)
+                                      .Union()),
+                BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                &mock_allocator_, (void**)&output_data),
+            kTfLiteError);
+  EXPECT_THAT(mock_reporter_.GetString(),
+              HasSubstr("'padding' attribute of 'stablehlo.reduce_window' does "
+                        "not have the expected size"));
+}
+
+TEST_F(StablehloReduceWindowFlatbufferConversionsTest,
+       ParseStablehloReduceWindowFailsWithWrongOptions) {
+  const Operator* stablehlo_reduce_window_op =
+      BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, 0);
+  TfLiteStablehloReduceWindowParams* output_data = nullptr;
+  EXPECT_EQ(
+      ParseOpData(stablehlo_reduce_window_op,
+                  BuiltinOperator_STABLEHLO_REDUCE_WINDOW, &mock_reporter_,
+                  &mock_allocator_, (void**)&output_data),
+      kTfLiteError);
+  EXPECT_THAT(
+      mock_reporter_.GetString(),
+      HasSubstr(
+          "Could not get 'stablehlo.reduce_window' operation parameters."));
+}
+
 }  // namespace tflite
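As an aside for readers of these tests: the recurring `EXPECT_THAT(std::make_tuple(ptr, count), ...)` pattern relies on gMock's documented support for matching a (pointer, size) tuple as an array-like container. A minimal, self-contained sketch of the same pattern (the test name and values are illustrative only):

    #include <cstdint>
    #include <tuple>

    #include <gmock/gmock.h>
    #include <gtest/gtest.h>

    TEST(PointerSizeTupleExample, MatchesArrayPrefix) {
      int64_t values[] = {1, 2, 3, 4};
      // gMock interprets the tuple as "the first 2 elements pointed to by values".
      EXPECT_THAT(std::make_tuple(&values[0], 2), testing::ElementsAre(1, 2));
    }
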
diff --git a/tensorflow/lite/core/c/BUILD b/tensorflow/lite/core/c/BUILD
index 746de5d..e9e7fb1 100644
--- a/tensorflow/lite/core/c/BUILD
+++ b/tensorflow/lite/core/c/BUILD
@@ -48,6 +48,18 @@
     ],
 )
 
+filegroup(
+    name = "tflite_internal_cc_3p_api_deps_src",
+    srcs = [
+        "builtin_op_data.h",
+        "common.cc",
+        "common.h",
+    ],
+    visibility = [
+        "//tensorflow/lite:__pkg__",
+    ],
+)
+
 tflite_cc_library_with_c_headers_test(
     name = "c_api",
     hdrs = [
diff --git a/tensorflow/lite/core/c/builtin_op_data.h b/tensorflow/lite/core/c/builtin_op_data.h
index 8464a26..b96350f 100644
--- a/tensorflow/lite/core/c/builtin_op_data.h
+++ b/tensorflow/lite/core/c/builtin_op_data.h
@@ -34,6 +34,7 @@
 #define TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT 8
 #define TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT 8
 #define TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT 8
+#define TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT 8
 
 // TODO(aselle): Consider using "if this then that" for testing.
 
@@ -605,6 +606,22 @@
   bool indices_are_sorted;
 } TfLiteStablehloGatherParams;
 
+typedef struct {
+  // See the stablehlo spec for the explanation of the attributes:
+  // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window
+  int64_t window_dimensions
+      [TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT];
+  int64_t
+      window_strides[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT];
+  int64_t
+      base_dilations[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT];
+  int64_t window_dilations
+      [TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT];
+  int64_t
+      padding[2 * TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT];
+  int body_subgraph_index;
+} TfLiteStablehloReduceWindowParams;
+
 enum TfLiteReduceWindowFunction {
   TfLiteReduceWindowFunctionUnsupported,
   TfLiteReduceWindowFunctionAdd,
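As a sketch of how the new `TfLiteStablehloReduceWindowParams` struct might be consumed once `ParseOpData` has filled it in, here is a hypothetical validation helper; the function, the `rank` argument, and the checks are illustrative assumptions rather than part of TensorFlow Lite:

    #include <cstdint>

    #include "tensorflow/lite/core/c/builtin_op_data.h"

    // Hypothetical helper: sanity-checks the first `rank` entries of the parsed
    // attributes. `rank` is assumed to come from the operand tensor's shape.
    bool ReduceWindowParamsLookValid(
        const TfLiteStablehloReduceWindowParams* params, int rank) {
      if (rank <= 0 ||
          rank > TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT) {
        return false;
      }
      for (int i = 0; i < rank; ++i) {
        // Window sizes, strides and dilations must be strictly positive.
        if (params->window_dimensions[i] <= 0 ||
            params->window_strides[i] <= 0 || params->base_dilations[i] <= 0 ||
            params->window_dilations[i] <= 0) {
          return false;
        }
      }
      // `padding` holds two entries per dimension, so 2 * rank values apply.
      return params->body_subgraph_index >= 0;
    }
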
diff --git a/tensorflow/lite/core/c/c_api.h b/tensorflow/lite/core/c/c_api.h
index e299970..b98fddf 100644
--- a/tensorflow/lite/core/c/c_api.h
+++ b/tensorflow/lite/core/c/c_api.h
@@ -31,7 +31,7 @@
 #include "tensorflow/lite/core/c/c_api_types.h"  // IWYU pragma: export
 #include "tensorflow/lite/core/c/registration_external.h"  // IWYU pragma: export
 
-/// C API for TensorFlow Lite:
+/// C API for TensorFlow Lite.
 ///
 /// The API leans towards simplicity and uniformity instead of convenience, as
 /// most usage will be by language-specific wrappers. It provides largely the
@@ -81,9 +81,13 @@
 extern "C" {
 #endif  // __cplusplus
 
-/** \addtogroup c_api tensorflow/lite/c/c_api.h
+// clang-format off
+// NOLINTBEGIN(whitespace/line_length)
+/** \defgroup c_api tensorflow/lite/c/c_api.h
  *  @{
  */
+// NOLINTEND(whitespace/line_length)
+// clang-format on
 
 // This header should be valid in both C (e.g. C99) and C++,
 // so 'void' in parameters is not redundant.
diff --git a/tensorflow/lite/core/c/c_api_opaque.h b/tensorflow/lite/core/c/c_api_opaque.h
index b3999ee..06bdc19 100644
--- a/tensorflow/lite/core/c/c_api_opaque.h
+++ b/tensorflow/lite/core/c/c_api_opaque.h
@@ -37,9 +37,13 @@
 /// schedule than for the other TensorFlow Lite APIs. See
 /// https://www.tensorflow.org/guide/versions#separate_version_number_for_tensorflow_lite_extension_apis.
 
-/** \addtogroup c_api_opaque tensorflow/lite/c/c_api_opaque.h
+// clang-format off
+// NOLINTBEGIN(whitespace/line_length)
+/** \defgroup c_api_opaque tensorflow/lite/c/c_api_opaque.h
  *  @{
  */
+// NOLINTEND(whitespace/line_length)
+// clang-format on
 
 // --------------------------------------------------------------------------
 // Accessors for TfLiteOpaqueTensor.
@@ -139,103 +143,103 @@
     const TfLiteOpaqueTensor* opaque_tensor, void* output_data,
     size_t output_data_size);
 
-// Returns the number of strings stored in the provided 'tensor'.  Returns -1 in
-// case of failure.
+/// Returns the number of strings stored in the provided 'tensor'.
+/// Returns -1 in case of failure.
 int TfLiteOpaqueTensorGetStringCount(const TfLiteOpaqueTensor* tensor);
 
-// Stores the address of the n-th (denoted by the provided 'index') string
-// contained in the provided 'tensor' in the provided '*str' pointer.  Stores
-// the length of the string in the provided '*len' argument.
-//
-// Returns 'kTfLiteOk' if '*str' and '*len' have been set successfully.  Any
-// other return value indicates a failure, which leaves '*str' and '*len' in an
-// unspecified state.
-//
-// The range of valid indices is defined by the half open interval [0, N),
-// where N == TfLiteOpaqueTensorGetStringCount(tensor).
-//
-// Note that 'str' is not guaranteed to be null-terminated. Also note that this
-// function will not create a copy of the underlying string data.  The data is
-// owned by the 'tensor'.
+/// Stores the address of the n-th (denoted by the provided 'index') string
+/// contained in the provided 'tensor' in the provided '*str' pointer.  Stores
+/// the length of the string in the provided '*len' argument.
+///
+/// Returns 'kTfLiteOk' if '*str' and '*len' have been set successfully.  Any
+/// other return value indicates a failure, which leaves '*str' and '*len' in an
+/// unspecified state.
+///
+/// The range of valid indices is defined by the half open interval [0, N),
+/// where N == TfLiteOpaqueTensorGetStringCount(tensor).
+///
+/// Note that 'str' is not guaranteed to be null-terminated. Also note that this
+/// function will not create a copy of the underlying string data.  The data is
+/// owned by the 'tensor'.
 TfLiteStatus TfLiteOpaqueTensorGetString(const TfLiteOpaqueTensor* tensor,
                                          int index, const char** str, int* len);
 
-// Writes the array of strings specified by 'str_array' into
-// the specified 'tensor'.  The strings provided via the 'str_array' are being
-// copied into the 'tensor'. Returns 'kTfLiteOk' in case of success.  Any other
-// return value indicates a failure.
-//
-// The provided 'str_array_len' must denote the length of 'str_array'
-// and 'str_n_len[i]' must denote the length of the i-th string.
-//
-// The provided strings don't need to be null terminated and may contain
-// embedded null characters.  The amount of bytes copied into the 'tensor' is
-// entirely determined by 'str_n_len[i]' and it is the caller's responsibility
-// to set this value correctly to avoid undefined behavior.
-//
-// Also note that calling 'TfLiteOpaqueTensorWriteStrings' deallocates any
-// previously stored data in the 'tensor'.
+/// Writes the array of strings specified by 'str_array' into
+/// the specified 'tensor'.  The strings provided via the 'str_array' are being
+/// copied into the 'tensor'. Returns 'kTfLiteOk' in case of success.  Any other
+/// return value indicates a failure.
+///
+/// The provided 'str_array_len' must denote the length of 'str_array'
+/// and 'str_n_len[i]' must denote the length of the i-th string.
+///
+/// The provided strings don't need to be null terminated and may contain
+/// embedded null characters.  The number of bytes copied into the 'tensor' is
+/// entirely determined by 'str_n_len[i]' and it is the caller's responsibility
+/// to set this value correctly to avoid undefined behavior.
+///
+/// Also note that calling 'TfLiteOpaqueTensorWriteStrings' deallocates any
+/// previously stored data in the 'tensor'.
 TfLiteStatus TfLiteOpaqueTensorWriteStrings(TfLiteOpaqueTensor* tensor,
                                             const char* const* str_array,
                                             int str_array_len,
                                             const int* str_n_len);
 
-// Writes the string pointed to by the provided 'str' pointer of length 'len'
-// into the provided 'tensor'.  The string provided via 'str' is
-// copied into the 'tensor'.  Returns 'kTfLiteOk' in case of success.  Any
-// other return value indicates a failure.
-//
-// Note that calling 'TfLiteOpaqueTensorWriteString' deallocates any
-// previously stored data in the 'tensor'.  E.g. suppose 't' denotes a
-// 'TfLiteOpaqueTensor*', then calling 'TfLiteOpaqueTensorWriteString(t, "AB",
-// 2)' followed by a call to 'TfLiteOpaqueTensorWriteString(t, "CD", 2)' will
-// lead to 't' containing 'CD', not 'ABCD'.
-//
-// 'TfLiteOpaqueTensorWriteString' is a convenience function for the use case
-// of writing a single string to a tensor and its effects are identical to
-// calling 'TfLiteOpaqueTensorWriteStrings' with an array of a single string.
+/// Writes the string pointed to by the provided 'str' pointer of length 'len'
+/// into the provided 'tensor'.  The string provided via 'str' is
+/// copied into the 'tensor'.  Returns 'kTfLiteOk' in case of success. Any
+/// other return value indicates a failure.
+///
+/// Note that calling 'TfLiteOpaqueTensorWriteString' deallocates any
+/// previously stored data in the 'tensor'.  E.g. suppose 't' denotes a
+/// 'TfLiteOpaqueTensor*', then calling 'TfLiteOpaqueTensorWriteString(t, "AB",
+/// 2)' followed by a call to 'TfLiteOpaqueTensorWriteString(t, "CD", 2)' will
+/// lead to 't' containing 'CD', not 'ABCD'.
+///
+/// 'TfLiteOpaqueTensorWriteString' is a convenience function for the use case
+/// of writing a single string to a tensor and its effects are identical to
+/// calling 'TfLiteOpaqueTensorWriteStrings' with an array of a single string.
 TfLiteStatus TfLiteOpaqueTensorWriteString(TfLiteOpaqueTensor* tensor,
                                            const char* str, int len);
 
-// An opaque type to create a tensor.
+/// An opaque type to create a tensor.
 typedef struct TfLiteOpaqueTensorBuilder TfLiteOpaqueTensorBuilder;
 
-// Creates an opaque tensor builder object.
+/// Creates an opaque tensor builder object.
 TfLiteOpaqueTensorBuilder* TfLiteOpaqueTensorBuilderCreate();
 
-// Deletes an opaque tensor builder object.
+/// Deletes an opaque tensor builder object.
 void TfLiteOpaqueTensorBuilderDelete(TfLiteOpaqueTensorBuilder* builder);
 
-// Sets the 'TfLiteType' of the provided 'builder' to the provided 'type'.
-// Returns the address of the provided 'builder', so that builder calls can be
-// chained together.
+/// Sets the 'TfLiteType' of the provided 'builder' to the provided 'type'.
+/// Returns the address of the provided 'builder', so that builder calls can be
+/// chained together.
 TfLiteOpaqueTensorBuilder* TfLiteOpaqueTensorBuilderSetType(
     TfLiteOpaqueTensorBuilder* builder, TfLiteType type);
 
-// Sets the raw data of the provided 'builder' to the provided 'data'. Returns
-// the address of the provided 'builder', so that builder calls can be chained
-// together.
+/// Sets the raw data of the provided 'builder' to the provided 'data'. Returns
+/// the address of the provided 'builder', so that builder calls can be chained
+/// together.
 TfLiteOpaqueTensorBuilder* TfLiteOpaqueTensorBuilderSetData(
     TfLiteOpaqueTensorBuilder* builder, void* data);
 
-// Sets the allocation type of the provided 'builder' to the provided
-// 'allocation_type'.  The 'allocation_type' must be one of the following:
-// 'kTfLiteDynamic', 'kTfLiteArenaRw' or 'kTfLiteArenaRwPersistent'.  If the
-// provided 'allocation_type' is not one of those values then
-// 'TfLiteOpaqueContextAddTensor' will return an error. Returns the address of
-// the provided 'builder', so that builder calls can be chained together.
+/// Sets the allocation type of the provided 'builder' to the provided
+/// 'allocation_type'.  The 'allocation_type' must be one of the following:
+/// 'kTfLiteDynamic', 'kTfLiteArenaRw' or 'kTfLiteArenaRwPersistent'.  If the
+/// provided 'allocation_type' is not one of those values then
+/// 'TfLiteOpaqueContextAddTensor' will return an error. Returns the address of
+/// the provided 'builder', so that builder calls can be chained together.
 TfLiteOpaqueTensorBuilder* TfLiteOpaqueTensorBuilderSetAllocationType(
     TfLiteOpaqueTensorBuilder* builder, TfLiteAllocationType allocation_type);
 
-// Sets the quantization params of the provided 'builder' to the provided
-// 'params'. Returns the address of the provided 'builder', so that builder
-// calls can be chained together.
+/// Sets the quantization params of the provided 'builder' to the provided
+/// 'params'. Returns the address of the provided 'builder', so that builder
+/// calls can be chained together.
 TfLiteOpaqueTensorBuilder* TfLiteOpaqueTensorBuilderSetQuantizationParams(
     TfLiteOpaqueTensorBuilder* builder, TfLiteQuantizationParams params);
 
-// Sets the quantization of the provided 'builder' to the provided
-// 'quantization'. Returns the address of the provided 'builder', so that
-// builder calls can be chained together.
+/// Sets the quantization of the provided 'builder' to the provided
+/// 'quantization'. Returns the address of the provided 'builder', so that
+/// builder calls can be chained together.
 TfLiteOpaqueTensorBuilder* TfLiteOpaqueTensorBuilderSetQuantization(
     TfLiteOpaqueTensorBuilder* builder, TfLiteQuantization quantization);
 
@@ -337,18 +341,18 @@
                                          const int** temporaries,
                                          int* num_temporaries);
 
-// Given an 'index_of_input', which must be in the range of [0, N), where N is
-// the number of input tensors of the provided 'opaque_node', returns the
-// (global) index of the tensor that holds the input.  Returns -1 if
-// 'index_of_input' is not within the [0, N) range.
+/// Given an 'index_of_input', which must be in the range of [0, N), where N is
+/// the number of input tensors of the provided 'opaque_node', returns the
+/// (global) index of the tensor that holds the input.  Returns -1 if
+/// 'index_of_input' is not within the [0, N) range.
 TFL_CAPI_EXPORT
 int TfLiteOpaqueNodeGetInputTensorIndex(const TfLiteOpaqueNode* opaque_node,
                                         int index_of_input);
 
-// Given an 'index_of_output', which must be in the range of [0, N), where N is
-// the number of output tensors of the provided 'opaque_node', returns the
-// (global) index of the tensor that holds the output.  Returns -1 if
-// 'index_of_output' is not within the [0, N) range.
+/// Given an 'index_of_output', which must be in the range of [0, N), where N is
+/// the number of output tensors of the provided 'opaque_node', returns the
+/// (global) index of the tensor that holds the output. Returns -1 if
+/// 'index_of_output' is not within the [0, N) range.
 TFL_CAPI_EXPORT
 int TfLiteOpaqueNodeGetOutputTensorIndex(const TfLiteOpaqueNode* opaque_node,
                                          int index_of_output);
@@ -506,6 +510,7 @@
 /// Retrieves the corresponding TfLiteOpaqueContext of a subgraph given a
 /// subgraph index and switches to the delegate context for this subgraph. If an
 /// invalid subgraph index is given, then returns kTfLiteError.
+///
 /// NOTE: This function is expected to be paired with
 /// TfLiteOpaqueContextReleaseSubgraphContext() once the delegate preparation is
 /// done and/or the delegate context functions are no longer needed.
@@ -518,6 +523,7 @@
 ///
 /// Releases the corresponding TfLiteOpaqueContext by switching back to the
 /// TFLite kernel context for this specified subgraph.
+///
 /// NOTE: This function is expected to be used after
 /// TfLiteOpaqueContextAcquireSubgraphContext() once the delegate preparation is
 /// done and/or the delegate context functions are no longer needed.
@@ -535,6 +541,7 @@
 /// a specific TfLiteOpaqueDelegate that is already supposed to be
 /// aware of this condition, and therefore, TfLiteInterpreter can skip invoking
 /// `ModifyGraphWithDelegate` on this subgraph.
+///
 /// NOTE: This function is expected to be called only when the subgraph that
 /// `subgraph_index` is pointing to should be skipped by
 /// interpreter::ModifyGraphWithDelegate (e.g. the subgraph is part of the list
@@ -548,53 +555,55 @@
 ///   1. The delegate supports while op
 ///   2. Both condition subgraph `i` and body subgraph `j` can be fully
 ///   delegated to the delegate.
+///
 /// Then if the delegate decides to support the while node along with both body
 /// and condition subgraphs, it should mark subgraphs `i` and `j` skippable so
 /// that those two subgraphs won't be delegated to another delegate.
+///
 /// WARNING: It is the delegate's responsibility to define when to skip
-/// Subgraph::ModifyGraphWithDelegate, to check for any edge cases (i.e.
+/// `Subgraph::ModifyGraphWithDelegate`, to check for any edge cases (i.e.
 /// multiple references to the subgraph that `subgraph_index` is pointing to),
 /// and to mark a subgraph as skippable by using this function.
 TFL_CAPI_EXPORT
 TfLiteStatus TfLiteOpaqueContextMarkSubgraphAsDelegationSkippable(
     TfLiteOpaqueContext* opaque_context, int subgraph_index);
 
-// Loads metadata of a TF Lite node's custom initialization data.  Specifically:
-// * Loads into the supplied 'fd' the file descriptor of the file that stores
-//   the 'node's custom  initialization data.  This output parameter will be
-//   loaded if the TF Lite runtime has access to the file descriptor, though
-//   this is not always the case, e.g. if a client provides a tflite::Model
-//   directly to the TF Lite runtime.  If 'fd' can be loaded then 'kTfLiteOk'
-//   will be returned, otherwise 'kTfLiteError' is returned.
-// * Loads into the supplied 'custom_initial_data_offset_in_file' pointer the
-//   offset of the 'node's custom init data in the file associated with 'fd'.
-//   This output parameter will be set to -1 if the 'node' does not have custom
-//   init data set.
-// * Loads into the supplied 'custom_initial_data_size' the size of the
-//   custom initialization data.  This output parameter will be set to -1 if the
-//   'node' does not have custom init data set.
-//
-// Returns 'kTfLiteOk' when 'fd' has been loaded successfully and 'kTfLiteError'
-// otherwise.  Note that this means that 'kTfLiteOk' can be returned, even if
-// the 'node' does not have custom init data set.
+/// Loads metadata of a TF Lite node's custom initialization data. Specifically:
+/// * Loads into the supplied 'fd' the file descriptor of the file that stores
+///   the 'node's custom initialization data.  This output parameter will be
+///   loaded if the TF Lite runtime has access to the file descriptor, though
+///   this is not always the case, e.g. if a client provides a tflite::Model
+///   directly to the TF Lite runtime.  If 'fd' can be loaded then 'kTfLiteOk'
+///   will be returned, otherwise 'kTfLiteError' is returned.
+/// * Loads into the supplied 'custom_initial_data_offset_in_file' pointer the
+///   offset of the 'node's custom init data in the file associated with 'fd'.
+///   This output parameter will be set to -1 if the 'node' does not have custom
+///   init data set.
+/// * Loads into the supplied 'custom_initial_data_size' the size of the
+///   custom initialization data.  This output parameter will be set to -1 if
+///   the 'node' does not have custom init data set.
+///
+/// Returns 'kTfLiteOk' when 'fd' has been loaded successfully and
+/// 'kTfLiteError' otherwise.  Note that this means that 'kTfLiteOk' can be
+/// returned, even if the 'node' does not have custom init data set.
 TFL_CAPI_EXPORT
 TfLiteStatus TfLiteOpaqueContextGetNodeInitDataMmapInfo(
     const TfLiteOpaqueContext* context, const TfLiteOpaqueNode* node, int* fd,
     int64_t* custom_initial_data_offset_in_file,
     int64_t* custom_initial_data_size);
 
-// Adds an additional tensor and configures its properties based on the provided
-// 'builder', preserving pre-existing Tensor entries.  If non-null, the value
-// pointed to by 'new_tensor_index' will be set to the index of the
-// new tensor.  Returns 'kTfLiteOk' when the tensor has been added
-// successfully.  Returns 'kTfLiteError' in case of failure.
+/// Adds an additional tensor and configures its properties based on the
+/// provided 'builder', preserving pre-existing Tensor entries. If non-null,
+/// the value pointed to by 'new_tensor_index' will be set to the index of the
+/// new tensor.  Returns 'kTfLiteOk' when the tensor has been added
+/// successfully.  Returns 'kTfLiteError' in case of failure.
 TFL_CAPI_EXPORT
 TfLiteStatus TfLiteOpaqueContextAddTensor(TfLiteOpaqueContext* context,
                                           TfLiteOpaqueTensorBuilder* builder,
                                           int* new_tensor_index);
 
-// Populates the size in bytes of a provide 'type' into 'bytes'.  Returns
-// 'kTfLiteOk' for valid types, and 'kTfLiteError' otherwise.
+/// Populates the size in bytes of a provided 'type' into 'bytes'.  Returns
+/// 'kTfLiteOk' for valid types, and 'kTfLiteError' otherwise.
 TFL_CAPI_EXPORT
 TfLiteStatus TfLiteOpaqueContextGetSizeOfType(TfLiteOpaqueContext* context,
                                               TfLiteType type, size_t* bytes);
@@ -617,20 +626,21 @@
 void TfLiteOpaqueContextReportError(struct TfLiteOpaqueContext* opaque_context,
                                     const char* format, ...);
 
-/// Same as 'TfLiteOpaqueContextReportError', but with the variable arguments
-/// passed via a 'va_list' instead of directly.
+/// Same as `TfLiteOpaqueContextReportError`, but with the variable arguments
+/// passed via a `va_list` instead of directly.
 ///
 /// Callers that receive an ellipsis and want to forward it to
 /// to the opaque context error reporting API can add the ellipsis content to a
-/// 'va_list' and then call 'TfLiteOpaqueContextReportErrorVa'. E.g.:
+/// `va_list` and then call `TfLiteOpaqueContextReportErrorVa`. E.g.:
 ///
-/// void MyErrorReporter(struct TfLiteOpaqueContext* opaque_context,
-///                                     const char* format, ...) {
-///   va_list vlist;
-///   va_start(vlist, format);
-///   TfLiteOpaqueContextReportErrorVa(opaque_context, format, vlist);
-///   va_end(vlist);
-/// }
+///
+///     void MyErrorReporter(struct TfLiteOpaqueContext* opaque_context,
+///                                      const char* format, ...) {
+///       va_list vlist;
+///       va_start(vlist, format);
+///       TfLiteOpaqueContextReportErrorVa(opaque_context, format, vlist);
+///       va_end(vlist);
+///     }
 TFL_CAPI_EXPORT
 void TfLiteOpaqueContextReportErrorVa(
     struct TfLiteOpaqueContext* opaque_context, const char* format,
@@ -676,8 +686,8 @@
 
 #endif  // TF_LITE_STRIP_ERROR_STRINGS
 
-// Check whether value is true, and if not return kTfLiteError from
-// the current function (and report the error string msg).
+/// Check whether value is true, and if not return kTfLiteError from
+/// the current function (and report the error string msg).
 #if !defined(TF_LITE_OPAQUE_ENSURE_MSG)
 #define TF_LITE_OPAQUE_ENSURE_MSG(opaque_context, value, msg)        \
   do {                                                               \
@@ -688,8 +698,8 @@
   } while (0)
 #endif
 
-// Check whether the value `a` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
+/// Check whether the value `a` is true, and if not return kTfLiteError from
+/// the current function, while also reporting the location of the error.
 #if !defined(TF_LITE_OPAQUE_ENSURE)
 #define TF_LITE_OPAQUE_ENSURE(opaque_context, a)                           \
   do {                                                                     \
@@ -701,11 +711,12 @@
   } while (0)
 #endif
 
-// Check whether the value `a == b` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-// `a` and `b` may be evaluated more than once, so no side effects or
-// extremely expensive computations should be done.
-// NOTE: Use TF_LITE_ENSURE_TYPES_EQ if comparing TfLiteTypes.
+/// Check whether the value `a == b` is true, and if not return kTfLiteError
+/// from the current function, while also reporting the location of the error.
+/// `a` and `b` may be evaluated more than once, so no side effects or
+/// extremely expensive computations should be done.
+///
+/// NOTE: Use TF_LITE_ENSURE_TYPES_EQ if comparing TfLiteTypes.
 #if !defined(TF_LITE_OPAQUE_ENSURE_EQ)
 #define TF_LITE_OPAQUE_ENSURE_EQ(opaque_context, a, b)                  \
   do {                                                                  \
diff --git a/tensorflow/lite/core/c/c_api_types.h b/tensorflow/lite/core/c/c_api_types.h
index 3a6594d..c1f0c56 100644
--- a/tensorflow/lite/core/c/c_api_types.h
+++ b/tensorflow/lite/core/c/c_api_types.h
@@ -34,9 +34,13 @@
 extern "C" {
 #endif
 
-/** \addtogroup c_api_types tensorflow/lite/c/c_api_types.h
+// clang-format off
+// NOLINTBEGIN(whitespace/line_length)
+/** \defgroup c_api_types tensorflow/lite/c/c_api_types.h
  *  @{
  */
+// NOLINTEND(whitespace/line_length)
+// clang-format on
 
 // Define TFL_CAPI_EXPORT macro to export a function properly with a shared
 // library.
@@ -123,12 +127,11 @@
   kTfLiteInt4 = 18,
 } TfLiteType;
 
-/// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
+/// Legacy. Will be deprecated in favor of `TfLiteAffineQuantization`.
 /// If per-layer quantization is specified this field will still be populated in
-/// addition to TfLiteAffineQuantization.
+/// addition to `TfLiteAffineQuantization`.
 /// Parameters for asymmetric quantization. Quantized values can be converted
-/// back to float using:
-///     real_value = scale * (quantized_value - zero_point)
+/// back to float using: `real_value = scale * (quantized_value - zero_point)`
 typedef struct TfLiteQuantizationParams {
   float scale;
   int32_t zero_point;
@@ -156,6 +159,7 @@
 /// This is an abstract type that is intended to have the same
 /// role as TfLiteDelegate, but without exposing the implementation
 /// details of how delegates are implemented.
+///
 /// WARNING: This is an experimental type and subject to change.
 typedef struct TfLiteOpaqueDelegateStruct TfLiteOpaqueDelegateStruct;
 
@@ -163,6 +167,7 @@
 /// TfLiteDelegate; allows delegation of nodes to alternative backends.
 /// For TF Lite in Play Services, this is an opaque type,
 /// but for regular TF Lite, this is just a typedef for TfLiteDelegate.
+///
 /// WARNING: This is an experimental type and subject to change.
 #if TFLITE_WITH_STABLE_ABI || TFLITE_USE_OPAQUE_DELEGATE
 typedef TfLiteOpaqueDelegateStruct TfLiteOpaqueDelegate;
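To make the dequantization formula documented for `TfLiteQuantizationParams` above concrete, a small worked sketch; the helper name and the numeric values are chosen purely for illustration:

    #include <cstdint>

    #include "tensorflow/lite/core/c/c_api_types.h"

    // real_value = scale * (quantized_value - zero_point)
    float Dequantize(TfLiteQuantizationParams params, int32_t quantized_value) {
      return params.scale *
             static_cast<float>(quantized_value - params.zero_point);
    }

    // Example: scale = 0.5f, zero_point = 10, quantized_value = 14 -> 2.0f.
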
diff --git a/tensorflow/lite/core/c/common.h b/tensorflow/lite/core/c/common.h
index d9abf9e..0ebba76 100644
--- a/tensorflow/lite/core/c/common.h
+++ b/tensorflow/lite/core/c/common.h
@@ -13,31 +13,31 @@
 limitations under the License.
 ==============================================================================*/
 
-// This file defines common C types and APIs for implementing operations,
-// delegates and other constructs in TensorFlow Lite. The actual operations and
-// delegates can be defined using C++, but the interface between the interpreter
-// and the operations are C.
-//
-// Summary of abstractions
-// TF_LITE_ENSURE - Self-sufficient error checking
-// TfLiteStatus - Status reporting
-// TfLiteIntArray - stores tensor shapes (dims),
-// TfLiteContext - allows an op to access the tensors
-// TfLiteTensor - tensor (a multidimensional array)
-// TfLiteNode - a single node or operation
-// TfLiteRegistration - the implementation of a conceptual operation.
-// TfLiteDelegate - allows delegation of nodes to alternative backends.
-//
-// Some abstractions in this file are created and managed by Interpreter.
-//
-// NOTE: The order of values in these structs are "semi-ABI stable". New values
-// should be added only to the end of structs and never reordered.
+/// This file defines common C types and APIs for implementing operations,
+/// delegates and other constructs in TensorFlow Lite. The actual operations and
+/// delegates can be defined using C++, but the interface between the
+/// interpreter and the operations is C.
+///
+/// Summary of abstractions:
+/// * `TF_LITE_ENSURE` - self-sufficient error checking
+/// * `TfLiteStatus` - status reporting
+/// * `TfLiteIntArray` - stores tensor shapes (dims),
+/// * `TfLiteContext` - allows an op to access the tensors
+/// * `TfLiteTensor` - tensor (a multidimensional array)
+/// * `TfLiteNode` - a single node or operation
+/// * `TfLiteRegistration` - the implementation of a conceptual operation.
+/// * `TfLiteDelegate` - allows delegation of nodes to alternative backends.
+///
+/// Some abstractions in this file are created and managed by Interpreter.
+///
+/// NOTE: The order of values in these structs is "semi-ABI stable". New values
+/// should be added only to the end of structs and never reordered.
 
-/// WARNING: Users of TensorFlow Lite should not include this file directly,
-/// but should instead include
-/// "third_party/tensorflow/lite/c/common.h".
-/// Only the TensorFlow Lite implementation itself should include this
-/// file directly.
+// WARNING: Users of TensorFlow Lite should not include this file directly,
+// but should instead include
+// "third_party/tensorflow/lite/c/common.h".
+// Only the TensorFlow Lite implementation itself should include this
+// file directly.
 // IWYU pragma: private, include "third_party/tensorflow/lite/c/common.h"
 
 #ifndef TENSORFLOW_LITE_CORE_C_COMMON_H_
@@ -54,15 +54,23 @@
 extern "C" {
 #endif  // __cplusplus
 
-// The list of external context types known to TF Lite. This list exists solely
-// to avoid conflicts and to ensure ops can share the external contexts they
-// need. Access to the external contexts is controlled by one of the
-// corresponding support files.
+// clang-format off
+// NOLINTBEGIN(whitespace/line_length)
+/** \defgroup common tensorflow/lite/c/common.h
+ *  @{
+ */
+// NOLINTEND(whitespace/line_length)
+// clang-format on
+
+/// The list of external context types known to TF Lite. This list exists solely
+/// to avoid conflicts and to ensure ops can share the external contexts they
+/// need. Access to the external contexts is controlled by one of the
+/// corresponding support files.
 typedef enum TfLiteExternalContextType {
-  kTfLiteEigenContext = 0,       // include eigen_support.h to use.
-  kTfLiteGemmLowpContext = 1,    // include gemm_support.h to use.
-  kTfLiteEdgeTpuContext = 2,     // Placeholder for Edge TPU support.
-  kTfLiteCpuBackendContext = 3,  // include cpu_backend_context.h to use.
+  kTfLiteEigenContext = 0,       /// include eigen_support.h to use.
+  kTfLiteGemmLowpContext = 1,    /// include gemm_support.h to use.
+  kTfLiteEdgeTpuContext = 2,     /// Placeholder for Edge TPU support.
+  kTfLiteCpuBackendContext = 3,  /// include cpu_backend_context.h to use.
   kTfLiteMaxExternalContexts = 4
 } TfLiteExternalContextType;
 
@@ -73,11 +81,11 @@
 struct TfLiteRegistration;
 struct TfLiteOpaqueDelegateBuilder;
 
-// An external context is a collection of information unrelated to the TF Lite
-// framework, but useful to a subset of the ops. TF Lite knows very little
-// about the actual contexts, but it keeps a list of them, and is able to
-// refresh them if configurations like the number of recommended threads
-// change.
+/// An external context is a collection of information unrelated to the TF Lite
+/// framework, but useful to a subset of the ops. TF Lite knows very little
+/// about the actual contexts, but it keeps a list of them, and is able to
+/// refresh them if configurations like the number of recommended threads
+/// change.
 typedef struct TfLiteExternalContext {
   TfLiteExternalContextType type;
   TfLiteStatus (*Refresh)(struct TfLiteContext* context);
@@ -85,8 +93,8 @@
 
 #define kTfLiteOptionalTensor (-1)
 
-// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
-// indices
+/// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
+/// indices
 typedef struct TfLiteIntArray {
   int size;
 
@@ -105,33 +113,33 @@
 #endif
 } TfLiteIntArray;
 
-// Given the size (number of elements) in a TfLiteIntArray, calculate its size
-// in bytes.
+/// Given the size (number of elements) in a TfLiteIntArray, calculate its size
+/// in bytes.
 size_t TfLiteIntArrayGetSizeInBytes(int size);
 
 #ifndef TF_LITE_STATIC_MEMORY
-// Create a array of a given `size` (uninitialized entries).
-// This returns a pointer, that you must free using TfLiteIntArrayFree().
+/// Create an array of a given `size` (uninitialized entries).
+/// This returns a pointer that you must free using TfLiteIntArrayFree().
 TfLiteIntArray* TfLiteIntArrayCreate(int size);
 #endif
 
-// Check if two intarrays are equal. Returns 1 if they are equal, 0 otherwise.
+/// Check if two intarrays are equal. Returns 1 if they are equal, 0 otherwise.
 int TfLiteIntArrayEqual(const TfLiteIntArray* a, const TfLiteIntArray* b);
 
-// Check if an intarray equals an array. Returns 1 if equals, 0 otherwise.
+/// Check if an intarray equals an array. Returns 1 if equals, 0 otherwise.
 int TfLiteIntArrayEqualsArray(const TfLiteIntArray* a, int b_size,
                               const int b_data[]);
 
 #ifndef TF_LITE_STATIC_MEMORY
-// Create a copy of an array passed as `src`.
-// You are expected to free memory with TfLiteIntArrayFree
+/// Create a copy of an array passed as `src`.
+/// You are expected to free the memory with TfLiteIntArrayFree().
 TfLiteIntArray* TfLiteIntArrayCopy(const TfLiteIntArray* src);
 
-// Free memory of array `a`.
+/// Free memory of array `a`.
 void TfLiteIntArrayFree(TfLiteIntArray* a);
 #endif  // TF_LITE_STATIC_MEMORY
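+
+// A minimal usage sketch of the TfLiteIntArray helpers above (the shape
+// values are purely illustrative):
+//
+//     TfLiteIntArray* dims = TfLiteIntArrayCreate(2);  // entries uninitialized
+//     dims->data[0] = 3;
+//     dims->data[1] = 2;
+//     const int expected[] = {3, 2};
+//     if (TfLiteIntArrayEqualsArray(dims, 2, expected)) {
+//       // shapes match
+//     }
+//     TfLiteIntArrayFree(dims);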
 
-// Fixed size list of floats. Used for per-channel quantization.
+/// Fixed size list of floats. Used for per-channel quantization.
 typedef struct TfLiteFloatArray {
   int size;
 #if defined(_MSC_VER)
@@ -149,20 +157,20 @@
 #endif
 } TfLiteFloatArray;
 
-// Given the size (number of elements) in a TfLiteFloatArray, calculate its size
-// in bytes.
+/// Given the size (number of elements) in a TfLiteFloatArray, calculate its
+/// size in bytes.
 int TfLiteFloatArrayGetSizeInBytes(int size);
 
 #ifndef TF_LITE_STATIC_MEMORY
-// Create a array of a given `size` (uninitialized entries).
-// This returns a pointer, that you must free using TfLiteFloatArrayFree().
+/// Create an array of a given `size` (uninitialized entries).
+/// This returns a pointer that you must free using TfLiteFloatArrayFree().
 TfLiteFloatArray* TfLiteFloatArrayCreate(int size);
 
-// Create a copy of an array passed as `src`.
-// You are expected to free memory with TfLiteFloatArrayFree.
+/// Create a copy of an array passed as `src`.
+/// You are expected to free memory with TfLiteFloatArrayFree.
 TfLiteFloatArray* TfLiteFloatArrayCopy(const TfLiteFloatArray* src);
 
-// Free memory of array `a`.
+/// Free memory of array `a`.
 void TfLiteFloatArrayFree(TfLiteFloatArray* a);
 #endif  // TF_LITE_STATIC_MEMORY
 
@@ -191,18 +199,18 @@
 #define TF_LITE_MAYBE_KERNEL_LOG(context, ...) ARGS_UNUSED(__VA_ARGS__)
 #endif  // TF_LITE_STRIP_ERROR_STRINGS
 
-// Check whether value is true, and if not return kTfLiteError from
-// the current function (and report the error string msg).
-#define TF_LITE_ENSURE_MSG(context, value, msg)        \
-  do {                                                 \
-    if (!(value)) {                                    \
-      TF_LITE_KERNEL_LOG((context), __FILE__ " " msg); \
-      return kTfLiteError;                             \
-    }                                                  \
+/// Check whether value is true, and if not return kTfLiteError from
+/// the current function (and report the error string msg).
+#define TF_LITE_ENSURE_MSG(context, value, ...)                \
+  do {                                                         \
+    if (!(value)) {                                            \
+      TF_LITE_KERNEL_LOG((context), __FILE__ " " __VA_ARGS__); \
+      return kTfLiteError;                                     \
+    }                                                          \
   } while (0)
 
-// Check whether the value `a` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
+/// Check whether the value `a` is true, and if not return kTfLiteError from
+/// the current function, while also reporting the location of the error.
 #define TF_LITE_ENSURE(context, a)                                      \
   do {                                                                  \
     if (!(a)) {                                                         \
@@ -220,11 +228,12 @@
     }                            \
   } while (0)
 
-// Check whether the value `a == b` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-// `a` and `b` may be evaluated more than once, so no side effects or
-// extremely expensive computations should be done.
-// NOTE: Use TF_LITE_ENSURE_TYPES_EQ if comparing TfLiteTypes.
+/// Check whether the value `a == b` is true, and if not return kTfLiteError
+/// from the current function, while also reporting the location of the error.
+/// `a` and `b` may be evaluated more than once, so no side effects or
+/// extremely expensive computations should be done.
+///
+/// NOTE: Use TF_LITE_ENSURE_TYPES_EQ if comparing TfLiteTypes.
 #define TF_LITE_ENSURE_EQ(context, a, b)                                   \
   do {                                                                     \
     if ((a) != (b)) {                                                      \
@@ -263,61 +272,62 @@
     }                                      \
   } while (0)
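+
+// A sketch (with a hypothetical kernel name) of how the TF_LITE_ENSURE*
+// macros above might be used inside an op's Prepare function; each macro
+// returns kTfLiteError from the enclosing function when its check fails.
+//
+//     static TfLiteStatus MyOpPrepare(TfLiteContext* context, TfLiteNode* node) {
+//       TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
+//       const TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
+//       TF_LITE_ENSURE_MSG(context, input->dims != NULL,
+//                          "input tensor must have a shape");
+//       return kTfLiteOk;
+//     }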
 
-// Single-precision complex data type compatible with the C99 definition.
+/// Single-precision complex data type compatible with the C99 definition.
 typedef struct TfLiteComplex64 {
-  float re, im;  // real and imaginary parts, respectively.
+  float re, im;  /// real and imaginary parts, respectively.
 } TfLiteComplex64;
 
-// Double-precision complex data type compatible with the C99 definition.
+/// Double-precision complex data type compatible with the C99 definition.
 typedef struct TfLiteComplex128 {
-  double re, im;  // real and imaginary parts, respectively.
+  double re, im;  /// real and imaginary parts, respectively.
 } TfLiteComplex128;
 
-// Half precision data type compatible with the C99 definition.
+/// Half precision data type compatible with the C99 definition.
 typedef struct TfLiteFloat16 {
   uint16_t data;
 } TfLiteFloat16;
 
-// Return the name of a given type, for error reporting purposes.
+/// Return the name of a given type, for error reporting purposes.
 const char* TfLiteTypeGetName(TfLiteType type);
 
-// SupportedQuantizationTypes.
+/// SupportedQuantizationTypes.
 typedef enum TfLiteQuantizationType {
-  // No quantization.
+  /// No quantization.
   kTfLiteNoQuantization = 0,
-  // Affine quantization (with support for per-channel quantization).
-  // Corresponds to TfLiteAffineQuantization.
+  /// Affine quantization (with support for per-channel quantization).
+  /// Corresponds to TfLiteAffineQuantization.
   kTfLiteAffineQuantization = 1,
 } TfLiteQuantizationType;
 
-// Structure specifying the quantization used by the tensor, if-any.
+/// Structure specifying the quantization used by the tensor, if-any.
 typedef struct TfLiteQuantization {
-  // The type of quantization held by params.
+  /// The type of quantization held by params.
   TfLiteQuantizationType type;
-  // Holds an optional reference to a quantization param structure. The actual
-  // type depends on the value of the `type` field (see the comment there for
-  // the values and corresponding types).
+  /// Holds an optional reference to a quantization param structure. The actual
+  /// type depends on the value of the `type` field (see the comment there for
+  /// the values and corresponding types).
   void* params;
 } TfLiteQuantization;
 
-// Parameters for asymmetric quantization across a dimension (i.e per output
-// channel quantization).
-// quantized_dimension specifies which dimension the scales and zero_points
-// correspond to.
-// For a particular value in quantized_dimension, quantized values can be
-// converted back to float using:
-//     real_value = scale * (quantized_value - zero_point)
+/// Parameters for asymmetric quantization across a dimension (i.e.
+/// per-output-channel quantization).
+/// quantized_dimension specifies which dimension the scales and zero_points
+/// correspond to.
+/// For a particular value in quantized_dimension, quantized values can be
+/// converted back to float using:
+///     `real_value = scale * (quantized_value - zero_point)`
 typedef struct TfLiteAffineQuantization {
   TfLiteFloatArray* scale;
   TfLiteIntArray* zero_point;
   int32_t quantized_dimension;
 } TfLiteAffineQuantization;
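+
+// A small worked sketch of the dequantization formula above for a single
+// per-channel value (the helper name and the numbers are made up):
+//
+//     // scale = 0.5, zero_point = 3, quantized_value = 7
+//     // real_value = scale * (quantized_value - zero_point)
+//     //            = 0.5 * (7 - 3) = 2.0
+//     float Dequantize(float scale, int32_t zero_point, int32_t quantized_value) {
+//       return scale * (float)(quantized_value - zero_point);
+//     }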
 
-/* A union of pointers that points to memory for a given tensor. */
+/// A union of pointers that points to memory for a given tensor.
+///
+/// Do not access these members directly, if possible, use
+/// `GetTensorData<TYPE>(tensor)` instead, otherwise only access `.data`, as
+/// other members are deprecated.
 typedef union TfLitePtrUnion {
-  /* Do not access these members directly, if possible, use
-   * GetTensorData<TYPE>(tensor) instead, otherwise only access .data, as other
-   * members are deprecated. */
   int32_t* i32;
   uint32_t* u32;
   int64_t* i64;
@@ -334,24 +344,26 @@
   TfLiteComplex64* c64;
   TfLiteComplex128* c128;
   int8_t* int8;
-  /* Only use this member. */
+  /// Only use this member.
   void* data;
 } TfLitePtrUnion;
 
-// Memory allocation strategies.
-//  * kTfLiteMmapRo: Read-only memory-mapped data, or data externally allocated.
-//  * kTfLiteArenaRw: Arena allocated with no guarantees about persistence,
-//        and available during eval.
-//  * kTfLiteArenaRwPersistent: Arena allocated but persistent across eval, and
-//        only available during eval.
-//  * kTfLiteDynamic: Allocated during eval, or for string tensors.
-//  * kTfLitePersistentRo: Allocated and populated during prepare. This is
-//        useful for tensors that can be computed during prepare and treated
-//        as constant inputs for downstream ops (also in prepare).
-//  * kTfLiteCustom: Custom memory allocation provided by the user. See
-//        TfLiteCustomAllocation below.
-// * kTfLiteVariantObject: Allocation is an arbitrary type-erased C++ object.
-//        Allocation and deallocation are done through `new` and `delete`.
+/// Memory allocation strategies.
+///  * `kTfLiteMmapRo`: Read-only memory-mapped data, or data externally
+///        allocated.
+///  * `kTfLiteArenaRw`: Arena allocated with no guarantees about persistence,
+///        and available during eval.
+///  * `kTfLiteArenaRwPersistent`: Arena allocated but persistent across eval,
+///        and only available during eval.
+///  * `kTfLiteDynamic`: Allocated during eval, or for string tensors.
+///  * `kTfLitePersistentRo`: Allocated and populated during prepare. This is
+///        useful for tensors that can be computed during prepare and treated
+///        as constant inputs for downstream ops (also in prepare).
+///  * `kTfLiteCustom`: Custom memory allocation provided by the user. See
+///        TfLiteCustomAllocation below.
+///  * `kTfLiteVariantObject`: Allocation is an arbitrary type-erased C++
+///        object. Allocation and deallocation are done through `new` and
+///        `delete`.
 typedef enum TfLiteAllocationType {
   kTfLiteMemNone = 0,
   kTfLiteMmapRo,
@@ -363,30 +375,30 @@
   kTfLiteVariantObject,
 } TfLiteAllocationType;
 
-// Memory allocation strategies.
-//
-// TfLiteAllocationType values have been overloaded to mean more than their
-// original intent. This enum should only be used to document the allocation
-// strategy used by a tensor for it data.
+/// Memory allocation strategies.
+///
+/// TfLiteAllocationType values have been overloaded to mean more than their
+/// original intent. This enum should only be used to document the allocation
+/// strategy used by a tensor for its data.
 typedef enum TfLiteAllocationStrategy {
   kTfLiteAllocationStrategyUnknown,
-  kTfLiteAllocationStrategyNone,    // No data is allocated.
-  kTfLiteAllocationStrategyMMap,    // Data is mmaped.
-  kTfLiteAllocationStrategyArena,   // Handled by the arena.
-  kTfLiteAllocationStrategyMalloc,  // Uses `malloc`/`free`.
-  kTfLiteAllocationStrategyNew      // Uses `new[]`/`delete[]`.
+  kTfLiteAllocationStrategyNone,    /// No data is allocated.
+  kTfLiteAllocationStrategyMMap,    /// Data is mmaped.
+  kTfLiteAllocationStrategyArena,   /// Handled by the arena.
+  kTfLiteAllocationStrategyMalloc,  /// Uses `malloc`/`free`.
+  kTfLiteAllocationStrategyNew      /// Uses `new[]`/`delete[]`.
 } TfLiteAllocationStrategy;
 
-// Describes how stable a tensor attribute is with regards to an interpreter
-// runs.
+/// Describes how stable a tensor attribute is with regard to interpreter runs.
 typedef enum TfLiteRunStability {
   kTfLiteRunStabilityUnknown,
-  kTfLiteRunStabilityUnstable,   // May change at any time.
-  kTfLiteRunStabilitySingleRun,  // Will stay the same for one run.
-  kTfLiteRunStabilityAcrossRuns  // Will stay the same across all runs.
+  kTfLiteRunStabilityUnstable,   /// May change at any time.
+  kTfLiteRunStabilitySingleRun,  /// Will stay the same for one run.
+  kTfLiteRunStabilityAcrossRuns  /// Will stay the same across all runs.
 } TfLiteRunStability;
 
-// Describes the steps of a TFLite operation life cycle.
+/// Describes the steps of a TFLite operation life cycle.
 typedef enum TfLiteRunStep {
   kTfLiteRunStepUnknown,
   kTfLiteRunStepInit,
@@ -394,20 +406,20 @@
   kTfLiteRunStepEval
 } TfLiteRunStep;
 
-// The delegates should use zero or positive integers to represent handles.
-// -1 is reserved from unallocated status.
+/// Delegates should use zero or positive integers to represent handles.
+/// -1 is reserved for unallocated status.
 typedef int TfLiteBufferHandle;
 enum {
   kTfLiteNullBufferHandle = -1,
 };
 
-// Storage format of each dimension in a sparse tensor.
+/// Storage format of each dimension in a sparse tensor.
 typedef enum TfLiteDimensionType {
   kTfLiteDimDense = 0,
   kTfLiteDimSparseCSR,
 } TfLiteDimensionType;
 
-// Metadata to encode each dimension in a sparse tensor.
+/// Metadata to encode each dimension in a sparse tensor.
 typedef struct TfLiteDimensionMetadata {
   TfLiteDimensionType format;
   int dense_size;
@@ -415,8 +427,8 @@
   TfLiteIntArray* array_indices;
 } TfLiteDimensionMetadata;
 
-// Parameters used to encode a sparse tensor. For detailed explanation of each
-// field please refer to lite/schema/schema.fbs.
+/// Parameters used to encode a sparse tensor. For detailed explanation of each
+/// field please refer to lite/schema/schema.fbs.
 typedef struct TfLiteSparsity {
   TfLiteIntArray* traversal_order;
   TfLiteIntArray* block_map;
@@ -424,133 +436,139 @@
   int dim_metadata_size;
 } TfLiteSparsity;
 
-// Defines a custom memory allocation not owned by the runtime.
-// `data` should be aligned to kDefaultTensorAlignment defined in
-// lite/util.h. (Currently 64 bytes)
-// NOTE: See Interpreter.SetCustomAllocationForTensor for details on usage.
+/// Defines a custom memory allocation not owned by the runtime.
+/// `data` should be aligned to kDefaultTensorAlignment defined in
+/// lite/util.h. (Currently 64 bytes)
+/// NOTE: See `Interpreter::SetCustomAllocationForTensor` for details on usage.
 typedef struct TfLiteCustomAllocation {
   void* data;
   size_t bytes;
 } TfLiteCustomAllocation;
 
-// The flags used in `Interpreter::SetCustomAllocationForTensor`.
-// Note that this is a bitmask, so the values should be 1, 2, 4, 8, ...etc.
+/// The flags used in `Interpreter::SetCustomAllocationForTensor`.
+/// Note that this is a bitmask, so the values should be 1, 2, 4, 8, ...etc.
 typedef enum TfLiteCustomAllocationFlags {
   kTfLiteCustomAllocationFlagsNone = 0,
-  // Skips checking whether allocation.data points to an aligned buffer as
-  // expected by the TFLite runtime.
-  // NOTE: Setting this flag can cause crashes when calling Invoke().
-  // Use with caution.
+  /// Skips checking whether allocation.data points to an aligned buffer as
+  /// expected by the TFLite runtime.
+  /// NOTE: Setting this flag can cause crashes when calling Invoke().
+  /// Use with caution.
   kTfLiteCustomAllocationFlagsSkipAlignCheck = 1,
 } TfLiteCustomAllocationFlags;
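+
+// A sketch of preparing a custom allocation as described above (assumes C11
+// aligned_alloc from <stdlib.h>; the 64-byte alignment mirrors the
+// kDefaultTensorAlignment note above, and `needed_bytes` is illustrative):
+//
+//     TfLiteCustomAllocation alloc;
+//     alloc.bytes = needed_bytes;  // round up to a multiple of 64 for aligned_alloc
+//     alloc.data = aligned_alloc(/*alignment=*/64, alloc.bytes);
+//     // Pass `alloc` to Interpreter::SetCustomAllocationForTensor with
+//     // kTfLiteCustomAllocationFlagsNone (the default checks alignment).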
 
-// A tensor in the interpreter system which is a wrapper around a buffer of
-// data including a dimensionality (or NULL if not currently defined).
+/// A tensor in the interpreter system which is a wrapper around a buffer of
+/// data including a dimensionality (or NULL if not currently defined).
 #ifndef TF_LITE_STATIC_MEMORY
 typedef struct TfLiteTensor {
-  // The data type specification for data stored in `data`. This affects
-  // what member of `data` union should be used.
+  /// The data type specification for data stored in `data`. This affects
+  /// what member of `data` union should be used.
   TfLiteType type;
-  // A union of data pointers. The appropriate type should be used for a typed
-  // tensor based on `type`.
+  /// A union of data pointers. The appropriate type should be used for a typed
+  /// tensor based on `type`.
   TfLitePtrUnion data;
-  // A pointer to a structure representing the dimensionality interpretation
-  // that the buffer should have. NOTE: the product of elements of `dims`
-  // and the element datatype size should be equal to `bytes` below.
+  /// A pointer to a structure representing the dimensionality interpretation
+  /// that the buffer should have. NOTE: the product of elements of `dims`
+  /// and the element datatype size should be equal to `bytes` below.
   TfLiteIntArray* dims;
-  // Quantization information.
+  /// Quantization information.
   TfLiteQuantizationParams params;
-  // How memory is mapped
-  //  kTfLiteMmapRo: Memory mapped read only.
-  //  i.e. weights
-  //  kTfLiteArenaRw: Arena allocated read write memory
-  //  (i.e. temporaries, outputs).
+  /// How memory is mapped:
+  ///  * kTfLiteMmapRo: Memory mapped read only (i.e. weights).
+  ///  * kTfLiteArenaRw: Arena allocated read-write memory
+  ///    (i.e. temporaries, outputs).
   TfLiteAllocationType allocation_type;
-  // The number of bytes required to store the data of this Tensor. I.e.
-  // (bytes of each element) * dims[0] * ... * dims[n-1].  For example, if
-  // type is kTfLiteFloat32 and dims = {3, 2} then
-  // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
+  /// The number of bytes required to store the data of this Tensor. I.e.
+  /// (bytes of each element) * dims[0] * ... * dims[n-1].  For example, if
+  /// type is kTfLiteFloat32 and dims = {3, 2} then
+  /// bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
   size_t bytes;
 
-  // An opaque pointer to a tflite::MMapAllocation
+  /// An opaque pointer to a tflite::MMapAllocation
   const void* allocation;
 
-  // Null-terminated name of this tensor.
+  /// Null-terminated name of this tensor.
   const char* name;
 
-  // The delegate which knows how to handle `buffer_handle`.
-  // WARNING: This is an experimental interface that is subject to change.
+  /// The delegate which knows how to handle `buffer_handle`.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   struct TfLiteDelegate* delegate;
 
-  // An integer buffer handle that can be handled by `delegate`.
-  // The value is valid only when delegate is not null.
-  // WARNING: This is an experimental interface that is subject to change.
+  /// An integer buffer handle that can be handled by `delegate`.
+  /// The value is valid only when delegate is not null.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   TfLiteBufferHandle buffer_handle;
 
-  // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
-  // responsible to set data_is_stale to true.
-  // `delegate->CopyFromBufferHandle` can be called to copy the data from
-  // delegate buffer.
-  // WARNING: This is an // experimental interface that is subject to change.
+  /// If the delegate uses its own buffer (e.g. GPU memory), the delegate is
+  /// responsible for setting data_is_stale to true.
+  /// `delegate->CopyFromBufferHandle` can be called to copy the data from the
+  /// delegate buffer.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   bool data_is_stale;
 
-  // True if the tensor is a variable.
+  /// True if the tensor is a variable.
   bool is_variable;
 
-  // Quantization information. Replaces params field above.
+  /// Quantization information. Replaces params field above.
   TfLiteQuantization quantization;
 
-  // Parameters used to encode a sparse tensor.
-  // This is optional. The field is NULL if a tensor is dense.
-  // WARNING: This is an experimental interface that is subject to change.
+  /// Parameters used to encode a sparse tensor.
+  /// This is optional. The field is NULL if a tensor is dense.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   TfLiteSparsity* sparsity;
 
-  // Optional. Encodes shapes with unknown dimensions with -1. This field is
-  // only populated when unknown dimensions exist in a read-write tensor (i.e.
-  // an input or output tensor). (e.g.  `dims` contains [1, 1, 1, 3] and
-  // `dims_signature` contains [1, -1, -1, 3]).  If no unknown dimensions exist
-  // then `dims_signature` is either null, or set to an empty array.  Note that
-  // this field only exists when TF_LITE_STATIC_MEMORY is not defined.
+  /// Optional. Encodes shapes with unknown dimensions with -1. This field is
+  /// only populated when unknown dimensions exist in a read-write tensor (i.e.
+  /// an input or output tensor). (e.g.  `dims` contains [1, 1, 1, 3] and
+  /// `dims_signature` contains [1, -1, -1, 3]).  If no unknown dimensions exist
+  /// then `dims_signature` is either null, or set to an empty array.  Note that
+  /// this field only exists when TF_LITE_STATIC_MEMORY is not defined.
   const TfLiteIntArray* dims_signature;
 } TfLiteTensor;
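+
+// A sketch of how the `bytes` field above relates to `dims` for a dense
+// tensor (assumes a valid kTfLiteFloat32 tensor named `t`; the names are
+// illustrative):
+//
+//     size_t num_elements = 1;
+//     for (int i = 0; i < t->dims->size; ++i) {
+//       num_elements *= t->dims->data[i];
+//     }
+//     size_t expected_bytes = sizeof(float) * num_elements;
+//     // For dims = {3, 2}: expected_bytes == 4 * 3 * 2 == 24 == t->bytes.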
 
-// A structure representing an instance of a node.
-// This structure only exhibits the inputs, outputs, user defined data and some
-// node properties (like statefulness), not other features like the type.
+/// A structure representing an instance of a node.
+/// This structure only exhibits the inputs, outputs, user-defined data and some
+/// node properties (like statefulness), not other features like the type.
 typedef struct TfLiteNode {
-  // Inputs to this node expressed as indices into the simulator's tensors.
+  /// Inputs to this node expressed as indices into the simulator's tensors.
   TfLiteIntArray* inputs;
 
-  // Outputs to this node expressed as indices into the simulator's tensors.
+  /// Outputs to this node expressed as indices into the simulator's tensors.
   TfLiteIntArray* outputs;
 
-  // intermediate tensors to this node expressed as indices into the simulator's
-  // tensors.
+  /// Intermediate tensors to this node expressed as indices into the
+  /// simulator's tensors.
   TfLiteIntArray* intermediates;
 
-  // Temporary tensors uses during the computations. This usually contains no
-  // tensors, but ops are allowed to change that if they need scratch space of
-  // any sort.
+  /// Temporary tensors used during the computations. This usually contains no
+  /// tensors, but ops are allowed to change that if they need scratch space of
+  /// any sort.
   TfLiteIntArray* temporaries;
 
-  // Opaque data provided by the node implementer through `Registration.init`.
+  /// Opaque data provided by the node implementer through `Registration.init`.
   void* user_data;
 
-  // Opaque data provided to the node if the node is a builtin. This is usually
-  // a structure defined in builtin_op_data.h
+  /// Opaque data provided to the node if the node is a builtin. This is usually
+  /// a structure defined in builtin_op_data.h
   void* builtin_data;
 
-  // Custom initial data. This is the opaque data provided in the flatbuffer.
-  // WARNING: This is an experimental interface that is subject to change.
+  /// Custom initial data. This is the opaque data provided in the flatbuffer.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   const void* custom_initial_data;
   int custom_initial_data_size;
 
-  // The pointer to the delegate. This is non-null only when the node is
-  // created by calling `interpreter.ModifyGraphWithDelegate`.
-  // WARNING: This is an experimental interface that is subject to change.
+  /// The pointer to the delegate. This is non-null only when the node is
+  /// created by calling `interpreter.ModifyGraphWithDelegate`.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   struct TfLiteDelegate* delegate;
 
-  // Whether this op might have side effect (e.g. stateful op).
+  /// Whether this op might have side effect (e.g. stateful op).
   bool might_have_side_effect;
 } TfLiteNode;
 #else   // defined(TF_LITE_STATIC_MEMORY)?
@@ -633,90 +651,89 @@
   void* builtin_data;
 
   // Custom initial data. This is the opaque data provided in the flatbuffer.
+  //
   // WARNING: This is an experimental interface that is subject to change.
   const void* custom_initial_data;
   int custom_initial_data_size;
 } TfLiteNode;
 #endif  // TF_LITE_STATIC_MEMORY
 
-// Light-weight tensor struct for TF Micro runtime. Provides the minimal amount
-// of information required for a kernel to run during TfLiteRegistration::Eval.
+/// Light-weight tensor struct for TF Micro runtime. Provides the minimal amount
+/// of information required for a kernel to run during TfLiteRegistration::Eval.
 // TODO(b/160955687): Move this field into TF_LITE_STATIC_MEMORY when TFLM
 // builds with this flag by default internally.
 typedef struct TfLiteEvalTensor {
-  // A union of data pointers. The appropriate type should be used for a typed
-  // tensor based on `type`.
+  /// A union of data pointers. The appropriate type should be used for a typed
+  /// tensor based on `type`.
   TfLitePtrUnion data;
 
-  // A pointer to a structure representing the dimensionality interpretation
-  // that the buffer should have.
+  /// A pointer to a structure representing the dimensionality interpretation
+  /// that the buffer should have.
   TfLiteIntArray* dims;
 
-  // The data type specification for data stored in `data`. This affects
-  // what member of `data` union should be used.
+  /// The data type specification for data stored in `data`. This affects
+  /// what member of `data` union should be used.
   TfLiteType type;
 } TfLiteEvalTensor;
 
 #ifndef TF_LITE_STATIC_MEMORY
-// Free data memory of tensor `t`.
+/// Free data memory of tensor `t`.
 void TfLiteTensorDataFree(TfLiteTensor* t);
 
-// Free quantization data.
+/// Free quantization data.
 void TfLiteQuantizationFree(TfLiteQuantization* quantization);
 
-// Free sparsity parameters.
+/// Free sparsity parameters.
 void TfLiteSparsityFree(TfLiteSparsity* sparsity);
 
-// Free memory of tensor `t`.
+/// Free memory of tensor `t`.
 void TfLiteTensorFree(TfLiteTensor* t);
 
-// Set all of a tensor's fields (and free any previously allocated data).
+/// Set all of a tensor's fields (and free any previously allocated data).
 void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
                        TfLiteQuantizationParams quantization, char* buffer,
                        size_t size, TfLiteAllocationType allocation_type,
                        const void* allocation, bool is_variable,
                        TfLiteTensor* tensor);
 
-// Copies the contents of 'src' in 'dst'.
-// Function does nothing if either 'src' or 'dst' is passed as nullptr and
-// return kTfLiteOk.
-// Returns kTfLiteError if 'src' and 'dst' doesn't have matching data size.
-// Note function copies contents, so it won't create new data pointer
-// or change allocation type.
-// All Tensor related properties will be copied from 'src' to 'dst' like
-// quantization, sparsity, ...
+/// Copies the contents of `src` into `dst`.
+/// The function does nothing and returns `kTfLiteOk` if either `src` or `dst`
+/// is passed as nullptr.
+/// Returns `kTfLiteError` if `src` and `dst` don't have matching data sizes.
+/// Note the function copies the contents, so it won't create a new data
+/// pointer or change the allocation type.
+/// All tensor-related properties, such as quantization and sparsity, will be
+/// copied from `src` to `dst`.
 TfLiteStatus TfLiteTensorCopy(const TfLiteTensor* src, TfLiteTensor* dst);
 
-// Change the size of the memory block owned by `tensor` to `num_bytes`.
-// Tensors with allocation types other than `kTfLiteDynamic` will be ignored and
-// a kTfLiteOk will be returned.
-// `tensor`'s internal data buffer will be assigned a pointer
-// which can safely be passed to free or realloc if `num_bytes` is zero.
-// If `preserve_data` is true, tensor data will be unchanged in the range from
-// the start of the region up to the minimum of the old and new sizes. In the
-// case of NULL tensor, or an error allocating new memory, returns
-// `kTfLiteError`.
+/// Change the size of the memory block owned by `tensor` to `num_bytes`.
+/// Tensors with allocation types other than `kTfLiteDynamic` will be ignored
+/// and `kTfLiteOk` will be returned. `tensor`'s internal data buffer will be
+/// assigned a pointer which can safely be passed to free or realloc if
+/// `num_bytes` is zero. If `preserve_data` is true, tensor data will be
+/// unchanged in the range from the start of the region up to the minimum of the
+/// old and new sizes. In the case of NULL tensor, or an error allocating new
+/// memory, returns `kTfLiteError`.
 TfLiteStatus TfLiteTensorResizeMaybeCopy(size_t num_bytes, TfLiteTensor* tensor,
                                          bool preserve_data);
 
-// Change the size of the memory block owned by `tensor` to `num_bytes`.
-// Tensors with allocation types other than kTfLiteDynamic will be ignored and
-// a kTfLiteOk will be returned.
-// `tensor`'s internal data buffer will be assigned a pointer
-// which can safely be passed to free or realloc if `num_bytes` is zero.
-// Tensor data will be unchanged in the range from the start of the region up to
-// the minimum of the old and new sizes. In the case
-// of NULL tensor, or an error allocating new memory, returns `kTfLiteError`.
+/// Change the size of the memory block owned by `tensor` to `num_bytes`.
+/// Tensors with allocation types other than `kTfLiteDynamic` will be ignored
+/// and `kTfLiteOk` will be returned. `tensor`'s internal data buffer will be
+/// assigned a pointer which can safely be passed to free or realloc if
+/// `num_bytes` is zero. Tensor data will be unchanged in the range from the
+/// start of the region up to the minimum of the old and new sizes. In the case
+/// of NULL tensor, or an error allocating new memory, returns `kTfLiteError`.
 TfLiteStatus TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
 #endif  // TF_LITE_STATIC_MEMORY
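+
+// A minimal sketch of resizing a dynamically allocated tensor with the helper
+// above while keeping previously written data (assumes `t` has allocation
+// type kTfLiteDynamic; the variable names are illustrative):
+//
+//     size_t new_bytes = 2 * t->bytes;
+//     TF_LITE_ENSURE_STATUS(
+//         TfLiteTensorResizeMaybeCopy(new_bytes, t, /*preserve_data=*/true));
+//     // Data in the overlapping prefix of the old and new regions is preserved.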
 
-// WARNING: This is an experimental interface that is subject to change.
-//
-// Currently, TfLiteDelegateParams has to be allocated in a way that it's
-// trivially destructable. It will be stored as `builtin_data` field in
-// `TfLiteNode` of the delegate node.
-//
-// See also the `CreateDelegateParams` function in `interpreter.cc` details.
+/// WARNING: This is an experimental interface that is subject to change.
+///
+/// Currently, TfLiteDelegateParams has to be allocated in a way that it's
+/// trivially destructible. It will be stored in the `builtin_data` field of
+/// the `TfLiteNode` of the delegate node.
+///
+/// See also the `CreateDelegateParams` function in `interpreter.cc` for
+/// details.
 typedef struct TfLiteDelegateParams {
   struct TfLiteDelegate* delegate;
   TfLiteIntArray* nodes_to_replace;
@@ -724,14 +741,14 @@
   TfLiteIntArray* output_tensors;
 } TfLiteDelegateParams;
 
-// WARNING: This is an experimental interface that is subject to change.
-//
-// Currently, TfLiteOpaqueDelegateParams has to be allocated in a way that it's
-// trivially destructable. It will be stored as `builtin_data` field in
-// `TfLiteNode` of the delegate node.
-//
-// See also the `CreateOpaqueDelegateParams` function in `subgraph.cc`
-// details.
+/// WARNING: This is an experimental interface that is subject to change.
+///
+/// Currently, TfLiteOpaqueDelegateParams has to be allocated in a way that
+/// it's trivially destructible. It will be stored in the `builtin_data` field
+/// of the `TfLiteNode` of the delegate node.
+///
+/// See also the `CreateOpaqueDelegateParams` function in `subgraph.cc` for
+/// details.
 typedef struct TfLiteOpaqueDelegateParams {
   TfLiteOpaqueDelegate* delegate;
   void* delegate_data;
@@ -740,371 +757,424 @@
   TfLiteIntArray* output_tensors;
 } TfLiteOpaqueDelegateParams;
 
+/// `TfLiteContext` allows an op to access the tensors.
+///
+/// `TfLiteContext` is a struct that is created by the TF Lite runtime
+/// and passed to the "methods" (C function pointers) in the
+/// `TfLiteRegistration` struct that are used to define custom ops and custom
+/// delegate kernels. It contains information and methods (C function pointers)
+/// that can be called by the code implementing a custom op or a custom delegate
+/// kernel. These methods provide access to the context in which that custom op
+/// or custom delegate kernel occurs, such as access to the input and output
+/// tensors for that op, as well as methods for allocating memory buffers
+/// and intermediate tensors, etc.
+///
+/// See also `TfLiteOpaqueContext`, which is a more ABI-stable equivalent.
 typedef struct TfLiteContext {
-  // Number of tensors in the context.
+  /// Number of tensors in the context.
   size_t tensors_size;
 
-  // The execution plan contains a list of the node indices in execution
-  // order. execution_plan->size is the current number of nodes. And,
-  // execution_plan->data[0] is the first node that needs to be run.
-  // TfLiteDelegates can traverse the current execution plan by iterating
-  // through each member of this array and using GetNodeAndRegistration() to
-  // access details about a node. i.e.
-  //
-  // TfLiteIntArray* execution_plan;
-  // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
-  // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
-  //    int node_index = execution_plan->data[exec_index];
-  //    TfLiteNode* node;
-  //    TfLiteRegistration* reg;
-  //    context->GetNodeAndRegistration(context, node_index, &node, &reg);
-  // }
-  // Note: the memory pointed by '`*execution_plan` is OWNED by TfLite runtime.
-  // Future calls to GetExecutionPlan invalidates earlier outputs. The following
-  // code snippet shows the issue of such an invocation pattern. After calling
-  // CheckNode, subsequent access to `plan_1st` is undefined.
-  //
-  // void CheckNode(const TfLiteNode* node) {
-  //   ...
-  //   TfLiteIntArray* plan_2nd;
-  //   TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan_2nd));
-  //   ...
-  // }
-  //
-  // TfLiteIntArray* plan_1st;
-  // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan_1st));
-  // for (int exec_index = 0; exec_index < plan_1st->size; exec_index++) {
-  //    int node_index = plan_1st->data[exec_index];
-  //    TfLiteNode* node;
-  //    TfLiteRegistration* reg;
-  //    context->GetNodeAndRegistration(context, node_index, &node, &reg);
-  //    CheckNode(node);
-  // }
-  //
-  // WARNING: This is an experimental interface that is subject to change.
+  /// The execution plan contains a list of the node indices in execution
+  /// order. execution_plan->size is the current number of nodes. And,
+  /// execution_plan->data[0] is the first node that needs to be run.
+  /// TfLiteDelegates can traverse the current execution plan by iterating
+  /// through each member of this array and using GetNodeAndRegistration() to
+  /// access details about a node. i.e.
+  ///
+  ///
+  ///     TfLiteIntArray* execution_plan;
+  ///     TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context,
+  ///                                                     &execution_plan));
+  ///     for (int exec_index = 0; exec_index < execution_plan->size;
+  ///           exec_index++) {
+  ///        int node_index = execution_plan->data[exec_index];
+  ///        TfLiteNode* node;
+  ///        TfLiteRegistration* reg;
+  ///        context->GetNodeAndRegistration(context, node_index, &node, &reg);
+  ///     }
+  ///
+  /// Note: the memory pointed to by `*execution_plan` is OWNED by the TfLite
+  /// runtime. Future calls to GetExecutionPlan invalidate earlier outputs. The
+  /// following code snippet shows the issue of such an invocation pattern.
+  /// After calling CheckNode, subsequent access to `plan_1st` is undefined.
+  ///
+  ///     void CheckNode(const TfLiteNode* node) {
+  ///       ...
+  ///       TfLiteIntArray* plan_2nd;
+  ///       TF_LITE_ENSURE_STATUS(
+  ///           context->GetExecutionPlan(context, &plan_2nd)
+  ///       );
+  ///       ...
+  ///     }
+  ///
+  ///     TfLiteIntArray* plan_1st;
+  ///     TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan_1st));
+  ///     for (int exec_index = 0; exec_index < plan_1st->size; exec_index++) {
+  ///        int node_index = plan_1st->data[exec_index];
+  ///        TfLiteNode* node;
+  ///        TfLiteRegistration* reg;
+  ///        context->GetNodeAndRegistration(context, node_index, &node, &reg);
+  ///        CheckNode(node);
+  ///     }
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
                                    TfLiteIntArray** execution_plan);
 
-  // An array of tensors in the interpreter context (of length `tensors_size`)
+  /// An array of tensors in the interpreter context (of length `tensors_size`)
   TfLiteTensor* tensors;
 
-  // opaque full context ptr (an opaque c++ data structure)
+  /// Opaque full context pointer (an opaque C++ data structure).
   void* impl_;
 
-  // Request memory pointer be resized. Updates dimensions on the tensor.
-  // NOTE: ResizeTensor takes ownership of newSize.
+  /// Request memory pointer be resized. Updates dimensions on the tensor.
+  /// NOTE: ResizeTensor takes ownership of `new_size`.
   TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
                                TfLiteIntArray* new_size);
-  // Request that an error be reported with format string msg.
+  /// Request that an error be reported with format string msg.
   void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
 
-  // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries.  If
-  // non-null, the value pointed to by `first_new_tensor_index` will be set to
-  // the index of the first new tensor.
+  /// Add `tensors_to_add` tensors, preserving pre-existing Tensor entries.  If
+  /// non-null, the value pointed to by `first_new_tensor_index` will be set to
+  /// the index of the first new tensor.
   TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
                              int* first_new_tensor_index);
 
-  // Get a Tensor node by node_index.
-  // WARNING: This is an experimental interface that is subject to change.
+  /// Get a Tensor node by node_index.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   TfLiteStatus (*GetNodeAndRegistration)(
       struct TfLiteContext*, int node_index, TfLiteNode** node,
       struct TfLiteRegistration** registration);
 
-  // Replace ops with one or more stub delegate operations. This function
-  // does not take ownership of `nodes_to_replace`.
+  /// Replace ops with one or more stub delegate operations. This function
+  /// does not take ownership of `nodes_to_replace`.
   TfLiteStatus (*ReplaceNodeSubsetsWithDelegateKernels)(
       struct TfLiteContext*, struct TfLiteRegistration registration,
       const TfLiteIntArray* nodes_to_replace, struct TfLiteDelegate* delegate);
 
-  // Number of threads that are recommended to subsystems like gemmlowp and
-  // eigen.
+  /// Number of threads that are recommended to subsystems like gemmlowp and
+  /// eigen.
   int recommended_num_threads;
 
-  // Access external contexts by type.
-  // WARNING: This is an experimental interface that is subject to change.
+  /// Access external contexts by type.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
                                                TfLiteExternalContextType);
-  // Set the value of a external context. Does not take ownership of the
-  // pointer.
-  // WARNING: This is an experimental interface that is subject to change.
+  /// Set the value of an external context. Does not take ownership of the
+  /// pointer.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
                              TfLiteExternalContext*);
 
-  // Flag for allowing float16 precision for FP32 calculation.
-  // default: false.
-  // WARNING: This is an experimental API and subject to change.
+  /// Flag for allowing float16 precision for FP32 calculation.
+  /// default: false.
+  ///
+  /// WARNING: This is an experimental API and subject to change.
   bool allow_fp32_relax_to_fp16;
 
-  // Pointer to the op-level profiler, if set; nullptr otherwise.
+  /// Pointer to the op-level profiler, if set; nullptr otherwise.
   void* profiler;
 
-  // Allocate persistent buffer which has the same life time as the interpreter.
-  // Returns nullptr on failure.
-  // The memory is allocated from heap for TFL, and from tail in TFLM.
-  // This method is only available in Init or Prepare stage.
-  // WARNING: This is an experimental interface that is subject to change.
+  /// Allocate a persistent buffer which has the same lifetime as the
+  /// interpreter. Returns `nullptr` on failure. The memory is allocated from
+  /// heap for TFL, and from tail in TFLM. This method is only available in
+  /// `Init` or `Prepare` stage.
+  ///
+  /// WARNING: This is an experimental interface that is subject
+  /// to change.
   void* (*AllocatePersistentBuffer)(struct TfLiteContext* ctx, size_t bytes);
 
-  // Allocate a buffer which will be deallocated right after invoke phase.
-  // The memory is allocated from heap in TFL, and from volatile arena in TFLM.
-  // This method is only available in invoke stage.
-  // NOTE: If possible use RequestScratchBufferInArena method to avoid memory
-  // allocation during inference time.
-  // WARNING: This is an experimental interface that is subject to change.
+  /// Allocate a buffer which will be deallocated right after invoke phase.
+  /// The memory is allocated from heap in TFL, and from volatile arena in TFLM.
+  /// This method is only available in invoke stage.
+  ///
+  /// NOTE: If possible use `RequestScratchBufferInArena` method to avoid memory
+  /// allocation during inference time.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   TfLiteStatus (*AllocateBufferForEval)(struct TfLiteContext* ctx, size_t bytes,
                                         void** ptr);
 
-  // Request a scratch buffer in the arena through static memory planning.
-  // This method is only available in Prepare stage and the buffer is allocated
-  // by the interpreter between Prepare and Eval stage. In Eval stage,
-  // GetScratchBuffer API can be used to fetch the address.
-  // WARNING: This is an experimental interface that is subject to change.
+  /// Request a scratch buffer in the arena through static memory planning.
+  /// This method is only available in `Prepare` stage and the buffer is
+  /// allocated by the interpreter between Prepare and Eval stage. In `Eval`
+  /// stage, `GetScratchBuffer` API can be used to fetch the address.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   TfLiteStatus (*RequestScratchBufferInArena)(struct TfLiteContext* ctx,
                                               size_t bytes, int* buffer_idx);
 
-  // Get the scratch buffer pointer.
-  // This method is only available in Eval stage.
-  // WARNING: This is an experimental interface that is subject to change.
+  /// Get the scratch buffer pointer.
+  /// This method is only available in Eval stage.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   void* (*GetScratchBuffer)(struct TfLiteContext* ctx, int buffer_idx);
 
-  // Resize the memory pointer of the `tensor`. This method behaves the same as
-  // `ResizeTensor`, except that it makes a copy of the shape array internally
-  // so the shape array could be deallocated right afterwards.
-  // WARNING: This is an experimental interface that is subject to change.
+  /// Resize the memory pointer of the `tensor`. This method behaves the same as
+  /// `ResizeTensor`, except that it makes a copy of the shape array internally
+  /// so the shape array could be deallocated right afterwards.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   TfLiteStatus (*ResizeTensorExplicit)(struct TfLiteContext* ctx,
                                        TfLiteTensor* tensor, int dims,
                                        const int* shape);
 
-  // This method provides a preview of post-delegation partitioning. Each
-  // TfLiteDelegateParams in the referenced array corresponds to one instance of
-  // the delegate kernel.
-  // Example usage:
-  //
-  // TfLiteIntArray* nodes_to_replace = ...;
-  // TfLiteDelegateParams* params_array;
-  // int num_partitions = 0;
-  // TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning(
-  //    context, delegate, nodes_to_replace, &params_array, &num_partitions));
-  // for (int idx = 0; idx < num_partitions; idx++) {
-  //    const auto& partition_params = params_array[idx];
-  //    ...
-  // }
-  //
-  // NOTE: The context owns the memory referenced by partition_params_array. It
-  // will be cleared with another call to PreviewDelegatePartitioning, or after
-  // TfLiteDelegateParams::Prepare returns.
-  //
-  // WARNING: This is an experimental interface that is subject to change.
+  /// This method provides a preview of post-delegation partitioning. Each
+  /// TfLiteDelegateParams in the referenced array corresponds to one instance
+  /// of the delegate kernel. Example usage:
+  ///
+  ///     TfLiteIntArray* nodes_to_replace = ...;
+  ///     TfLiteDelegateParams* params_array;
+  ///     int num_partitions = 0;
+  ///     TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning(
+  ///        context, delegate, nodes_to_replace, &params_array,
+  ///        &num_partitions));
+  ///     for (int idx = 0; idx < num_partitions; idx++) {
+  ///        const auto& partition_params = params_array[idx];
+  ///        ...
+  ///     }
+  ///
+  /// NOTE: The context owns the memory referenced by partition_params_array. It
+  /// will be cleared with another call to PreviewDelegatePartitioning, or after
+  /// TfLiteDelegateParams::Prepare returns.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   TfLiteStatus (*PreviewDelegatePartitioning)(
       struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
       TfLiteDelegateParams** partition_params_array, int* num_partitions);
 
-  // Returns a TfLiteTensor struct for a given index.
-  // WARNING: This is an experimental interface that is subject to change.
-  // WARNING: This method may not be available on all platforms.
+  /// Returns a TfLiteTensor struct for a given index.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
+  ///
+  /// WARNING: This method may not be available on all platforms.
   TfLiteTensor* (*GetTensor)(const struct TfLiteContext* context,
                              int tensor_idx);
 
-  // Returns a TfLiteEvalTensor struct for a given index.
-  // WARNING: This is an experimental interface that is subject to change.
-  // WARNING: This method may not be available on all platforms.
+  /// Returns a TfLiteEvalTensor struct for a given index.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
+  ///
+  /// WARNING: This method may not be available on all platforms.
   TfLiteEvalTensor* (*GetEvalTensor)(const struct TfLiteContext* context,
                                      int tensor_idx);
 
-  // Retrieves named metadata buffer from the TFLite model.
-  // Returns kTfLiteOk if metadata is successfully obtained from the flatbuffer
-  // Model: that is, there exists a `metadata` entry with given `name` string.
-  // (see TFLite's schema.fbs).
-  // The corresponding `buffer` information is populated in `ptr` & `bytes`.
-  // The data from `ptr` is valid for the lifetime of the Interpreter.
-  //
-  // WARNING: This is an experimental interface that is subject to change.
+  /// Retrieves named metadata buffer from the TFLite model.
+  /// Returns kTfLiteOk if metadata is successfully obtained from the flatbuffer
+  /// Model: that is, there exists a `metadata` entry with the given `name`
+  /// string (see TFLite's schema.fbs).
+  /// The corresponding `buffer` information is populated in `ptr` & `bytes`.
+  /// The data from `ptr` is valid for the lifetime of the Interpreter.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   TfLiteStatus (*GetModelMetadata)(const struct TfLiteContext* context,
                                    const char* name, const char** ptr,
                                    size_t* bytes);
 
-  // Retrieves the corresponding TfLiteContext of a subgraph that the given
-  // subgraph_index points to and switches to the delegate context for that
-  // subgraph. If an invalid subgraph index is given, returns kTfLiteError.
-  // NOTE: This function is expected to be paired with ReleaseSubgraphContext()
-  // once the delegate preparation is done and/or the delegate context functions
-  // are no longer needed.
-  //
-  // WARNING: This is an experimental interface that is subject to change.
+  /// Retrieves the corresponding TfLiteContext of a subgraph that the given
+  /// subgraph_index points to and switches to the delegate context for that
+  /// subgraph. If an invalid subgraph index is given, returns kTfLiteError.
+  ///
+  /// NOTE: This function is expected to be paired with ReleaseSubgraphContext()
+  /// once the delegate preparation is done and/or the delegate context
+  /// functions are no longer needed.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   TfLiteStatus (*AcquireSubgraphContext)(
       struct TfLiteContext* context, int subgraph_index,
       struct TfLiteContext** acquired_context);
-  // Releases the subgraph context by switching back to the TFLite kernel
-  // context for the subgraph that the given subgraph_index points to.
-  // NOTE: This function is expected to be used after AcquireSubgraphContext()
-  // once the delegate preparation is done and/or the delegate context functions
-  // are no longer needed.
-  //
-  // WARNING: This is an experimental interface that is subject to change.
+  /// Releases the subgraph context by switching back to the TFLite kernel
+  /// context for the subgraph that the given subgraph_index points to.
+  ///
+  /// NOTE: This function is expected to be used after AcquireSubgraphContext()
+  /// once the delegate preparation is done and/or the delegate context
+  /// functions are no longer needed.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   TfLiteStatus (*ReleaseSubgraphContext)(struct TfLiteContext* context,
                                          int subgraph_index);
 } TfLiteContext;
 
-// `TfLiteRegistrationExternal` is an external version of `TfLiteRegistration`
-// for C API which doesn't use internal types (such as `TfLiteContext`) but only
-// uses stable API types (such as `TfLiteOpaqueContext`). The purpose of each
-// field is the exactly the same as with `TfLiteRegistration`.
+/// `TfLiteRegistrationExternal` is an external version of `TfLiteRegistration`
+/// for C API which doesn't use internal types (such as `TfLiteContext`) but
+/// only uses stable API types (such as `TfLiteOpaqueContext`). The purpose of
+/// each field is exactly the same as with `TfLiteRegistration`.
 typedef struct TfLiteRegistrationExternal TfLiteRegistrationExternal;
 
-// The valid values of the `inplace_operator` field in `TfLiteRegistration`.
-// This allow an op to signal to the runtime that the same data pointer
-// may be passed as an input and output without impacting the result.
-// This does not mean that the memory can safely be reused, it is up to the
-// runtime to determine this, e.g. if another op consumes the same input or not
-// or if an input tensor has sufficient memory allocated to store the output
-// data.
-//
-// Setting these flags authorizes the runtime to set the data pointers of an
-// input and output tensor to the same value. In such cases, the memory required
-// by the output must be less than or equal to that required by the shared
-// input, never greater. If kTfLiteInplaceOpDataUnmodified is set, then the
-// runtime can share the same input tensor with multiple operator's outputs,
-// provided that kTfLiteInplaceOpDataUnmodified is set for all of them.
-// Otherwise, if an input tensor is consumed by multiple operators, it may only
-// be shared with the operator which is the last to consume it.
-//
-// Note that this is a bitmask, so the values should be 1, 2, 4, 8, ...etc.
+/// The valid values of the `inplace_operator` field in `TfLiteRegistration`.
+/// This allows an op to signal to the runtime that the same data pointer
+/// may be passed as an input and output without impacting the result.
+/// This does not mean that the memory can safely be reused, it is up to the
+/// runtime to determine this, e.g. if another op consumes the same input or not
+/// or if an input tensor has sufficient memory allocated to store the output
+/// data.
+///
+/// Setting these flags authorizes the runtime to set the data pointers of an
+/// input and output tensor to the same value. In such cases, the memory
+/// required by the output must be less than or equal to that required by the
+/// shared input, never greater. If kTfLiteInplaceOpDataUnmodified is set, then
+/// the runtime can share the same input tensor with multiple operators'
+/// outputs, provided that kTfLiteInplaceOpDataUnmodified is set for all of
+/// them. Otherwise, if an input tensor is consumed by multiple operators, it
+/// may only be shared with the operator which is the last to consume it.
+///
+/// Note that this is a bitmask, so the values should be 1, 2, 4, 8, ...etc.
 typedef enum {
-  // The default value. This indicates that the same data pointer cannot safely
-  // be passed as an op's input and output.
+  /// The default value. This indicates that the same data pointer cannot safely
+  /// be passed as an op's input and output.
   kTfLiteInplaceOpNone = 0,
-  // This indicates that an op's first output's data is identical to its first
-  // input's data, for example Reshape.
+  /// This indicates that an op's first output's data is identical to its first
+  /// input's data, for example Reshape.
   kTfLiteInplaceOpDataUnmodified = 1,
-  // Setting kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput means
-  // that InputN may be shared with OutputN instead of with the first output.
-  // This flag requires one or more of kTfLiteInplaceOpInputNShared to be set.
+  /// Setting kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput means
+  /// that InputN may be shared with OutputN instead of with the first output.
+  /// This flag requires one or more of kTfLiteInplaceOpInputNShared to be set.
   kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput = 2,
-  // kTfLiteInplaceOpInputNShared indicates that it is safe for an op to share
-  // InputN's data pointer with an output tensor. If
-  // kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput is set then
-  // kTfLiteInplaceOpInputNShared indicates that InputN may be shared
-  // with OutputN, otherwise kTfLiteInplaceOpInputNShared indicates that InputN
-  // may be shared with the first output.
-  //
-  // Indicates that an op's first input may be shared with the first output
-  // tensor. kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput has
-  // no impact on the behavior allowed by this flag.
+  /// kTfLiteInplaceOpInputNShared indicates that it is safe for an op to share
+  /// InputN's data pointer with an output tensor. If
+  /// kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput is set then
+  /// kTfLiteInplaceOpInputNShared indicates that InputN may be shared
+  /// with OutputN, otherwise kTfLiteInplaceOpInputNShared indicates that InputN
+  /// may be shared with the first output.
+  ///
+  /// Indicates that an op's first input may be shared with the first output
+  /// tensor. kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput has
+  /// no impact on the behavior allowed by this flag.
   kTfLiteInplaceOpInput0Shared = 4,
-  // Indicates that an op's second input may be shared with the first output
-  // if kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput is not set
-  // or second output if kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput
-  // is set.
+  /// Indicates that an op's second input may be shared with the first output
+  /// if kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput is not set
+  /// or second output if kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput
+  /// is set.
   kTfLiteInplaceOpInput1Shared = 8,
-  // Indicates that an op's third input may be shared with the first output
-  // if kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput is not set
-  // or third output if kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput is
-  // set.
+  /// Indicates that an op's third input may be shared with the first output
+  /// if kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput is not set
+  /// or third output if kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput
+  /// is set.
   kTfLiteInplaceOpInput2Shared = 16,
-  // Placeholder to ensure that enum can hold 64 bit values to accommodate
-  // future fields.
+  /// Placeholder to ensure that enum can hold 64 bit values to accommodate
+  /// future fields.
   kTfLiteInplaceOpMaxValue = UINT64_MAX,
 } TfLiteInPlaceOp;
 
-// The number of shareable inputs supported.
+/// The number of shareable inputs supported.
 static const int kTfLiteMaxSharableOpInputs = 3;
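As a rough illustration of how these bitmask values are meant to be combined, here is a minimal sketch (assuming the header above is included); the helper function is hypothetical and only composes the flag word for an op whose first input may be overwritten by its corresponding output.

// Minimal sketch, assuming this header is included; the helper name is
// hypothetical. Per the comments above, the "corresponding output" flag
// requires at least one kTfLiteInplaceOpInputNShared flag to be set.
static uint64_t HypotheticalElementwiseInplaceFlags(void) {
  return (uint64_t)kTfLiteInplaceInputCanBeSharedWithCorrespondingOutput |
         (uint64_t)kTfLiteInplaceOpInput0Shared;
}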
 
+/// `TfLiteRegistration` defines the implementation of an operation
+/// (a built-in op, custom op, or custom delegate kernel).
+///
+/// It is a struct containing "methods" (C function pointers) that will be
+/// invoked by the TF Lite runtime to evaluate instances of the operation.
+///
+/// See also `TfLiteRegistrationExternal` which is a more ABI-stable equivalent.
 typedef struct TfLiteRegistration {
-  // Initializes the op from serialized data.
-  // Called only *once* for the lifetime of the op, so any one-time allocations
-  // should be made here (unless they depend on tensor sizes).
-  //
-  // If a built-in op:
-  //   `buffer` is the op's params data (TfLiteLSTMParams*).
-  //   `length` is zero.
-  // If custom op:
-  //   `buffer` is the op's `custom_options`.
-  //   `length` is the size of the buffer.
-  //
-  // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
-  // or an instance of a struct).
-  //
-  // The returned pointer will be stored with the node in the `user_data` field,
-  // accessible within prepare and invoke functions below.
-  // NOTE: if the data is already in the desired format, simply implement this
-  // function to return `nullptr` and implement the free function to be a no-op.
+  /// Initializes the op from serialized data.
+  /// Called only *once* for the lifetime of the op, so any one-time allocations
+  /// should be made here (unless they depend on tensor sizes).
+  ///
+  /// * If a built-in op:
+  ///       * `buffer` is the op's params data (TfLiteLSTMParams*).
+  ///       * `length` is zero.
+  /// * If custom op:
+  ///       * `buffer` is the op's `custom_options`.
+  ///       * `length` is the size of the buffer.
+  ///
+  /// Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
+  /// or an instance of a struct).
+  ///
+  /// The returned pointer will be stored with the node in the `user_data`
+  /// field, accessible within prepare and invoke functions below.
+  ///
+  /// NOTE: if the data is already in the desired format, simply implement this
+  /// function to return `nullptr` and implement the free function to be a
+  /// no-op.
   void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
 
-  // The pointer `buffer` is the data previously returned by an init invocation.
+  /// The pointer `buffer` is the data previously returned by an init
+  /// invocation.
   void (*free)(TfLiteContext* context, void* buffer);
 
-  // prepare is called when the inputs this node depends on have been resized.
-  // context->ResizeTensor() can be called to request output tensors to be
-  // resized.
-  // Can be called multiple times for the lifetime of the op.
-  //
-  // Returns kTfLiteOk on success.
+  /// prepare is called when the inputs this node depends on have been resized.
+  /// `context->ResizeTensor()` can be called to request output tensors to be
+  /// resized.
+  /// Can be called multiple times for the lifetime of the op.
+  ///
+  /// Returns `kTfLiteOk` on success.
   TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
 
-  // Execute the node (should read node->inputs and output to node->outputs).
-  // Returns kTfLiteOk on success.
+  /// Execute the node (should read `node->inputs` and output to
+  /// `node->outputs`).
+  ///
+  /// Returns `kTfLiteOk` on success.
   TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
 
-  // profiling_string is called during summarization of profiling information
-  // in order to group executions together. Providing a value here will cause a
-  // given op to appear multiple times is the profiling report. This is
-  // particularly useful for custom ops that can perform significantly
-  // different calculations depending on their `user-data`.
+  /// `profiling_string` is called during summarization of profiling information
+  /// in order to group executions together. Providing a value here will cause a
+  /// given op to appear multiple times in the profiling report. This is
+  /// particularly useful for custom ops that can perform significantly
+  /// different calculations depending on their `user_data`.
   const char* (*profiling_string)(const TfLiteContext* context,
                                   const TfLiteNode* node);
 
-  // Builtin codes. If this kernel refers to a builtin this is the code
-  // of the builtin. This is so we can do marshaling to other frameworks like
-  // NN API.
-  // Note: It is the responsibility of the registration binder to set this
-  // properly.
+  /// Builtin codes. If this kernel refers to a builtin this is the code
+  /// of the builtin. This is so we can do marshaling to other frameworks like
+  /// NN API.
+  ///
+  /// Note: It is the responsibility of the registration binder to set this
+  /// properly.
   int32_t builtin_code;
 
-  // Custom op name. If the op is a builtin, this will be null.
-  // Note: It is the responsibility of the registration binder to set this
-  // properly.
-  // WARNING: This is an experimental interface that is subject to change.
+  /// Custom op name. If the op is a builtin, this will be `null`.
+  ///
+  /// Note: It is the responsibility of the registration binder to set this
+  /// properly.
+  ///
+  /// WARNING: This is an experimental interface that is subject to change.
   const char* custom_name;
 
-  // The version of the op.
-  // Note: It is the responsibility of the registration binder to set this
-  // properly.
+  /// The version of the op.
+  /// Note: It is the responsibility of the registration binder to set this
+  /// properly.
   int version;
 
-  // The external version of `TfLiteRegistration`. Since we can't use internal
-  // types (such as `TfLiteContext`) for C API to maintain ABI stability.
-  // C API user will provide `TfLiteRegistrationExternal` to implement custom
-  // ops. We keep it inside of `TfLiteRegistration` and use it to route
-  // callbacks properly.
+  /// The external version of `TfLiteRegistration`. Since we can't use internal
+  /// types (such as `TfLiteContext`) in the C API while maintaining ABI
+  /// stability, C API users provide `TfLiteRegistrationExternal` to implement
+  /// custom ops. We keep it inside `TfLiteRegistration` and use it to route
+  /// callbacks properly.
   TfLiteRegistrationExternal* registration_external;
 
-  // Retrieves asynchronous kernel.
-  //
-  // If the `async_kernel` field is nullptr, it means the operation described by
-  // this TfLiteRegistration object does not support asynchronous execution.
-  // Otherwise, the function that the field points to should only be called for
-  // delegate kernel nodes, i.e. `node` should be a delegate kernel node created
-  // by applying a delegate.
-  // If the function returns nullptr, that means that the underlying delegate
-  // does not support asynchronous execution for this `node`.
+  /// Retrieves asynchronous kernel.
+  ///
+  /// If the `async_kernel` field is nullptr, it means the operation described
+  /// by this TfLiteRegistration object does not support asynchronous execution.
+  /// Otherwise, the function that the field points to should only be called for
+  /// delegate kernel nodes, i.e. `node` should be a delegate kernel node
+  /// created by applying a delegate. If the function returns nullptr, that
+  /// means that the underlying delegate does not support asynchronous execution
+  /// for this `node`.
   struct TfLiteAsyncKernel* (*async_kernel)(TfLiteContext* context,
                                             TfLiteNode* node);
 
-  // Indicates if an operator's output may safely overwrite its inputs.
-  // See the comments in `TfLiteInPlaceOp`.
+  /// Indicates if an operator's output may safely overwrite its inputs.
+  /// See the comments in `TfLiteInPlaceOp`.
   uint64_t inplace_operator;
 } TfLiteRegistration;
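To make the struct above concrete, here is a minimal sketch of a custom-op registration, assuming this header is included; the op name `MyCopyOp`, its callbacks, and their trivial bodies are invented for illustration and are not part of this change.

// A minimal, hypothetical custom-op registration sketch.
static void* MyCopyOpInit(TfLiteContext* context, const char* buffer,
                          size_t length) {
  (void)context; (void)buffer; (void)length;
  return NULL;  // No per-node state is needed for this sketch.
}

static void MyCopyOpFree(TfLiteContext* context, void* buffer) {
  (void)context; (void)buffer;  // Nothing was allocated in init.
}

static TfLiteStatus MyCopyOpPrepare(TfLiteContext* context, TfLiteNode* node) {
  (void)context; (void)node;
  return kTfLiteOk;  // Output sizes would be requested here via ResizeTensor.
}

static TfLiteStatus MyCopyOpInvoke(TfLiteContext* context, TfLiteNode* node) {
  (void)context; (void)node;
  return kTfLiteOk;  // The actual computation would go here.
}

static TfLiteRegistration MyCopyOpRegistration(void) {
  TfLiteRegistration r = {0};  // Zero-init so unused callbacks stay null.
  r.init = MyCopyOpInit;
  r.free = MyCopyOpFree;
  r.prepare = MyCopyOpPrepare;
  r.invoke = MyCopyOpInvoke;
  r.custom_name = "MyCopyOp";  // Null for builtins; set for custom ops.
  r.version = 1;
  r.inplace_operator = kTfLiteInplaceOpDataUnmodified;  // Output mirrors input.
  return r;
}

The `builtin_code` field is left at zero here because, per the comment above, setting it properly is the responsibility of the registration binder.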
 
 /// \private
-// Old version of `TfLiteRegistration` to maintain binary backward
-// compatibility.
-// The legacy registration type must be a POD struct type whose field types must
-// be a prefix of the field types in TfLiteRegistration, and offset of the first
-// field in TfLiteRegistration that is not present in the legacy registration
-// type must be greater than or equal to the size of the legacy registration
-// type.
-// WARNING: This structure is deprecated / not an official part of the
-// API. It should be only used for binary backward compatibility.
+/// Old version of `TfLiteRegistration` to maintain binary backward
+/// compatibility.
+/// The legacy registration type must be a POD struct type whose field types
+/// must be a prefix of the field types in TfLiteRegistration, and offset of the
+/// first field in TfLiteRegistration that is not present in the legacy
+/// registration type must be greater than or equal to the size of the legacy
+/// registration type.
+///
+/// WARNING: This structure is deprecated / not an official part of the
+/// API. It should be only used for binary backward compatibility.
 typedef struct TfLiteRegistration_V3 {
   void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
   void (*free)(TfLiteContext* context, void* buffer);
@@ -1121,15 +1191,16 @@
 } TfLiteRegistration_V3;
 
 /// \private
-// Old version of `TfLiteRegistration` to maintain binary backward
-// compatibility.
-// The legacy registration type must be a POD struct type whose field types must
-// be a prefix of the field types in TfLiteRegistration, and offset of the first
-// field in TfLiteRegistration that is not present in the legacy registration
-// type must be greater than or equal to the size of the legacy registration
-// type.
-// WARNING: This structure is deprecated / not an official part of the
-// API. It should be only used for binary backward compatibility.
+/// Old version of `TfLiteRegistration` to maintain binary backward
+/// compatibility.
+/// The legacy registration type must be a POD struct type whose field types
+/// must be a prefix of the field types in TfLiteRegistration, and offset of the
+/// first field in TfLiteRegistration that is not present in the legacy
+/// registration type must be greater than or equal to the size of the legacy
+/// registration type.
+///
+/// WARNING: This structure is deprecated / not an official part of the
+/// API. It should be only used for binary backward compatibility.
 typedef struct TfLiteRegistration_V2 {
   void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
   void (*free)(TfLiteContext* context, void* buffer);
@@ -1144,15 +1215,16 @@
 } TfLiteRegistration_V2;
 
 /// \private
-// Old version of `TfLiteRegistration` to maintain binary backward
-// compatibility.
-// The legacy registration type must be a POD struct type whose field types must
-// be a prefix of the field types in TfLiteRegistration, and offset of the first
-// field in TfLiteRegistration that is not present in the legacy registration
-// type must be greater than or equal to the size of the legacy registration
-// type.
-// WARNING: This structure is deprecated / not an official part of the
-// API. It should be only used for binary backward compatibility.
+/// Old version of `TfLiteRegistration` to maintain binary backward
+/// compatibility.
+/// The legacy registration type must be a POD struct type whose field types
+/// must be a prefix of the field types in TfLiteRegistration, and offset of the
+/// first field in TfLiteRegistration that is not present in the legacy
+/// registration type must be greater than or equal to the size of the legacy
+/// registration type.
+///
+/// WARNING: This structure is deprecated / not an official part of the
+/// API. It should be only used for binary backward compatibility.
 typedef struct TfLiteRegistration_V1 {
   void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
   void (*free)(TfLiteContext* context, void* buffer);
@@ -1165,207 +1237,214 @@
   int version;
 } TfLiteRegistration_V1;
 
-// The flags used in `TfLiteDelegate`. Note that this is a bitmask, so the
-// values should be 1, 2, 4, 8, ...etc.
+/// The flags used in `TfLiteDelegate`. Note that this is a bitmask, so the
+/// values should be 1, 2, 4, 8, ...etc.
 typedef enum TfLiteDelegateFlags {
   kTfLiteDelegateFlagsNone = 0,
-  // The flag is set if the delegate can handle dynamic sized tensors.
-  // For example, the output shape of a `Resize` op with non-constant shape
-  // can only be inferred when the op is invoked.
-  // In this case, the Delegate is responsible for calling
-  // `SetTensorToDynamic` to mark the tensor as a dynamic tensor, and calling
-  // `ResizeTensor` when invoking the op.
-  //
-  // If the delegate isn't capable to handle dynamic tensors, this flag need
-  // to be set to false.
+  /// The flag is set if the delegate can handle dynamic sized tensors.
+  /// For example, the output shape of a `Resize` op with non-constant shape
+  /// can only be inferred when the op is invoked.
+  /// In this case, the Delegate is responsible for calling
+  /// `SetTensorToDynamic` to mark the tensor as a dynamic tensor, and calling
+  /// `ResizeTensor` when invoking the op.
+  ///
+  /// If the delegate isn't capable of handling dynamic tensors, this flag
+  /// needs to be set to false.
   kTfLiteDelegateFlagsAllowDynamicTensors = 1,
 
-  // This flag can be used by delegates (that allow dynamic tensors) to ensure
-  // applicable tensor shapes are automatically propagated in the case of tensor
-  // resizing.
-  // This means that non-dynamic (allocation_type != kTfLiteDynamic) I/O tensors
-  // of a delegate kernel will have correct shapes before its Prepare() method
-  // is called. The runtime leverages TFLite builtin ops in the original
-  // execution plan to propagate shapes.
-  //
-  // A few points to note:
-  // 1. This requires kTfLiteDelegateFlagsAllowDynamicTensors. If that flag is
-  // false, this one is redundant since the delegate kernels are re-initialized
-  // every time tensors are resized.
-  // 2. Enabling this flag adds some overhead to AllocateTensors(), since extra
-  // work is required to prepare the original execution plan.
-  // 3. This flag requires that the original execution plan only have ops with
-  // valid registrations (and not 'dummy' custom ops like with Flex).
-  // WARNING: This feature is experimental and subject to change.
+  /// This flag can be used by delegates (that allow dynamic tensors) to ensure
+  /// applicable tensor shapes are automatically propagated in the case of
+  /// tensor resizing. This means that non-dynamic (allocation_type !=
+  /// kTfLiteDynamic) I/O tensors of a delegate kernel will have correct shapes
+  /// before its Prepare() method is called. The runtime leverages TFLite
+  /// builtin ops in the original execution plan to propagate shapes.
+  ///
+  /// A few points to note:
+  /// 1. This requires kTfLiteDelegateFlagsAllowDynamicTensors. If that flag is
+  /// false, this one is redundant since the delegate kernels are re-initialized
+  /// every time tensors are resized.
+  /// 2. Enabling this flag adds some overhead to AllocateTensors(), since extra
+  /// work is required to prepare the original execution plan.
+  /// 3. This flag requires that the original execution plan only have ops with
+  /// valid registrations (and not 'dummy' custom ops like with Flex).
+  ///
+  /// WARNING: This feature is experimental and subject to change.
   kTfLiteDelegateFlagsRequirePropagatedShapes = 2,
 
-  // This flag can be used by delegates to request per-operator profiling. If a
-  // node is a delegate node, this flag will be checked before profiling. If
-  // set, then the node will not be profiled. The delegate will then add per
-  // operator information using Profiler::EventType::OPERATOR_INVOKE_EVENT and
-  // the results will appear in the operator-wise Profiling section and not in
-  // the Delegate internal section.
+  /// This flag can be used by delegates to request per-operator profiling. If a
+  /// node is a delegate node, this flag will be checked before profiling. If
+  /// set, then the node will not be profiled. The delegate will then add per
+  /// operator information using `Profiler::EventType::OPERATOR_INVOKE_EVENT`
+  /// and the results will appear in the operator-wise Profiling section and not
+  /// in the Delegate internal section.
   kTfLiteDelegateFlagsPerOperatorProfiling = 4
 } TfLiteDelegateFlags;
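A small sketch of how these flag values compose (assuming the header above); as stated above, the propagated-shapes flag only makes sense together with the dynamic-tensors flag. The helper name is hypothetical.

// Hypothetical helper: a delegate that handles dynamic tensors and wants
// shapes propagated before its Prepare() sets both flags together.
static int64_t HypotheticalDelegateFlags(void) {
  return (int64_t)kTfLiteDelegateFlagsAllowDynamicTensors |
         (int64_t)kTfLiteDelegateFlagsRequirePropagatedShapes;
}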
 
-// WARNING: This is an experimental interface that is subject to change.
+/// WARNING: This is an experimental interface that is subject to change.
 typedef struct TfLiteDelegate {
-  // Data that delegate needs to identify itself. This data is owned by the
-  // delegate. The delegate is owned in the user code, so the delegate is
-  // responsible for deallocating this when it is destroyed.
+  /// Data that the delegate needs to identify itself. This data is owned by
+  /// the delegate. The delegate is owned in the user code, so the delegate is
+  /// responsible for deallocating this when it is destroyed.
   void* data_;
 
-  // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
-  // delegate a view of the current graph through TfLiteContext*. It typically
-  // will look at the nodes and call ReplaceNodeSubsetsWithDelegateKernels()
-  // to ask the TensorFlow lite runtime to create macro-nodes to represent
-  // delegated subgraphs of the original graph.
+  /// Invoked by `ModifyGraphWithDelegate`. This prepare is called, giving the
+  /// delegate a view of the current graph through `TfLiteContext*`. It
+  /// typically will look at the nodes and call
+  /// `ReplaceNodeSubsetsWithDelegateKernels()` to ask the TensorFlow Lite
+  /// runtime to create macro-nodes to represent delegated subgraphs of the
+  /// original graph.
   TfLiteStatus (*Prepare)(TfLiteContext* context,
                           struct TfLiteDelegate* delegate);
 
-  // Copy the data from delegate buffer handle into raw memory of the given
-  // 'tensor'. Note that the delegate is allowed to allocate the raw bytes as
-  // long as it follows the rules for kTfLiteDynamic tensors, in which case this
-  // cannot be null.
+  /// Copy the data from delegate buffer handle into raw memory of the given
+  /// `tensor`. Note that the delegate is allowed to allocate the raw bytes as
+  /// long as it follows the rules for `kTfLiteDynamic` tensors, in which case
+  /// this cannot be null.
   TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
                                        struct TfLiteDelegate* delegate,
                                        TfLiteBufferHandle buffer_handle,
                                        TfLiteTensor* tensor);
 
-  // Copy the data from raw memory of the given 'tensor' to delegate buffer
-  // handle. This can be null if the delegate doesn't use its own buffer.
+  /// Copy the data from raw memory of the given `tensor` to delegate buffer
+  /// handle. This can be null if the delegate doesn't use its own buffer.
   TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
                                      struct TfLiteDelegate* delegate,
                                      TfLiteBufferHandle buffer_handle,
                                      TfLiteTensor* tensor);
 
-  // Free the Delegate Buffer Handle. Note: This only frees the handle, but
-  // this doesn't release the underlying resource (e.g. textures). The
-  // resources are either owned by application layer or the delegate.
-  // This can be null if the delegate doesn't use its own buffer.
+  /// Free the Delegate Buffer Handle. Note: This only frees the handle, but
+  /// this doesn't release the underlying resource (e.g. textures). The
+  /// resources are either owned by the application layer or by the delegate.
+  /// This can be null if the delegate doesn't use its own buffer.
   void (*FreeBufferHandle)(TfLiteContext* context,
                            struct TfLiteDelegate* delegate,
                            TfLiteBufferHandle* handle);
 
-  // Bitmask flags. See the comments in `TfLiteDelegateFlags`.
+  /// Bitmask flags. See the comments in `TfLiteDelegateFlags`.
   int64_t flags;
 
-  // The opaque delegate builder associated with this object.  If set then the
-  // TF Lite runtime will give precedence to this field.  E.g. instead of
-  // invoking 'Prepare' via the function pointer inside the 'TfLiteDelegate'
-  // object, the runtime will first check if the corresponding function
-  // pointer inside 'opaque_delegate_builder' is set and if so invoke that.
-  //
-  // If this field is non-null, then the 'Prepare' field (of the
-  // 'TfLiteDelegate') should be null.
+  /// The opaque delegate builder associated with this object.  If set then the
+  /// TF Lite runtime will give precedence to this field.  E.g. instead of
+  /// invoking `Prepare` via the function pointer inside the `TfLiteDelegate`
+  /// object, the runtime will first check if the corresponding function
+  /// pointer inside `opaque_delegate_builder` is set and if so invoke that.
+  ///
+  /// If this field is non-null, then the `Prepare` field (of the
+  /// `TfLiteDelegate`) should be null.
   struct TfLiteOpaqueDelegateBuilder* opaque_delegate_builder;
 } TfLiteDelegate;
 
-// Build a 'null' delegate, with all the fields properly set to their default
-// values.
+/// Build a `null` delegate, with all the fields properly set to their default
+/// values.
 TfLiteDelegate TfLiteDelegateCreate(void);
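As a sketch of how `TfLiteDelegateCreate` is typically used (assuming this header), start from the defaulted struct it returns and fill in only the fields a simple delegate needs; `MyDelegatePrepare`, `MakeMyDelegate`, and `my_state` are hypothetical.

// Hypothetical delegate sketch built on the defaults from
// TfLiteDelegateCreate().
static TfLiteStatus MyDelegatePrepare(TfLiteContext* context,
                                      struct TfLiteDelegate* delegate) {
  (void)context; (void)delegate;
  // A real delegate would inspect the graph here and call
  // ReplaceNodeSubsetsWithDelegateKernels() for the nodes it claims.
  return kTfLiteOk;
}

static TfLiteDelegate MakeMyDelegate(void* my_state) {
  TfLiteDelegate d = TfLiteDelegateCreate();  // All fields at their defaults.
  d.data_ = my_state;                         // Owned by the caller's code.
  d.Prepare = MyDelegatePrepare;
  d.flags = kTfLiteDelegateFlagsNone;         // See TfLiteDelegateFlags above.
  return d;
}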
 
-// `TfLiteOpaqueDelegateBuilder` is used for constructing
-// `TfLiteOpaqueDelegate`, see `TfLiteOpaqueDelegateCreate` below.  Note:
-// This struct is not ABI stable.
-//
-// For forward source compatibility `TfLiteOpaqueDelegateBuilder` objects should
-// be brace-initialized, so that all fields (including any that might be added
-// in the future) get zero-initialized.  The purpose of each field is exactly
-// the same as with `TfLiteDelegate`.
-//
-// WARNING: This is an experimental interface that is subject to change.
+/// `TfLiteOpaqueDelegateBuilder` is used for constructing
+/// `TfLiteOpaqueDelegate`, see `TfLiteOpaqueDelegateCreate` below.  Note:
+/// This struct is not ABI stable.
+///
+/// For forward source compatibility `TfLiteOpaqueDelegateBuilder` objects
+/// should be brace-initialized, so that all fields (including any that might be
+/// added in the future) get zero-initialized.  The purpose of each field is
+/// exactly the same as with `TfLiteDelegate`.
+///
+/// WARNING: This is an experimental interface that is subject to change.
 typedef struct TfLiteOpaqueDelegateBuilder {
-  // Data that delegate needs to identify itself. This data is owned by the
-  // delegate. The delegate is owned in the user code, so the delegate is
-  // responsible for deallocating this when it is destroyed.
+  /// Data that the delegate needs to identify itself. This data is owned by
+  /// the delegate. The delegate is owned in the user code, so the delegate is
+  /// responsible for deallocating this when it is destroyed.
   void* data;
-  // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
-  // delegate a view of the current graph through TfLiteContext*. It typically
-  // will look at the nodes and call ReplaceNodeSubsetsWithDelegateKernels()
-  // to ask the TensorFlow lite runtime to create macro-nodes to represent
-  // delegated subgraphs of the original graph.
+  /// Invoked by `ModifyGraphWithDelegate`. This prepare is called, giving the
+  /// delegate a view of the current graph through `TfLiteContext*`. It
+  /// typically will look at the nodes and call
+  /// `ReplaceNodeSubsetsWithDelegateKernels()` to ask the TensorFlow Lite
+  /// runtime to create macro-nodes to represent delegated subgraphs of the
+  /// original graph.
   TfLiteStatus (*Prepare)(TfLiteOpaqueContext* context,  // NOLINT
                           TfLiteOpaqueDelegate* delegate, void* data);
-  // Copies the data from delegate buffer handle into raw memory of the given
-  // 'tensor'. Note that the delegate is allowed to allocate the raw bytes as
-  // long as it follows the rules for kTfLiteDynamic tensors, in which case this
-  // cannot be null.
+  /// Copies the data from delegate buffer handle into raw memory of the given
+  /// `tensor`. Note that the delegate is allowed to allocate the raw bytes as
+  /// long as it follows the rules for kTfLiteDynamic tensors, in which case
+  /// this cannot be null.
   TfLiteStatus (*CopyFromBufferHandle)(  // NOLINT
       TfLiteOpaqueContext* context, TfLiteOpaqueDelegate* delegate, void* data,
       TfLiteBufferHandle buffer_handle, TfLiteOpaqueTensor* tensor);
-  // Copies the data from raw memory of the given 'tensor' to delegate buffer
-  // handle. This can be null if the delegate doesn't use its own buffer.
+  /// Copies the data from raw memory of the given `tensor` to delegate buffer
+  /// handle. This can be null if the delegate doesn't use its own buffer.
   TfLiteStatus (*CopyToBufferHandle)(  // NOLINT
       TfLiteOpaqueContext* context, TfLiteOpaqueDelegate* delegate, void* data,
       TfLiteBufferHandle buffer_handle, TfLiteOpaqueTensor* tensor);
-  // Frees the Delegate Buffer Handle. Note: This only frees the handle, but
-  // this doesn't release the underlying resource (e.g. textures). The
-  // resources are either owned by application layer or the delegate.
-  // This can be null if the delegate doesn't use its own buffer.
+  /// Frees the Delegate Buffer Handle. Note: This only frees the handle, but
+  /// this doesn't release the underlying resource (e.g. textures). The
+  /// resources are either owned by the application layer or by the delegate.
+  /// This can be null if the delegate doesn't use its own buffer.
   void (*FreeBufferHandle)(TfLiteOpaqueContext* context,  // NOLINT
                            TfLiteOpaqueDelegate* delegate, void* data,
                            TfLiteBufferHandle* handle);
-  // Bitmask flags. See the comments in `TfLiteDelegateFlags`.
+  /// Bitmask flags. See the comments in `TfLiteDelegateFlags`.
   int64_t flags;
 } TfLiteOpaqueDelegateBuilder;
 
 #ifndef TF_LITE_STATIC_MEMORY
-// Creates an opaque delegate and returns its address.  The opaque delegate will
-// behave according to the provided 'opaque_delegate_builder'.  The lifetime of
-// the objects pointed to by any of the fields within the
-// 'opaque_delegate_builder' must outlive the returned
-// 'TfLiteOpaqueDelegate' and any 'TfLiteInterpreter',
-// 'TfLiteInterpreterOptions', 'tflite::Interpreter', or
-// 'tflite::InterpreterBuilder' that the delegate is added to.  The returned
-// address should be passed to 'TfLiteOpaqueDelegateDelete' for deletion.  If
-// 'opaque_delegate_builder' is a null pointer, then a null pointer will be
-// returned.
+/// Creates an opaque delegate and returns its address.  The opaque delegate
+/// will behave according to the provided `opaque_delegate_builder`.  The
+/// lifetime of the objects pointed to by any of the fields within the
+/// `opaque_delegate_builder` must outlive the returned
+/// `TfLiteOpaqueDelegate` and any `TfLiteInterpreter`,
+/// `TfLiteInterpreterOptions`, `tflite::Interpreter`, or
+/// `tflite::InterpreterBuilder` that the delegate is added to.  The returned
+/// address should be passed to `TfLiteOpaqueDelegateDelete` for deletion.  If
+/// `opaque_delegate_builder` is a null pointer, then a null pointer will be
+/// returned.
 TfLiteOpaqueDelegate* TfLiteOpaqueDelegateCreate(
     const TfLiteOpaqueDelegateBuilder* opaque_delegate_builder);
 
-// Deletes the provided opaque 'delegate'.  This function has no effect if the
-// 'delegate' is a null pointer.
+/// Deletes the provided opaque `delegate`.  This function has no effect if the
+/// `delegate` is a null pointer.
 void TfLiteOpaqueDelegateDelete(TfLiteOpaqueDelegate* delegate);
 #endif  // TF_LITE_STATIC_MEMORY
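Combining the two functions above, here is a sketch (assuming this header and that TF_LITE_STATIC_MEMORY is not defined) of creating and destroying an opaque delegate from a brace-initialized builder; `MyOpaquePrepare` and `OpaqueDelegateExample` are hypothetical.

// Hypothetical opaque-delegate lifecycle sketch.
static TfLiteStatus MyOpaquePrepare(TfLiteOpaqueContext* context,
                                    TfLiteOpaqueDelegate* delegate,
                                    void* data) {
  (void)context; (void)delegate; (void)data;
  return kTfLiteOk;
}

static void OpaqueDelegateExample(void) {
  // Brace-initialize so fields added in future versions are zeroed.
  TfLiteOpaqueDelegateBuilder builder = {0};
  builder.Prepare = MyOpaquePrepare;

  TfLiteOpaqueDelegate* delegate = TfLiteOpaqueDelegateCreate(&builder);
  // ... pass `delegate` to the interpreter options and run inference ...
  TfLiteOpaqueDelegateDelete(delegate);  // No-op if `delegate` is null.
}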
 
-// Returns a pointer to the data associated with the provided opaque 'delegate'.
-//
-// A null pointer will be returned when:
-// - The 'delegate' is null.
-// - The 'data' field of the 'TfLiteOpaqueDelegateBuilder' used to construct the
-//   'delegate' was null.
-// - Or in case of any other error.
-// - The 'delegate' has been constructed via a 'TfLiteOpaqueDelegateBuilder',
-//   but the 'data' field of the 'TfLiteOpaqueDelegateBuilder' is null.
-//
-//  The data_ field of 'delegate' will be returned if the
-//  'opaque_delegate_builder' field is null.
+/// Returns a pointer to the data associated with the provided opaque
+/// `delegate`.
+///
+/// A null pointer will be returned when:
+/// - The `delegate` is null.
+/// - The `delegate` was constructed via a `TfLiteOpaqueDelegateBuilder`, but
+///   the `data` field of the `TfLiteOpaqueDelegateBuilder` is null.
+/// - Any other error occurred.
+///
+/// The `data_` field of `delegate` will be returned if the
+/// `opaque_delegate_builder` field is null.
 void* TfLiteOpaqueDelegateGetData(const TfLiteOpaqueDelegate* delegate);
 
-// Returns a tensor data allocation strategy.
+/// Returns a tensor data allocation strategy.
 TfLiteAllocationStrategy TfLiteTensorGetAllocationStrategy(
     const TfLiteTensor* t);
 
-// Returns how stable a tensor data buffer address is across runs.
+/// Returns how stable a tensor's data buffer address is across runs.
 TfLiteRunStability TfLiteTensorGetBufferAddressStability(const TfLiteTensor* t);
 
-// Returns how stable a tensor data values are across runs.
+/// Returns how stable a tensor's data values are across runs.
 TfLiteRunStability TfLiteTensorGetDataStability(const TfLiteTensor* t);
 
-// Returns the operation step when the data of a tensor is populated.
-//
-// Some operations can precompute their results before the evaluation step. This
-// makes the data available earlier for subsequent operations.
+/// Returns the operation step when the data of a tensor is populated.
+///
+/// Some operations can precompute their results before the evaluation step.
+/// This makes the data available earlier for subsequent operations.
 TfLiteRunStep TfLiteTensorGetDataKnownStep(const TfLiteTensor* t);
 
-// Returns the operation steop when the shape of a tensor is computed.
-//
-// Some operations can precompute the shape of their results before the
-// evaluation step. This makes the shape available earlier for subsequent
-// operations.
+/// Returns the operation step when the shape of a tensor is computed.
+///
+/// Some operations can precompute the shape of their results before the
+/// evaluation step. This makes the shape available earlier for subsequent
+/// operations.
 TfLiteRunStep TfLiteTensorGetShapeKnownStep(const TfLiteTensor* t);
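A short usage sketch for the introspection getters above, assuming a valid `TfLiteTensor*` obtained from an interpreter; the wrapper function name is made up.

// Hypothetical helper that queries how stable a tensor's buffer and data are.
static void InspectTensorStability(const TfLiteTensor* t) {
  TfLiteAllocationStrategy strategy = TfLiteTensorGetAllocationStrategy(t);
  TfLiteRunStability address_stability =
      TfLiteTensorGetBufferAddressStability(t);
  TfLiteRunStability data_stability = TfLiteTensorGetDataStability(t);
  TfLiteRunStep data_known = TfLiteTensorGetDataKnownStep(t);
  TfLiteRunStep shape_known = TfLiteTensorGetShapeKnownStep(t);
  (void)strategy; (void)address_stability; (void)data_stability;
  (void)data_known; (void)shape_known;
}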
 
+/** @} */
+// Ends `\addtogroup`; it's important for the doc generator that this doesn't
+// include the CC code below.
+
 #ifdef __cplusplus
 }  // extern "C"
 
diff --git a/tensorflow/lite/core/experimental/acceleration/configuration/c/BUILD b/tensorflow/lite/core/experimental/acceleration/configuration/c/BUILD
index 0774c13..47cad4e 100644
--- a/tensorflow/lite/core/experimental/acceleration/configuration/c/BUILD
+++ b/tensorflow/lite/core/experimental/acceleration/configuration/c/BUILD
@@ -37,9 +37,7 @@
     name = "delegate_plugin",
     hdrs = ["delegate_plugin.h"],
     compatible_with = get_compatible_with_portable(),
-    visibility = [
-        "//tensorflow/lite:__subpackages__",
-    ] + delegate_plugin_visibility_allowlist(),
+    visibility = delegate_plugin_visibility_allowlist(),
     deps = [
         "//tensorflow/lite/core/acceleration/configuration/c:delegate_plugin",
     ],
@@ -48,9 +46,7 @@
 tflite_cc_library_with_c_headers_test(
     name = "gpu_plugin",
     hdrs = ["gpu_plugin.h"],
-    visibility = [
-        "//tensorflow/lite:__subpackages__",
-    ] + gpu_plugin_visibility_allowlist(),
+    visibility = gpu_plugin_visibility_allowlist(),
     deps = [
         "//tensorflow/lite/core/acceleration/configuration/c:gpu_plugin",
     ],
@@ -67,9 +63,7 @@
 tflite_cc_library_with_c_headers_test(
     name = "xnnpack_plugin",
     hdrs = ["xnnpack_plugin.h"],
-    visibility = [
-        "//tensorflow/lite:__subpackages__",
-    ] + xnnpack_plugin_visibility_allowlist(),
+    visibility = xnnpack_plugin_visibility_allowlist(),
     deps = [
         "//tensorflow/lite/core/acceleration/configuration/c:xnnpack_plugin",
     ],
diff --git a/tensorflow/lite/core/interpreter.h b/tensorflow/lite/core/interpreter.h
index 54e9a47..98a2fd6 100644
--- a/tensorflow/lite/core/interpreter.h
+++ b/tensorflow/lite/core/interpreter.h
@@ -994,7 +994,7 @@
   // The flag is shared across all subgraphs in the interpreter.
   // When the application calls `Cancel`, the flag will be set to false.
   // It "resets" to true at the beginning of each `Invoke`.
-  std::atomic_flag continue_invocation_{false};
+  std::atomic_flag continue_invocation_ = ATOMIC_FLAG_INIT;
   bool cancellation_enabled_ = false;
 };
 
diff --git a/tensorflow/lite/core/kernels/builtin_op_kernels.h b/tensorflow/lite/core/kernels/builtin_op_kernels.h
index 2163a679..20362ad 100644
--- a/tensorflow/lite/core/kernels/builtin_op_kernels.h
+++ b/tensorflow/lite/core/kernels/builtin_op_kernels.h
@@ -200,18 +200,19 @@
 TfLiteRegistration*
 Register_STABLEHLO_LOGISTIC();  // WARNING: not implemented, using this op will
                                 // crash the runtime
-TfLiteRegistration*
-Register_STABLEHLO_ADD();  // WARNING: not implemented, using this op will crash
-                           // the runtime
+
+TfLiteRegistration* Register_STABLEHLO_ADD();
+
 TfLiteRegistration*
 Register_STABLEHLO_DIVIDE();  // WARNING: not implemented, using this op will
                               // crash the runtime
-TfLiteRegistration*
-Register_STABLEHLO_MULTIPLY();  // WARNING: not implemented, using this op will
-                                // crash the runtime
-TfLiteRegistration*
-Register_STABLEHLO_MAXIMUM();  // WARNING: not implemented, using this op will
-                               // crash the runtime
+
+TfLiteRegistration* Register_STABLEHLO_MULTIPLY();
+
+TfLiteRegistration* Register_STABLEHLO_MAXIMUM();
+
+TfLiteRegistration* Register_STABLEHLO_MINIMUM();
+
 TfLiteRegistration*
 Register_STABLEHLO_RESHAPE();  // WARNING: not implemented, using this op will
                                // crash the runtime
@@ -254,9 +255,7 @@
 TfLiteRegistration*
 Register_STABLEHLO_LOG();  // WARNING: not implemented, using this
                            // op will crash the runtime
-TfLiteRegistration*
-Register_STABLEHLO_MINIMUM();  // WARNING: not implemented, using this
-                               // op will crash the runtime
+
 TfLiteRegistration*
 Register_STABLEHLO_NEGATE();  // WARNING: not implemented, using this
                               // op will crash the runtime
@@ -301,9 +300,9 @@
 TfLiteRegistration*
 Register_STABLEHLO_DOT_GENERAL();  // WARNING: not implemented, using this
                                    // op will crash the runtime
-TfLiteRegistration*
-Register_STABLEHLO_REDUCE_WINDOW();  // WARNING: not implemented, using this
-                                     // op will crash the runtime
+
+TfLiteRegistration* Register_STABLEHLO_REDUCE_WINDOW();
+
 TfLiteRegistration*
 Register_STABLEHLO_SORT();  // WARNING: not implemented, using this
                             // op will crash the runtime
@@ -313,9 +312,8 @@
 
 TfLiteRegistration* Register_STABLEHLO_SCATTER();
 
-TfLiteRegistration*
-Register_STABLEHLO_GATHER();  // WARNING: not implemented, using this
-                              // op will crash the runtime
+TfLiteRegistration* Register_STABLEHLO_GATHER();
+
 TfLiteRegistration*
 Register_STABLEHLO_TRANSPOSE();  // WARNING: not implemented, using this
                                  // op will crash the runtime
diff --git a/tensorflow/lite/core/kernels/register.cc b/tensorflow/lite/core/kernels/register.cc
index add966d..cb53c20 100644
--- a/tensorflow/lite/core/kernels/register.cc
+++ b/tensorflow/lite/core/kernels/register.cc
@@ -373,6 +373,13 @@
   AddBuiltin(BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR,
              Register_STABLEHLO_RNG_BIT_GENERATOR());
   AddBuiltin(BuiltinOperator_REDUCE_WINDOW, Register_REDUCE_WINDOW());
+  AddBuiltin(BuiltinOperator_STABLEHLO_REDUCE_WINDOW,
+             Register_STABLEHLO_REDUCE_WINDOW());
+  AddBuiltin(BuiltinOperator_STABLEHLO_GATHER, Register_STABLEHLO_GATHER());
+  AddBuiltin(BuiltinOperator_STABLEHLO_ADD, Register_STABLEHLO_ADD());
+  AddBuiltin(BuiltinOperator_STABLEHLO_MULTIPLY, Register_STABLEHLO_MULTIPLY());
+  AddBuiltin(BuiltinOperator_STABLEHLO_MAXIMUM, Register_STABLEHLO_MAXIMUM());
+  AddBuiltin(BuiltinOperator_STABLEHLO_MINIMUM, Register_STABLEHLO_MINIMUM());
   AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
diff --git a/tensorflow/lite/core/model_builder.cc b/tensorflow/lite/core/model_builder.cc
index 0c4969a..0582283 100644
--- a/tensorflow/lite/core/model_builder.cc
+++ b/tensorflow/lite/core/model_builder.cc
@@ -243,15 +243,13 @@
   auto buffers = model_->buffers();
   if (buffers && buffers->size() > 0) {
     auto first_buffer = buffers->Get(0);
-    if (first_buffer && first_buffer->data()) {
-      if (first_buffer->data()->size() != 0) {
-        // Note the 0th entry of this array must be an empty buffer (sentinel).
-        // This is a convention so that tensors without a buffer can provide 0
-        // as their buffer.
-        TF_LITE_REPORT_ERROR(
-            error_reporter,
-            "The 0th entry of the model buffer must be an empty buffer.");
-      }
+    if (first_buffer && first_buffer->size() != 0) {
+      // Note the 0th entry of this array must be an empty buffer (sentinel).
+      // This is a convention so that tensors without a buffer can provide 0
+      // as their buffer.
+      TF_LITE_REPORT_ERROR(
+          error_reporter,
+          "The 0th entry of the model buffer must be an empty buffer.");
     }
   }
 }
diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc
index 3fa9fa3..489b5d4 100644
--- a/tensorflow/lite/core/subgraph.cc
+++ b/tensorflow/lite/core/subgraph.cc
@@ -1726,15 +1726,35 @@
 TfLiteStatus Subgraph::ResizeTensor(TfLiteContext* context,
                                     TfLiteTensor* tensor,
                                     TfLiteIntArray* new_size) {
-  // If the dimensions don't change, avoiding
-  // unnecessary (re)allocations.
+  // If the dimensions don't change, avoid unnecessary (re)allocations.
   //
   // Note that it's required to check `tensor->data.raw != nullptr`. Otherwise
   // the subgraph won't allocate memory for a dynamic tensor when its size
   // is equal to the original tensor size.
-  if (tensor->data.raw != nullptr &&
-      EqualArrayAndTfLiteIntArray(tensor->dims, new_size->size,
-                                  new_size->data)) {
+  //
+  // We also need to check the bytes count because some direct calls to
+  // TfLiteTensorResizeMaybeCopy may lead to inconsistent dims and bytes in a
+  // tensor.
+  const bool can_reuse_allocation = [tensor, new_size, context] {
+    if (tensor->data.raw == nullptr) {
+      return false;
+    }
+    if (!EqualArrayAndTfLiteIntArray(tensor->dims, new_size->size,
+                                     new_size->data)) {
+      return false;
+    }
+    // The byte sizes of these data types are not handled by
+    // `ResizeTensorImpl`.
+    if (tensor->type == kTfLiteString || tensor->type == kTfLiteResource ||
+        tensor->type == kTfLiteVariant) {
+      return true;
+    }
+    size_t new_bytes = 0;
+    tflite::BytesRequired(tensor->type, tensor->dims->data, tensor->dims->size,
+                          &new_bytes, context);
+    return new_bytes == tensor->bytes;
+  }();
+
+  if (can_reuse_allocation) {
     // A number of clients assume |new_size| remains valid upon success, so
     // swap it in as the new (but logically identical) tensor dims.
     if (new_size != tensor->dims) {
diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h
index 58bcd8c..5de4916 100644
--- a/tensorflow/lite/core/subgraph.h
+++ b/tensorflow/lite/core/subgraph.h
@@ -229,6 +229,14 @@
   // Return read-only vector of node indices in the order of execution.
   const std::vector<int>& execution_plan() const { return execution_plan_; }
 
+  // Return read-only vector of node indices in the order of execution before
+  // any delegate was applied.
+  //
+  // Note: if no delegate is applied, this vector will be empty.
+  const std::vector<int>& pre_delegation_execution_plan() const {
+    return pre_delegation_execution_plan_;
+  }
+
   const std::vector<std::pair<TfLiteNode, TfLiteRegistration>>&
   nodes_and_registration() const {
     return nodes_and_registration_;
diff --git a/tensorflow/lite/core/subgraph_test.cc b/tensorflow/lite/core/subgraph_test.cc
index 9e2f7e9..7515fcc 100644
--- a/tensorflow/lite/core/subgraph_test.cc
+++ b/tensorflow/lite/core/subgraph_test.cc
@@ -20,14 +20,13 @@
 #include <functional>
 #include <memory>
 #include <numeric>
-#include <string>
 #include <vector>
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "tensorflow/lite/c/c_api_types.h"
 #include "tensorflow/lite/core/interpreter.h"
 #include "tensorflow/lite/stderr_reporter.h"
-#include "tensorflow/lite/testing/util.h"
 #include "tensorflow/lite/util.h"
 
 namespace tflite {
@@ -241,7 +240,6 @@
 TEST_F(SubgraphResizeTensorTest,
        ResizeDynamicTensorWithDifferentShapeReallocatesData) {
   ASSERT_EQ(context_.ResizeTensor(&context_, &tensor_, dims_), kTfLiteOk);
-  const void* const initial_data = tensor_.data.data;
 
   TfLiteIntArray* dims2 = ConvertVectorToTfLiteIntArray({5, 4, 6});
   const int dims2_bytes = BytesFor(type_, *dims2);
@@ -251,12 +249,87 @@
 
   // Some alignment requirements may lead to more memory being allocated.
   EXPECT_GE(tensor_.bytes, dims2_bytes);
-  EXPECT_NE(tensor_.data.data, initial_data);
   EXPECT_EQ(tensor_.dims, dims2);
   // Touch memory to trigger ASAN in case of incorrect handling.
   std::fill_n(tensor_.data.raw, dims2_bytes, 0);
   std::fill_n(tensor_.dims->data, tensor_.dims->size, 1);
 }
 
+TEST_F(SubgraphResizeTensorTest,
+       ResizeDynamicTensorWithSameShapeButDifferentBytesReallocatesData) {
+  ASSERT_EQ(context_.ResizeTensor(&context_, &tensor_, dims_), kTfLiteOk);
+  // Resize the tensor manually with more bytes than it already has.
+  TfLiteTensorResizeMaybeCopy(tensor_.bytes + 15, &tensor_,
+                              /*preserve_data=*/true);
+  ASSERT_GT(tensor_.bytes, reference_dims_bytes_);
+
+  ASSERT_EQ(context_.ResizeTensor(&context_, &tensor_, tensor_.dims),
+            kTfLiteOk);
+
+  // Some alignment requirements may lead to more memory being allocated.
+  EXPECT_GE(tensor_.bytes, reference_dims_bytes_);
+  EXPECT_EQ(tensor_.dims, dims_);
+  // Touch memory to trigger ASAN in case of incorrect handling.
+  std::fill_n(tensor_.data.raw, tensor_.bytes, 0);
+  std::fill_n(tensor_.dims->data, tensor_.dims->size, 1);
+}
+
+TEST_F(SubgraphResizeTensorTest,
+       ResizeDynamicTensorWithSameShapeButStringTypeSizeReallocatesData) {
+  constexpr size_t manual_bytes = 10;
+  // Allocate the tensor manually.
+  TfLiteTensorResizeMaybeCopy(manual_bytes, &tensor_, /*preserve_data=*/true);
+  tensor_.dims = dims_;
+  tensor_.type = kTfLiteString;
+
+  // This will fail if the memory reuse check cannot compute the new byte size.
+  ASSERT_EQ(context_.ResizeTensor(&context_, &tensor_, tensor_.dims),
+            kTfLiteOk);
+
+  // Some alignment requirements may lead to more memory being allocated.
+  EXPECT_EQ(tensor_.dims, dims_);
+  // Touch memory to trigger ASAN in case of incorrect handling.
+  std::fill_n(tensor_.data.raw, tensor_.bytes, 0);
+  std::fill_n(tensor_.dims->data, tensor_.dims->size, 1);
+}
+
+TEST_F(SubgraphResizeTensorTest,
+       ResizeDynamicTensorWithSameShapeButResourceTypeSizeReallocatesData) {
+  constexpr size_t manual_bytes = 10;
+  // Allocate the tensor manually.
+  TfLiteTensorResizeMaybeCopy(manual_bytes, &tensor_, /*preserve_data=*/true);
+  tensor_.dims = dims_;
+  tensor_.type = kTfLiteResource;
+
+  // This will fail if the memory reuse check cannot compute the new byte size.
+  ASSERT_EQ(context_.ResizeTensor(&context_, &tensor_, tensor_.dims),
+            kTfLiteOk);
+
+  // Some alignment requirements may lead to more memory being allocated.
+  EXPECT_EQ(tensor_.dims, dims_);
+  // Touch memory to trigger ASAN in case of incorrect handling.
+  std::fill_n(tensor_.data.raw, tensor_.bytes, 0);
+  std::fill_n(tensor_.dims->data, tensor_.dims->size, 1);
+}
+
+TEST_F(SubgraphResizeTensorTest,
+       ResizeDynamicTensorWithSameShapeButVariantTypeSizeReallocatesData) {
+  constexpr size_t manual_bytes = 10;
+  // Allocate the tensor manually.
+  TfLiteTensorResizeMaybeCopy(manual_bytes, &tensor_, /*preserve_data=*/true);
+  tensor_.dims = dims_;
+  tensor_.type = kTfLiteVariant;
+
+  // This will fail if the memory reuse check cannot compute the new byte size.
+  ASSERT_EQ(context_.ResizeTensor(&context_, &tensor_, tensor_.dims),
+            kTfLiteOk);
+
+  // Some alignment requirements may lead to more memory being allocated.
+  EXPECT_EQ(tensor_.dims, dims_);
+  // Touch memory to trigger ASAN in case of incorrect handling.
+  std::fill_n(tensor_.data.raw, tensor_.bytes, 0);
+  std::fill_n(tensor_.dims->data, tensor_.dims->size, 1);
+}
+
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD
index 5d27b02..857e266 100644
--- a/tensorflow/lite/delegates/flex/BUILD
+++ b/tensorflow/lite/delegates/flex/BUILD
@@ -4,8 +4,8 @@
     "if_not_mobile",
     "tf_cc_test",
     "tf_features_nolayering_check_if_ios",
-    "tf_opts_nortti_if_android",
     "tf_opts_nortti_if_lite_protos",
+    "tf_opts_nortti_if_mobile",
 )
 load("//tensorflow/lite:build_def.bzl", "tflite_copts")
 load("//tensorflow/lite:special_rules.bzl", "internal_visibility_allowlist")
@@ -157,7 +157,7 @@
         "delegate.h",
     ],
     compatible_with = get_compatible_with_portable(),
-    copts = tflite_copts() + tf_opts_nortti_if_android(),
+    copts = tflite_copts() + tf_opts_nortti_if_mobile(),
     features = tf_features_nolayering_check_if_ios(),
     visibility = ["//visibility:public"],
     deps = [
@@ -210,7 +210,7 @@
     srcs = ["delegate_data.cc"],
     hdrs = ["delegate_data.h"],
     compatible_with = get_compatible_with_portable(),
-    copts = tf_opts_nortti_if_android(),
+    copts = tf_opts_nortti_if_mobile(),
     features = tf_features_nolayering_check_if_ios(),
     visibility = ["//visibility:public"],
     deps = [
@@ -405,7 +405,7 @@
     name = "tflite_subgraph_execute",
     srcs = ["tflite_subgraph_execute.cc"],
     compatible_with = get_compatible_with_portable(),
-    copts = tf_opts_nortti_if_android(),
+    copts = tf_opts_nortti_if_mobile(),
     features = tf_features_nolayering_check_if_ios(),
     deps = [
         ":buffer_map_util",
diff --git a/tensorflow/lite/delegates/flex/build_def.bzl b/tensorflow/lite/delegates/flex/build_def.bzl
index 0106292..44588f9 100644
--- a/tensorflow/lite/delegates/flex/build_def.bzl
+++ b/tensorflow/lite/delegates/flex/build_def.bzl
@@ -21,8 +21,8 @@
     "tflite_jni_binary",
     "tflite_jni_linkopts",
 )
-load("@build_bazel_rules_android//android:rules.bzl", "android_library")
 load("//tensorflow/lite:special_rules.bzl", "flex_portable_tensorflow_deps")
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
 
 def generate_flex_kernel_header(
         name,
@@ -147,9 +147,9 @@
             ],
             visibility = visibility,
             deps = flex_portable_tensorflow_deps() + [
+                clean_dep("@ducc//:fft_wrapper"),
                 clean_dep("//tensorflow/core:protos_all_cc"),
                 clean_dep("//tensorflow/core:portable_tensorflow_lib_lite"),
-                clean_dep("//tensorflow/core/kernels:portable_fft_impl"),
                 clean_dep("//tensorflow/core/platform:strong_hash"),
                 clean_dep("//tensorflow/lite/delegates/flex:portable_images_lib"),
             ],
diff --git a/tensorflow/lite/delegates/flex/test/BUILD b/tensorflow/lite/delegates/flex/test/BUILD
index e0bb3f4..238b6ad 100644
--- a/tensorflow/lite/delegates/flex/test/BUILD
+++ b/tensorflow/lite/delegates/flex/test/BUILD
@@ -8,9 +8,6 @@
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
-    default_visibility = [
-        "//tensorflow/lite/android:__subpackages__",
-    ],
     licenses = ["notice"],
 )
 
diff --git a/tensorflow/lite/delegates/gpu/build_defs.bzl b/tensorflow/lite/delegates/gpu/build_defs.bzl
index 462ec7c..e6e6fa2 100644
--- a/tensorflow/lite/delegates/gpu/build_defs.bzl
+++ b/tensorflow/lite/delegates/gpu/build_defs.bzl
@@ -25,3 +25,18 @@
         ],
         "//conditions:default": [],
     }) + nativewindow_linkopts()
+
+def tflite_angle_heapcheck_deps():
+    # copybara:uncomment_begin(google-only)
+    # return select({
+    # "//tensorflow/lite/delegates/gpu:tflite_gpu_angle": [
+    # "@com_google_googletest//:gtest_main_no_heapcheck",
+    # ],
+    # "//conditions:default": [
+    # "@com_google_googletest//:gtest_main",
+    # ],
+    # })
+    # copybara:uncomment_end
+    # copybara:comment_begin(oss-only)
+    return ["@com_google_googletest//:gtest_main"]
+    # copybara:comment_end
diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
index 684ed5c..93982f5 100644
--- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc
+++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
@@ -238,6 +238,12 @@
     flush_periodically = true;
     flush_period = 16;
   }
+  // clvk knows internally when to flush; do not flush at the application
+  // level.
+  if (gpu_info.IsApiOpenCl() && gpu_info.opencl_info.IsCLVK()) {
+    need_flush = false;
+    flush_periodically = false;
+  }
 }
 
 absl::Status InferenceContext::InitFromGraph(
diff --git a/tensorflow/lite/delegates/gpu/common/task/profiling_info.cc b/tensorflow/lite/delegates/gpu/common/task/profiling_info.cc
index 630290f..0c9d3ad 100644
--- a/tensorflow/lite/delegates/gpu/common/task/profiling_info.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/profiling_info.cc
@@ -49,7 +49,7 @@
           dispatch.read_mem_size + dispatch.write_mem_size;
       const double giga_bytes = total_size / 1024.0 / 1024.0 / 1024.0;
       const double giga_bytes_per_sec = times_per_sec * giga_bytes;
-      result += ", " + std::to_string(giga_bytes_per_sec) + " Gb/s";
+      result += ", " + std::to_string(giga_bytes_per_sec) + " Gbyte/s";
     }
     if (dispatch.flops) {
       const double giga_flops = dispatch.flops / 1000.0 / 1000.0 / 1000.0;
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/BUILD b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD
index 169e0cd..eafba64 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD
@@ -2,7 +2,12 @@
     "//tensorflow/core/platform:build_config_root.bzl",
     "tf_gpu_tests_tags",
 )
-load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite_combined")
+load(
+    "//tensorflow/lite:special_rules.bzl",
+    "tflite_extra_gles_deps",
+    "tflite_portable_test_suite_combined",
+)
+load("//tensorflow/lite/delegates/gpu:build_defs.bzl", "tflite_angle_heapcheck_deps")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -33,14 +38,7 @@
     name = "converter_test",
     size = "small",
     srcs = ["converter_test.cc"],
-    linkopts = [
-        "-lEGL",
-        "-lGLESv3",
-    ],
     tags = tf_gpu_tests_tags() + [
-        "local",
-        "nobuilder",
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
@@ -53,8 +51,7 @@
         "//tensorflow/lite/delegates/gpu/gl:gl_buffer",
         "//tensorflow/lite/delegates/gpu/gl:portable",
         "@com_google_absl//absl/types:span",
-        "@com_google_googletest//:gtest_main",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -78,15 +75,13 @@
     srcs = ["add_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":add",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -108,15 +103,13 @@
     srcs = ["concat_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":concat",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -143,15 +136,13 @@
     srcs = ["conv_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":conv",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -187,15 +178,13 @@
     srcs = ["depthwise_conv_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":depthwise_conv",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -217,15 +206,13 @@
     srcs = ["elementwise_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":elementwise",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -248,15 +235,13 @@
     srcs = ["fully_connected_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":fully_connected",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -277,15 +262,13 @@
     srcs = ["lstm_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":lstm",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -307,15 +290,13 @@
     srcs = ["max_unpooling_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":max_unpooling",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -338,15 +319,13 @@
     srcs = ["mean_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":mean",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -373,15 +352,13 @@
     srcs = ["mul_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":mul",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -403,15 +380,13 @@
     srcs = ["pad_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":pad",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -433,15 +408,13 @@
     srcs = ["pooling_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":pooling",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -465,15 +438,13 @@
     srcs = ["prelu_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":prelu",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -496,7 +467,6 @@
     srcs = ["quantize_and_dequantize_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
@@ -504,8 +474,7 @@
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/kernels/internal:quantization_util",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -527,15 +496,13 @@
     srcs = ["relu_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":relu",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -556,15 +523,13 @@
     srcs = ["resampler_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":resampler",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -585,15 +550,13 @@
     srcs = ["reshape_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":reshape",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -614,15 +577,13 @@
     srcs = ["resize_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":resize",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -644,15 +605,13 @@
     srcs = ["slice_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":slice",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -676,7 +635,6 @@
     srcs = ["softmax_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
@@ -684,8 +642,7 @@
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -706,15 +663,13 @@
     srcs = ["space_to_depth_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":space_to_depth",
         ":test_util",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -729,6 +684,10 @@
             "-ldl",
             "-lm",
         ],
+        # copybara:uncomment_begin(google-only)
+        # "//tensorflow/lite/delegates/gpu:tflite_gpu_angle": [],
+        # "//tensorflow/lite/delegates/gpu:tflite_gpu_extra_gles_deps": [],
+        # copybara:uncomment_end
         "//conditions:default": [
             "-lEGL",
             "-lGLESv3",
@@ -751,8 +710,7 @@
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_googletest//:gtest",
-        "@com_google_googletest//:gtest_main",
-    ],
+    ] + tflite_extra_gles_deps(),
 )
 
 cc_library(
@@ -773,15 +731,13 @@
     srcs = ["tile_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":test_util",
         ":tile",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 cc_library(
@@ -806,15 +762,13 @@
     srcs = ["transpose_conv_test.cc"],
     linkstatic = True,
     tags = tf_gpu_tests_tags() + [
-        "notap",
         "tflite_not_portable_ios",
     ],
     deps = [
         ":test_util",
         ":transpose_conv",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "@com_google_googletest//:gtest",
-    ],
+    ] + tflite_angle_heapcheck_deps(),
 )
 
 TFLITE_GPU_BINARY_RELEASE_OPERATORS = [
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index d0f37eb..a9a5cd1 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -1459,15 +1459,48 @@
   TfLiteStatus TransformCosIntoSupportedOps(int lite_node_index,
                                             TfLiteNode* node,
                                             TfLiteRegistration* reg) {
-    const TfLiteTensor& theta = context_->tensors[node->inputs->data[0]];
-
-    // NNAPI only supports float sin
-    auto tensor_size = theta.bytes / sizeof(float);
+    const TfLiteTensor& input = context_->tensors[node->inputs->data[0]];
+    const TfLiteTensor& output = context_->tensors[node->outputs->data[0]];
 
     // Convert cos to sin: $cos(x) = sin(\frac{\pi}{2} - x)$
-    auto data = theta.data.f;
-    for (int i = 0; i < tensor_size; i++) {
-      data[i] = M_PI_2 - data[i];
+
+    int diff_out_ann_index;
+    // stage 1: $\frac{\pi}{2} - x$
+    {
+      // NNAPI only supports float sin
+      auto tensor_size = input.bytes / sizeof(float);
+
+      int tensor_index;
+      TF_LITE_ENSURE_OK(context_,
+                        AddNewInputConstantTensor(
+                            ANEURALNETWORKS_TENSOR_FLOAT32, kTfLiteFloat32,
+                            input.dims, std::vector<float>(tensor_size, M_PI_2),
+                            input.params, &tensor_index));
+
+      TF_LITE_ENSURE_OK(
+          context_, AddTensorInput(node->inputs->data[0], /*hybrid_op=*/false));
+
+      TF_LITE_ENSURE_OK(context_,
+                        AddScalarInt32Operand(ANEURALNETWORKS_FUSED_NONE));
+
+      TF_LITE_ENSURE_OK(
+          context_,
+          AddAdditionalOutputTensor(
+              output.dims->size, reinterpret_cast<uint32_t*>(output.dims->data),
+              ANEURALNETWORKS_TENSOR_FLOAT32, 0, 0, &diff_out_ann_index));
+
+      TF_LITE_ENSURE_OK(
+          context_, FinalizeAddOperation(ANEURALNETWORKS_SUB, lite_node_index));
+    }
+
+    // stage 2: $sin(\frac{\pi}{2} - x)$
+    {
+      augmented_inputs_.push_back(diff_out_ann_index);
+
+      TF_LITE_ENSURE_OK(context_, AddTensorOutput(node->outputs->data[0]));
+
+      TF_LITE_ENSURE_OK(
+          context_, FinalizeAddOperation(ANEURALNETWORKS_SIN, lite_node_index));
     }
 
     return kTfLiteOk;
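
The two stages above implement the trigonometric identity behind the rewrite:
stage 1 builds a constant tensor filled with M_PI_2 and emits ANEURALNETWORKS_SUB
to compute d = pi/2 - x, and stage 2 feeds d into ANEURALNETWORKS_SIN. In LaTeX,
with a one-line sanity check:

    \cos(x) = \sin\!\left(\tfrac{\pi}{2} - x\right),
    \qquad x = 0 \;\Rightarrow\; \sin\!\left(\tfrac{\pi}{2}\right) = 1 = \cos(0).
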
diff --git a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_app_using_stable_delegate.cc b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_app_using_stable_delegate.cc
index 2ef7d18..d09f355 100644
--- a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_app_using_stable_delegate.cc
+++ b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_app_using_stable_delegate.cc
@@ -22,6 +22,7 @@
 #include "tensorflow/lite/acceleration/configuration/configuration_generated.h"
 #include "tensorflow/lite/c/c_api.h"  // For TfLiteTensorByteSize.
 #include "tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.h"
+#include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/interpreter_builder.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model_builder.h"
diff --git a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external.cc b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external.cc
index 4076626..135d6f9 100644
--- a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external.cc
+++ b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external.cc
@@ -15,13 +15,13 @@
 #include <memory>
 #include <utility>
 
+#include "tensorflow/lite/acceleration/configuration/c/delegate_plugin.h"
+#include "tensorflow/lite/acceleration/configuration/c/stable_delegate.h"
 #include "tensorflow/lite/c/c_api_types.h"
 #include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/acceleration/configuration/c/delegate_plugin.h"
 #include "tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate.h"
 #include "tensorflow/lite/delegates/utils/experimental/stable_delegate/stable_delegate_interface.h"
 #include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h"
-#include "tensorflow/lite/acceleration/configuration/c/stable_delegate.h"
 
 namespace {
 
diff --git a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc
index 9c26723..e237eef 100644
--- a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc
+++ b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc
@@ -46,6 +46,23 @@
   EXPECT_STREQ(stable_delegate_handle->delegate_version,
                tflite::example::kSampleStableDelegateVersion);
   ASSERT_NE(stable_delegate_handle->delegate_plugin, nullptr);
+}
+
+TEST(SampleStableDelegate, LoadFromSharedLibraryTestFile) {
+  // Load the example stable opaque_delegate that implements the ADD operation
+  // from a shared library file.
+  const TfLiteStableDelegate* stable_delegate_handle =
+      LoadDelegateFromSharedLibrary(
+          "tensorflow/lite/delegates/utils/experimental/"
+          "sample_stable_delegate/libtensorflowlite_sample_stable_delegate_for_test.so");
+  ASSERT_NE(stable_delegate_handle, nullptr);
+  EXPECT_STREQ(stable_delegate_handle->delegate_abi_version,
+               TFL_STABLE_DELEGATE_ABI_VERSION);
+  EXPECT_STREQ(stable_delegate_handle->delegate_name,
+               tflite::example::kSampleStableDelegateName);
+  EXPECT_STREQ(stable_delegate_handle->delegate_version,
+               tflite::example::kSampleStableDelegateVersion);
+  ASSERT_NE(stable_delegate_handle->delegate_plugin, nullptr);
 
   // Build TFLiteSettings flatbuffer and pass into opaque_delegate plugin
   // create method.
@@ -60,9 +77,7 @@
       stable_delegate_handle->delegate_plugin->create(settings);
   ASSERT_NE(opaque_delegate, nullptr);
 
-  //
   // Create the model and the interpreter
-  //
   TfLiteModel* model =
       TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin");
   ASSERT_NE(model, nullptr);
@@ -74,9 +89,7 @@
   // The options can be deleted immediately after interpreter creation.
   TfLiteInterpreterOptionsDelete(options);
 
-  //
   // Allocate the tensors and fill the input tensor.
-  //
   ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk);
   TfLiteTensor* input_tensor =
       TfLiteInterpreterGetInputTensor(interpreter, /*input_index=*/0);
@@ -88,9 +101,7 @@
                                        input.size() * sizeof(float)),
             kTfLiteOk);
 
-  //
   // Run the interpreter and read the output tensor.
-  //
   ASSERT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk);
 
   const TfLiteTensor* output_tensor =
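
The comments above walk the standard TFLite C API flow (create the model and the
interpreter, allocate tensors and fill the input, invoke and read the output).
A minimal self-contained sketch of that same flow, without a delegate attached,
with error handling reduced to early returns and no cleanup on the error paths,
assuming the public C API from tensorflow/lite/c/c_api.h (add.bin is the test
model used above):

    #include <vector>

    #include "tensorflow/lite/c/c_api.h"

    int RunAddModelOnce() {
      // Create the model and the interpreter.
      TfLiteModel* model =
          TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin");
      if (model == nullptr) return 1;
      TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
      TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
      // The options can be deleted immediately after interpreter creation.
      TfLiteInterpreterOptionsDelete(options);
      if (interpreter == nullptr) return 1;

      // Allocate the tensors and fill the input tensor.
      if (TfLiteInterpreterAllocateTensors(interpreter) != kTfLiteOk) return 1;
      TfLiteTensor* input =
          TfLiteInterpreterGetInputTensor(interpreter, /*input_index=*/0);
      std::vector<float> input_data(
          TfLiteTensorByteSize(input) / sizeof(float), 1.0f);
      TfLiteTensorCopyFromBuffer(input, input_data.data(),
                                 input_data.size() * sizeof(float));

      // Run the interpreter and read the output tensor.
      if (TfLiteInterpreterInvoke(interpreter) != kTfLiteOk) return 1;
      const TfLiteTensor* output =
          TfLiteInterpreterGetOutputTensor(interpreter, /*output_index=*/0);
      std::vector<float> output_data(TfLiteTensorByteSize(output) /
                                     sizeof(float));
      TfLiteTensorCopyToBuffer(output, output_data.data(),
                               output_data.size() * sizeof(float));

      TfLiteInterpreterDelete(interpreter);
      TfLiteModelDelete(model);
      return 0;
    }
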
diff --git a/tensorflow/lite/delegates/utils/experimental/stable_delegate/BUILD b/tensorflow/lite/delegates/utils/experimental/stable_delegate/BUILD
index cadf44e..5ac1eb2 100644
--- a/tensorflow/lite/delegates/utils/experimental/stable_delegate/BUILD
+++ b/tensorflow/lite/delegates/utils/experimental/stable_delegate/BUILD
@@ -79,6 +79,7 @@
     deps = [
         ":delegate_loader",
         "//tensorflow/lite/acceleration/configuration:configuration_fbs",
+        "//tensorflow/lite/acceleration/configuration/c:stable_delegate",
         "//tensorflow/lite/delegates/utils/experimental/sample_stable_delegate",
         "@com_google_googletest//:gtest_main",
     ],
diff --git a/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.cc b/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.cc
index abf89f7..7aff358 100644
--- a/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.cc
+++ b/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.cc
@@ -21,6 +21,7 @@
 #include <cerrno>
 #include <string>
 
+#include "tensorflow/lite/acceleration/configuration/c/stable_delegate.h"
 #include "tensorflow/lite/experimental/acceleration/compatibility/android_info.h"
 #include "tensorflow/lite/tools/logging.h"
 
diff --git a/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.h b/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.h
index 940e4fd..e932fe8 100644
--- a/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.h
+++ b/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.h
@@ -23,8 +23,8 @@
 namespace delegates {
 namespace utils {
 
-const char kTfLiteStableDelegateSymbol[] = "TFL_TheStableDelegate";
-const char kTfLiteLibraryPathEnvironmentVariable[] =
+constexpr char kTfLiteStableDelegateSymbol[] = "TFL_TheStableDelegate";
+constexpr char kTfLiteLibraryPathEnvironmentVariable[] =
     "TFLITE_STABLE_DELEGATE_LIBRARY_PATH";
 
 // Loads the TFLite delegate shared library and returns the pointer to
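
A minimal usage sketch of this loader API follows: it loads the sample stable
delegate shared object by the same path the tests use and checks that the
loaded handle reports the ABI version this client was compiled against. The
fully qualified namespace is assumed to be tflite::delegates::utils, matching
the namespaces opened in this header; the helper name is hypothetical.

    #include <cstring>

    #include "tensorflow/lite/acceleration/configuration/c/stable_delegate.h"
    #include "tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.h"

    // Hypothetical helper: returns true if the sample stable delegate can be
    // loaded and reports a matching ABI version string.
    bool CanLoadSampleStableDelegate() {
      const TfLiteStableDelegate* handle =
          tflite::delegates::utils::LoadDelegateFromSharedLibrary(
              "tensorflow/lite/delegates/utils/experimental/"
              "sample_stable_delegate/"
              "libtensorflowlite_sample_stable_delegate.so");
      if (handle == nullptr) return false;
      return std::strcmp(handle->delegate_abi_version,
                         TFL_STABLE_DELEGATE_ABI_VERSION) == 0;
    }
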
diff --git a/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader_test.cc b/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader_test.cc
index 4a25c28..d6880d5 100644
--- a/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader_test.cc
+++ b/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader_test.cc
@@ -14,10 +14,10 @@
 ==============================================================================*/
 #include "tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.h"
 
-#include <cstddef>
 #include <cstdlib>
 
 #include <gtest/gtest.h>
+#include "tensorflow/lite/acceleration/configuration/c/stable_delegate.h"
 #include "tensorflow/lite/acceleration/configuration/configuration_generated.h"
 #include "tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate.h"
 
@@ -32,9 +32,11 @@
   const TfLiteStableDelegate* stable_delegate_handle =
       LoadDelegateFromSharedLibrary(
           "tensorflow/lite/delegates/utils/experimental/"
-          "sample_stable_delegate/libtensorflowlite_sample_stable_delegate.so");
+          "sample_stable_delegate/"
+          "libtensorflowlite_sample_stable_delegate.so"
+          );
 
-  EXPECT_NE(stable_delegate_handle, nullptr);
+  ASSERT_NE(stable_delegate_handle, nullptr);
   EXPECT_STREQ(stable_delegate_handle->delegate_abi_version,
                TFL_STABLE_DELEGATE_ABI_VERSION);
   EXPECT_STREQ(stable_delegate_handle->delegate_name,
@@ -58,9 +60,10 @@
       flatbuffer_builder.GetBufferPointer());
   auto delegate = stable_delegate_handle->delegate_plugin->create(settings);
 
-  EXPECT_NE(delegate, nullptr);
+  ASSERT_NE(delegate, nullptr);
   EXPECT_EQ(
       stable_delegate_handle->delegate_plugin->get_delegate_errno(delegate), 0);
+
   stable_delegate_handle->delegate_plugin->destroy(delegate);
 }
 
diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD
index f892521..affa060 100644
--- a/tensorflow/lite/delegates/xnnpack/BUILD
+++ b/tensorflow/lite/delegates/xnnpack/BUILD
@@ -236,6 +236,7 @@
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/tools/optimize:reduced_precision_support",
         "@XNNPACK",
+        "@XNNPACK//:experiments_config",
     ],
 )
 
@@ -270,6 +271,7 @@
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/tools/optimize:reduced_precision_support",
         "@XNNPACK//:XNNPACK_test_mode",
+        "@XNNPACK//:experiments_config",
     ],
 )
 
@@ -457,6 +459,26 @@
 )
 
 cc_library(
+    name = "dynamically_quantized_conv_2d_tester",
+    testonly = 1,
+    srcs = ["dynamically_quantized_conv_2d_tester.cc"],
+    hdrs = ["dynamically_quantized_conv_2d_tester.h"],
+    deps = [
+        ":xnnpack_delegate_test_mode",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite:schema_fbs_version",
+        "//tensorflow/lite/c:c_api_types",
+        "//tensorflow/lite/core:framework",
+        "//tensorflow/lite/core/c:common",
+        "//tensorflow/lite/core/kernels:builtin_ops",
+        "//tensorflow/lite/schema:schema_conversion_utils",
+        "//tensorflow/lite/schema:schema_fbs",
+        "//testing/base/public:gunit_for_library_testonly",
+        "@flatbuffers",
+    ],
+)
+
+cc_library(
     name = "fully_connected_tester",
     testonly = 1,
     srcs = ["fully_connected_tester.cc"],
@@ -1279,6 +1301,22 @@
 )
 
 cc_test(
+    name = "dynamically_quantized_conv_2d_test",
+    srcs = ["dynamically_quantized_conv_2d_test.cc"],
+    linkopts = select({
+        "//tensorflow:emscripten": EMSCRIPTEN_LINKOPTS,
+        "//conditions:default": [],
+    }),
+    deps = [
+        ":dynamically_quantized_conv_2d_tester",
+        ":test_main",
+        ":xnnpack_delegate_test_mode",
+        "//tensorflow/lite/c:c_api_types",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+cc_test(
     name = "elu_test",
     srcs = ["elu_test.cc"],
     linkopts = select({
diff --git a/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc b/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc
index cab06da..5654c28 100644
--- a/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc
+++ b/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc
@@ -816,5 +816,40 @@
       .Test(xnnpack_delegate.get());
 }
 
+TEST(Conv2D, AdaptiveAvxOptimization) {
+  TfLiteXNNPackDelegateOptions xnnpack_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  xnnpack_options.num_threads = 2;
+  xnnpack_options.experimental_adaptive_avx_optimization = true;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&xnnpack_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(5, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  Conv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .Test(xnnpack_delegate.get());
+}
+
 }  // namespace xnnpack
 }  // namespace tflite
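
Every test in this file, and every test in the new dynamically quantized
convolution test below, owns the XNNPACK delegate through the same RAII
pattern: a std::unique_ptr whose deleter is TfLiteXNNPackDelegateDelete. A
minimal sketch of that pattern factored into a helper (the helper name is
hypothetical; the option fields are the ones the tests set):

    #include <memory>

    #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"

    // Hypothetical helper: creates an XNNPACK delegate whose lifetime is tied
    // to the returned unique_ptr, so TfLiteXNNPackDelegateDelete runs on scope
    // exit even when a test assertion returns early.
    std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
    MakeScopedXnnpackDelegate(int num_threads) {
      TfLiteXNNPackDelegateOptions options =
          TfLiteXNNPackDelegateOptionsDefault();
      options.num_threads = num_threads;
      return std::unique_ptr<TfLiteDelegate,
                             decltype(&TfLiteXNNPackDelegateDelete)>(
          TfLiteXNNPackDelegateCreate(&options), TfLiteXNNPackDelegateDelete);
    }

Tests then pass the raw pointer via xnnpack_delegate.get() while the unique_ptr
retains ownership.
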
diff --git a/tensorflow/lite/delegates/xnnpack/dynamically_quantized_conv_2d_test.cc b/tensorflow/lite/delegates/xnnpack/dynamically_quantized_conv_2d_test.cc
new file mode 100644
index 0000000..3e0629e
--- /dev/null
+++ b/tensorflow/lite/delegates/xnnpack/dynamically_quantized_conv_2d_test.cc
@@ -0,0 +1,727 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <random>
+
+#include <gtest/gtest.h>
+#include "tensorflow/lite/c/c_api_types.h"
+#include "tensorflow/lite/delegates/xnnpack/dynamically_quantized_conv_2d_tester.h"
+#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
+
+namespace tflite {
+namespace xnnpack {
+
+TEST(DynamicallyQuantizedConv2D, 1x1) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(5, 25), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(1)
+      .KernelWidth(1)
+      .ValidPadding()
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, 3x3) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(5, 25), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(3)
+      .KernelWidth(3)
+      .SamePadding()
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, 3x3Stride2) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(5, 25), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(3)
+      .KernelWidth(3)
+      .StrideHeight(2)
+      .StrideWidth(2)
+      .SamePadding()
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, Grouped) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(5, 25), std::ref(rng));
+  auto channel_per_group_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+  auto groups_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 8), std::ref(rng));
+
+  auto groups = groups_rng();
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(groups * channel_per_group_rng())
+      .OutputChannels(groups * channel_per_group_rng())
+      .Groups(groups)
+      .KernelHeight(3)
+      .KernelWidth(3)
+      .SamePadding()
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, SmallKernelWithSamePadding) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 7), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .SamePadding()
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, SmallKernelWithValidPadding) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 7), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .ValidPadding()
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, StrideWithSamePadding) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .SamePadding()
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, StrideWithValidPadding) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .ValidPadding()
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, DilationWithSamePadding) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto dilation_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .DilationHeight(dilation_rng())
+      .DilationWidth(dilation_rng())
+      .SamePadding()
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, DilationWithValidPadding) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto dilation_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .DilationHeight(dilation_rng())
+      .DilationWidth(dilation_rng())
+      .ValidPadding()
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, TensorWiseQuantizedInt8Weights) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, ChannelWiseQuantizedInt8Weights) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, ReluActivation) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .ReluActivation()
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, Relu6Activation) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .Relu6Activation()
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, ReluMinus1To1Activation) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .ReluMinus1To1Activation()
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, TanhActivation) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .TanhActivation()
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, SignBitActivation) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .SignBitActivation()
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, MultiThreading) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  delegate_options.num_threads = 2;
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, WeightsCache) {
+  TfLiteXNNPackDelegateOptions delegate_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  std::unique_ptr<TfLiteXNNPackDelegateWeightsCache,
+                  decltype(&TfLiteXNNPackDelegateWeightsCacheDelete)>
+      weights_cache(TfLiteXNNPackDelegateWeightsCacheCreate(),
+                    TfLiteXNNPackDelegateWeightsCacheDelete);
+  delegate_options.weights_cache = weights_cache.get();
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(10, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .WeightsCache(weights_cache.get())
+      .Test(xnnpack_delegate.get());
+}
+
+TEST(DynamicallyQuantizedConv2D, TransientIndirectionBuffer) {
+  TfLiteXNNPackDelegateOptions xnnpack_options =
+      TfLiteXNNPackDelegateOptionsDefault();
+  xnnpack_options.num_threads = 2;
+  xnnpack_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_TRANSIENT_INDIRECTION_BUFFER;
+  xnnpack_options.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+  std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
+      xnnpack_delegate(TfLiteXNNPackDelegateCreate(&xnnpack_options),
+                       TfLiteXNNPackDelegateDelete);
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto batch_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 4), std::ref(rng));
+  auto input_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(5, 25), std::ref(rng));
+  auto kernel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(3, 5), std::ref(rng));
+  auto stride_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 3), std::ref(rng));
+  auto channel_rng =
+      std::bind(std::uniform_int_distribution<int32_t>(2, 16), std::ref(rng));
+
+  DynamicallyQuantizedConv2DTester()
+      .BatchSize(batch_rng())
+      .InputHeight(input_rng())
+      .InputWidth(input_rng())
+      .InputChannels(channel_rng())
+      .OutputChannels(channel_rng())
+      .KernelHeight(kernel_rng())
+      .KernelWidth(kernel_rng())
+      .StrideHeight(stride_rng())
+      .StrideWidth(stride_rng())
+      .Test(xnnpack_delegate.get());
+}
+
+}  // namespace xnnpack
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/xnnpack/dynamically_quantized_conv_2d_tester.cc b/tensorflow/lite/delegates/xnnpack/dynamically_quantized_conv_2d_tester.cc
new file mode 100644
index 0000000..de9276f
--- /dev/null
+++ b/tensorflow/lite/delegates/xnnpack/dynamically_quantized_conv_2d_tester.cc
@@ -0,0 +1,246 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/xnnpack/dynamically_quantized_conv_2d_tester.h"
+
+#include <algorithm>
+#include <array>
+#include <cstdint>
+#include <cstdlib>
+#include <functional>
+#include <limits>
+#include <memory>
+#include <random>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "flatbuffers/buffer.h"  // from @flatbuffers
+#include "flatbuffers/flatbuffer_builder.h"  // from @flatbuffers
+#include "flatbuffers/string.h"  // from @flatbuffers
+#include "tensorflow/lite/c/c_api_types.h"
+#include "tensorflow/lite/core/kernels/register.h"
+#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/schema/schema_conversion_utils.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/version.h"
+
+namespace tflite {
+namespace xnnpack {
+
+void DynamicallyQuantizedConv2DTester::Test(TfLiteDelegate* delegate) const {
+  std::vector<char> buffer = CreateTfLiteModel();
+  const Model* model = GetModel(buffer.data());
+
+  std::unique_ptr<Interpreter> delegate_interpreter;
+  ASSERT_EQ(
+      InterpreterBuilder(
+          model,
+          ::tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates())(
+          &delegate_interpreter),
+      kTfLiteOk);
+  std::unique_ptr<Interpreter> default_interpreter;
+  ASSERT_EQ(
+      InterpreterBuilder(
+          model,
+          ::tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates())(
+          &default_interpreter),
+      kTfLiteOk);
+
+  ASSERT_TRUE(delegate_interpreter);
+  ASSERT_TRUE(default_interpreter);
+
+  ASSERT_EQ(delegate_interpreter->inputs().size(), 1);
+  ASSERT_EQ(default_interpreter->inputs().size(), 1);
+
+  ASSERT_EQ(delegate_interpreter->outputs().size(), 1);
+  ASSERT_EQ(default_interpreter->outputs().size(), 1);
+
+  ASSERT_EQ(delegate_interpreter->AllocateTensors(), kTfLiteOk);
+  ASSERT_EQ(default_interpreter->AllocateTensors(), kTfLiteOk);
+
+  ASSERT_EQ(delegate_interpreter->ModifyGraphWithDelegate(delegate), kTfLiteOk);
+
+  if (weights_cache_ != nullptr) {
+    TfLiteXNNPackDelegateWeightsCacheFinalizeHard(weights_cache_);
+  }
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto input_rng =
+      std::bind(std::uniform_real_distribution<float>(-10, 10), std::ref(rng));
+  float* default_input_data = default_interpreter->typed_input_tensor<float>(0);
+  std::generate_n(default_input_data,
+                  BatchSize() * InputHeight() * InputWidth() * InputChannels(),
+                  input_rng);
+
+  float* delegate_input_data =
+      delegate_interpreter->typed_input_tensor<float>(0);
+  std::copy_n(default_input_data,
+              BatchSize() * InputHeight() * InputWidth() * InputChannels(),
+              delegate_input_data);
+
+  ASSERT_EQ(default_interpreter->Invoke(), kTfLiteOk);
+  ASSERT_EQ(delegate_interpreter->Invoke(), kTfLiteOk);
+
+  float* default_output_data =
+      default_interpreter->typed_output_tensor<float>(0);
+  float* delegate_output_data =
+      delegate_interpreter->typed_output_tensor<float>(0);
+
+  const int num_output_values =
+      BatchSize() * OutputHeight() * OutputWidth() * OutputChannels();
+  int different_output_values = 0;
+  // TFLite rounds to nearest with ties away from zero, while XNNPACK rounds
+  // to nearest with ties to even, the IEEE 754 default for binary floating
+  // point and its recommended default for decimal. Because the two rounding
+  // modes disagree on ties, many output values may differ slightly.
+  for (int32_t i = 0; i < BatchSize(); i++) {
+    for (int32_t y = 0; y < OutputHeight(); y++) {
+      for (int32_t x = 0; x < OutputWidth(); x++) {
+        for (int32_t c = 0; c < OutputChannels(); c++) {
+          const int32_t index = ((i * OutputHeight() + y) * OutputWidth() + x) *
+                                    OutputChannels() +
+                                c;
+          if (std::abs(default_output_data[index] -
+                       delegate_output_data[index]) >
+              0.005 * std::abs(default_output_data[index])) {
+            ++different_output_values;
+          }
+        }
+      }
+    }
+  }
+  if (different_output_values > 0.05 * num_output_values) {
+    GTEST_FAIL() << (float)different_output_values / num_output_values * 100.f
+                 << "% of output values differ";
+  }
+}
+
+std::vector<char> DynamicallyQuantizedConv2DTester::CreateTfLiteModel() const {
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto filter_rng = std::bind(std::uniform_int_distribution<int32_t>(
+                                  -std::numeric_limits<int8_t>::max(),
+                                  std::numeric_limits<int8_t>::max()),
+                              std::ref(rng));
+  auto bias_rng =
+      std::bind(std::uniform_real_distribution<float>(-10, 10), std::ref(rng));
+  auto kernel_scale_rng =
+      std::bind(std::uniform_real_distribution<float>(0.1, 3), std::ref(rng));
+
+  flatbuffers::FlatBufferBuilder builder;
+  const std::array<flatbuffers::Offset<OperatorCode>, 1> operator_codes{
+      {CreateOperatorCode(builder, BuiltinOperator_CONV_2D)}};
+  std::vector<flatbuffers::Offset<Operator>> operators;
+
+  std::vector<int8_t> filter_data(OutputChannels() * KernelHeight() *
+                                  KernelWidth() * KernelInputChannels());
+  std::generate(filter_data.begin(), filter_data.end(), std::ref(filter_rng));
+  std::vector<float> bias_data(OutputChannels());
+  std::generate(bias_data.begin(), bias_data.end(), std::ref(bias_rng));
+  std::vector<float> kernel_scale(OutputChannels());
+  std::generate(kernel_scale.begin(), kernel_scale.end(),
+                std::ref(kernel_scale_rng));
+
+  const std::array<flatbuffers::Offset<Buffer>, 3> buffersq{{
+      CreateBuffer(builder, builder.CreateVector({})),
+      CreateBuffer(builder,
+                   builder.CreateVector(
+                       reinterpret_cast<const uint8_t*>(filter_data.data()),
+                       sizeof(int8_t) * filter_data.size())),
+      CreateBuffer(builder,
+                   builder.CreateVector(
+                       reinterpret_cast<const uint8_t*>(bias_data.data()),
+                       sizeof(float) * bias_data.size())),
+  }};
+
+  const std::array<int32_t, 4> filter_shape{
+      {OutputChannels(), KernelHeight(), KernelWidth(), KernelInputChannels()}};
+  const std::array<int32_t, 1> bias_shape{{OutputChannels()}};
+
+  const std::array<int32_t, 4> input_shape{
+      {BatchSize(), InputHeight(), InputWidth(), InputChannels()}};
+  const std::array<int32_t, 4> output_shape{
+      {BatchSize(), OutputHeight(), OutputWidth(), OutputChannels()}};
+  std::vector<flatbuffers::Offset<Tensor>> tensors;
+  tensors.emplace_back(CreateTensor(
+      builder,
+      builder.CreateVector<int32_t>(input_shape.data(), input_shape.size()),
+      TensorType_FLOAT32, /*buffer=*/0));
+  tensors.emplace_back(CreateTensor(
+      builder,
+      builder.CreateVector<int32_t>(filter_shape.data(), filter_shape.size()),
+      TensorType_INT8, /*buffer=*/1, /*name=*/0,
+      CreateQuantizationParameters(
+          builder, /*min=*/0, /*max=*/0,
+          builder.CreateVector<float>(kernel_scale),
+          builder.CreateVector<int64_t>(
+              std::vector<int64_t>(OutputChannels(), 0)))));
+  tensors.emplace_back(CreateTensor(
+      builder,
+      builder.CreateVector<int32_t>(bias_shape.data(), bias_shape.size()),
+      TensorType_FLOAT32, /*buffer=*/2));
+  tensors.emplace_back(CreateTensor(
+      builder,
+      builder.CreateVector<int32_t>(output_shape.data(), output_shape.size()),
+      TensorType_FLOAT32));
+
+  const flatbuffers::Offset<Conv2DOptions> conv2d_options =
+      CreateConv2DOptions(builder, Padding(), StrideWidth(), StrideHeight(),
+                          Activation(), DilationWidth(), DilationHeight());
+
+  const std::vector<int32_t> op_inputs{static_cast<int32_t>(tensors.size()) - 4,
+                                       static_cast<int32_t>(tensors.size()) - 3,
+                                       static_cast<int32_t>(tensors.size()) - 2};
+  const std::array<int32_t, 1> op_outputs{
+      {static_cast<int32_t>(tensors.size()) - 1}};
+  operators.emplace_back(CreateOperator(
+      builder, /*opcode_index=*/0,
+      builder.CreateVector<int32_t>(op_inputs.data(), op_inputs.size()),
+      builder.CreateVector<int32_t>(op_outputs.data(), op_outputs.size()),
+      BuiltinOptions_Conv2DOptions, conv2d_options.Union()));
+
+  const std::array<int32_t, 1> subgraph_inputs{
+      {static_cast<int32_t>(tensors.size()) - 4}};
+  const std::array<int32_t, 1> subgraph_outputs{
+      {static_cast<int>(tensors.size()) - 1}};
+  flatbuffers::Offset<SubGraph> subgraph = CreateSubGraph(
+      builder, builder.CreateVector(tensors.data(), tensors.size()),
+      builder.CreateVector<int32_t>(subgraph_inputs.data(),
+                                    subgraph_inputs.size()),
+      builder.CreateVector<int32_t>(subgraph_outputs.data(),
+                                    subgraph_outputs.size()),
+      builder.CreateVector(operators.data(), operators.size()));
+
+  flatbuffers::Offset<flatbuffers::String> description =
+      builder.CreateString("Dynamically Quantized Conv2D model");
+
+  flatbuffers::Offset<Model> model_buffer = CreateModel(
+      builder, TFLITE_SCHEMA_VERSION,
+      builder.CreateVector(operator_codes.data(), operator_codes.size()),
+      builder.CreateVector(&subgraph, 1), description,
+      builder.CreateVector(buffersq.data(), buffersq.size()));
+
+  builder.Finish(model_buffer);
+
+  return std::vector<char>(builder.GetBufferPointer(),
+                           builder.GetBufferPointer() + builder.GetSize());
+}
+
+}  // namespace xnnpack
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/xnnpack/dynamically_quantized_conv_2d_tester.h b/tensorflow/lite/delegates/xnnpack/dynamically_quantized_conv_2d_tester.h
new file mode 100644
index 0000000..dc7439d
--- /dev/null
+++ b/tensorflow/lite/delegates/xnnpack/dynamically_quantized_conv_2d_tester.h
@@ -0,0 +1,245 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_DYNAMICALLY_QUANTIZED_CONV_2D_TESTER_H_
+#define TENSORFLOW_LITE_DELEGATES_XNNPACK_DYNAMICALLY_QUANTIZED_CONV_2D_TESTER_H_
+
+#include <cstdint>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/lite/core/c/common.h"
+#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace xnnpack {
+
+class DynamicallyQuantizedConv2DTester {
+ public:
+  DynamicallyQuantizedConv2DTester() = default;
+  DynamicallyQuantizedConv2DTester(const DynamicallyQuantizedConv2DTester&) =
+      delete;
+  DynamicallyQuantizedConv2DTester& operator=(
+      const DynamicallyQuantizedConv2DTester&) = delete;
+
+  inline DynamicallyQuantizedConv2DTester& BatchSize(int32_t batch_size) {
+    EXPECT_GT(batch_size, 0);
+    batch_size_ = batch_size;
+    return *this;
+  }
+
+  inline int32_t BatchSize() const { return batch_size_; }
+
+  inline DynamicallyQuantizedConv2DTester& InputChannels(
+      int32_t input_channels) {
+    EXPECT_GT(input_channels, 0);
+    input_channels_ = input_channels;
+    return *this;
+  }
+
+  inline int32_t InputChannels() const { return input_channels_; }
+
+  inline DynamicallyQuantizedConv2DTester& OutputChannels(
+      int32_t output_channels) {
+    EXPECT_GT(output_channels, 0);
+    output_channels_ = output_channels;
+    return *this;
+  }
+
+  inline int32_t OutputChannels() const { return output_channels_; }
+
+  inline DynamicallyQuantizedConv2DTester& Groups(int32_t groups) {
+    EXPECT_EQ(InputChannels() % groups, 0);
+    EXPECT_EQ(OutputChannels() % groups, 0);
+    groups_ = groups;
+    return *this;
+  }
+
+  inline int32_t Groups() const { return groups_; }
+
+  inline int32_t KernelInputChannels() const {
+    return input_channels_ / groups_;
+  }
+
+  inline DynamicallyQuantizedConv2DTester& InputHeight(int32_t input_height) {
+    EXPECT_GT(input_height, 0);
+    input_height_ = input_height;
+    return *this;
+  }
+
+  inline int32_t InputHeight() const { return input_height_; }
+
+  inline DynamicallyQuantizedConv2DTester& InputWidth(int32_t input_width) {
+    EXPECT_GT(input_width, 0);
+    input_width_ = input_width;
+    return *this;
+  }
+
+  inline int32_t InputWidth() const { return input_width_; }
+
+  inline int32_t OutputWidth() const {
+    if (Padding() == ::tflite::Padding_SAME) {
+      EXPECT_GE(InputWidth(), 1);
+      return (InputWidth() - 1) / StrideWidth() + 1;
+    } else {
+      EXPECT_GE(InputWidth(), DilatedKernelWidth());
+      return 1 + (InputWidth() - DilatedKernelWidth()) / StrideWidth();
+    }
+  }
+
+  inline int32_t OutputHeight() const {
+    if (Padding() == ::tflite::Padding_SAME) {
+      EXPECT_GE(InputHeight(), 1);
+      return (InputHeight() - 1) / StrideHeight() + 1;
+    } else {
+      EXPECT_GE(InputHeight(), DilatedKernelHeight());
+      return 1 + (InputHeight() - DilatedKernelHeight()) / StrideHeight();
+    }
+  }
+
+  inline DynamicallyQuantizedConv2DTester& KernelHeight(int32_t kernel_height) {
+    EXPECT_GT(kernel_height, 0);
+    kernel_height_ = kernel_height;
+    return *this;
+  }
+
+  inline int32_t KernelHeight() const { return kernel_height_; }
+
+  inline DynamicallyQuantizedConv2DTester& KernelWidth(int32_t kernel_width) {
+    EXPECT_GT(kernel_width, 0);
+    kernel_width_ = kernel_width;
+    return *this;
+  }
+
+  inline int32_t KernelWidth() const { return kernel_width_; }
+
+  inline DynamicallyQuantizedConv2DTester& StrideHeight(int32_t stride_height) {
+    EXPECT_GT(stride_height, 0);
+    stride_height_ = stride_height;
+    return *this;
+  }
+
+  inline int32_t StrideHeight() const { return stride_height_; }
+
+  inline DynamicallyQuantizedConv2DTester& StrideWidth(int32_t stride_width) {
+    EXPECT_GT(stride_width, 0);
+    stride_width_ = stride_width;
+    return *this;
+  }
+
+  inline int32_t StrideWidth() const { return stride_width_; }
+
+  inline DynamicallyQuantizedConv2DTester& DilationHeight(
+      int32_t dilation_height) {
+    EXPECT_GT(dilation_height, 0);
+    dilation_height_ = dilation_height;
+    return *this;
+  }
+
+  inline int32_t DilationHeight() const { return dilation_height_; }
+
+  inline DynamicallyQuantizedConv2DTester& DilationWidth(
+      int32_t dilation_width) {
+    EXPECT_GT(dilation_width, 0);
+    dilation_width_ = dilation_width;
+    return *this;
+  }
+
+  inline int32_t DilationWidth() const { return dilation_width_; }
+
+  inline int32_t DilatedKernelHeight() const {
+    return (KernelHeight() - 1) * DilationHeight() + 1;
+  }
+
+  inline int32_t DilatedKernelWidth() const {
+    return (KernelWidth() - 1) * DilationWidth() + 1;
+  }
+
+  inline DynamicallyQuantizedConv2DTester& SamePadding() {
+    padding_ = ::tflite::Padding_SAME;
+    return *this;
+  }
+
+  inline DynamicallyQuantizedConv2DTester& ValidPadding() {
+    padding_ = ::tflite::Padding_VALID;
+    return *this;
+  }
+
+  inline DynamicallyQuantizedConv2DTester& ReluActivation() {
+    activation_ = ::tflite::ActivationFunctionType_RELU;
+    return *this;
+  }
+
+  inline DynamicallyQuantizedConv2DTester& Relu6Activation() {
+    activation_ = ::tflite::ActivationFunctionType_RELU6;
+    return *this;
+  }
+
+  inline DynamicallyQuantizedConv2DTester& ReluMinus1To1Activation() {
+    activation_ = ::tflite::ActivationFunctionType_RELU_N1_TO_1;
+    return *this;
+  }
+
+  inline DynamicallyQuantizedConv2DTester& TanhActivation() {
+    activation_ = ::tflite::ActivationFunctionType_TANH;
+    return *this;
+  }
+
+  inline DynamicallyQuantizedConv2DTester& SignBitActivation() {
+    activation_ = ::tflite::ActivationFunctionType_SIGN_BIT;
+    return *this;
+  }
+
+  inline DynamicallyQuantizedConv2DTester& WeightsCache(
+      TfLiteXNNPackDelegateWeightsCache* weights_cache) {
+    weights_cache_ = weights_cache;
+    return *this;
+  }
+
+  void Test(TfLiteDelegate* delegate) const;
+
+  std::vector<char> CreateTfLiteModel() const;
+
+ private:
+  inline ::tflite::Padding Padding() const { return padding_; }
+
+  inline ::tflite::ActivationFunctionType Activation() const {
+    return activation_;
+  }
+
+  int32_t batch_size_ = 1;
+  int32_t input_channels_ = 1;
+  int32_t output_channels_ = 1;
+  int32_t groups_ = 1;
+  int32_t input_height_ = 1;
+  int32_t input_width_ = 1;
+  int32_t kernel_height_ = 1;
+  int32_t kernel_width_ = 1;
+  int32_t stride_height_ = 1;
+  int32_t stride_width_ = 1;
+  int32_t dilation_height_ = 1;
+  int32_t dilation_width_ = 1;
+  ::tflite::Padding padding_ = ::tflite::Padding_VALID;
+
+  ::tflite::ActivationFunctionType activation_ =
+      ::tflite::ActivationFunctionType_NONE;
+  TfLiteXNNPackDelegateWeightsCache* weights_cache_ = nullptr;
+};
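+
+// A typical use in a test looks roughly like this (sketch only; delegate
+// construction and the exact dimensions are placeholders):
+//   DynamicallyQuantizedConv2DTester()
+//       .BatchSize(1).InputHeight(5).InputWidth(5)
+//       .InputChannels(8).OutputChannels(4)
+//       .KernelHeight(3).KernelWidth(3)
+//       .Test(delegate.get());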
+
+}  // namespace xnnpack
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_XNNPACK_DYNAMICALLY_QUANTIZED_CONV_2D_TESTER_H_
diff --git a/tensorflow/lite/delegates/xnnpack/dynamically_quantized_fully_connected_tester.cc b/tensorflow/lite/delegates/xnnpack/dynamically_quantized_fully_connected_tester.cc
index f52dbd5..7ea1bed 100644
--- a/tensorflow/lite/delegates/xnnpack/dynamically_quantized_fully_connected_tester.cc
+++ b/tensorflow/lite/delegates/xnnpack/dynamically_quantized_fully_connected_tester.cc
@@ -79,10 +79,21 @@
   float* delegate_output_data =
       delegate_interpreter->typed_output_tensor<float>(0);
 
-  for (size_t i = 0; i < ComputeSize(OutputShape()); i++) {
-    EXPECT_NEAR(default_output_data[i], delegate_output_data[i],
-                std::numeric_limits<float>::epsilon() *
-                    std::max(std::abs(default_output_data[i]) * 20.0f, 1.0f));
+  const int num_output_values = ComputeSize(OutputShape());
+  int different_output_values = 0;
+  // TFLite rounds to nearest with ties away from zero, while XNNPACK rounds to
+  // nearest with ties to even. IEEE 754 specifies round-to-nearest, ties-to-even
+  // as the default rounding mode for binary floating point (and recommends it
+  // for decimal), so many output values may differ slightly between the two.
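+  // For example, 2.5 rounds to 2 under ties-to-even but to 3 when ties round
+  // away from zero, hence the tolerance check below.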
+  for (int i = 0; i < num_output_values; i++) {
+    if (std::abs(default_output_data[i] - delegate_output_data[i]) >
+        0.005 * std::abs(default_output_data[i])) {
+      ++different_output_values;
+    }
+  }
+  if (different_output_values > 0.05 * num_output_values) {
+    GTEST_FAIL() << (float)different_output_values / num_output_values * 100.f
+                 << "% of output values differ";
   }
 }
 
diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc
index 2507a02..ca7bdfc 100644
--- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc
+++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc
@@ -30,6 +30,7 @@
 #include <utility>
 #include <vector>
 
+#include "experiments-config.h"  // from @XNNPACK
 #include "xnnpack.h"  // from @XNNPACK
 #include "tensorflow/lite/builtin_ops.h"
 #include "tensorflow/lite/core/api/profiler.h"
@@ -555,6 +556,10 @@
             TFLITE_XNNPACK_DELEGATE_FLAG_TRANSIENT_INDIRECTION_BUFFER) != 0;
   }
 
+  bool experimental_adaptive_avx_optimization() const {
+    return options_.experimental_adaptive_avx_optimization;
+  }
+
   pthreadpool_t threadpool() const {
 #if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__)
     return nullptr;
@@ -1017,6 +1022,9 @@
     if (delegate.transient_indirection_buffer()) {
       flags |= XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER;
     }
+    if (delegate.experimental_adaptive_avx_optimization()) {
+      xnn_experiment_enable_adaptive_avx_optimization();
+    }
     if (delegate.force_fp16()) {
       flags |= XNN_FLAG_FORCE_FP16_INFERENCE;
     } else {
@@ -3211,8 +3219,11 @@
     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
         logging_context, output_tensor, node->outputs->data[0], node_index));
 
+    bool dynamically_quantized = (delegate.enable_latest_operators() &&
+                                  (input_tensor.type == kTfLiteFloat32 &&
+                                   filter_tensor.type == kTfLiteInt8));
     if (input_tensor.type != output_tensor.type ||
-        input_tensor.type != filter_tensor.type) {
+        ((input_tensor.type != filter_tensor.type) && !dynamically_quantized)) {
       TF_LITE_MAYBE_KERNEL_LOG(
           logging_context, "unsupported mixed types in CONV_2D operator #%d",
           node_index);
@@ -3236,28 +3247,109 @@
         &output_max));
 
     if (subgraph != nullptr) {
-      const xnn_status status = xnn_define_convolution_2d(
-          subgraph,
-          /*input_padding_top=*/0,
-          /*input_padding_right=*/0,
-          /*input_padding_bottom=*/0,
-          /*input_padding_left=*/0, static_cast<uint32_t>(kernel_height),
-          static_cast<uint32_t>(kernel_width),
-          static_cast<uint32_t>(conv_params->stride_height),
-          static_cast<uint32_t>(conv_params->stride_width),
-          static_cast<uint32_t>(conv_params->dilation_height_factor),
-          static_cast<uint32_t>(conv_params->dilation_width_factor), groups,
-          static_cast<size_t>(input_channels),
-          static_cast<size_t>(output_channels) / groups, output_min, output_max,
-          /*input_id=*/input_output_tensors.at(node->inputs->data[0]),
-          /*filter_id=*/input_output_tensors.at(node->inputs->data[1]),
-          /*bias_id=*/input_output_tensors.at(node->inputs->data[2]),
-          /*output_id=*/input_output_tensors.at(node->outputs->data[0]), flags);
-      if (status != xnn_status_success) {
-        TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d",
-                           EnumNameBuiltinOperator(BuiltinOperator_CONV_2D),
-                           node_index);
-        return kTfLiteError;
+      if (dynamically_quantized) {
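+        // Dynamic quantization path: insert a Convert node that quantizes the
+        // float input to qdint8 at runtime, register the int8 filter as a
+        // channelwise-quantized static tensor, and wire both into the
+        // convolution defined below.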
+        TfLiteAffineQuantization* filter_params =
+            reinterpret_cast<TfLiteAffineQuantization*>(
+                filter_tensor.quantization.params);
+        if (filter_params->scale->size != output_channels) {
+          TfLiteFloatArrayFree(filter_params->scale);
+          filter_params->scale = TfLiteFloatArrayCreate(output_channels);
+          for (int i = 0; i < output_channels; ++i) {
+            filter_params->scale->data[i] = filter_tensor.params.scale;
+          }
+        }
+        uint32_t dq_quantized_id = XNN_INVALID_VALUE_ID;
+        std::vector<size_t> input_dims(
+            &input_tensor.dims->data[0],
+            &input_tensor.dims->data[NumDimensions(&input_tensor)]);
+        xnn_status status = xnn_define_dynamically_quantized_tensor_value(
+            subgraph, xnn_datatype_qdint8, input_dims.size(),
+            /*num_nonbatch_dims=*/3, input_dims.data(), XNN_INVALID_VALUE_ID,
+            /*flags=*/0, &dq_quantized_id);
+        if (status != xnn_status_success) {
+          TF_LITE_KERNEL_LOG(logging_context,
+                             "failed to create XNNPACK Value for tensor %d",
+                             -1);
+          return kTfLiteError;
+        }
+        status = xnn_define_convert(
+            subgraph,
+            /*input_id=*/input_output_tensors.at(node->inputs->data[0]),
+            dq_quantized_id, /*flags=*/0);
+        if (status != xnn_status_success) {
+          TF_LITE_KERNEL_LOG(
+              logging_context, "failed to delegate %s node #%d",
+              EnumNameBuiltinOperator(BuiltinOperator_CONV_2D),
+              node_index);
+          return kTfLiteError;
+        }
+        std::vector<size_t> filter_dims(
+            &filter_tensor.dims->data[0],
+            &filter_tensor.dims->data[NumDimensions(&filter_tensor)]);
+        uint32_t kernel_id = XNN_INVALID_VALUE_ID;
+        status = xnn_define_channelwise_quantized_tensor_value(
+            subgraph, xnn_datatype_qcint8, filter_params->scale->data,
+            filter_dims.size(), /*channel_dim=*/0, filter_dims.data(),
+            GetTensorData<int8_t>(&filter_tensor), XNN_INVALID_VALUE_ID,
+            /*flags=*/0, &kernel_id);
+        if (status != xnn_status_success) {
+          TF_LITE_KERNEL_LOG(
+              logging_context, "failed to update filter tensor %s node #%d",
+              EnumNameBuiltinOperator(BuiltinOperator_CONV_2D),
+              node_index);
+          return kTfLiteError;
+        }
+        status = xnn_define_convolution_2d(
+            subgraph,
+            /*input_padding_top=*/0,
+            /*input_padding_right=*/0,
+            /*input_padding_bottom=*/0,
+            /*input_padding_left=*/0, static_cast<uint32_t>(kernel_height),
+            static_cast<uint32_t>(kernel_width),
+            static_cast<uint32_t>(conv_params->stride_height),
+            static_cast<uint32_t>(conv_params->stride_width),
+            static_cast<uint32_t>(conv_params->dilation_height_factor),
+            static_cast<uint32_t>(conv_params->dilation_width_factor), groups,
+            static_cast<size_t>(input_channels),
+            static_cast<size_t>(output_channels) / groups, output_min,
+            output_max,
+            /*input_id=*/dq_quantized_id,
+            /*filter_id=*/kernel_id,
+            /*bias_id=*/input_output_tensors.at(node->inputs->data[2]),
+            /*output_id=*/input_output_tensors.at(node->outputs->data[0]),
+            flags);
+        if (status != xnn_status_success) {
+          TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d",
+                             EnumNameBuiltinOperator(BuiltinOperator_CONV_2D),
+                             node_index);
+          return kTfLiteError;
+        }
+      } else {
+        const xnn_status status = xnn_define_convolution_2d(
+            subgraph,
+            /*input_padding_top=*/0,
+            /*input_padding_right=*/0,
+            /*input_padding_bottom=*/0,
+            /*input_padding_left=*/0, static_cast<uint32_t>(kernel_height),
+            static_cast<uint32_t>(kernel_width),
+            static_cast<uint32_t>(conv_params->stride_height),
+            static_cast<uint32_t>(conv_params->stride_width),
+            static_cast<uint32_t>(conv_params->dilation_height_factor),
+            static_cast<uint32_t>(conv_params->dilation_width_factor), groups,
+            static_cast<size_t>(input_channels),
+            static_cast<size_t>(output_channels) / groups, output_min,
+            output_max,
+            /*input_id=*/input_output_tensors.at(node->inputs->data[0]),
+            /*filter_id=*/input_output_tensors.at(node->inputs->data[1]),
+            /*bias_id=*/input_output_tensors.at(node->inputs->data[2]),
+            /*output_id=*/input_output_tensors.at(node->outputs->data[0]),
+            flags);
+        if (status != xnn_status_success) {
+          TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d",
+                             EnumNameBuiltinOperator(BuiltinOperator_CONV_2D),
+                             node_index);
+          return kTfLiteError;
+        }
       }
     }
 
diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h
index 1154528..f6cf417 100644
--- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h
+++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h
@@ -64,6 +64,8 @@
   // Deprecated. Use the flags bitfield with the
   // TFLITE_XNNPACK_DELEGATE_FLAG_VARIABLE_OPERATORS mask.
   bool handle_variable_ops;
+  // Enable adaptive optimization for AVX CPUs.
+  bool experimental_adaptive_avx_optimization;
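+  // A minimal usage sketch (assuming the usual delegate creation flow):
+  //   TfLiteXNNPackDelegateOptions opts = TfLiteXNNPackDelegateOptionsDefault();
+  //   opts.experimental_adaptive_avx_optimization = true;
+  //   TfLiteDelegate* xnnpack_delegate = TfLiteXNNPackDelegateCreate(&opts);
+  //   interpreter->ModifyGraphWithDelegate(xnnpack_delegate);
+  //   ...
+  //   TfLiteXNNPackDelegateDelete(xnnpack_delegate);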
 } TfLiteXNNPackDelegateOptions;
 
 // Returns a structure with the default XNNPack delegate options.
diff --git a/tensorflow/lite/experimental/acceleration/configuration/BUILD b/tensorflow/lite/experimental/acceleration/configuration/BUILD
index d97b2a6..2e14980 100644
--- a/tensorflow/lite/experimental/acceleration/configuration/BUILD
+++ b/tensorflow/lite/experimental/acceleration/configuration/BUILD
@@ -18,6 +18,7 @@
 load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_copts_warnings")
 load("//tensorflow/lite:special_rules.bzl", "nnapi_plugin_impl_visibility_allowlist", "tflite_portable_test_suite")
 load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
+# copybara:uncomment load("//tools/build_defs/proto/cpp:cc_proto_library.bzl", "cc_proto_library")
 
 # copybara:comment_begin(oss-only)
 load("@local_tsl//tsl/platform/default:build_config.bzl", "tf_proto_library_py")
diff --git a/tensorflow/lite/java/AndroidManifestApi.xml b/tensorflow/lite/java/AndroidManifestApi.xml
new file mode 100644
index 0000000..fa022be
--- /dev/null
+++ b/tensorflow/lite/java/AndroidManifestApi.xml
@@ -0,0 +1,12 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+    package="org.tensorflow.lite.api">
+
+    <!-- TFLite Java Library is built against NDK API 19. -->
+    <uses-sdk
+        android:minSdkVersion="19" />
+
+    <application />
+
+</manifest>
+
diff --git a/tensorflow/lite/java/AndroidManifestGpu.xml b/tensorflow/lite/java/AndroidManifestGpu.xml
index 085a663..4c4c484 100644
--- a/tensorflow/lite/java/AndroidManifestGpu.xml
+++ b/tensorflow/lite/java/AndroidManifestGpu.xml
@@ -1,22 +1,12 @@
 <?xml version="1.0" encoding="utf-8"?>
 <manifest xmlns:android="http://schemas.android.com/apk/res/android"
-    package="org.tensorflow.lite">
+    package="org.tensorflow.lite.gpu">
 
-    <application>
-        <!-- Applications that target Android S+ require explicit declaration of
-             any referenced vendor-provided libraries. -->
-        <uses-native-library
-            android:name="libOpenCL.so"
-            android:required="false" />
+    <!-- TFLite Java Library is built against NDK API 19. -->
+    <uses-sdk
+        android:minSdkVersion="19" />
 
-        <uses-native-library
-            android:name="libOpenCL-car.so"
-            android:required="false" />
-
-        <uses-native-library
-            android:name="libOpenCL-pixel.so"
-            android:required="false" />
-    </application>
+    <application />
 
 </manifest>
 
diff --git a/tensorflow/lite/java/AndroidManifestGpuApi.xml b/tensorflow/lite/java/AndroidManifestGpuApi.xml
new file mode 100644
index 0000000..1343f5d
--- /dev/null
+++ b/tensorflow/lite/java/AndroidManifestGpuApi.xml
@@ -0,0 +1,22 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+    package="org.tensorflow.lite.gpu.api">
+
+    <application>
+        <!-- Applications that target Android S+ require explicit declaration of
+             any referenced vendor-provided libraries. -->
+        <uses-native-library
+            android:name="libOpenCL.so"
+            android:required="false" />
+
+        <uses-native-library
+            android:name="libOpenCL-car.so"
+            android:required="false" />
+
+        <uses-native-library
+            android:name="libOpenCL-pixel.so"
+            android:required="false" />
+    </application>
+
+</manifest>
+
diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD
index 15106b6..3c0f026 100644
--- a/tensorflow/lite/java/BUILD
+++ b/tensorflow/lite/java/BUILD
@@ -31,7 +31,9 @@
     "src/testdata/uint8.bin",
     "src/testdata/with_custom_op.lite",
     "AndroidManifest.xml",
+    "AndroidManifestApi.xml",
     "AndroidManifestGpu.xml",
+    "AndroidManifestGpuApi.xml",
     "proguard.flags",
     "tflite_version_script.lds",
 ])
@@ -269,7 +271,7 @@
 android_library(
     name = "tensorflowlite_api",
     srcs = [":java_api_srcs"],
-    manifest = "AndroidManifest.xml",
+    manifest = "AndroidManifestApi.xml",
     proguard_specs = ["proguard.flags"],
     deps = [
         "@org_checkerframework_qual",
@@ -318,7 +320,7 @@
 # EXPERIMENTAL: Android target for GPU acceleration. Note that this library
 # contains *only* the GPU delegate and its Java wrapper; clients must also
 # include the core `tensorflowlite` runtime.
-# Note that AndroidManifestGpu.xml usage requires AGP 4.2.0+.
+# Note that AndroidManifestGpuApi.xml usage requires AGP 4.2.0+.
 alias(
     name = "tensorflowlite_gpu",
     actual = "tensorflowlite_gpu_impl",
@@ -327,7 +329,7 @@
 # EXPERIMENTAL: Android target for the implementation of the GPU acceleration API, including the
 # native library. Note that this library contains *only* the GPU delegate and its Java wrapper;
 # clients must also include the core `tensorflowlite` runtime.
-# Note that AndroidManifestGpu.xml usage requires AGP 4.2.0+.
+# Note that AndroidManifestGpuApi.xml usage requires AGP 4.2.0+.
 android_library(
     name = "tensorflowlite_gpu_impl",
     # Note that we need to directly includes all the Java source files we intend to ship directly in
@@ -339,7 +341,7 @@
     # Note that this uses the standard manifest and doesn't export it: the declaration is required
     # because android_library targets require a non-empty Android package in Bazel. The API target
     # exports the GPU manifest.
-    manifest = "AndroidManifest.xml",
+    manifest = "AndroidManifestGpu.xml",
     exports = [
         ":tensorflowlite_gpu_api",
         ":tensorflowlite_gpu_native",
@@ -388,12 +390,12 @@
 # EXPERIMENTAL: Android target for GPU acceleration API, EXCLUDING implementation.
 # Note that this library contains *only* the GPU delegate API; clients must also include
 # an implementation, as well as the core `tensorflowlite` runtime.
-# Note that AndroidManifestGpu.xml usage requires AGP 4.2.0+.
+# Note that AndroidManifestGpuApi.xml usage requires AGP 4.2.0+.
 android_library(
     name = "tensorflowlite_gpu_api",
     srcs = ["//tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu:gpu_delegate"],
     exports_manifest = 1,
-    manifest = "AndroidManifestGpu.xml",
+    manifest = "AndroidManifestGpuApi.xml",
     proguard_specs = ["proguard.flags"],
     deps = [":tensorflowlite_api"],
 )
diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index 99f7df0..bf2d189 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -12,6 +12,14 @@
     licenses = ["notice"],
 )
 
+filegroup(
+    name = "tflite_internal_cc_3p_api_deps_src",
+    srcs = ["op_macros.h"],
+    visibility = [
+        "//tensorflow/lite:__pkg__",
+    ],
+)
+
 # Enables usage of ruy exclusively as the GEMM backend in TFLite kernels.
 # This will cause TFLite to build with ruy only, providing a smaller binary.
 # WARNING: This build flag is experimental and subject to change.
@@ -495,6 +503,24 @@
 )
 
 cc_library(
+    name = "stablehlo_elementwise",
+    srcs = ["stablehlo_elementwise.cc"],
+    hdrs = [
+        "stablehlo_elementwise.h",
+    ],
+    compatible_with = get_compatible_with_portable(),
+    copts = tflite_copts(),
+    deps = [
+        ":kernel_util",
+        "//tensorflow/lite/core/c:common",
+        "//tensorflow/lite/kernels/internal:runtime_shape",
+        "//tensorflow/lite/kernels/internal:tensor_ctypes",
+        "//tensorflow/lite/kernels/internal:types",
+        "@eigen_archive//:eigen3",
+    ],
+)
+
+cc_library(
     name = "rng_util",
     srcs = [
         "rng_util.cc",
@@ -727,7 +753,6 @@
     "range.cc",
     "rank.cc",
     "reduce.cc",
-    "reduce_window.cc",
     "reshape.cc",
     "resize_bilinear.cc",
     "resize_nearest_neighbor.cc",
@@ -739,6 +764,11 @@
     "segment_sum.cc",
     "select.cc",
     "shape.cc",
+    "stablehlo_gather.cc",
+    "stablehlo_add.cc",
+    "stablehlo_multiply.cc",
+    "stablehlo_reduce_window.cc",
+    "stablehlo_min_max.cc",
     "stablehlo_scatter.cc",
     "sign.cc",
     "skip_gram.cc",
@@ -786,6 +816,7 @@
     ":lstm_shared",
     ":op_macros",
     ":padding",
+    ":stablehlo_elementwise",
     ":control_flow_common",
     "@eigen_archive//:eigen3",
     "@flatbuffers",
@@ -818,6 +849,12 @@
 }) + select({
     "//tensorflow/lite:tflite_with_xnnpack_explicit_false": [],
     "//conditions:default": [
+        "@pthreadpool",
+    ],
+}) + select({
+    "//tensorflow/lite:tflite_with_xnnpack_explicit_false": [],
+    "//tensorflow/lite:tflite_kernel_use_xnnpack_false": [],
+    "//conditions:default": [
         "@XNNPACK",
     ],
 }) + select({
@@ -1301,6 +1338,47 @@
     ],
 )
 
+cc_library(
+    name = "stablehlo_reduce_window_test_util",
+    hdrs = ["stablehlo_reduce_window_test_util.h"],
+    deps = [
+        "@com_google_absl//absl/algorithm:container",
+    ],
+)
+
+cc_test(
+    name = "stablehlo_reduce_window_test_util_test",
+    size = "small",
+    srcs = ["stablehlo_reduce_window_test_util_test.cc"],
+    deps = [
+        ":stablehlo_reduce_window_test_util",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_test(
+    name = "stablehlo_reduce_window_test",
+    size = "small",
+    srcs = ["stablehlo_reduce_window_test.cc"],
+    tags = ["tflite_nnapi"],
+    deps = [
+        ":stablehlo_reduce_window_test_util",
+        ":subgraph_test_util",
+        ":test_main",
+        ":test_util",
+        "//tensorflow/lite/c:c_api_types",
+        "//tensorflow/lite/core/c:common",
+        "//tensorflow/lite/schema:schema_fbs",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/log:absl_log",
+        "@com_google_absl//absl/random",
+        "@com_google_absl//absl/random:bit_gen_ref",
+        "@com_google_absl//absl/random:distributions",
+        "@com_google_absl//absl/types:span",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_test(
     name = "space_to_batch_nd_test",
     size = "small",
@@ -3082,6 +3160,68 @@
 )
 
 cc_test(
+    name = "stablehlo_gather_test",
+    size = "small",
+    srcs = ["stablehlo_gather_test.cc"],
+    deps = [
+        ":test_main",
+        ":test_util",
+        "//tensorflow/lite/c:c_api_types",
+        "//tensorflow/lite/core/c:common",
+        "//tensorflow/lite/schema:schema_fbs",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+cc_test(
+    name = "stablehlo_add_test",
+    size = "small",
+    srcs = ["stablehlo_add_test.cc"],
+    deps = [
+        ":subgraph_test_util",
+        ":test_util",
+        "//tensorflow/lite/c:c_api_types",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core:subgraph",
+        "//tensorflow/lite/core/c:common",
+        "//tensorflow/lite/schema:schema_fbs",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_test(
+    name = "stablehlo_multiply_test",
+    size = "small",
+    srcs = ["stablehlo_multiply_test.cc"],
+    deps = [
+        ":subgraph_test_util",
+        ":test_util",
+        "//tensorflow/lite/c:c_api_types",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core:subgraph",
+        "//tensorflow/lite/core/c:common",
+        "//tensorflow/lite/schema:schema_fbs",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_test(
+    name = "stablehlo_min_max_test",
+    size = "small",
+    srcs = ["stablehlo_min_max_test.cc"],
+    deps = [
+        ":test_util",
+        "//tensorflow/lite/c:c_api_types",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core:subgraph",
+        "//tensorflow/lite/core/c:common",
+        "//tensorflow/lite/schema:schema_fbs",
+        "@com_google_absl//absl/log:absl_log",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_test(
     name = "stablehlo_scatter_test",
     size = "small",
     srcs = ["stablehlo_scatter_test.cc"],
diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc
index a28b819..4ed6784 100644
--- a/tensorflow/lite/kernels/activations.cc
+++ b/tensorflow/lite/kernels/activations.cc
@@ -46,6 +46,13 @@
 #include "tensorflow/lite/kernels/internal/types.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 
+#ifdef TFLITE_KERNEL_USE_XNNPACK
+#include "xnnpack.h"  // from @XNNPACK
+#include "tensorflow/lite/logger.h"
+#include "tensorflow/lite/minimal_logging.h"
+#include "pthreadpool.h"  // from @pthreadpool
+#endif  // TFLITE_KERNEL_USE_XNNPACK
+
 namespace tflite {
 namespace ops {
 namespace builtin {
@@ -737,6 +744,24 @@
   const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
   switch (input->type) {
     case kTfLiteFloat32: {
+#ifdef TFLITE_KERNEL_USE_XNNPACK
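+      // Try the XNNPACK clamp kernel first; if it reports an error, fall back
+      // to the reference optimized_ops::Relu implementation below.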
+      const size_t channel_dim = 1;
+      const size_t batch_size = NumElements(input->dims);
+      CpuBackendContext* cpu_backend_context =
+          CpuBackendContext::GetFromContext(context);
+      pthreadpool_t threadpool = cpu_backend_context->get_xnnpack_threadpool();
+      xnn_status status = xnn_run_clamp_nc_f32(
+          channel_dim, channel_dim, channel_dim, batch_size,
+          GetTensorData<float>(input), GetTensorData<float>(output),
+          /*min=*/0.0f, /*max=*/std::numeric_limits<float>::infinity(),
+          /*flags=*/XNN_FLAG_YIELD_WORKERS, threadpool);
+      if (status == xnn_status_success) {
+        return kTfLiteOk;
+      }
+      TFLITE_LOG(TFLITE_LOG_INFO,
+                 "Failed to run xnnpack xnn_run_clamp_nc_f32. Error code: %d",
+                 status);
+#endif
       optimized_ops::Relu(GetTensorShape(input), GetTensorData<float>(input),
                           GetTensorShape(output), GetTensorData<float>(output));
     } break;
@@ -772,6 +797,24 @@
   const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
   switch (input->type) {
     case kTfLiteFloat32: {
+#ifdef TFLITE_KERNEL_USE_XNNPACK
+      const size_t channel_dim = 1;
+      const size_t batch_size = NumElements(input->dims);
+      CpuBackendContext* cpu_backend_context =
+          CpuBackendContext::GetFromContext(context);
+      pthreadpool_t threadpool = cpu_backend_context->get_xnnpack_threadpool();
+      xnn_status status = xnn_run_clamp_nc_f32(
+          channel_dim, channel_dim, channel_dim, batch_size,
+          GetTensorData<float>(input), GetTensorData<float>(output),
+          /*min=*/-1.0f, /*max=*/1.0f, /*flags=*/XNN_FLAG_YIELD_WORKERS,
+          threadpool);
+      if (status == xnn_status_success) {
+        return kTfLiteOk;
+      }
+      TFLITE_LOG(TFLITE_LOG_INFO,
+                 "Failed to run xnnpack xnn_run_clamp_nc_f32. Error code: %d",
+                 status);
+#endif
       optimized_ops::Relu1(GetTensorShape(input), GetTensorData<float>(input),
                            GetTensorShape(output),
                            GetTensorData<float>(output));
@@ -809,6 +852,25 @@
             GetTensorShape(input), GetTensorData<float>(input),
             GetTensorShape(output), GetTensorData<float>(output));
       } else {
+#ifdef TFLITE_KERNEL_USE_XNNPACK
+        const size_t channel_dim = 1;
+        const size_t batch_size = NumElements(input->dims);
+        CpuBackendContext* cpu_backend_context =
+            CpuBackendContext::GetFromContext(context);
+        pthreadpool_t threadpool =
+            cpu_backend_context->get_xnnpack_threadpool();
+        xnn_status status = xnn_run_hardswish_nc_f32(
+            channel_dim, channel_dim, channel_dim, batch_size,
+            GetTensorData<float>(input), GetTensorData<float>(output),
+            /*flags=*/XNN_FLAG_YIELD_WORKERS, threadpool);
+        if (status == xnn_status_success) {
+          return kTfLiteOk;
+        }
+        TFLITE_LOG(
+            TFLITE_LOG_INFO,
+            "Failed to run xnnpack xnn_run_hardswish_nc_f32. Error code: %d",
+            status);
+#endif
         optimized_ops::HardSwish(
             GetTensorShape(input), GetTensorData<float>(input),
             GetTensorShape(output), GetTensorData<float>(output));
@@ -858,6 +920,24 @@
   const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
   switch (input->type) {
     case kTfLiteFloat32: {
+#ifdef TFLITE_KERNEL_USE_XNNPACK
+      const size_t channel_dim = 1;
+      const size_t batch_size = NumElements(input->dims);
+      CpuBackendContext* cpu_backend_context =
+          CpuBackendContext::GetFromContext(context);
+      pthreadpool_t threadpool = cpu_backend_context->get_xnnpack_threadpool();
+      xnn_status status = xnn_run_clamp_nc_f32(
+          channel_dim, channel_dim, channel_dim, batch_size,
+          GetTensorData<float>(input), GetTensorData<float>(output),
+          /*min=*/0.0f, /*max=*/1.0f, /*flags=*/XNN_FLAG_YIELD_WORKERS,
+          threadpool);
+      if (status == xnn_status_success) {
+        return kTfLiteOk;
+      }
+      TFLITE_LOG(TFLITE_LOG_INFO,
+                 "Failed to run xnnpack xnn_run_clamp_nc_f32. Error code: %d",
+                 status);
+#endif
       optimized_ops::Relu0To1(
           GetTensorShape(input), GetTensorData<float>(input),
           GetTensorShape(output), GetTensorData<float>(output));
@@ -888,6 +968,24 @@
   ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
   switch (input->type) {
     case kTfLiteFloat32: {
+#ifdef TFLITE_KERNEL_USE_XNNPACK
+      const size_t channel_dim = 1;
+      const size_t batch_size = NumElements(input->dims);
+      CpuBackendContext* cpu_backend_context =
+          CpuBackendContext::GetFromContext(context);
+      pthreadpool_t threadpool = cpu_backend_context->get_xnnpack_threadpool();
+      xnn_status status = xnn_run_clamp_nc_f32(
+          channel_dim, channel_dim, channel_dim, batch_size,
+          GetTensorData<float>(input), GetTensorData<float>(output),
+          /*min=*/0.0f, /*max=*/6.0f, /*flags=*/XNN_FLAG_YIELD_WORKERS,
+          threadpool);
+      if (status == xnn_status_success) {
+        return kTfLiteOk;
+      }
+      TFLITE_LOG(TFLITE_LOG_INFO,
+                 "Failed to run xnnpack xnn_run_clamp_nc_f32. Error code: %d",
+                 status);
+#endif
       size_t elements = input->bytes / sizeof(float);
       const float* in = GetTensorData<float>(input);
       const float* in_end = in + elements;
@@ -929,6 +1027,24 @@
                             GetTensorShape(output),
                             GetTensorData<float>(output));
       } else {
+#ifdef TFLITE_KERNEL_USE_XNNPACK
+        const size_t channel_dim = 1;
+        const size_t batch_size = NumElements(input->dims);
+        CpuBackendContext* cpu_backend_context =
+            CpuBackendContext::GetFromContext(context);
+        pthreadpool_t threadpool =
+            cpu_backend_context->get_xnnpack_threadpool();
+        xnn_status status = xnn_run_tanh_nc_f32(
+            channel_dim, channel_dim, channel_dim, batch_size,
+            GetTensorData<float>(input), GetTensorData<float>(output),
+            /*flags=*/XNN_FLAG_YIELD_WORKERS, threadpool);
+        if (status == xnn_status_success) {
+          return kTfLiteOk;
+        }
+        TFLITE_LOG(TFLITE_LOG_INFO,
+                   "Failed to run xnnpack xnn_run_tanh_nc_f32. Error code: %d",
+                   status);
+#endif
         optimized_ops::Tanh(GetTensorShape(input), GetTensorData<float>(input),
                             GetTensorShape(output),
                             GetTensorData<float>(output));
@@ -1011,6 +1127,25 @@
             GetTensorShape(input), GetTensorData<float>(input),
             GetTensorShape(output), GetTensorData<float>(output));
       } else {
+#ifdef TFLITE_KERNEL_USE_XNNPACK
+        const size_t channel_dim = 1;
+        const size_t batch_size = NumElements(input->dims);
+        CpuBackendContext* cpu_backend_context =
+            CpuBackendContext::GetFromContext(context);
+        pthreadpool_t threadpool =
+            cpu_backend_context->get_xnnpack_threadpool();
+        xnn_status status = xnn_run_sigmoid_nc_f32(
+            channel_dim, channel_dim, channel_dim, batch_size,
+            GetTensorData<float>(input), GetTensorData<float>(output),
+            /*flags=*/XNN_FLAG_YIELD_WORKERS, threadpool);
+        if (status == xnn_status_success) {
+          return kTfLiteOk;
+        }
+        TFLITE_LOG(
+            TFLITE_LOG_INFO,
+            "Failed to run xnnpack xnn_run_sigmoid_nc_f32. Error code: %d",
+            status);
+#endif
         optimized_ops::Logistic(
             GetTensorShape(input), GetTensorData<float>(input),
             GetTensorShape(output), GetTensorData<float>(output));
@@ -1450,6 +1585,24 @@
   LeakyReluParams op_params;
   switch (input->type) {
     case kTfLiteFloat32: {
+#ifdef TFLITE_KERNEL_USE_XNNPACK
+      const size_t channel_dim = 1;
+      const size_t batch_size = NumElements(input->dims);
+      CpuBackendContext* cpu_backend_context =
+          CpuBackendContext::GetFromContext(context);
+      pthreadpool_t threadpool = cpu_backend_context->get_xnnpack_threadpool();
+      xnn_status status = xnn_run_leaky_relu_nc_f32(
+          channel_dim, channel_dim, channel_dim, batch_size,
+          GetTensorData<float>(input), GetTensorData<float>(output),
+          params->alpha, /*flags=*/XNN_FLAG_YIELD_WORKERS, threadpool);
+      if (status == xnn_status_success) {
+        return kTfLiteOk;
+      }
+      TFLITE_LOG(
+          TFLITE_LOG_INFO,
+          "Failed to run xnnpack xnn_run_leaky_relu_nc_f32. Error code: %d",
+          status);
+#endif
       op_params.alpha = params->alpha;
       optimized_ops::LeakyRelu(
           op_params, GetTensorShape(input), GetTensorData<float>(input),
@@ -1502,6 +1655,23 @@
   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   switch (input->type) {
     case kTfLiteFloat32: {
+#ifdef TFLITE_KERNEL_USE_XNNPACK
+      const size_t channel_dim = 1;
+      const size_t batch_size = NumElements(input->dims);
+      CpuBackendContext* cpu_backend_context =
+          CpuBackendContext::GetFromContext(context);
+      pthreadpool_t threadpool = cpu_backend_context->get_xnnpack_threadpool();
+      xnn_status status = xnn_run_elu_nc_f32(
+          channel_dim, channel_dim, channel_dim, batch_size,
+          GetTensorData<float>(input), GetTensorData<float>(output),
+          /*alpha=*/1.0f, /*flags=*/XNN_FLAG_YIELD_WORKERS, threadpool);
+      if (status == xnn_status_success) {
+        return kTfLiteOk;
+      }
+      TFLITE_LOG(TFLITE_LOG_INFO,
+                 "Failed to run xnnpack xnn_run_elu_nc_f32. Error code: %d",
+                 status);
+#endif
       optimized_ops::Elu(GetTensorShape(input), GetTensorData<float>(input),
                          GetTensorShape(output), GetTensorData<float>(output));
       return kTfLiteOk;
diff --git a/tensorflow/lite/kernels/concatenation.cc b/tensorflow/lite/kernels/concatenation.cc
index 99a6212..4f11644 100644
--- a/tensorflow/lite/kernels/concatenation.cc
+++ b/tensorflow/lite/kernels/concatenation.cc
@@ -16,6 +16,8 @@
 
 #include <stdint.h>
 
+#include <cstddef>
+#include <cstring>
 #include <limits>
 
 #include "tensorflow/lite/core/c/builtin_op_data.h"
@@ -27,6 +29,7 @@
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/util.h"
 
 namespace tflite {
 namespace ops {
@@ -134,7 +137,8 @@
   TfLiteType input_type = t0->type;
   if (axis < 0) axis += t0->dims->size;
   TF_LITE_ENSURE(context, axis >= 0);
-  TF_LITE_ENSURE(context, axis < t0->dims->size);
+  TF_LITE_ENSURE(context,
+                 axis < t0->dims->size || (t0->dims->size == 0 && axis == 0));
 
   TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
   TF_LITE_ENSURE(context,
@@ -143,23 +147,60 @@
                      input_type == kTfLiteInt32 || input_type == kTfLiteInt64 ||
                      input_type == kTfLiteBool || input_type == kTfLiteUInt32);
 
-  // Output dimensions will match input dimensions, except 'axis', which
-  // will be the sum of inputs
-  int sum_axis = t0->dims->data[axis];
-  for (int i = 1; i < num_inputs; ++i) {
+  // Check to see if we can calculate the output now.
+  bool all_inputs_at_prepare = true;
+  for (int i = 0; i < num_inputs; ++i) {
     const TfLiteTensor* t;
     TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
-    TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
-    TF_LITE_ENSURE_EQ(context, t->type, input_type);
-    for (int d = 0; d < t0->dims->size; ++d) {
-      if (d == axis) {
-        // Avoid integer overflow in sum_axis below
-        TF_LITE_ENSURE(context, t->dims->data[axis] >= 0);
-        TF_LITE_ENSURE(context, t->dims->data[axis] <=
-                                    std::numeric_limits<int>::max() - sum_axis);
-        sum_axis += t->dims->data[axis];
-      } else {
-        TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]);
+    if (!IsConstantOrPersistentTensor(t)) {
+      all_inputs_at_prepare = false;
+      break;
+    }
+  }
+  // Output dimensions will match input dimensions, except 'axis', which
+  // will be the sum of inputs
+  int sum_axis = t0->dims->size > 0 ? t0->dims->data[axis] : 1;
+  // Check if we are concatenating constant scalars.
+  if (all_inputs_at_prepare && t0->dims->size == 0 && axis == 0) {
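+    // All inputs are constant scalars: materialize the output at Prepare time
+    // as a persistent 1-D tensor of length num_inputs and copy each scalar in.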
+    for (int i = 1; i < num_inputs; ++i) {
+      const TfLiteTensor* t;
+      TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
+      TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
+    }
+    TfLiteTensor* output;
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+    TfLiteIntArray* output_size = TfLiteIntArrayCreate(1);
+    output_size->data[0] = num_inputs;
+    SetTensorToPersistentRo(output);
+    context->ResizeTensor(context, output, output_size);
+    size_t input_type_size;
+    TF_LITE_ENSURE_STATUS(GetSizeOfType(context, t0->type, &input_type_size));
+    void* o_data = output->data.data;
+    for (int i = 0; i < num_inputs; ++i) {
+      const TfLiteTensor* t;
+      TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
+      const void* i_data = t->data.data;
+      memcpy(o_data, i_data, input_type_size);
+      o_data = (void*)((uintptr_t)o_data + input_type_size);
+    }
+    return kTfLiteOk;
+  } else {
+    for (int i = 1; i < num_inputs; ++i) {
+      const TfLiteTensor* t;
+      TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
+      TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
+      TF_LITE_ENSURE_EQ(context, t->type, input_type);
+      for (int d = 0; d < t0->dims->size; ++d) {
+        if (d == axis) {
+          // Avoid integer overflow in sum_axis below
+          TF_LITE_ENSURE(context, t->dims->data[axis] >= 0);
+          TF_LITE_ENSURE(context,
+                         t->dims->data[axis] <=
+                             std::numeric_limits<int>::max() - sum_axis);
+          sum_axis += t->dims->data[axis];
+        } else {
+          TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]);
+        }
       }
     }
   }
@@ -195,16 +236,6 @@
     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
   }
 
-  // Check to see if we can calculate the output now.
-  bool all_inputs_at_prepare = true;
-  for (int i = 0; i < num_inputs; ++i) {
-    const TfLiteTensor* t;
-    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
-    if (!IsConstantOrPersistentTensor(t)) {
-      all_inputs_at_prepare = false;
-      break;
-    }
-  }
   if (all_inputs_at_prepare) {
     SetTensorToPersistentRo(output);
     context->ResizeTensor(context, output, output_size);
diff --git a/tensorflow/lite/kernels/concatenation_test.cc b/tensorflow/lite/kernels/concatenation_test.cc
index e9b7b93..a5b4a3c 100644
--- a/tensorflow/lite/kernels/concatenation_test.cc
+++ b/tensorflow/lite/kernels/concatenation_test.cc
@@ -762,6 +762,22 @@
   }
 }
 
+TYPED_TEST(ConcatenationOpPersistentModelTest, PersistentScalarTest) {
+  PersistentTestCase test_case{TestInputType::kPersistentRo,
+                               GetTensorType<TypeParam>(), false};
+  std::vector<std::vector<TypeParam>> input_data_lists = {{1}, {7}};
+  std::vector<TensorData> input_template = {{GetTensorType<TypeParam>(), {}},
+                                            {GetTensorType<TypeParam>(), {}}};
+  TensorData output_template = {GetTensorType<TypeParam>(), {2}};
+  PersistentConcatenationOpModel<TypeParam> m0(
+      input_template, /*axis=*/0, output_template, test_case, input_data_lists);
+  m0.PopulateInputTensors();
+  ASSERT_EQ(m0.Invoke(), kTfLiteOk);
+  ASSERT_EQ(m0.IsPersistentOutput(),
+            test_case.test_type == TestInputType::kPersistentRo);
+  EXPECT_THAT(m0.GetOutput(), ElementsAreArray(ArrayFloatNear({1.0, 7.0})));
+}
+
 TYPED_TEST(ConcatenationOpPersistentModelTest, QuantizedPersistentTest) {
   const bool is_quantized = true;
   for (PersistentTestCase test_case :
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index b6a2912..f74cf28 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -40,6 +40,14 @@
     ],
 })
 
+filegroup(
+    name = "tflite_internal_cc_3p_api_deps_src",
+    srcs = ["compatibility.h"],
+    visibility = [
+        "//tensorflow/lite:__pkg__",
+    ],
+)
+
 cc_library(
     name = "compatibility",
     hdrs = ["compatibility.h"],
diff --git a/tensorflow/lite/kernels/reduce_window.cc b/tensorflow/lite/kernels/reduce_window.cc
deleted file mode 100644
index a6906a6..0000000
--- a/tensorflow/lite/kernels/reduce_window.cc
+++ /dev/null
@@ -1,342 +0,0 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-         //
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <algorithm>
-#include <array>
-#include <cstdint>
-#include <functional>
-#include <type_traits>
-
-#include "tensorflow/lite/array.h"
-#include "tensorflow/lite/c/c_api_types.h"
-#include "tensorflow/lite/core/c/builtin_op_data.h"
-#include "tensorflow/lite/core/c/common.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-
-namespace tflite {
-namespace ops {
-namespace builtin {
-namespace reduce_window {
-namespace {
-
-constexpr int32_t kMaxReduceWindowDims = 6;
-
-template <int Val>
-using IntCst = std::integral_constant<int, Val>;
-
-// Reduces the elements of a tensor viewed through a strided window.
-//
-// This applies a reduction to a tensor by skipping over elements that are not
-// in the window defined by the given shape and strides. The window is reduced
-// to one element.
-//
-// The shape is the shape of the window. The strides are based on the actual
-// tensor and the distance between window elements, counted in elements. Sparse
-// windows are possible.
-//
-// For instance: the following window has a [2, 2] shape and [8, 3] strides.
-//
-// ┌──┐     ┌──┐
-// │ 1│ 2  3│ 4│
-// └──┘     └──┘
-//   5  6  7  8    is reduced to 1 + 4 + 9 + 12 = 26
-// ┌──┐     ┌──┐
-// │ 9│10 11│12│
-// └──┘     └──┘
-//  13 14 15 16
-//
-// This is a recursive implementation of the strided reduction.
-template <class Op, class Type>
-void StridedReduce(const Type* input, const int64_t* const shape,
-                   const int64_t* const strides, Type& accu, const int rank,
-                   const int depth) {
-  const int64_t stride = strides[depth];
-  const int64_t size = shape[depth];
-  if (depth + 1 == rank) {
-    const Op op;
-    for (int64_t i = 0; i < size; ++i) {
-      accu = op(accu, *input);
-      input += stride;
-    }
-  } else {
-    for (int64_t i = 0; i < size; ++i) {
-      StridedReduce<Op, Type>(input, shape, strides, accu, rank, depth + 1);
-      input += stride;
-    }
-  }
-}
-
-// Recursively computes strided reductions using a sliding window over the given
-// tensor.
-//
-// The window is defined using a shape and a dilation. The shape defines the
-// elements that the window will let the reduction *see*. The dilation defines
-// the step between window elements.
-//
-// For instance: the following window has a [2, 2] shape and [2, 3] dilations.
-//
-//    3
-// ┌────┐
-// ┌─┐   ┌─┐
-// │X│X X│X│┐
-// └─┘   └─┘│2
-//  X X X X ┘
-// ┌─┐   ┌─┐
-// │X│X X│X│
-// └─┘   └─┘
-template <class Op, class Type>
-void ReduceWindowImpl(const Type* input, Type* output,
-                      const int64_t* const output_shape,
-                      const int64_t* const output_strides,
-                      const int64_t* const window_offset_strides,
-                      const int64_t* const window_shape,
-                      const int64_t* const window_reduce_strides,
-                      const Type init, const int rank, const int depth) {
-  if (depth + 1 == rank) {
-    for (int32_t dim = 0; dim < output_shape[depth]; ++dim) {
-      *output = init;
-      StridedReduce<Op, Type>(input, window_shape, window_reduce_strides,
-                              *output, rank, 0);
-      input += window_offset_strides[depth];
-      output += output_strides[depth];
-    }
-  } else {
-    for (int32_t dim = 0; dim < output_shape[depth]; ++dim) {
-      ReduceWindowImpl<Op, Type>(input, output, output_shape, output_strides,
-                                 window_offset_strides, window_shape,
-                                 window_reduce_strides, init, rank, depth + 1);
-      input += window_offset_strides[depth];
-      output += output_strides[depth];
-    }
-  }
-}
-
-std::array<int64_t, kMaxReduceWindowDims> ComputeStrides(
-    const int64_t* const shape, const int64_t rank) {
-  std::array<int64_t, kMaxReduceWindowDims> strides;
-  strides[rank - 1] = 1;
-  for (int64_t i = rank - 2; i >= 0; --i) {
-    strides[i] = shape[i + 1] * strides[i + 1];
-  }
-  return strides;
-}
-
-// Element-wise multiplication of the two operands of given size.
-std::array<int64_t, kMaxReduceWindowDims> Multiply(const int64_t* const vec1,
-                                                   const int64_t* const vec2,
-                                                   const int64_t size) {
-  std::array<int64_t, kMaxReduceWindowDims> result;
-  for (int64_t i = 0; i < size; ++i) {
-    result[i] = vec2[i] * vec1[i];
-  }
-  return result;
-}
-
-// Computes the output shape of the ReduceWindow operator.
-std::array<int64_t, kMaxReduceWindowDims> ComputeOutputShape(
-    const int64_t* const shape, const int64_t* const window_shape,
-    const int64_t* const window_strides, const int64_t* const window_dilations,
-    const int64_t rank) {
-  std::array<int64_t, kMaxReduceWindowDims> dilated_window_shape;
-  for (int64_t i = 0; i < rank; ++i) {
-    dilated_window_shape[i] = (window_shape[i] - 1) * window_dilations[i] + 1;
-  }
-
-  std::array<int64_t, kMaxReduceWindowDims> window_range;
-  for (int64_t i = 0; i < rank; ++i) {
-    window_range[i] =
-        (shape[i] - dilated_window_shape[i]) / window_strides[i] + 1;
-  }
-  return window_range;
-}
-
-template <class Op, class Type>
-void ReduceWindow(const Type* const input, Type* output,
-                  const int64_t* const shape, const int64_t* const window_shape,
-                  const int64_t* const window_strides,
-                  const int64_t* const window_dilations, const Type init,
-                  const int rank) {
-  const std::array<int64_t, kMaxReduceWindowDims> strides =
-      ComputeStrides(shape, rank);
-  const std::array<int64_t, kMaxReduceWindowDims> window_reduce_strides =
-      Multiply(strides.data(), window_dilations, rank);
-  const std::array<int64_t, kMaxReduceWindowDims> window_offset_strides =
-      Multiply(strides.data(), window_strides, rank);
-  const std::array<int64_t, kMaxReduceWindowDims> output_shape =
-      ComputeOutputShape(shape, window_shape, window_strides, window_dilations,
-                         rank);
-  const std::array<int64_t, kMaxReduceWindowDims> output_strides =
-      ComputeStrides(output_shape.data(), rank);
-  ReduceWindowImpl<Op, Type>(input, output, output_shape.data(),
-                             output_strides.data(),
-                             window_offset_strides.data(), window_shape,
-                             window_reduce_strides.data(), init, rank, 0);
-}
-
-std::array<int64_t, kMaxReduceWindowDims> AsInt64(const int32_t* data,
-                                                  const int size) {
-  std::array<int64_t, kMaxReduceWindowDims> res;
-  std::copy_n(data, size, res.data());
-  return res;
-}
-
-// Holds the tensors and operation context for convenience.
-struct ReduceWindowContext {
-  enum InputTensorId {
-    kInput,
-    kInitValue,
-    kWindowShape,
-    kWindowStrides,
-    kWindowDilations,
-    kNumInputTensors
-  };
-  enum OutputTensorId { kOutput, kNumOutputTensors };
-
-  ReduceWindowContext(TfLiteContext* context, TfLiteNode* node)
-      : context(context),
-        node(node),
-        input_tensor(GetInput(context, node, kInput)),
-        init_value_tensor(GetInput(context, node, kInitValue)),
-        window_shape_tensor(GetInput(context, node, kWindowShape)),
-        window_strides_tensor(GetInput(context, node, kWindowStrides)),
-        window_dilations_tensor(GetInput(context, node, kWindowDilations)),
-        output_tensor(GetOutput(context, node, kOutput)) {}
-
-  TfLiteContext* context;
-  TfLiteNode* node;
-  const TfLiteTensor* input_tensor;
-  const TfLiteTensor* init_value_tensor;
-  const TfLiteTensor* window_shape_tensor;
-  const TfLiteTensor* window_strides_tensor;
-  const TfLiteTensor* window_dilations_tensor;
-  TfLiteTensor* output_tensor;
-};
-
-TfLiteStatus SetupOutputTensor(const ReduceWindowContext& ctx) {
-  const int rank = ctx.input_tensor->dims->size;
-  const std::array<int64_t, kMaxReduceWindowDims> input_shape =
-      AsInt64(ctx.input_tensor->dims->data, rank);
-  const std::array<int64_t, kMaxReduceWindowDims> output_shape_data =
-      ComputeOutputShape(input_shape.data(), ctx.window_shape_tensor->data.i64,
-                         ctx.window_strides_tensor->data.i64,
-                         ctx.window_dilations_tensor->data.i64, rank);
-  IntArrayUniquePtr output_shape =
-      BuildTfLiteArray<int32_t>(rank, output_shape_data.data());
-  return ctx.context->ResizeTensor(ctx.context, ctx.output_tensor,
-                                   output_shape.release());
-}
-
-template <class Op>
-TfLiteStatus DispatchReduceWindowType(ReduceWindowContext& ctx) {
-  const int rank = ctx.input_tensor->dims->size;
-  const std::array<int64_t, kMaxReduceWindowDims> input_shape =
-      AsInt64(ctx.input_tensor->dims->data, rank);
-#define REDUCE_WINDOW_TYPE_CASE(CPP_TYPE, TENSOR_TYPE)                        \
-  case TENSOR_TYPE:                                                           \
-    ReduceWindow<Op, CPP_TYPE>(                                               \
-        reinterpret_cast<const CPP_TYPE*>(ctx.input_tensor->data.raw),        \
-        reinterpret_cast<CPP_TYPE*>(ctx.output_tensor->data.raw),             \
-        input_shape.data(), ctx.window_shape_tensor->data.i64,                \
-        ctx.window_strides_tensor->data.i64,                                  \
-        ctx.window_dilations_tensor->data.i64,                                \
-        *reinterpret_cast<CPP_TYPE*>(ctx.init_value_tensor->data.raw), rank); \
-    break;
-  switch (ctx.input_tensor->type) {
-    REDUCE_WINDOW_TYPE_CASE(int8_t, kTfLiteBool);
-    REDUCE_WINDOW_TYPE_CASE(int8_t, kTfLiteInt8);
-    REDUCE_WINDOW_TYPE_CASE(int16_t, kTfLiteInt16);
-    REDUCE_WINDOW_TYPE_CASE(int32_t, kTfLiteInt32);
-    REDUCE_WINDOW_TYPE_CASE(int64_t, kTfLiteInt64);
-    REDUCE_WINDOW_TYPE_CASE(uint8_t, kTfLiteUInt8);
-    // REDUCE_WINDOW_TYPE_CASE(uint16_t, kTfLiteUInt16);
-    // REDUCE_WINDOW_TYPE_CASE(uint32_t, kTfLiteUInt32);
-    // REDUCE_WINDOW_TYPE_CASE(uint64_t, kTfLiteUInt64);
-    REDUCE_WINDOW_TYPE_CASE(float, kTfLiteFloat32);
-    static_assert(sizeof(float) == 4,
-                  "float type is expected to be 32 bit long");
-    REDUCE_WINDOW_TYPE_CASE(double, kTfLiteFloat64);
-    static_assert(sizeof(double) == 8,
-                  "double type is expected to be 64 bit long");
-    default:
-      return kTfLiteError;
-  }
-#undef REDUCE_WINDOW_TYPE_CASE
-  return kTfLiteOk;
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TF_LITE_ENSURE_EQ(context, NumInputs(node),
-                    ReduceWindowContext::kNumInputTensors);
-  TF_LITE_ENSURE_EQ(context, NumOutputs(node),
-                    ReduceWindowContext::kNumOutputTensors);
-  ReduceWindowContext ctx(context, node);
-  TF_LITE_ENSURE(context, IsConstantTensor(ctx.window_shape_tensor));
-  TF_LITE_ENSURE(context, IsConstantTensor(ctx.window_strides_tensor));
-  TF_LITE_ENSURE(context, IsConstantTensor(ctx.window_dilations_tensor));
-  TF_LITE_ENSURE(context, ctx.input_tensor->dims != nullptr);
-  TF_LITE_ENSURE(context, ctx.input_tensor->dims->size > 0);
-  TF_LITE_ENSURE(context, ctx.input_tensor->dims->size <= kMaxReduceWindowDims);
-  return SetupOutputTensor(ctx);
-}
-
-struct Max {
-  template <class T>
-  constexpr T operator()(const T& a, const T& b) const {
-    return a >= b ? a : b;
-  }
-};
-
-struct Min {
-  template <class T>
-  constexpr T operator()(const T& a, const T& b) const {
-    return a <= b ? a : b;
-  }
-};
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const auto& params =
-      *reinterpret_cast<TfLiteReduceWindowParams*>(node->builtin_data);
-  ReduceWindowContext ctx(context, node);
-  switch (params.reduce_function) {
-    case TfLiteReduceWindowFunctionUnsupported:
-      return kTfLiteError;
-    case TfLiteReduceWindowFunctionAdd:
-      return DispatchReduceWindowType<std::plus<>>(ctx);
-    case TfLiteReduceWindowFunctionMul:
-      return DispatchReduceWindowType<std::multiplies<>>(ctx);
-    case TfLiteReduceWindowFunctionAll:
-      return DispatchReduceWindowType<std::logical_and<>>(ctx);
-    case TfLiteReduceWindowFunctionAny:
-      return DispatchReduceWindowType<std::logical_or<>>(ctx);
-    case TfLiteReduceWindowFunctionMin:
-      return DispatchReduceWindowType<Min>(ctx);
-    case TfLiteReduceWindowFunctionMax:
-      return DispatchReduceWindowType<Max>(ctx);
-  }
-}
-
-}  // namespace
-}  // namespace reduce_window
-
-TfLiteRegistration* Register_REDUCE_WINDOW() {
-  static TfLiteRegistration r = {/*.init=*/nullptr, /*.free=*/nullptr,
-                                 /*.prepare=*/reduce_window::Prepare,
-                                 /*.invoke=*/reduce_window::Eval};
-  return &r;
-}
-
-}  // namespace builtin
-}  // namespace ops
-}  // namespace tflite
diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc
index af37557..7db3969 100644
--- a/tensorflow/lite/kernels/register_ref.cc
+++ b/tensorflow/lite/kernels/register_ref.cc
@@ -188,6 +188,12 @@
 TfLiteRegistration* Register_DILATE();
 TfLiteRegistration* Register_STABLEHLO_RNG_BIT_GENERATOR();
 TfLiteRegistration* Register_REDUCE_WINDOW();
+TfLiteRegistration* Register_STABLEHLO_GATHER();
+TfLiteRegistration* Register_STABLEHLO_ADD();
+TfLiteRegistration* Register_STABLEHLO_MULTIPLY();
+TfLiteRegistration* Register_STABLEHLO_REDUCE_WINDOW();
+TfLiteRegistration* Register_STABLEHLO_MAXIMUM();
+TfLiteRegistration* Register_STABLEHLO_MINIMUM();
 
 namespace {
 
@@ -541,10 +547,17 @@
   AddBuiltin(BuiltinOperator_BITWISE_XOR, Register_BITWISE_XOR());
   AddBuiltin(BuiltinOperator_RIGHT_SHIFT, Register_RIGHT_SHIFT());
   AddBuiltin(BuiltinOperator_STABLEHLO_SCATTER, Register_STABLEHLO_SCATTER());
+  AddBuiltin(BuiltinOperator_STABLEHLO_ADD, Register_STABLEHLO_ADD());
+  AddBuiltin(BuiltinOperator_STABLEHLO_MULTIPLY, Register_STABLEHLO_MULTIPLY());
+  AddBuiltin(BuiltinOperator_STABLEHLO_MAXIMUM, Register_STABLEHLO_MAXIMUM());
+  AddBuiltin(BuiltinOperator_STABLEHLO_MINIMUM, Register_STABLEHLO_MINIMUM());
   AddBuiltin(BuiltinOperator_DILATE, Register_DILATE());
   AddBuiltin(BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR,
              Register_STABLEHLO_RNG_BIT_GENERATOR());
   AddBuiltin(BuiltinOperator_REDUCE_WINDOW, Register_REDUCE_WINDOW());
+  AddBuiltin(BuiltinOperator_STABLEHLO_REDUCE_WINDOW,
+             Register_STABLEHLO_REDUCE_WINDOW());
+  AddBuiltin(BuiltinOperator_STABLEHLO_GATHER, Register_STABLEHLO_GATHER());
   AddCustom("NumericVerify",
             tflite::ops::custom::Register_NUMERIC_VERIFY_REF());
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
diff --git a/tensorflow/lite/kernels/stablehlo_add.cc b/tensorflow/lite/kernels/stablehlo_add.cc
new file mode 100644
index 0000000..f22db17
--- /dev/null
+++ b/tensorflow/lite/kernels/stablehlo_add.cc
@@ -0,0 +1,26 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/core/c/common.h"
+#include "tensorflow/lite/kernels/stablehlo_elementwise.h"
+
+namespace tflite::ops::builtin {
+
+TfLiteRegistration* Register_STABLEHLO_ADD() {
+  static TfLiteRegistration r = {nullptr, nullptr, ElementwisePrepare,
+                                 ElementwiseEval<ComputationType::kAdd>};
+  return &r;
+}
+}  // namespace tflite::ops::builtin
diff --git a/tensorflow/lite/kernels/stablehlo_add_test.cc b/tensorflow/lite/kernels/stablehlo_add_test.cc
new file mode 100644
index 0000000..a62203a
--- /dev/null
+++ b/tensorflow/lite/kernels/stablehlo_add_test.cc
@@ -0,0 +1,64 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/c/c_api_types.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/subgraph.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+class AddOpModel : public SingleOpModel {
+ public:
+  AddOpModel(const TensorData& input1, const TensorData& input2,
+             const TensorData& output) {
+    input1_ = AddInput(input1);
+    input2_ = AddInput(input2);
+    output_ = AddOutput(output);
+    SetBuiltinOp(BuiltinOperator_STABLEHLO_ADD, BuiltinOptions_NONE, 0);
+    SetBypassDefaultDelegates();
+    BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+  }
+
+  int input1() { return input1_; }
+  int input2() { return input2_; }
+
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+  int input1_;
+  int input2_;
+  int output_;
+};
+
+TEST(StablehloElementwise, AddWorks) {
+  AddOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}},
+                   {TensorType_FLOAT32, {1, 2, 2, 1}},
+                   {TensorType_FLOAT32, {}});
+  model.PopulateTensor<float>(model.input1(), {-2.0, 0.2, 0.7, 0.8});
+  model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.3, 0.5});
+  ASSERT_EQ(model.Invoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutput(), ElementsAre(-1.9, 0.4, 1.0, 1.3));
+}
+
+}  // namespace
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/stablehlo_elementwise.cc b/tensorflow/lite/kernels/stablehlo_elementwise.cc
new file mode 100644
index 0000000..89b2e9b
--- /dev/null
+++ b/tensorflow/lite/kernels/stablehlo_elementwise.cc
@@ -0,0 +1,56 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/kernels/stablehlo_elementwise.h"
+
+#include "tensorflow/lite/core/c/common.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+TfLiteStatus ElementwisePrepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  const TfLiteTensor* input_tensor1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input_tensor1));
+  const TfLiteTensor* input_tensor2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input_tensor2));
+
+  // Check the two input tensors have the same type and size.
+  TF_LITE_ENSURE_TYPES_EQ(context, input_tensor1->type, input_tensor2->type);
+  TF_LITE_ENSURE_EQ(context, input_tensor1->dims->size,
+                    input_tensor2->dims->size);
+  for (int idx = 0; idx < input_tensor1->dims->size; ++idx) {
+    TF_LITE_ENSURE_EQ(context, input_tensor1->dims->data[idx],
+                      input_tensor2->dims->data[idx]);
+  }
+
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
+
+  // We need the copy since ResizeTensor takes ownership of output_size.
+  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_tensor1->dims);
+  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_size));
+
+  return TfLiteStatus::kTfLiteOk;
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/stablehlo_elementwise.h b/tensorflow/lite/kernels/stablehlo_elementwise.h
new file mode 100644
index 0000000..df0549b
--- /dev/null
+++ b/tensorflow/lite/kernels/stablehlo_elementwise.h
@@ -0,0 +1,150 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_STABLEHLO_ELEMENTWISE_H_
+#define TENSORFLOW_LITE_KERNELS_STABLEHLO_ELEMENTWISE_H_
+
+#include <cstdint>
+#include <vector>
+
+#include "Eigen/Core"  // from @eigen_archive
+#include "tensorflow/lite/core/c/common.h"
+#include "tensorflow/lite/kernels/internal/runtime_shape.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+// Indicates the type of the computation performed by the element-wise op.
+enum class ComputationType { kAdd, kSub, kMax, kMin, kMul };
+
+TfLiteStatus ElementwisePrepare(TfLiteContext* context, TfLiteNode* node);
+
+// A helper function that converts a tensor index into a flat array index.
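+// For instance (row-major layout): with shape = {2, 3, 4} and
+// index = {1, 2, 3}, the flat index is (1 * 3 + 2) * 4 + 3 = 23.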
+template <typename IndexType>
+static IndexType TensorIndexToFlat(const IndexType* index, const int64_t dims,
+                                   const RuntimeShape& shape) {
+  // If it's a scalar, just return the index of the first element.
+  if (dims == 0) {
+    return 0;
+  }
+  IndexType flat_index = index[0];
+  for (int64_t i = 1; i < dims; ++i) {
+    flat_index = flat_index * shape.Dims(i) + index[i];
+  }
+  return flat_index;
+}
+
+template <typename DataType, ComputationType computation_type>
+inline DataType ApplyComputation(DataType input1, DataType input2) {
+  if (computation_type == ComputationType::kAdd) {
+    return input1 + input2;
+  } else if (computation_type == ComputationType::kSub) {
+    return input1 - input2;
+  } else if (computation_type == ComputationType::kMax) {
+    return std::max(input1, input2);
+  } else if (computation_type == ComputationType::kMin) {
+    return std::min(input1, input2);
+  } else if (computation_type == ComputationType::kMul) {
+    return input1 * input2;
+  }
+}
+
+// Evaluates this node given the element type shared by the input and output
+// tensors.
+template <ComputationType computation_type, typename DataType>
+TfLiteStatus EvalWithType(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input_tensor1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input_tensor1));
+  RuntimeShape input_shape = GetTensorShape(input_tensor1);
+  const DataType* input_data1 = GetTensorData<DataType>(input_tensor1);
+
+  const TfLiteTensor* input_tensor2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input_tensor2));
+  const DataType* input_data2 = GetTensorData<DataType>(input_tensor2);
+
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
+  DataType* output_data = GetTensorData<DataType>(output);
+
+  int input_rank = input_tensor1->dims->size;
+  std::vector<int64_t> index(input_rank, 0);
+
+  do {
+    DataType input_value1 =
+        input_data1[TensorIndexToFlat(index.data(), input_rank, input_shape)];
+    DataType input_value2 =
+        input_data2[TensorIndexToFlat(index.data(), input_rank, input_shape)];
+
+    output_data[TensorIndexToFlat(index.data(), input_rank, input_shape)] =
+        ApplyComputation<DataType, computation_type>(input_value1,
+                                                     input_value2);
+  } while (NextIndex(input_rank, input_tensor1->dims->data, index.data()));
+
+  return TfLiteStatus::kTfLiteOk;
+}
+
+template <ComputationType computation_type>
+TfLiteStatus ElementwiseEval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input_tensor1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input_tensor1));
+
+  TfLiteType data_type = input_tensor1->type;
+
+  switch (data_type) {
+    case kTfLiteFloat16:
+      return EvalWithType<computation_type, Eigen::half>(context, node);
+    case kTfLiteFloat32:
+      return EvalWithType<computation_type, float>(context, node);
+    case kTfLiteFloat64:
+      return EvalWithType<computation_type, double>(context, node);
+    case kTfLiteInt8:
+      return EvalWithType<computation_type, int8_t>(context, node);
+    case kTfLiteInt16:
+      return EvalWithType<computation_type, int16_t>(context, node);
+    case kTfLiteInt32:
+      return EvalWithType<computation_type, int32_t>(context, node);
+    case kTfLiteInt64:
+      return EvalWithType<computation_type, int64_t>(context, node);
+    case kTfLiteUInt8:
+      return EvalWithType<computation_type, uint8_t>(context, node);
+    case kTfLiteUInt16:
+      return EvalWithType<computation_type, uint16_t>(context, node);
+    case kTfLiteUInt32:
+      return EvalWithType<computation_type, uint32_t>(context, node);
+    case kTfLiteUInt64:
+      return EvalWithType<computation_type, uint64_t>(context, node);
+    default:
+      TF_LITE_KERNEL_LOG(context, "(Data Type: %s) currently not supported.\n",
+                         TfLiteTypeGetName(data_type));
+      return TfLiteStatus::kTfLiteError;
+  }
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_STABLEHLO_ELEMENTWISE_H_
diff --git a/tensorflow/lite/kernels/stablehlo_gather.cc b/tensorflow/lite/kernels/stablehlo_gather.cc
new file mode 100644
index 0000000..684ce4e
--- /dev/null
+++ b/tensorflow/lite/kernels/stablehlo_gather.cc
@@ -0,0 +1,332 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "Eigen/Core"  // from @eigen_archive
+#include "tensorflow/lite/core/c/builtin_op_data.h"
+#include "tensorflow/lite/core/c/c_api_types.h"
+#include "tensorflow/lite/core/c/common.h"
+#include "tensorflow/lite/kernels/internal/runtime_shape.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/tensor_slice_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace stablehlo_gather {
+namespace {
+
+constexpr int kOperandTensor = 0;
+constexpr int kStartIndicesTensor = 1;
+constexpr int kOutputTensor = 0;
+
+using TfLiteIntArrayUniquePtr =
+    std::unique_ptr<TfLiteIntArray, decltype(&TfLiteIntArrayFree)>;
+
+// Clips the starting indices given the operand_shape and slice_sizes. This
+// means the starting index in a dimension will be shifted back if necessary so
+// that the whole slice can fit in the operand.
+// Example:
+// starting_index = [i, j], operand_shape = [oi, oj], slice_sizes = [si, sj]
+// starting_index will be transformed to [min(i, oi - si), min(j, oj - sj)]
+template <typename IndexType>
+TfLiteStatus ClipStartingIndex(const RuntimeShape& operand_shape,
+                               const int64_t* slice_sizes, int num_slice_sizes,
+                               Index<IndexType>& starting_index) {
+  if (operand_shape.DimensionsCount() != starting_index.size() ||
+      operand_shape.DimensionsCount() != num_slice_sizes) {
+    return kTfLiteError;
+  }
+  for (int dim = 0; dim < starting_index.size(); ++dim) {
+    starting_index[dim] = std::min((int64_t)starting_index[dim],
+                                   operand_shape.Dims(dim) - slice_sizes[dim]);
+  }
+  return kTfLiteOk;
+}
+
+// Returns a vector containing slice_sizes with all the entries with indices
+// that are present in collapsed_slice_dims removed.
+// Example: slice_sizes = {3, 5, 2, 7}, collapsed_slice_dims = {1, 3}
+// Result: {3, 2}
+static std::vector<int64_t> GetCollapsedSliceShape(
+    const int64_t* slice_sizes, int num_slice_sizes,
+    const int64_t* collapsed_slice_dims, int num_collapsed_slice_dims) {
+  std::vector<int64_t> result(num_slice_sizes - num_collapsed_slice_dims);
+  int result_ctr = 0;
+  for (int dim = 0; dim < num_slice_sizes; dim++) {
+    if (!ArrayContains(collapsed_slice_dims, num_collapsed_slice_dims, dim)) {
+      result[result_ctr] = slice_sizes[dim];
+      result_ctr++;
+    }
+  }
+  return result;
+}
+
+// Creates the result shape based on the rank of the result, options and
+// shape of the result_indices operand.
+// Refer to the spec for a full explanation:
+// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather
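+// Illustrative example (matching the gather tests in this change): with
+// start_indices of shape {2, 3, 2}, index_vector_dim = 2, offset_dims =
+// {2, 3}, collapsed_slice_dims = {0} and slice_sizes = {1, 2, 2}, a rank-4
+// result gets the shape {2, 3, 2, 2}.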
+static TfLiteIntArrayUniquePtr GetResultShape(
+    int64_t result_rank, const TfLiteStablehloGatherParams* data,
+    const RuntimeShape& start_indices_shape) {
+  TfLiteIntArrayUniquePtr result = TfLiteIntArrayUniquePtr(
+      TfLiteIntArrayCreate(result_rank), &TfLiteIntArrayFree);
+  int result_ctr = 0;
+
+  std::vector<int64_t> collapsed_slice_shape = GetCollapsedSliceShape(
+      data->slice_sizes, data->num_slice_sizes, data->collapsed_slice_dims,
+      data->num_collapsed_slice_dims);
+  int64_t slice_shape_ctr = 0;
+  int64_t start_indices_shape_ctr = 0;
+
+  for (int64_t dim = 0; dim < result_rank; dim++) {
+    if (ArrayContains(data->offset_dims, data->num_offset_dims, dim)) {
+      result->data[result_ctr] = collapsed_slice_shape[slice_shape_ctr];
+      slice_shape_ctr++;
+    } else {
+      if (start_indices_shape_ctr == data->index_vector_dim) {
+        start_indices_shape_ctr++;
+      }
+      result->data[result_ctr] =
+          start_indices_shape.Dims(start_indices_shape_ctr);
+      start_indices_shape_ctr++;
+    }
+    result_ctr++;
+  }
+  return result;
+}
+
+// Extracts the batch and offset indices out of a given result index.
+// Result index is the index of an element in the output(result) tensor.
+// The location of the offset dims is given in the offset_dims argument and
+// the rest are batch dimensions.
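+// For instance: with result_index = {i0, i1, i2, i3} and offset_dims = {2, 3},
+// batch_index becomes {i0, i1} and offset_index becomes {i2, i3}.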
+template <typename IndexType>
+TfLiteStatus SetBatchAndOffsetIndices(const Index<IndexType>& result_index,
+                                      const int64_t* offset_dims,
+                                      int num_offset_dims,
+                                      Index<IndexType>& batch_index,
+                                      Index<IndexType>& offset_index) {
+  int offset_index_ctr = 0;
+  int batch_index_ctr = 0;
+  for (int result_dim = 0; result_dim < result_index.size(); ++result_dim) {
+    if (ArrayContains(offset_dims, num_offset_dims, result_dim)) {
+      if (offset_index_ctr >= num_offset_dims) {
+        return kTfLiteError;
+      }
+      offset_index[offset_index_ctr] = result_index[result_dim];
+      offset_index_ctr++;
+    } else {
+      if (batch_index_ctr >= result_index.size() - num_offset_dims) {
+        return kTfLiteError;
+      }
+      batch_index[batch_index_ctr] = result_index[result_dim];
+      batch_index_ctr++;
+    }
+  }
+  return kTfLiteOk;
+}
+
+// Evaluates this node given the type of the elements in the start_indices
+// and the type of the elements in the operand tensor.
+template <typename IndexType, typename DataType>
+TfLiteStatus EvalWithTypes(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  const TfLiteTensor* operand;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kOperandTensor, &operand));
+  int operand_rank = operand->dims->size;
+  RuntimeShape operand_shape = GetTensorShape(operand);
+
+  const TfLiteTensor* start_indices;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartIndicesTensor,
+                                          &start_indices));
+
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
+
+  const TfLiteStablehloGatherParams* data =
+      reinterpret_cast<TfLiteStablehloGatherParams*>(node->builtin_data);
+
+  RuntimeShape start_indices_shape = GetTensorShape(start_indices);
+  int result_rank = output->dims->size;
+  RuntimeShape result_runtime_shape(result_rank, output->dims->data);
+  Index<IndexType> result_index = Index<IndexType>(result_rank, 0);
+
+  int64_t num_batch_dims = result_rank - data->num_offset_dims;
+
+  Index<IndexType> batch_index(num_batch_dims);
+  Index<IndexType> offset_index(data->num_offset_dims);
+  do {
+    TF_LITE_ENSURE_OK(
+        context, SetBatchAndOffsetIndices(result_index, data->offset_dims,
+                                          data->num_offset_dims, batch_index,
+                                          offset_index));
+
+    Index<IndexType> starting_index_vector =
+        ReadIndexVector(start_indices, start_indices_shape, batch_index,
+                        data->index_vector_dim);
+
+    Index<IndexType> final_starting_index;
+    ScatterIndex(starting_index_vector, data->start_index_map,
+                 data->num_start_index_map, operand_rank,
+                 &final_starting_index);
+
+    TF_LITE_ENSURE_OK(
+        context,
+        ClipStartingIndex(operand_shape, data->slice_sizes,
+                          data->num_slice_sizes, final_starting_index));
+
+    Index<IndexType> full_offset_index;
+    ExpandDims(offset_index, data->collapsed_slice_dims,
+               data->num_collapsed_slice_dims, &full_offset_index);
+
+    Index<IndexType> operand_lookup_index =
+        AddIndices(final_starting_index, full_offset_index);
+
+    const DataType* operand_data = GetTensorData<DataType>(operand);
+    IndexType flat_operand_index =
+        TensorIndexToFlat(operand_lookup_index.data(),
+                          operand_lookup_index.size(), GetTensorShape(operand));
+    DataType looked_up_value = operand_data[flat_operand_index];
+
+    DataType* result_data = GetTensorData<DataType>(output);
+    IndexType flat_result_index = TensorIndexToFlat(
+        result_index.data(), result_index.size(), GetTensorShape(output));
+    result_data[flat_result_index] = looked_up_value;
+  } while (NextIndex(result_rank, result_runtime_shape.DimsData(),
+                     result_index.data()));
+
+  return TfLiteStatus::kTfLiteOk;
+}
+
+// Evaluates this node given the type of the elements in the start_indices
+// tensor.
+template <typename IndexType>
+TfLiteStatus EvalWithIndexType(TfLiteContext* context, TfLiteNode* node,
+                               TfLiteType index_type, TfLiteType data_type) {
+  switch (data_type) {
+    case kTfLiteFloat16:
+      return EvalWithTypes<IndexType, Eigen::half>(context, node);
+    case kTfLiteFloat32:
+      return EvalWithTypes<IndexType, float>(context, node);
+    case kTfLiteFloat64:
+      return EvalWithTypes<IndexType, double>(context, node);
+    case kTfLiteInt8:
+      return EvalWithTypes<IndexType, int8_t>(context, node);
+    case kTfLiteInt16:
+      return EvalWithTypes<IndexType, int16_t>(context, node);
+    case kTfLiteInt32:
+      return EvalWithTypes<IndexType, int32_t>(context, node);
+    case kTfLiteInt64:
+      return EvalWithTypes<IndexType, int64_t>(context, node);
+    case kTfLiteUInt8:
+      return EvalWithTypes<IndexType, uint8_t>(context, node);
+    case kTfLiteUInt16:
+      return EvalWithTypes<IndexType, uint16_t>(context, node);
+    case kTfLiteUInt32:
+      return EvalWithTypes<IndexType, uint32_t>(context, node);
+    case kTfLiteUInt64:
+      return EvalWithTypes<IndexType, uint64_t>(context, node);
+    default:
+      TF_LITE_KERNEL_LOG(
+          context, "(Index Type: %s, Data Type: %s) currently not supported.\n",
+          TfLiteTypeGetName(index_type), TfLiteTypeGetName(data_type));
+      return TfLiteStatus::kTfLiteError;
+  }
+}
+
+}  // namespace
+
+// This is the kernel for stablehlo.gather which receives `slice_sizes` as a
+// static attribute.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* operand;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kOperandTensor, &operand));
+  const TfLiteTensor* start_indices;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartIndicesTensor,
+                                          &start_indices));
+
+  TfLiteType index_type = start_indices->type;
+  TfLiteType data_type = operand->type;
+
+  if (index_type == kTfLiteInt32) {
+    return EvalWithIndexType<int32_t>(context, node, index_type, data_type);
+  } else if (index_type == kTfLiteInt64) {
+    return EvalWithIndexType<int64_t>(context, node, index_type, data_type);
+  } else {
+    TF_LITE_KERNEL_LOG(context, "(Index Type: %s) currently not supported.\n",
+                       TfLiteTypeGetName(index_type));
+    return TfLiteStatus::kTfLiteError;
+  }
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  const TfLiteTensor* operand;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kOperandTensor, &operand));
+  const TfLiteTensor* start_indices;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartIndicesTensor,
+                                          &start_indices));
+
+  TfLiteType index_type = start_indices->type;
+  if (index_type != kTfLiteInt32 && index_type != kTfLiteInt64) {
+    TF_LITE_KERNEL_LOG(context, "(Index Type: %s) currently not supported.\n",
+                       TfLiteTypeGetName(index_type));
+    return TfLiteStatus::kTfLiteError;
+  }
+
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
+
+  const TfLiteStablehloGatherParams* data =
+      reinterpret_cast<TfLiteStablehloGatherParams*>(node->builtin_data);
+
+  RuntimeShape start_indices_shape = GetTensorShape(start_indices);
+
+  TfLiteIntArrayUniquePtr result_shape =
+      GetResultShape(output->dims->size, data, start_indices_shape);
+
+  // ResizeTensor takes ownership of result_shape
+  TF_LITE_ENSURE_STATUS(
+      context->ResizeTensor(context, output, result_shape.release()));
+
+  return TfLiteStatus::kTfLiteOk;
+}
+
+}  // namespace stablehlo_gather
+
+TfLiteRegistration* Register_STABLEHLO_GATHER() {
+  static TfLiteRegistration r = {nullptr, nullptr, stablehlo_gather::Prepare,
+                                 stablehlo_gather::Eval};
+  return &r;
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/stablehlo_gather_test.cc b/tensorflow/lite/kernels/stablehlo_gather_test.cc
index 1c27809..6f99d2e 100644
--- a/tensorflow/lite/kernels/stablehlo_gather_test.cc
+++ b/tensorflow/lite/kernels/stablehlo_gather_test.cc
@@ -15,7 +15,6 @@
 
 #include <cstdint>
 #include <initializer_list>
-#include <memory>
 #include <vector>
 
 #include <gmock/gmock.h>
@@ -96,6 +95,7 @@
   };
   StablehloGatherOpModel model({TensorType_FLOAT32, {3, 4, 2}},
                                {TensorType_INT64, {2, 3, 2}}, params);
+
   model.SetInput<float>({1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                          13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
   model.SetIndices<int64_t>({0, 0, 1, 0, 2, 1, 0, 1, 1, 1, 0, 2});
@@ -122,6 +122,54 @@
   };
   StablehloGatherOpModel model({TensorType_FLOAT32, {3, 4, 2}},
                                {TensorType_INT64, {2, 3, 2}}, params);
+
+  model.SetInput<float>({1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
+                         13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
+  model.SetIndices<int64_t>({0, 0, 1, 0, 2, 1, 0, 1, 1, 1, 0, 9});
+
+  ASSERT_EQ(model.Invoke(), kTfLiteOk);
+  std::vector<float> expected_values = {1,  2,  3,  4,  3,  4,  5,  6,
+                                        13, 14, 15, 16, 9,  10, 11, 12,
+                                        11, 12, 13, 14, 17, 18, 19, 20};
+  EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray(expected_values));
+}
+
+TEST(StablehloScatterOpTest, WorksWithDynamicShapes) {
+  TfLiteStablehloGatherParams params = {
+      {2, 3},     // offset_dims
+      2,          // num_offset_dims;
+      {0},        // collapsed_slice_dims
+      1,          // num_collapsed_slice_dims;
+      {1, 0},     // start_index_map
+      2,          // num_start_index_map;
+      2,          // index_vector_dim;
+      {1, 2, 2},  // slice_sizes
+      3,          // num_slice_sizes;
+      false       // indices_are_sorted;
+  };
+
+  TensorData indices_tensor = {TensorType_INT64,
+                               /*shape*/ {2, 3, 2},
+                               /*min*/ 0.0f,
+                               /*max*/ 0.0f,
+                               /*scale*/ 0.0f,
+                               /*zero_point*/ 0,
+                               /*per_channel_quantization*/ false,
+                               /*per_channel_quantization_scales*/ {},
+                               /*per_channel_quantization_offsets*/ {},
+                               /*channel_index*/ 0,
+                               /*traversal_order*/ {},
+                               /*format*/ {},
+                               /*block_size*/ {},
+                               /*block_map*/ {},
+                               /*shape_signature*/ {{-1, -1, 2}}};
+
+  // The shape_signature used when creating the model has -1 for unknown
+  // dimension sizes. `model.BuildInterpreter` then resizes the tensors to
+  // their actual shapes.
+  StablehloGatherOpModel model({TensorType_FLOAT32, {3, 4, 2}}, indices_tensor,
+                               params);
+
   model.SetInput<float>({1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                          13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
   model.SetIndices<int64_t>({0, 0, 1, 0, 2, 1, 0, 1, 1, 1, 0, 9});
diff --git a/tensorflow/lite/kernels/stablehlo_min_max.cc b/tensorflow/lite/kernels/stablehlo_min_max.cc
new file mode 100644
index 0000000..139ac18
--- /dev/null
+++ b/tensorflow/lite/kernels/stablehlo_min_max.cc
@@ -0,0 +1,31 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/core/c/common.h"
+#include "tensorflow/lite/kernels/stablehlo_elementwise.h"
+
+namespace tflite::ops::builtin {
+
+TfLiteRegistration* Register_STABLEHLO_MAXIMUM() {
+  static TfLiteRegistration r = {nullptr, nullptr, ElementwisePrepare,
+                                 ElementwiseEval<ComputationType::kMax>};
+  return &r;
+}
+TfLiteRegistration* Register_STABLEHLO_MINIMUM() {
+  static TfLiteRegistration r = {nullptr, nullptr, ElementwisePrepare,
+                                 ElementwiseEval<ComputationType::kMin>};
+  return &r;
+}
+}  // namespace tflite::ops::builtin
diff --git a/tensorflow/lite/kernels/stablehlo_min_max_test.cc b/tensorflow/lite/kernels/stablehlo_min_max_test.cc
new file mode 100644
index 0000000..4903a41
--- /dev/null
+++ b/tensorflow/lite/kernels/stablehlo_min_max_test.cc
@@ -0,0 +1,89 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/log/absl_log.h"
+#include "tensorflow/lite/c/c_api_types.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/subgraph.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+enum class ModelType { kMax, kMin };
+
+class MinMaxOpModel : public SingleOpModel {
+ public:
+  MinMaxOpModel(ModelType model_type, const TensorData& input1,
+                const TensorData& input2, const TensorData& output) {
+    input1_ = AddInput(input1);
+    input2_ = AddInput(input2);
+    output_ = AddOutput(output);
+    model_type_ = model_type;
+
+    switch (model_type_) {
+      case ModelType::kMax:
+        SetBuiltinOp(BuiltinOperator_STABLEHLO_MAXIMUM, BuiltinOptions_NONE, 0);
+        break;
+      case ModelType::kMin:
+        SetBuiltinOp(BuiltinOperator_STABLEHLO_MINIMUM, BuiltinOptions_NONE, 0);
+        break;
+      default:
+        ABSL_LOG(FATAL) << "Unknown model type.";
+    }
+    SetBypassDefaultDelegates();
+    BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+  }
+
+  int input1() { return input1_; }
+  int input2() { return input2_; }
+
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+  int input1_;
+  int input2_;
+  int output_;
+  ModelType model_type_;
+};
+
+TEST(StablehloElementwise, MaxWorks) {
+  MinMaxOpModel model(ModelType::kMax, {TensorType_FLOAT32, {1, 2, 2, 1}},
+                      {TensorType_FLOAT32, {1, 2, 2, 1}},
+                      {TensorType_FLOAT32, {}});
+  model.PopulateTensor<float>(model.input1(), {1.2, 2.5, -1.2, 1});
+  model.PopulateTensor<float>(model.input2(), {0.1, 3, 2, 0.5});
+  ASSERT_EQ(model.Invoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1.2, 3.0, 2.0, 1.0}));
+}
+
+TEST(StablehloElementwise, MinWorks) {
+  MinMaxOpModel model(ModelType::kMin, {TensorType_FLOAT32, {1, 2, 2, 1}},
+                      {TensorType_FLOAT32, {1, 2, 2, 1}},
+                      {TensorType_FLOAT32, {}});
+  model.PopulateTensor<float>(model.input1(), {1.2, 2.5, -1.2, 1});
+  model.PopulateTensor<float>(model.input2(), {0.1, 3, 2, 0.5});
+  ASSERT_EQ(model.Invoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0.1, 2.5, -1.2, 0.5}));
+}
+
+}  // namespace
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/stablehlo_multiply.cc b/tensorflow/lite/kernels/stablehlo_multiply.cc
new file mode 100644
index 0000000..be7abfd
--- /dev/null
+++ b/tensorflow/lite/kernels/stablehlo_multiply.cc
@@ -0,0 +1,26 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/core/c/common.h"
+#include "tensorflow/lite/kernels/stablehlo_elementwise.h"
+
+namespace tflite::ops::builtin {
+
+TfLiteRegistration* Register_STABLEHLO_MULTIPLY() {
+  static TfLiteRegistration r = {nullptr, nullptr, ElementwisePrepare,
+                                 ElementwiseEval<ComputationType::kMul>};
+  return &r;
+}
+}  // namespace tflite::ops::builtin
diff --git a/tensorflow/lite/kernels/stablehlo_multiply_test.cc b/tensorflow/lite/kernels/stablehlo_multiply_test.cc
new file mode 100644
index 0000000..07aa0d4
--- /dev/null
+++ b/tensorflow/lite/kernels/stablehlo_multiply_test.cc
@@ -0,0 +1,66 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/lite/c/c_api_types.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/subgraph.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace {
+
+class MultiplyOpModel : public SingleOpModel {
+ public:
+  MultiplyOpModel(const TensorData& input1, const TensorData& input2,
+                  const TensorData& output) {
+    input1_ = AddInput(input1);
+    input2_ = AddInput(input2);
+    output_ = AddOutput(output);
+    SetBuiltinOp(BuiltinOperator_STABLEHLO_MULTIPLY, BuiltinOptions_NONE, 0);
+    SetBypassDefaultDelegates();
+    BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+  }
+
+  int input1() { return input1_; }
+  int input2() { return input2_; }
+
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+  int input1_;
+  int input2_;
+  int output_;
+};
+
+TEST(StablehloElementwise, MultiplyWorks) {
+  MultiplyOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}},
+                        {TensorType_FLOAT32, {1, 2, 2, 1}},
+                        {TensorType_FLOAT32, {}});
+  model.PopulateTensor<float>(model.input1(), {1.2, 2.5, -1.2, 1});
+  model.PopulateTensor<float>(model.input2(), {0.1, 3, 2, 0.5});
+  ASSERT_EQ(model.Invoke(), kTfLiteOk);
+  std::vector<float> expected_values = {0.12, 7.5, -2.4, 0.5};
+  std::vector<float> actual_values = model.GetOutput();
+  ASSERT_EQ(actual_values.size(), expected_values.size());
+  for (int idx = 0; idx < expected_values.size(); ++idx) {
+    ASSERT_NEAR(actual_values[idx], expected_values[idx], 1e-6);
+  }
+}
+
+}  // namespace
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/stablehlo_reduce_window.cc b/tensorflow/lite/kernels/stablehlo_reduce_window.cc
new file mode 100644
index 0000000..32bf358
--- /dev/null
+++ b/tensorflow/lite/kernels/stablehlo_reduce_window.cc
@@ -0,0 +1,964 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+#include <functional>
+#include <limits>
+#include <memory>
+#include <type_traits>
+#include <vector>
+
+#include "tensorflow/lite/array.h"
+#include "tensorflow/lite/builtin_ops.h"
+#include "tensorflow/lite/c/c_api_types.h"
+#include "tensorflow/lite/core/c/builtin_op_data.h"
+#include "tensorflow/lite/core/c/common.h"
+#include "tensorflow/lite/core/subgraph.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+
+namespace {
+constexpr int32_t kMaxReduceWindowRank = 6;
+
+// Recursive implementation of a strided copy of a tensor.
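+//
+// Both the strides and element_size are expressed in bytes: the copy walks
+// raw char buffers and copies element_size bytes at each visited position.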
+void StridedCopy(const int rank, const char* input, const int64_t* input_shape,
+                 const int64_t* input_strides, char* output,
+                 const int64_t* output_strides, const int64_t element_size,
+                 const int depth) {
+  if (depth + 1 == rank) {
+    for (int64_t i = 0; i < input_shape[depth]; ++i) {
+      std::memcpy(output, input, element_size);
+      input += input_strides[depth];
+      output += output_strides[depth];
+    }
+  } else {
+    for (int64_t i = 0; i < input_shape[depth]; ++i) {
+      StridedCopy(rank, input, input_shape, input_strides, output,
+                  output_strides, element_size, depth + 1);
+      input += input_strides[depth];
+      output += output_strides[depth];
+    }
+  }
+}
+
+}  // namespace
+
+namespace dilate {
+namespace {
+
+const int64_t kTFLiteDefaultBaseDilation[kMaxReduceWindowRank] = {1, 1, 1,
+                                                                  1, 1, 1};
+
+// Computes and holds the parameters that can be precomputed for the dilation
+// operation.
+struct DilateData {
+  DilateData() = default;
+
+  DilateData(const int rank, const int64_t* input_shape,
+             const int64_t* dilation, const int64_t element_size)
+      : rank(rank), init_element_size(element_size) {
+    std::copy_n(input_shape, rank, shape);
+    std::copy_n(dilation, rank, base_dilations);
+    ComputeOutputShapeAndSize(element_size);
+    skip = std::all_of(dilation, dilation + rank,
+                       [](int64_t d) { return d == 1; });
+    if (skip) {
+      return;
+    }
+    MergeTrailingDilations(element_size);
+    ComputeInputStrides();
+    ComputeOutputStridesAndSizes();
+  }
+
+  // Trailing dilation factors of 1 can be merged to the left.
+  //
+  // This optimisation artificially reduces the number of dimensions of the
+  // input tensor. If a dilation factor is 1 then no padding element is added
+  // between elements of the given dimension. From the innermost dimension we
+  // can collapse all the adjacent dimensions that have a dilation factor
+  // of 1.
+  //
+  // Note: this function updates input_strides[rank-1].
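+  //
+  // Illustrative example: with shape = {2, 3, 4} and dilations = {2, 1, 1},
+  // the tensor is treated as a rank-1 tensor of 2 elements, each of size
+  // 3 * 4 * element_size bytes.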
+  void MergeTrailingDilations(int64_t element_size) {
+    for (int i = rank - 2; i >= 0; --i) {
+      if (base_dilations[i + 1] == 1) {
+        element_size *= shape[i + 1];
+        --rank;
+      } else {
+        break;
+      }
+    }
+    // This can only happen if all the dilation factors are 1. It would be
+    // better to simply not apply the operation in that case, but we handle
+    // it here as a failsafe.
+    if (rank == 1 && base_dilations[0] == 1) {
+      element_size *= shape[0];
+      shape[0] = 1;
+    }
+    input_strides[rank - 1] = element_size;
+  }
+
+  // Computes the input strides using the shape and the element size.
+  //
+  // Note the element size must be stored in `input_strides[rank-1]`.
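+  //
+  // Illustrative example: with shape = {2, 3, 4} and an element size of e
+  // bytes, the strides become {12e, 4e, e}.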
+  void ComputeInputStrides() {
+    assert(input_strides[rank - 1] != 0);
+    for (int i = rank - 2; i >= 0; --i) {
+      input_strides[i] = shape[i + 1] * input_strides[i + 1];
+    }
+  }
+
+  // Computes the output stride and the byte size for each dimension.
+  //
+  // The size of a dimension is not the same as the stride of the next
+  // inner dimension because of the dilation.
+  //
+  // Note the element size must be stored in `input_strides[rank-1]`.
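+  //
+  // Illustrative example: with shape = {3, 3}, dilations = {2, 3} and an
+  // element size of e bytes, the inner dimension gets size e and stride 3e,
+  // while the outer dimension gets size (3 - 1) * 3e + e = 7e and stride 14e.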
+  void ComputeOutputStridesAndSizes() {
+    output_dimension_sizes[rank - 1] = input_strides[rank - 1];
+    output_strides[rank - 1] =
+        base_dilations[rank - 1] * output_dimension_sizes[rank - 1];
+    for (int i = rank - 2; i >= 0; --i) {
+      output_dimension_sizes[i] = ((shape[i + 1] - 1) * output_strides[i + 1] +
+                                   output_dimension_sizes[i + 1]);
+      output_strides[i] = base_dilations[i] * output_dimension_sizes[i];
+    }
+  }
+
+  void ComputeOutputShapeAndSize(const int64_t element_size) {
+    output_size = element_size;
+    for (int i = 0; i < rank; ++i) {
+      output_shape[i] = (shape[i] - 1) * base_dilations[i] + 1;
+      output_size *= output_shape[i];
+    }
+  }
+
+  int64_t ElementSize() const { return input_strides[rank - 1]; }
+
+  bool skip = true;
+  int rank = 0;
+  int64_t init_element_size = 0;
+  int64_t shape[kMaxReduceWindowRank] = {};
+  int64_t base_dilations[kMaxReduceWindowRank] = {};
+  int64_t output_strides[kMaxReduceWindowRank] = {};
+  int64_t output_dimension_sizes[kMaxReduceWindowRank] = {};
+  int64_t input_strides[kMaxReduceWindowRank] = {};
+  int64_t output_shape[kMaxReduceWindowRank] = {};
+  int64_t output_size = 1;
+};
+
+// Dilates the input tensor following the parameters held in the given context.
+//
+// The dilation operation scatters the elements of its input into a new tensor
+// according to a dilation factor for each dimension. The new tensor elements
+// are initialized to 0.
+//
+// This operation can also be seen as adding interior padding to the tensor. In
+// that case, `interior padding size = dilation factor - 1`.
+//
+// For instance:
+//
+//                        1 2 3
+// A is a 3x3 tensor. A = 4 5 6
+//                        7 8 9
+//
+// We apply a dilation of 2x3.
+//
+//                         1 0 0 2 0 0 3
+//                         0 0 0 0 0 0 0
+// B = dilate(A, [2, 3]) = 4 0 0 5 0 0 6
+//                         0 0 0 0 0 0 0
+//                         7 0 0 8 0 0 9
+//
+// More rigorously:
+// - Let [s0, ..., sN] be the shape of A.
+// - Let [d0, ..., dN] be the dilation factors.
+//
+// - The shape of B is [(s0 - 1) * d0 + 1, ..., (sN - 1) * dN + 1].
+// - B(i0, ..., iN) = ┌ A(i0 / d0, ..., iN / dN)   if iX % dX == 0 for all X
+//                    └ 0 otherwise.
+void Dilate(const DilateData& ctx, const char* input, const char* init_value,
+            char* output) {
+  assert(!ctx.skip);
+  // Fill the output tensor with the padding value.
+  {
+    std::memcpy(output, init_value, ctx.init_element_size);
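+    // Grow the initialized region by copying it onto the bytes that follow
+    // it, doubling the amount written on each iteration (except for the
+    // last, possibly partial, copy).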
+    int64_t remaining_bytes = ctx.output_size - ctx.init_element_size;
+    int64_t copied_bytes = ctx.init_element_size;
+    while (remaining_bytes) {
+      int64_t bytes = std::min(remaining_bytes, copied_bytes);
+      std::memcpy(output + copied_bytes, output, bytes);
+      remaining_bytes -= bytes;
+      copied_bytes += bytes;
+    }
+  }
+  // Copy the relevant input elements into the output tensor.
+  StridedCopy(ctx.rank, input, ctx.shape, ctx.input_strides, output,
+              ctx.output_strides, ctx.ElementSize(), 0);
+}
+
+}  // namespace
+}  // namespace dilate
+
+namespace pad {
+namespace {
+
+const int64_t kTFLiteDefaultPadding[kMaxReduceWindowRank] = {0, 0, 0, 0, 0, 0};
+
+// Computes and holds the parameters that can be precomputed for the padding
+// operation. Note that StableHLO padding treats negative values as cropping.
+struct PadCropData {
+  PadCropData() = default;
+
+  PadCropData(int rank, const int64_t* dims, const int64_t* padding,
+              const int64_t element_size)
+      : rank(rank), element_size(element_size) {
+    assert(rank > 0);
+    assert(rank <= kMaxReduceWindowRank);
+
+    // Compute the output shape.
+    output_size = element_size;
+    for (int i = 0; i < rank; ++i) {
+      output_shape[i] = dims[i] + padding[2 * i] + padding[2 * i + 1];
+      output_size *= output_shape[i];
+    }
+
+    skip = std::all_of(padding, padding + 2 * rank,
+                       [](int64_t v) { return v == 0; });
+    if (skip) {
+      return;
+    }
+
+    // Compute the strides for the input and the output tensors.
+    output_strides[rank - 1] = element_size;
+    input_strides[rank - 1] = element_size;
+    for (int i = rank - 2; i >= 0; --i) {
+      output_strides[i] = output_shape[i + 1] * output_strides[i + 1];
+      input_strides[i] = dims[i + 1] * input_strides[i + 1];
+    }
+
+    // Compute the offset to apply to the pointers to take into account
+    // padding.
+    for (int i = 0; i < rank; ++i) {
+      input_offset += std::max<int64_t>(-padding[2 * i], 0) * input_strides[i];
+      output_offset += std::max<int64_t>(padding[2 * i], 0) * output_strides[i];
+      cropped_input_shape[i] = dims[i] + std::min<int64_t>(padding[2 * i], 0) +
+                               std::min<int64_t>(padding[2 * i + 1], 0);
+    }
+  }
+
+  bool skip = true;
+  int rank = 0;
+  int64_t element_size = 0;
+  int64_t cropped_input_shape[kMaxReduceWindowRank];
+  int64_t input_strides[kMaxReduceWindowRank];
+  int64_t output_shape[kMaxReduceWindowRank];
+  int64_t output_strides[kMaxReduceWindowRank];
+  int64_t input_offset = 0;
+  int64_t output_offset = 0;
+  int64_t output_size = 0;
+};
+
+// Pads and crops the input tensor following the parameters held in the given
+// context.
+//
+// The StableHLO padding algorithm uses negative values to denote cropping.
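+//
+// For instance (illustrative values): with dims = [3] and padding = [1, -1],
+// the output shape is 3 + 1 - 1 = 3, the cropped input shape is 3 - 1 = 2 and
+// the copy starts at `output + element_size`: one padding element is
+// prepended and the last input element is dropped.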
+void PadCrop(const PadCropData& ctx, const char* input, const char* init_value,
+             char* output) {
+  assert(!ctx.skip);
+  // Fill the output tensor with the padding value.
+  {
+    std::memcpy(output, init_value, ctx.element_size);
+    int64_t remaining_bytes = ctx.output_size - ctx.element_size;
+    int64_t copied_bytes = ctx.element_size;
+    while (remaining_bytes) {
+      int64_t bytes = std::min(remaining_bytes, copied_bytes);
+      std::memcpy(output + copied_bytes, output, bytes);
+      remaining_bytes -= bytes;
+      copied_bytes += bytes;
+    }
+  }
+  // Copy the relevant input elements into the output tensor.
+  StridedCopy(ctx.rank, input + ctx.input_offset, ctx.cropped_input_shape,
+              ctx.input_strides, output + ctx.output_offset, ctx.output_strides,
+              ctx.element_size, /*depth=*/0);
+}
+
+}  // namespace
+}  // namespace pad
+
+namespace reduce_window {
+namespace {
+
+// Reduces the elements of a tensor viewed through a strided window.
+//
+// This applies a reduction to a tensor by skipping over elements that are not
+// in the window defined by the given shape and strides. The window is reduced
+// to one element.
+//
+// The shape is the shape of the window. The strides are based on the actual
+// tensor and the distance between window elements, counted in elements.
+// Sparse windows are possible.
+//
+// For instance: the following window has a [2, 2] shape and [8, 3] strides.
+//
+// ┌──┐     ┌──┐
+// │ 1│ 2  3│ 4│
+// └──┘     └──┘
+//   5  6  7  8    is reduced to 1 + 4 + 9 + 12 = 26
+// ┌──┐     ┌──┐
+// │ 9│10 11│12│
+// └──┘     └──┘
+//  13 14 15 16
+//
+// This is a recursive implementation of the strided reduction.
+template <class Op, class Type>
+void StridedReduce(const Type* input, const int64_t* const shape,
+                   const int64_t* const strides, Type& accu, const int rank,
+                   const int depth) {
+  const int64_t stride = strides[depth];
+  const int64_t size = shape[depth];
+  if (depth + 1 == rank) {
+    const Op op;
+    for (int64_t i = 0; i < size; ++i) {
+      accu = op(accu, *input);
+      input += stride;
+    }
+  } else {
+    for (int64_t i = 0; i < size; ++i) {
+      StridedReduce<Op, Type>(input, shape, strides, accu, rank, depth + 1);
+      input += stride;
+    }
+  }
+}
+
+// Recursively computes strided reductions using a sliding window over the
+// given tensor.
+//
+// The window is defined using a shape and a dilation. The shape defines the
+// elements that the window will let the reduction *see*. The dilation defines
+// the step between window elements.
+//
+// For instance: the following window has a [2, 2] shape and [2, 3] dilations.
+//
+//    3
+// ┌────┐
+// ┌─┐   ┌─┐
+// │X│X X│X│┐
+// └─┘   └─┘│2
+//  X X X X ┘
+// ┌─┐   ┌─┐
+// │X│X X│X│
+// └─┘   └─┘
+template <class Op, class Type>
+void ReduceWindowImpl(const Type* input, Type* output,
+                      const int64_t* const output_shape,
+                      const int64_t* const output_strides,
+                      const int64_t* const window_offset_strides,
+                      const int64_t* const window_shape,
+                      const int64_t* const window_reduce_strides,
+                      const Type init, const int rank, const int depth) {
+  if (depth + 1 == rank) {
+    for (int32_t dim = 0; dim < output_shape[depth]; ++dim) {
+      *output = init;
+      StridedReduce<Op, Type>(input, window_shape, window_reduce_strides,
+                              *output, rank, /*depth=*/0);
+      input += window_offset_strides[depth];
+      output += output_strides[depth];
+    }
+  } else {
+    for (int32_t dim = 0; dim < output_shape[depth]; ++dim) {
+      ReduceWindowImpl<Op, Type>(input, output, output_shape, output_strides,
+                                 window_offset_strides, window_shape,
+                                 window_reduce_strides, init, rank, depth + 1);
+      input += window_offset_strides[depth];
+      output += output_strides[depth];
+    }
+  }
+}
+
+// Computes and holds the parameters that can be precomputed for the reduce
+// window operation.
+struct ReduceWindowData {
+  ReduceWindowData() = default;
+
+  ReduceWindowData(const int rank, const int64_t* input_shape,
+                   const int64_t* window_shape, const int64_t* window_strides,
+                   const int64_t* window_dilations)
+      : rank(rank),
+        input_shape(input_shape),
+        window_shape(window_shape),
+        window_dilations(window_dilations),
+        window_strides(window_strides) {
+    ComputeStrides(input_strides, input_shape);
+    Multiply(window_reduce_strides, input_strides, window_dilations);
+    Multiply(window_offset_strides, input_strides, window_strides);
+    ComputeOutputShape();
+    ComputeStrides(output_strides, output_shape);
+  }
+
+  void ComputeStrides(int64_t* strides, const int64_t* const shape) {
+    strides[rank - 1] = 1;
+    for (int64_t i = rank - 2; i >= 0; --i) {
+      strides[i] = shape[i + 1] * strides[i + 1];
+    }
+  }
+
+  void Multiply(int64_t* dst, const int64_t* const vec1,
+                const int64_t* const vec2) {
+    for (int64_t i = 0; i < rank; ++i) {
+      dst[i] = vec2[i] * vec1[i];
+    }
+  }
+
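+  // Computes the output shape of the windowed reduction.
+  //
+  // For each dimension: out = (in - dilated_window) / stride + 1, or 0 when
+  // the dilated window does not fit in the input. For instance (illustrative
+  // values): in = 114, window = 3, dilation = 1 and stride = 2 give
+  // (114 - 3) / 2 + 1 = 56.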
+  void ComputeOutputShape() {
+    int64_t dilated_window_shape[kMaxReduceWindowRank];
+    for (int64_t i = 0; i < rank; ++i) {
+      dilated_window_shape[i] = (window_shape[i] - 1) * window_dilations[i] + 1;
+    }
+    for (int64_t i = 0; i < rank; ++i) {
+      if (input_shape[i] < dilated_window_shape[i]) {
+        output_shape[i] = 0;
+      } else {
+        output_shape[i] =
+            (input_shape[i] - dilated_window_shape[i]) / window_strides[i] + 1;
+      }
+    }
+  }
+
+  int rank = 0;
+  const int64_t* input_shape;
+  const int64_t* window_shape;
+  const int64_t* window_dilations;
+  const int64_t* window_strides;
+  int64_t input_strides[kMaxReduceWindowRank] = {};
+  int64_t window_offset_strides[kMaxReduceWindowRank] = {};
+  int64_t window_reduce_strides[kMaxReduceWindowRank] = {};
+  int64_t output_shape[kMaxReduceWindowRank] = {};
+  int64_t output_strides[kMaxReduceWindowRank] = {};
+};
+
+template <class Op, class Type>
+void ReduceWindow(const ReduceWindowData& ctx, const Type* const input,
+                  const Type init, Type* output) {
+  ReduceWindowImpl<Op, Type>(input, output, ctx.output_shape,
+                             ctx.output_strides, ctx.window_offset_strides,
+                             ctx.window_shape, ctx.window_reduce_strides, init,
+                             ctx.rank, /*depth=*/0);
+}
+
+}  // namespace
+}  // namespace reduce_window
+
+/// Operator implementation
+
+namespace reduce_window_op {
+namespace {
+
+// Holds the data needed throughout the node lifetime.
+struct NodeData {
+  // These members are only used by the STABLEHLO_REDUCE_WINDOW semantic.
+  enum { kDilateOutput, kPadOutput, kTempTensorCount };
+  int temporary_tensor_offset = -1;
+  // These members are shared.
+  pad::PadCropData pad_ctx;
+  dilate::DilateData dilate_ctx;
+  reduce_window::ReduceWindowData reduce_window_ctx;
+  TfLiteReduceWindowFunction body;
+};
+
+// Holds the operation data. This is extended by the StablehloData and the
+// TFLiteData classes.
+//
+// There are two available semantics for this op implementation.
+//
+// - StablehloData, which models the STABLEHLO_REDUCE_WINDOW op.
+// - TFLiteData, which models the DEPRECATED initial REDUCE_WINDOW op.
+struct OpData {
+  OpData(TfLiteContext* context, TfLiteNode* node)
+      : context(context), node(node) {}
+
+  TfLiteContext* context;
+  TfLiteNode* node;
+
+  TfLiteType type;
+  int rank;
+  int64_t element_size;
+  int64_t input_dims[kMaxReduceWindowRank];
+  const char* input;
+  const char* init_value;
+  const int64_t* window_dimensions;
+  const int64_t* window_strides;
+  const int64_t* base_dilations;
+  const int64_t* window_dilations;
+  const int64_t* padding;
+  char* dilate_output = nullptr;
+  char* pad_output = nullptr;
+  char* output;
+
+  // Helper to resize a tensor.
+  TfLiteStatus ResizeTensor(TfLiteTensor* const tensor,
+                            const int64_t* const shape) {
+    auto dims = BuildTfLiteArray<int32_t>(rank, shape);
+    return context->ResizeTensor(context, tensor, dims.release());
+  }
+
+  // Sets the operation data type and the associated byte size.
+  TfLiteStatus SetElementType(TfLiteType t) {
+    type = t;
+    size_t unsigned_element_size;
+    TF_LITE_ENSURE_OK(context,
+                      GetSizeOfType(context, type, &unsigned_element_size));
+    TF_LITE_ENSURE_MSG(
+        context,
+        // Directly comparing the unsigned_element_size to the max value of
+        // int64_t fails the -Wtautological-constant-out-of-range-compare
+        // warning when building on 32 bit targets.
+        sizeof(unsigned_element_size) < sizeof(int64_t) ||
+            unsigned_element_size <= std::numeric_limits<int64_t>::max(),
+        "The element size cannot be contained in an int64_t value.");
+    element_size = unsigned_element_size;
+    return kTfLiteOk;
+  }
+
+  // Factors out the initialization steps that are common across semantics.
+  //
+  // Semantic is one of StablehloData or TFLiteData.
+  template <class Semantic>
+  TfLiteStatus InitializeBase() {
+    init_value = reinterpret_cast<const char*>(
+        GetInput(context, node, Semantic::kInitValue)->data.data);
+
+    const TfLiteTensor* const input_tensor =
+        GetInput(context, node, Semantic::kInput);
+    SetElementType(input_tensor->type);
+    rank = input_tensor->dims->size;
+    std::copy_n(input_tensor->dims->data, rank, input_dims);
+    input = reinterpret_cast<const char*>(input_tensor->data.data);
+
+    TfLiteTensor* const output_tensor =
+        GetOutput(context, node, Semantic::kOutput);
+    output = reinterpret_cast<char*>(output_tensor->data.data);
+    return kTfLiteOk;
+  }
+};
+
+// Specializes OpData for the STABLEHLO_REDUCE_WINDOW operation.
+struct StablehloData : public OpData {
+  enum InputTensorId { kInput, kInitValue, kNumInputTensors };
+  enum OutputTensorId { kOutput, kNumOutputTensors };
+
+  using OpData::OpData;
+
+  TfLiteTensor* GetTemporary(int id) {
+    return tflite::GetTemporary(context, node, id);
+  }
+
+  TfLiteStatus Check() const {
+    TF_LITE_ENSURE_EQ(context, NumInputs(node), kNumInputTensors);
+    TF_LITE_ENSURE_EQ(context, NumOutputs(node), kNumOutputTensors);
+    const TfLiteTensor* const input_tensor = GetInput(context, node, kInput);
+    const TfLiteTensor* const output_tensor = GetOutput(context, node, kOutput);
+    const TfLiteTensor* const init_value_tensor =
+        GetInput(context, node, kInitValue);
+    TF_LITE_ENSURE_EQ(context, input_tensor->type, output_tensor->type);
+    TF_LITE_ENSURE_EQ(context, input_tensor->type, init_value_tensor->type);
+    TF_LITE_ENSURE(context, input_tensor->dims != nullptr);
+    TF_LITE_ENSURE(context, input_tensor->dims->size > 0);
+    TF_LITE_ENSURE(context, input_tensor->dims->size <= kMaxReduceWindowRank);
+    return kTfLiteOk;
+  }
+
+  TfLiteStatus Initialize() {
+    TF_LITE_ENSURE_OK(context, InitializeBase<StablehloData>());
+    const auto& params = *reinterpret_cast<TfLiteStablehloReduceWindowParams*>(
+        node->builtin_data);
+    window_dimensions = params.window_dimensions;
+    window_strides = params.window_strides;
+    base_dilations = params.base_dilations;
+    window_dilations = params.window_dilations;
+    padding = params.padding;
+    auto AllGtThanZero = [&](const int64_t* const attr) {
+      return std::all_of(attr, attr + rank, [](int64_t d) { return d > 0; });
+    };
+    TF_LITE_ENSURE(context, AllGtThanZero(base_dilations));
+    TF_LITE_ENSURE(context, AllGtThanZero(window_dimensions));
+    TF_LITE_ENSURE(context, AllGtThanZero(window_strides));
+    TF_LITE_ENSURE(context, AllGtThanZero(window_dilations));
+
+    if (node->temporaries &&
+        node->temporaries->size >= NodeData::kTempTensorCount) {
+      TfLiteTensor* const dilated_tensor =
+          GetTemporary(NodeData::kDilateOutput);
+      TfLiteTensor* const padded_tensor = GetTemporary(NodeData::kPadOutput);
+      TF_LITE_ENSURE(context, dilated_tensor != nullptr);
+      TF_LITE_ENSURE(context, padded_tensor != nullptr);
+      // When called in Prepare, these pointers are bogus because the tensors
+      // have not been resized yet. This is ok in Eval.
+      dilate_output = dilated_tensor->data.raw;
+      pad_output = padded_tensor->data.raw;
+    }
+    return kTfLiteOk;
+  }
+
+  // Sets up the temporary and output tensors and the sub-ops to dilate, pad,
+  // crop and reduce.
+  //
+  // This should be called during Prepare.
+  TfLiteStatus Setup() {
+    NodeData& node_data = *reinterpret_cast<NodeData*>(node->user_data);
+
+    TfLiteIntArrayFree(node->temporaries);
+    node->temporaries = TfLiteIntArrayCreate(NodeData::kTempTensorCount);
+    for (int i = 0; i < NodeData::kTempTensorCount; ++i) {
+      node->temporaries->data[i] = node_data.temporary_tensor_offset + i;
+    }
+
+    node_data.body = GetBodyFunction();
+
+    node_data.dilate_ctx =
+        dilate::DilateData(rank, input_dims, base_dilations, element_size);
+    node_data.pad_ctx = pad::PadCropData(
+        rank, node_data.dilate_ctx.output_shape, padding, element_size);
+    node_data.reduce_window_ctx = reduce_window::ReduceWindowData(
+        rank, node_data.pad_ctx.output_shape, window_dimensions, window_strides,
+        window_dilations);
+
+    TfLiteTensor* const dilated_tensor = GetTemporary(NodeData::kDilateOutput);
+    TfLiteTensor* const padded_tensor = GetTemporary(NodeData::kPadOutput);
+    TfLiteTensor* const output_tensor = GetOutput(context, node, kOutput);
+    dilated_tensor->type = type;
+    dilated_tensor->allocation_type = kTfLiteArenaRw;
+    padded_tensor->type = type;
+    padded_tensor->allocation_type = kTfLiteArenaRw;
+
+    TF_LITE_ENSURE_OK(context, ResizeTensor(dilated_tensor,
+                                            node_data.dilate_ctx.output_shape));
+    TF_LITE_ENSURE_OK(
+        context, ResizeTensor(padded_tensor, node_data.pad_ctx.output_shape));
+    TF_LITE_ENSURE_OK(
+        context,
+        ResizeTensor(output_tensor, node_data.reduce_window_ctx.output_shape));
+    return kTfLiteOk;
+  }
+
+  // Inspects the subgraph associated to the STABLEHLO_REDUCE_WINDOW node to
+  // find out the reduction body.
+  TfLiteReduceWindowFunction GetBodyFunction() {
+    const TfLiteStablehloReduceWindowParams& params =
+        *reinterpret_cast<TfLiteStablehloReduceWindowParams*>(
+            node->builtin_data);
+    const int body_subgraph_index = params.body_subgraph_index;
+    const Subgraph& parent_subgraph =
+        *reinterpret_cast<Subgraph*>(context->impl_);
+    const std::vector<std::unique_ptr<Subgraph>>& subgraphs =
+        *parent_subgraph.GetSubgraphs();
+    if (body_subgraph_index >= subgraphs.size()) {
+      TF_LITE_KERNEL_LOG(
+          context, "Body subgraph not found for stablehlo.reduce_window: %d.",
+          body_subgraph_index);
+      return TfLiteReduceWindowFunctionUnsupported;
+    }
+    const Subgraph& body_subgraph = *subgraphs[body_subgraph_index];
+    const std::vector<int>& execution_plan =
+        body_subgraph.pre_delegation_execution_plan().empty()
+            ? body_subgraph.execution_plan()
+            : body_subgraph.pre_delegation_execution_plan();
+
+    if (execution_plan.size() != 1) {
+      TF_LITE_KERNEL_LOG(context,
+                         "Only one kernel is allowed within "
+                         "stablehlo.reduce_window body. (%d) kernels found.\n",
+                         execution_plan.size());
+      return TfLiteReduceWindowFunctionUnsupported;
+    }
+    const int body_kernel_index = execution_plan[0];
+    const TfLiteRegistration& body_kernel_registration =
+        body_subgraph.node_and_registration(body_kernel_index)->second;
+    switch (body_kernel_registration.builtin_code) {
+      case kTfLiteBuiltinAdd:
+      case kTfLiteBuiltinStablehloAdd:
+        return TfLiteReduceWindowFunctionAdd;
+      case kTfLiteBuiltinMul:
+      case kTfLiteBuiltinStablehloMultiply:
+        return TfLiteReduceWindowFunctionMul;
+      case kTfLiteBuiltinMaximum:
+      case kTfLiteBuiltinStablehloMaximum:
+        return TfLiteReduceWindowFunctionMax;
+      case kTfLiteBuiltinMinimum:
+      case kTfLiteBuiltinStablehloMinimum:
+        return TfLiteReduceWindowFunctionMin;
+      case kTfLiteBuiltinLogicalAnd:
+      case kTfLiteBuiltinStablehloAnd:
+        return TfLiteReduceWindowFunctionAll;
+      case kTfLiteBuiltinLogicalOr:
+      case kTfLiteBuiltinStablehloOr:
+        return TfLiteReduceWindowFunctionAny;
+      default:
+        TF_LITE_KERNEL_LOG(
+            context, "%s:%d unsupported reduction body builtin code: %d.\n",
+            __FILE__, __LINE__, body_kernel_registration.builtin_code);
+        return TfLiteReduceWindowFunctionUnsupported;
+    }
+  }
+};
+
+// Specializes OpData for the REDUCE_WINDOW operation.
+struct TFLiteData : public OpData {
+  enum InputTensorId {
+    kInput,
+    kInitValue,
+    kWindowShape,
+    kWindowStrides,
+    kWindowDilations,
+    kNumInputTensors
+  };
+  enum OutputTensorId { kOutput, kNumOutputTensors };
+
+  using OpData::OpData;
+
+  TfLiteStatus Check() const {
+    TF_LITE_ENSURE_EQ(context, NumInputs(node), kNumInputTensors);
+    TF_LITE_ENSURE_EQ(context, NumOutputs(node), kNumOutputTensors);
+    const TfLiteTensor* const input_tensor = GetInput(context, node, kInput);
+    const TfLiteTensor* const init_value_tensor =
+        GetInput(context, node, kInitValue);
+    const TfLiteTensor* const window_dimensions_tensor =
+        GetInput(context, node, kWindowShape);
+    const TfLiteTensor* const window_strides_tensor =
+        GetInput(context, node, kWindowStrides);
+    const TfLiteTensor* const window_dilations_tensor =
+        GetInput(context, node, kWindowDilations);
+    const TfLiteTensor* const output_tensor = GetOutput(context, node, kOutput);
+    TF_LITE_ENSURE(context, IsConstantTensor(window_dimensions_tensor));
+    TF_LITE_ENSURE(context, IsConstantTensor(window_strides_tensor));
+    TF_LITE_ENSURE(context, IsConstantTensor(window_dilations_tensor));
+    TF_LITE_ENSURE_EQ(context, input_tensor->type, output_tensor->type);
+    TF_LITE_ENSURE_EQ(context, input_tensor->type, init_value_tensor->type);
+    TF_LITE_ENSURE_EQ(context, window_dimensions_tensor->type, kTfLiteInt64);
+    TF_LITE_ENSURE_EQ(context, window_strides_tensor->type, kTfLiteInt64);
+    TF_LITE_ENSURE_EQ(context, window_dilations_tensor->type, kTfLiteInt64);
+    TF_LITE_ENSURE(context, input_tensor->dims != nullptr);
+    TF_LITE_ENSURE(context, input_tensor->dims->size > 0);
+    TF_LITE_ENSURE(context, input_tensor->dims->size <= kMaxReduceWindowRank);
+
+    return kTfLiteOk;
+  }
+
+  TfLiteStatus Initialize() {
+    TF_LITE_ENSURE_OK(context, InitializeBase<TFLiteData>());
+    window_dimensions = reinterpret_cast<const int64_t*>(
+        GetInput(context, node, kWindowShape)->data.data);
+    window_strides = reinterpret_cast<const int64_t*>(
+        GetInput(context, node, kWindowStrides)->data.data);
+    base_dilations = dilate::kTFLiteDefaultBaseDilation;
+    window_dilations = reinterpret_cast<const int64_t*>(
+        GetInput(context, node, kWindowDilations)->data.data);
+    padding = pad::kTFLiteDefaultPadding;
+    return kTfLiteOk;
+  }
+
+  TfLiteStatus Setup() {
+    NodeData& node_data = *reinterpret_cast<NodeData*>(node->user_data);
+    const auto& params =
+        *reinterpret_cast<TfLiteReduceWindowParams*>(node->builtin_data);
+    node_data.body = params.reduce_function;
+
+    node_data.dilate_ctx.skip = true;
+    node_data.pad_ctx.skip = true;
+    node_data.reduce_window_ctx = reduce_window::ReduceWindowData(
+        rank, input_dims, window_dimensions, window_strides, window_dilations);
+
+    TfLiteTensor* const output_tensor = GetOutput(context, node, kOutput);
+    return context->ResizeTensor(
+        context, output_tensor,
+        BuildTfLiteArray<int32_t>(rank,
+                                  node_data.reduce_window_ctx.output_shape)
+            .release());
+  }
+};
+
+// Applies the sub-ops that are needed to compute the whole
+// [STABLEHLO_]REDUCE_WINDOW op.
+//
+// The ops that aren't needed are skipped.
+template <class Op, class Type>
+void PadCropReduceWindow(const OpData& op_ctx) {
+  NodeData& node_data = *reinterpret_cast<NodeData*>(op_ctx.node->user_data);
+  const char* input = op_ctx.input;
+  const int64_t* input_shape = op_ctx.input_dims;
+
+  if (!node_data.dilate_ctx.skip) {
+    dilate::Dilate(node_data.dilate_ctx, input, op_ctx.init_value,
+                   op_ctx.dilate_output);
+    input = op_ctx.dilate_output;
+    input_shape = node_data.dilate_ctx.output_shape;
+  }
+
+  if (!node_data.pad_ctx.skip) {
+    pad::PadCrop(node_data.pad_ctx, input, op_ctx.init_value,
+                 op_ctx.pad_output);
+    input = op_ctx.pad_output;
+    input_shape = node_data.pad_ctx.output_shape;
+  }
+
+  reduce_window::ReduceWindow<Op, Type>(
+      node_data.reduce_window_ctx, reinterpret_cast<const Type*>(input),
+      *reinterpret_cast<const Type*>(op_ctx.init_value),
+      reinterpret_cast<Type*>(op_ctx.output));
+}
+
+// Dispatches to the template implementation according to the tensor type.
+template <class Op>
+TfLiteStatus DispatchReduceWindowType(OpData& ctx) {
+#define REDUCE_WINDOW_TYPE_CASE(CPP_TYPE, TENSOR_TYPE) \
+  case TENSOR_TYPE:                                    \
+    PadCropReduceWindow<Op, CPP_TYPE>(ctx);            \
+    break;
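+  // For instance, REDUCE_WINDOW_TYPE_CASE(float, kTfLiteFloat32) expands to
+  //   case kTfLiteFloat32: PadCropReduceWindow<Op, float>(ctx); break;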
+  switch (ctx.type) {
+    REDUCE_WINDOW_TYPE_CASE(int8_t, kTfLiteBool);
+    REDUCE_WINDOW_TYPE_CASE(int8_t, kTfLiteInt8);
+    REDUCE_WINDOW_TYPE_CASE(int16_t, kTfLiteInt16);
+    REDUCE_WINDOW_TYPE_CASE(int32_t, kTfLiteInt32);
+    REDUCE_WINDOW_TYPE_CASE(int64_t, kTfLiteInt64);
+    REDUCE_WINDOW_TYPE_CASE(uint8_t, kTfLiteUInt8);
+    REDUCE_WINDOW_TYPE_CASE(float, kTfLiteFloat32);
+    REDUCE_WINDOW_TYPE_CASE(double, kTfLiteFloat64);
+    default:
+      TF_LITE_KERNEL_LOG(
+          ctx.context,
+          "%s:%d unsupported kernel data type (TfliteType: %d a.k.a %s).",
+          __FILE__, __LINE__, ctx.type, TfLiteTypeGetName(ctx.type));
+      return kTfLiteError;
+  }
+#undef REDUCE_WINDOW_TYPE_CASE
+  return kTfLiteOk;
+}
+
+struct Max {
+  template <class T>
+  constexpr T operator()(const T& a, const T& b) const {
+    return a >= b ? a : b;
+  }
+};
+
+struct Min {
+  template <class T>
+  constexpr T operator()(const T& a, const T& b) const {
+    return a <= b ? a : b;
+  }
+};
+
+// Dispatches to the template instantiation according to the reduction body.
+TfLiteStatus DispatchReduceWindowBody(OpData& ctx) {
+  const NodeData& node_data = *static_cast<NodeData*>(ctx.node->user_data);
+  switch (node_data.body) {
+    case TfLiteReduceWindowFunctionUnsupported:
+      TF_LITE_KERNEL_LOG(ctx.context, "%s:%d unsupported reduction body.\n",
+                         __FILE__, __LINE__);
+      return kTfLiteError;
+    case TfLiteReduceWindowFunctionAdd:
+      return DispatchReduceWindowType<std::plus<>>(ctx);
+    case TfLiteReduceWindowFunctionMul:
+      return DispatchReduceWindowType<std::multiplies<>>(ctx);
+    case TfLiteReduceWindowFunctionAll:
+      return DispatchReduceWindowType<std::logical_and<>>(ctx);
+    case TfLiteReduceWindowFunctionAny:
+      return DispatchReduceWindowType<std::logical_or<>>(ctx);
+    case TfLiteReduceWindowFunctionMin:
+      return DispatchReduceWindowType<Min>(ctx);
+    case TfLiteReduceWindowFunctionMax:
+      return DispatchReduceWindowType<Max>(ctx);
+  }
+  TF_LITE_KERNEL_LOG(ctx.context, "%s:%d unhandled reduction body case.\n",
+                     __FILE__, __LINE__);
+  return kTfLiteError;
+}
+
+// Initializes the node's user data when the STABLEHLO_REDUCE_WINDOW semantic is
+// used.
+void* StablehloInit(TfLiteContext* context, const char* options,
+                    size_t options_len) {
+  NodeData* node_data = new NodeData();
+  context->AddTensors(context, NodeData::kTempTensorCount,
+                      &node_data->temporary_tensor_offset);
+  return node_data;
+}
+
+void* TFLiteInit(TfLiteContext* context, const char* options,
+                 size_t options_len) {
+  return new NodeData();
+}
+
+// Frees the node's user data. This is shared by both op semantics.
+void Free(TfLiteContext* context, void* node_data) {
+  delete static_cast<NodeData*>(node_data);
+}
+
+template <class Semantic>
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  Semantic ctx(context, node);
+  TF_LITE_ENSURE_OK(context, ctx.Check());
+  TF_LITE_ENSURE_OK(context, ctx.Initialize());
+  return ctx.Setup();
+}
+
+template <class Semantic>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  Semantic ctx(context, node);
+  TF_LITE_ENSURE_OK(context, ctx.Initialize());
+  // Too much cropping can lead to an empty output tensor (a dimension of
+  // size zero or less).
+  //
+  // This never happens with the REDUCE_WINDOW (TFLiteData) semantic, but
+  // since that op is deprecated, running the extra check for it too is
+  // harmless.
+  NodeData& node_data = *reinterpret_cast<NodeData*>(node->user_data);
+  TF_LITE_ENSURE_MSG(
+      context, node_data.pad_ctx.skip || node_data.pad_ctx.output_size > 0,
+      "The padding specification of stablehlo.reduce_window gives an empty "
+      "tensor.");
+  return DispatchReduceWindowBody(ctx);
+}
+
+}  // namespace
+}  // namespace reduce_window_op
+
+TfLiteRegistration* Register_STABLEHLO_REDUCE_WINDOW() {
+  static TfLiteRegistration r = {
+      /*.init=*/reduce_window_op::StablehloInit,
+      /*.free=*/reduce_window_op::Free,
+      /*.prepare=*/reduce_window_op::Prepare<reduce_window_op::StablehloData>,
+      /*.invoke=*/reduce_window_op::Eval<reduce_window_op::StablehloData>};
+  return &r;
+}
+
+TfLiteRegistration* Register_REDUCE_WINDOW() {
+  static TfLiteRegistration r = {
+      /*.init=*/reduce_window_op::TFLiteInit,
+      /*.free=*/reduce_window_op::Free,
+      /*.prepare=*/reduce_window_op::Prepare<reduce_window_op::TFLiteData>,
+      /*.invoke=*/reduce_window_op::Eval<reduce_window_op::TFLiteData>};
+  return &r;
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/stablehlo_reduce_window_test.cc b/tensorflow/lite/kernels/stablehlo_reduce_window_test.cc
new file mode 100644
index 0000000..a26c286
--- /dev/null
+++ b/tensorflow/lite/kernels/stablehlo_reduce_window_test.cc
@@ -0,0 +1,792 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <initializer_list>
+#include <limits>
+#include <ostream>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/algorithm/container.h"
+#include "absl/log/absl_log.h"
+#include "absl/random/bit_gen_ref.h"
+#include "absl/random/distributions.h"
+#include "absl/random/random.h"
+#include "absl/types/span.h"
+#include "tensorflow/lite/c/c_api_types.h"
+#include "tensorflow/lite/core/c/common.h"
+#include "tensorflow/lite/kernels/stablehlo_reduce_window_test_util.h"
+#include "tensorflow/lite/kernels/subgraph_test_util.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace reduce_window {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+// The TF_LITE_ENSURE* family of macros requires a context to be passed, which
+// we do not have when building the model.
+#define REDUCE_WINDOW_ENSURE_OK(expr)                        \
+  do {                                                       \
+    if (TfLiteStatus status = (expr); status != kTfLiteOk) { \
+      ABSL_LOG(ERROR) << #expr " failed.\n";                 \
+      return status;                                         \
+    }                                                        \
+  } while (false)
+
+// Returns kTfLiteError if the expression evaluates to false.
+#define REDUCE_WINDOW_ENSURE_IMPL(expr, msg) \
+  do {                                       \
+    if (!(expr)) {                           \
+      ABSL_LOG(ERROR) << #msg " failed.\n";  \
+      return kTfLiteError;                   \
+    }                                        \
+  } while (false)
+
+#define REDUCE_WINDOW_ENSURE(expr) REDUCE_WINDOW_ENSURE_IMPL((expr), #expr)
+
+#define REDUCE_WINDOW_ENSURE_EQ(a, b) \
+  REDUCE_WINDOW_ENSURE_IMPL((a) == (b), #a " == " #b)
+#define REDUCE_WINDOW_ENSURE_NE(a, b) \
+  REDUCE_WINDOW_ENSURE_IMPL((a) != (b), #a " != " #b)
+#define REDUCE_WINDOW_ENSURE_GE(a, b) \
+  REDUCE_WINDOW_ENSURE_IMPL((a) >= (b), #a " >= " #b)
+#define REDUCE_WINDOW_ENSURE_LE(a, b) \
+  REDUCE_WINDOW_ENSURE_IMPL((a) <= (b), #a " <= " #b)
+#define REDUCE_WINDOW_ENSURE_GT(a, b) \
+  REDUCE_WINDOW_ENSURE_IMPL((a) > (b), #a " > " #b)
+#define REDUCE_WINDOW_ENSURE_LT(a, b) \
+  REDUCE_WINDOW_ENSURE_IMPL((a) < (b), #a " < " #b)
+#define REDUCE_WINDOW_ENSURE_UNREACHABLE(msg) \
+  REDUCE_WINDOW_ENSURE_IMPL(false, msg)
+
+// Maps the native C++ types to the corresponding TFLite tensor type enum
+// values.
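+//
+// For instance, TensorTypeFor<float>::value is TensorType_FLOAT32 (see the
+// TENSOR_TYPE_ASSOC instantiations below).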
+template <class T>
+struct TensorTypeFor;
+
+#define TENSOR_TYPE_ASSOC(CPP_TYPE, TENSORTYPE_VALUE)     \
+  template <>                                             \
+  struct TensorTypeFor<CPP_TYPE> {                        \
+    static constexpr TensorType value = TENSORTYPE_VALUE; \
+  };
+
+TENSOR_TYPE_ASSOC(int8_t, TensorType_INT8);
+TENSOR_TYPE_ASSOC(int16_t, TensorType_INT16);
+TENSOR_TYPE_ASSOC(int32_t, TensorType_INT32);
+TENSOR_TYPE_ASSOC(int64_t, TensorType_INT64);
+TENSOR_TYPE_ASSOC(uint8_t, TensorType_UINT8);
+TENSOR_TYPE_ASSOC(uint16_t, TensorType_UINT16);
+TENSOR_TYPE_ASSOC(uint32_t, TensorType_UINT32);
+TENSOR_TYPE_ASSOC(uint64_t, TensorType_UINT64);
+TENSOR_TYPE_ASSOC(float, TensorType_FLOAT32);
+static_assert(sizeof(float) == 4, "float type is expected to be 32 bit long");
+TENSOR_TYPE_ASSOC(double, TensorType_FLOAT64);
+static_assert(sizeof(double) == 8, "double type is expected to be 64 bit long");
+
+enum class BodyFunction {
+  kUnset,
+  kUnsupported,
+  kAdd,
+  kMul,
+  kMax,
+  kMin,
+  kAll,
+  kAny
+};
+
+std::ostream& operator<<(std::ostream& os, const BodyFunction& f) {
+  switch (f) {
+    case BodyFunction::kUnset:
+      return os << "unset";
+    case BodyFunction::kUnsupported:
+      return os << "unsupported";
+    case BodyFunction::kAdd:
+      return os << "add";
+    case BodyFunction::kMul:
+      return os << "mul";
+    case BodyFunction::kMax:
+      return os << "max";
+    case BodyFunction::kMin:
+      return os << "min";
+    case BodyFunction::kAll:
+      return os << "all";
+    case BodyFunction::kAny:
+      return os << "any";
+  }
+  return os;
+}
+
+template <class T>
+class ReduceWindowOpModel : public SingleOpModel {
+  static constexpr TensorType kTensorType = TensorTypeFor<T>::value;
+
+ public:
+  // Sets the input tensor shape and data.
+  //
+  // If the data isn't provided, the buffer is filled with `iota`.
+  void SetInput(absl::Span<const int64_t> shape) {
+    input_shape_.assign(shape.begin(), shape.end());
+    input_data_.resize(absl::c_accumulate(shape, 1, std::multiplies<>()));
+    absl::c_iota(input_data_, 1);
+  }
+
+  void SetInput(absl::Span<const int64_t> shape, absl::Span<const T> data) {
+    input_shape_.assign(shape.begin(), shape.end());
+    input_data_.assign(data.begin(), data.end());
+  }
+
+  void SetInput(absl::Span<const int64_t> shape, absl::BitGenRef bitgen, T min,
+                T max) {
+    input_shape_.assign(shape.begin(), shape.end());
+    input_data_.resize(absl::c_accumulate(shape, 1, std::multiplies<>()));
+    absl::c_generate(input_data_, [&] {
+      return absl::Uniform(absl::IntervalClosed, bitgen, min, max);
+    });
+  }
+
+  void SetWindowDimensions(absl::Span<const int64_t> dimensions) {
+    window_dimensions_.assign(dimensions.begin(), dimensions.end());
+  }
+
+  // Note: the strides are counted in elements on the tensor grid, not in the
+  // underlying buffer.
+  //
+  // For instance, with {2,2} window strides on the following matrix, the window
+  // anchored at element 1 will reach elements 3 (+2 horizontally), 7 (+2
+  // vertically) and 9 (+2 vertically, +2 horizontally):
+  //
+  // 1 2 3
+  // 4 5 6
+  // 7 8 9
+  void SetWindowStrides(absl::Span<const int64_t> strides) {
+    window_strides_.assign(strides.begin(), strides.end());
+  }
+
+  void SetBaseDilations(absl::Span<const int64_t> dilations) {
+    base_dilations_.assign(dilations.begin(), dilations.end());
+  }
+
+  void SetWindowDilations(absl::Span<const int64_t> dilations) {
+    window_dilations_.assign(dilations.begin(), dilations.end());
+  }
+
+  void SetPadding(absl::Span<const int64_t> padding) {
+    padding_.assign(padding.begin(), padding.end());
+  }
+
+  void SetInitValue(const T& val) { init_value_ = val; }
+
+  void SetBody(const BodyFunction func) { body_function_ = func; }
+
+  TfLiteStatus Build() {
+    constexpr int kBodySubGraphIndex = 1;
+
+    REDUCE_WINDOW_ENSURE(!input_shape_.empty());
+    REDUCE_WINDOW_ENSURE_EQ(window_dimensions_.size(), input_shape_.size());
+    REDUCE_WINDOW_ENSURE_EQ(window_strides_.size(), input_shape_.size());
+    REDUCE_WINDOW_ENSURE_EQ(base_dilations_.size(), input_shape_.size());
+    REDUCE_WINDOW_ENSURE_EQ(window_dilations_.size(), input_shape_.size());
+    REDUCE_WINDOW_ENSURE_EQ(padding_.size(), 2 * input_shape_.size());
+    REDUCE_WINDOW_ENSURE_NE(body_function_, BodyFunction::kUnset);
+    REDUCE_WINDOW_ENSURE_NE(body_function_, BodyFunction::kUnsupported);
+
+    input_tensor_id_ =
+        AddInput({kTensorType,
+                  std::vector<int>(input_shape_.begin(), input_shape_.end())});
+    init_value_tensor_id_ = AddConstInput(kTensorType, {init_value_}, {1});
+    output_tensor_id_ = AddOutput(kTensorType);
+
+    SetBuiltinOp(BuiltinOperator_STABLEHLO_REDUCE_WINDOW,
+                 BuiltinOptions2_StablehloReduceWindowOptions,
+                 CreateStablehloReduceWindowOptions(
+                     builder_, builder_.CreateVector(window_dimensions_),
+                     builder_.CreateVector(window_strides_),
+                     builder_.CreateVector(base_dilations_),
+                     builder_.CreateVector(window_dilations_),
+                     builder_.CreateVector(padding_), kBodySubGraphIndex)
+                     .Union());
+
+    BuildInterpreter(
+        /*input_shapes=*/{std::vector<int>(input_shape_.begin(),
+                                           input_shape_.end())},
+        /*num_threads=*/-1, /*allow_fp32_relax_to_fp16=*/false,
+        /*apply_delegate=*/true, /*allocate_and_delegate=*/false,
+        /*use_simple_allocator=*/false);
+
+    int body_subgraph_index;
+    AddSubgraphs(1, &body_subgraph_index);
+    REDUCE_WINDOW_ENSURE_EQ(body_subgraph_index, kBodySubGraphIndex);
+    switch (body_function_) {
+      case BodyFunction::kAdd:
+        subgraph_builder_.BuildAddSubgraph(
+            interpreter_->subgraph(body_subgraph_index));
+        break;
+      case BodyFunction::kMul:
+        subgraph_builder_.BuildMulSubgraph(
+            interpreter_->subgraph(body_subgraph_index));
+        break;
+      case BodyFunction::kMax:
+        subgraph_builder_.BuildMaximumSubgraph(
+            interpreter_->subgraph(body_subgraph_index));
+        break;
+      case BodyFunction::kMin:
+        subgraph_builder_.BuildMinimumSubgraph(
+            interpreter_->subgraph(body_subgraph_index));
+        break;
+      case BodyFunction::kAll:
+        subgraph_builder_.BuildLogicalAndSubgraph(
+            interpreter_->subgraph(body_subgraph_index));
+        break;
+      case BodyFunction::kAny:
+        subgraph_builder_.BuildLogicalOrSubgraph(
+            interpreter_->subgraph(body_subgraph_index));
+        break;
+      default:
+        REDUCE_WINDOW_ENSURE_UNREACHABLE("Unhandled body function enum value.");
+    }
+
+    AllocateAndDelegate(/*apply_delegate=*/true);
+
+    PopulateTensor(input_tensor_id_, input_data_);
+    return kTfLiteOk;
+  }
+
+  TfLiteStatus BuildAndInvoke() {
+    REDUCE_WINDOW_ENSURE_OK(Build());
+    return Invoke();
+  }
+
+  absl::Span<const T> GetOutputData() {
+    return absl::Span<const T>(interpreter_->typed_tensor<T>(output_tensor_id_),
+                               GetTensorSize(output_tensor_id_));
+  }
+
+  absl::Span<const int> GetOutputShape() {
+    const TfLiteIntArray& shape =
+        *(interpreter_->tensor(output_tensor_id_)->dims);
+    return absl::Span<const int>(shape.data, shape.size);
+  }
+
+  const std::vector<T>& GetInput() const { return input_data_; }
+
+  const std::vector<int64_t>& GetInputShape() const { return input_shape_; }
+
+  const std::vector<int64_t>& GetWindowDimensions() const {
+    return window_dimensions_;
+  }
+
+  const std::vector<int64_t>& GetWindowStrides() const {
+    return window_strides_;
+  }
+
+  const std::vector<int64_t>& GetBaseDilations() const {
+    return base_dilations_;
+  }
+
+  const std::vector<int64_t>& GetWindowDilations() const {
+    return window_dilations_;
+  }
+
+  const std::vector<int64_t>& GetPadding() const { return padding_; }
+
+  const T& GetInitValue() const { return init_value_; }
+
+  const BodyFunction& GetBodyFunction() const { return body_function_; }
+
+  friend std::ostream& operator<<(std::ostream& os,
+                                  const ReduceWindowOpModel& model) {
+    using Adapt = ReduceWindowOpModel::VectorOutputAdapter;
+    os << "input dimensions: {" << Adapt{model.GetInputShape()} << "}\n";
+    os << "  base dilations: {" << Adapt{model.GetBaseDilations()} << "}\n";
+    os << "  padding: {" << Adapt{model.GetPadding()} << "}\n";
+    os << "  window dimensions: {" << Adapt{model.GetWindowDimensions()}
+       << "}\n";
+    os << "  window dilations: {" << Adapt{model.GetWindowDilations()} << "}\n";
+    os << "  window strides: {" << Adapt{model.GetWindowStrides()} << "}\n";
+    os << "  init value: " << +model.GetInitValue() << "\n";
+    os << "  body function: " << model.GetBodyFunction() << "\n";
+    return os;
+  }
+
+ protected:
+  struct VectorOutputAdapter {
+    const std::vector<int64_t>& data;
+    friend std::ostream& operator<<(std::ostream& os,
+                                    const VectorOutputAdapter& vec) {
+      if (!vec.data.empty()) {
+        os << +vec.data[0];
+        for (size_t i = 1; i < vec.data.size(); ++i) {
+          os << ", " << +vec.data[i];
+        }
+      }
+      return os;
+    }
+  };
+
+  int input_tensor_id_ = -1;
+  int init_value_tensor_id_ = -1;
+  int output_tensor_id_ = -1;
+  std::vector<T> input_data_;
+  T init_value_;
+  std::vector<int64_t> input_shape_;
+  std::vector<int64_t> window_dimensions_;
+  std::vector<int64_t> window_strides_;
+  std::vector<int64_t> base_dilations_;
+  std::vector<int64_t> window_dilations_;
+  std::vector<int64_t> padding_;
+  BodyFunction body_function_{};
+  subgraph_test_util::SubgraphBuilder subgraph_builder_;
+};
+
+template <class StorageType>
+class StablehloReduceWindowTest : public testing::Test {};
+
+using TestList =
+    testing::Types<int8_t, int16_t, int32_t, int64_t, uint8_t, float, double>;
+TYPED_TEST_SUITE(StablehloReduceWindowTest, TestList);
+
+TYPED_TEST(StablehloReduceWindowTest, Identity) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{3, 3});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({0, 0, 0, 0});
+  model.SetWindowDimensions({1, 1});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({1, 1});
+  model.SetInitValue(0);
+  model.SetBody(BodyFunction::kAdd);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 3));
+  EXPECT_THAT(model.GetOutputData(), ElementsAre(1, 2, 3, 4, 5, 6, 7, 8, 9));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, Dilate) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{3, 3});
+  model.SetBaseDilations({2, 2});
+  model.SetPadding({0, 0, 0, 0});
+  model.SetWindowDimensions({1, 1});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({1, 1});
+  model.SetInitValue(0);
+  model.SetBody(BodyFunction::kAdd);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(5, 5));
+  EXPECT_THAT(model.GetOutputData(),
+              ElementsAreArray({1, 0, 2, 0, 3, 0, 0, 0, 0, 0, 4, 0, 5,
+                                0, 6, 0, 0, 0, 0, 0, 7, 0, 8, 0, 9}));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, IdentityPadTop) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{3, 3});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({1, 0, 0, 0});
+  model.SetWindowDimensions({1, 1});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({1, 1});
+  model.SetInitValue(0);
+  model.SetBody(BodyFunction::kAdd);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3));
+  EXPECT_THAT(model.GetOutputData(),
+              ElementsAreArray({0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, IdentityPadBottom) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{3, 3});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({0, 1, 0, 0});
+  model.SetWindowDimensions({1, 1});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({1, 1});
+  model.SetInitValue(0);
+  model.SetBody(BodyFunction::kAdd);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3));
+  EXPECT_THAT(model.GetOutputData(),
+              ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0}));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, IdentityPadLeft) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{3, 3});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({0, 0, 1, 0});
+  model.SetWindowDimensions({1, 1});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({1, 1});
+  model.SetInitValue(0);
+  model.SetBody(BodyFunction::kAdd);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 4));
+  EXPECT_THAT(model.GetOutputData(),
+              ElementsAreArray({0, 1, 2, 3, 0, 4, 5, 6, 0, 7, 8, 9}));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, IdentityPadRight) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{3, 3});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({0, 0, 0, 1});
+  model.SetWindowDimensions({1, 1});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({1, 1});
+  model.SetInitValue(0);
+  model.SetBody(BodyFunction::kAdd);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 4));
+  EXPECT_THAT(model.GetOutputData(),
+              ElementsAreArray({1, 2, 3, 0, 4, 5, 6, 0, 7, 8, 9, 0}));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, IdentityPadAll) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{3, 3});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({1, 1, 1, 1});
+  model.SetWindowDimensions({1, 1});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({1, 1});
+  model.SetInitValue(0);
+  model.SetBody(BodyFunction::kAdd);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(5, 5));
+  EXPECT_THAT(model.GetOutputData(),
+              ElementsAreArray({0, 0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5,
+                                6, 0, 0, 7, 8, 9, 0, 0, 0, 0, 0, 0}));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, IdentityCropTop) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{3, 3});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({-1, 0, 0, 0});
+  model.SetWindowDimensions({1, 1});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({1, 1});
+  model.SetInitValue(0);
+  model.SetBody(BodyFunction::kAdd);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(model.GetOutputData(), ElementsAreArray({4, 5, 6, 7, 8, 9}));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, IdentityCropBottom) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{3, 3});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({0, -1, 0, 0});
+  model.SetWindowDimensions({1, 1});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({1, 1});
+  model.SetInitValue(0);
+  model.SetBody(BodyFunction::kAdd);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(model.GetOutputData(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, IdentityCropLeft) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{3, 3});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({0, 0, -1, 0});
+  model.SetWindowDimensions({1, 1});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({1, 1});
+  model.SetInitValue(0);
+  model.SetBody(BodyFunction::kAdd);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2));
+  EXPECT_THAT(model.GetOutputData(), ElementsAreArray({2, 3, 5, 6, 8, 9}));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, IdentityCropRight) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{3, 3});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({0, 0, 0, -1});
+  model.SetWindowDimensions({1, 1});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({1, 1});
+  model.SetInitValue(0);
+  model.SetBody(BodyFunction::kAdd);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2));
+  EXPECT_THAT(model.GetOutputData(), ElementsAreArray({1, 2, 4, 5, 7, 8}));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, IdentityCropAll) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{3, 3});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({-1, -1, -1, -1});
+  model.SetWindowDimensions({1, 1});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({1, 1});
+  model.SetInitValue(0);
+  model.SetBody(BodyFunction::kAdd);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1));
+  EXPECT_THAT(model.GetOutputData(), ElementsAre(5));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, ReduceWindowFullWindow) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{3, 3});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({0, 0, 0, 0});
+  model.SetWindowDimensions({3, 3});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({1, 1});
+  model.SetInitValue(0);
+  model.SetBody(BodyFunction::kAdd);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1));
+  EXPECT_THAT(model.GetOutputData(), ElementsAre(45));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, ReduceWindowNoDilation) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{3, 3});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({0, 0, 0, 0});
+  model.SetBody(BodyFunction::kAdd);
+  model.SetWindowDimensions({2, 2});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({1, 1});
+  model.SetInitValue(0);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2));
+  EXPECT_THAT(model.GetOutputData(), ElementsAre(12, 16, 24, 28));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, ReduceWindowFullWindowWithDilation) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{3, 3});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({0, 0, 0, 0});
+  model.SetBody(BodyFunction::kAdd);
+  model.SetWindowDimensions({2, 2});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({2, 2});
+  model.SetInitValue(0);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1));
+  EXPECT_THAT(model.GetOutputData(), ElementsAre(20));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, ReduceWindowWithDilation) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{4, 4});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({0, 0, 0, 0});
+  model.SetBody(BodyFunction::kAdd);
+  model.SetWindowDimensions({2, 2});
+  model.SetWindowStrides({1, 1});
+  model.SetWindowDilations({2, 2});
+  model.SetInitValue(0);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2));
+  EXPECT_THAT(model.GetOutputData(), ElementsAre(24, 28, 40, 44));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, ReduceWindowWithStrides) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{4, 4});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({0, 0, 0, 0});
+  model.SetBody(BodyFunction::kAdd);
+  model.SetWindowDimensions({2, 2});
+  model.SetWindowStrides({2, 2});
+  model.SetWindowDilations({1, 1});
+  model.SetInitValue(0);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2));
+  EXPECT_THAT(model.GetOutputData(), ElementsAre(14, 22, 46, 54));
+}
+
+TYPED_TEST(StablehloReduceWindowTest, ReduceWindowWithDilationAndStrides) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{5, 5});
+  model.SetBaseDilations({1, 1});
+  model.SetPadding({0, 0, 0, 0});
+  model.SetBody(BodyFunction::kAdd);
+  model.SetWindowDimensions({2, 2});
+  model.SetWindowStrides({2, 2});
+  model.SetWindowDilations({2, 2});
+  model.SetInitValue(2);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2));
+  EXPECT_THAT(model.GetOutputData(), ElementsAre(30, 38, 70, 78));
+}
+
+TYPED_TEST(StablehloReduceWindowTest,
+           ReduceWindowOutputShapeRoundingIsCorrect) {
+  ReduceWindowOpModel<TypeParam> model;
+  model.SetInput(/*shape=*/{1, 64, 114, 114});
+  model.SetBaseDilations({1, 1, 1, 1});
+  model.SetPadding({0, 0, 0, 0, 0, 0, 0, 0});
+  model.SetBody(BodyFunction::kAdd);
+  model.SetWindowDimensions({1, 1, 3, 3});
+  model.SetWindowStrides({1, 1, 2, 2});
+  model.SetWindowDilations({1, 1, 1, 1});
+  model.SetInitValue(2);
+
+  ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 64, 56, 56));
+}
+
+// Returns a vector of given size with elements in the range [min, max].
+template <class T>
+std::vector<T> RandomVector(absl::BitGen& bitgen, size_t size, T min, T max) {
+  std::vector<T> vec(size);
+  for (T& v : vec) {
+    v = absl::Uniform(absl::IntervalClosed, bitgen, min, max);
+  }
+  return vec;
+}
+
+struct Body {
+  static Body GetRandomSupported(absl::BitGen& bitgen) {
+    return Body{/*.body=*/static_cast<BodyFunction>(absl::Uniform<int>(
+        absl::IntervalClosed, bitgen, static_cast<int>(BodyFunction::kAdd),
+        static_cast<int>(BodyFunction::kAny)))};
+  }
+
+  template <class T>
+  T operator()(const T& a, const T& b) const noexcept {
+    switch (func) {
+      case BodyFunction::kUnset:
+      case BodyFunction::kUnsupported:
+        return -1;
+      case BodyFunction::kAdd:
+        return a + b;
+      case BodyFunction::kMul:
+        return a * b;
+      case BodyFunction::kMin:
+        return a <= b ? a : b;
+      case BodyFunction::kMax:
+        return a >= b ? a : b;
+      case BodyFunction::kAll:
+        return a && b;
+      case BodyFunction::kAny:
+        return a || b;
+    }
+  }
+
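+  // Returns the neutral element of the reduction function (0 for add, 1 for
+  // mul, the type's maximum for min, its lowest value for max, true for all
+  // and false for any) so that the initial value does not affect the result.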
+  template <class T>
+  T init_value() const noexcept {
+    switch (func) {
+      case BodyFunction::kUnset:
+      case BodyFunction::kUnsupported:
+        return -1;
+      case BodyFunction::kAdd:
+        return 0;
+      case BodyFunction::kMul:
+        return 1;
+      case BodyFunction::kMin:
+        return std::numeric_limits<T>::max();
+      case BodyFunction::kMax:
+        return std::numeric_limits<T>::lowest();
+      case BodyFunction::kAll:
+        return true;
+      case BodyFunction::kAny:
+        return false;
+    }
+  }
+
+  BodyFunction func;
+};
+
+TYPED_TEST(StablehloReduceWindowTest, FuzzyTest) {
+  absl::BitGen bitgen;
+
+  for (size_t iteration = 0; iteration < 1000; ++iteration) {
+    const int rank = absl::Uniform(absl::IntervalClosed, bitgen, 1, 3);
+
+    ReduceWindowOpModel<TypeParam> model;
+    Body body = Body::GetRandomSupported(bitgen);
+    model.SetInput(
+        /*shape=*/RandomVector<int64_t>(bitgen, rank, /*min=*/1, /*max=*/10),
+        bitgen, /*min=*/-5, /*max=*/5);
+    model.SetBaseDilations(
+        RandomVector<int64_t>(bitgen, rank, /*min=*/1, /*max=*/3));
+    model.SetPadding(
+        RandomVector<int64_t>(bitgen, 2 * rank, /*min=*/-5, /*max=*/5));
+    model.SetWindowDimensions(
+        RandomVector<int64_t>(bitgen, rank, /*min=*/1, /*max=*/3));
+    model.SetWindowStrides(
+        RandomVector<int64_t>(bitgen, rank, /*min=*/1, /*max=*/3));
+    model.SetWindowDilations(
+        RandomVector<int64_t>(bitgen, rank, /*min=*/1, /*max=*/3));
+    model.SetInitValue(body.init_value<TypeParam>());
+    model.SetBody(body.func);
+
+    // Skip invalid specifications.
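+    // A configuration where padding collapses a dimension to a non-positive
+    // size is discarded, and the iteration counter is rewound so that skipped
+    // configurations do not consume the iteration budget.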
+    const std::vector<int64_t> padded_shape = reference::PadCropShape(
+        reference::DilateShape(model.GetInputShape(), model.GetBaseDilations()),
+        model.GetPadding());
+    if (absl::c_any_of(padded_shape, [](int64_t d) { return d <= 0; })) {
+      iteration = iteration > 1 ? iteration - 1 : 0;
+      continue;
+    }
+
+    const reference::Tensor<TypeParam> expected = reference::ReduceWindow(
+        reference::Tensor<TypeParam>{/*shape=*/model.GetInputShape(),
+                                     /*data=*/model.GetInput()},
+        model.GetBaseDilations(), model.GetPadding(), model.GetInitValue(),
+        model.GetWindowDimensions(), model.GetWindowDilations(),
+        model.GetWindowStrides(), body);
+
+    ASSERT_EQ(model.BuildAndInvoke(), kTfLiteOk);
+    EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected.shape))
+        << model;
+    EXPECT_THAT(model.GetOutputData(), ElementsAreArray(expected.data))
+        << model;
+  }
+}
+
+}  // namespace
+}  // namespace reduce_window
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/stablehlo_reduce_window_test_util.h b/tensorflow/lite/kernels/stablehlo_reduce_window_test_util.h
new file mode 100644
index 0000000..c514587
--- /dev/null
+++ b/tensorflow/lite/kernels/stablehlo_reduce_window_test_util.h
@@ -0,0 +1,402 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_KERNELS_STABLEHLO_REDUCE_WINDOW_TEST_UTIL_H_
+#define TENSORFLOW_LITE_KERNELS_STABLEHLO_REDUCE_WINDOW_TEST_UTIL_H_
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <initializer_list>
+#include <numeric>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+
+namespace tflite {
+namespace reduce_window {
+namespace reference {
+
+constexpr int kMaxDims = 6;
+
+// Holds a data buffer and the shape associated with a tensor.
+template <class T>
+struct Tensor {
+  std::vector<int64_t> shape;
+  std::vector<T> data;
+
+  // Builds a tensor with the given shape and fills it with the given initial
+  // value.
+  static Tensor<T> FromShape(std::vector<int64_t> shape,
+                             const T init_value = 0) {
+    Tensor tensor{std::move(shape)};
+    tensor.data.resize(tensor.size(), init_value);
+    return tensor;
+  }
+
+  // Builds a tensor with the given shape and fills it with incrementing values
+  // starting from 1.
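+  //
+  // For example, `iota({2, 3})` holds the values {1, 2, 3, 4, 5, 6} in a 2x3
+  // tensor.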
+  template <class I>
+  static Tensor<T> iota(std::initializer_list<I> shape) {
+    Tensor<T> tensor;
+    tensor.shape.assign(shape.begin(), shape.end());
+    tensor.data.resize(absl::c_accumulate(shape, 1, std::multiplies<>()));
+    absl::c_iota(tensor.data, 1);
+    return tensor;
+  }
+
+  // Returns the number of values in the tensor.
+  int64_t size() const {
+    return absl::c_accumulate(shape, 1, std::multiplies<>());
+  }
+
+  // Computes the strides for each valid dimension in the tensor.
+  //
+  // The returned vector always has `kMaxDims` elements; dimensions beyond the
+  // tensor's rank have a stride of 0.
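+  //
+  // For example, a tensor of shape {2, 3, 4} has strides {12, 4, 1, 0, 0, 0}.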
+  std::vector<int64_t> Strides() const {
+    std::vector<int64_t> strides(kMaxDims, 0);
+    if (!shape.empty()) {
+      strides[shape.size() - 1] = 1;
+      for (size_t i = shape.size() - 1; i > 0; --i) {
+        strides[i - 1] = shape[i] * strides[i];
+      }
+    }
+    return strides;
+  }
+};
+
+// Returns a new vector resized to `kMaxDims` with `val` as the default value.
+inline std::vector<int64_t> ExtendToMaxDim(std::vector<int64_t> vec,
+                                           int64_t val = 0) {
+  vec.resize(kMaxDims, val);
+  return vec;
+}
+
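+// Computes the shape of a tensor dilated by the given factors: a dimension of
+// size `d` becomes `(d - 1) * dilation + 1`.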
+inline std::vector<int64_t> DilateShape(std::vector<int64_t> shape,
+                                        const std::vector<int64_t> dilations) {
+  for (size_t i = 0; i < shape.size(); ++i) {
+    shape[i] = (shape[i] - 1) * dilations[i] + 1;
+  }
+  return shape;
+}
+
+template <class T>
+Tensor<T> Dilate(const Tensor<T>& input, const std::vector<int64_t>& dilations,
+                 const T padding_value) {
+  Tensor<T> output =
+      Tensor<T>::FromShape(DilateShape(input.shape, dilations), padding_value);
+
+  const std::vector<int64_t> strides = input.Strides();
+  const std::vector<int64_t> output_strides = output.Strides();
+  const std::vector<int64_t> safe_dilations = ExtendToMaxDim(dilations);
+  const std::vector<int64_t> safe_input_shape = ExtendToMaxDim(input.shape);
+
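+  // The nested do/while loops visit every input coordinate (missing dimensions
+  // are iterated exactly once) and scatter each element to its dilated
+  // position in the output.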
+  int a = 0;
+  do {
+    int b = 0;
+    do {
+      int c = 0;
+      do {
+        int d = 0;
+        do {
+          int e = 0;
+          do {
+            int f = 0;
+            do {
+              const int i_idx = a * strides[0] + b * strides[1] +
+                                c * strides[2] + d * strides[3] +
+                                e * strides[4] + f * strides[5];
+              const int o_idx = a * safe_dilations[0] * output_strides[0] +
+                                b * safe_dilations[1] * output_strides[1] +
+                                c * safe_dilations[2] * output_strides[2] +
+                                d * safe_dilations[3] * output_strides[3] +
+                                e * safe_dilations[4] * output_strides[4] +
+                                f * safe_dilations[5] * output_strides[5];
+              output.data[o_idx] = input.data[i_idx];
+            } while (++f < safe_input_shape[5]);
+          } while (++e < safe_input_shape[4]);
+        } while (++d < safe_input_shape[3]);
+      } while (++c < safe_input_shape[2]);
+    } while (++b < safe_input_shape[1]);
+  } while (++a < safe_input_shape[0]);
+
+  return output;
+}
+
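+// Applies a pad/crop specification to a shape: each dimension grows by the sum
+// of its two (possibly negative) padding values.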
+inline std::vector<int64_t> PadCropShape(std::vector<int64_t> shape,
+                                         const std::vector<int64_t> padding) {
+  for (size_t i = 0; i < shape.size(); ++i) {
+    shape[i] = shape[i] + padding[2 * i] + padding[2 * i + 1];
+  }
+  return shape;
+}
+
+// Pads the input tensor.
+//
+// `Pad` and `Crop` share the same pad/crop specification. The positive values
+// specify padding and the negative values specify cropping.
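+//
+// For example, the specification {1, -2} for a dimension of size 5 adds one
+// element of padding at the start (handled by `Pad`) and crops the last two
+// elements (handled by `Crop`), for a resulting size of 5 + 1 - 2 = 4.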
+template <class T>
+Tensor<T> Pad(const Tensor<T>& input, const std::vector<int64_t>& padding,
+              const T padding_value) {
+  // Keep only positive values in the padding.
+  std::vector<int64_t> safe_padding(kMaxDims * 2, 0);
+  absl::c_transform(padding, safe_padding.begin(),
+                    [](int64_t p) { return std::max<int64_t>(p, 0); });
+
+  Tensor<T> output = Tensor<T>::FromShape(
+      PadCropShape(input.shape, safe_padding), padding_value);
+
+  const std::vector<int64_t> strides = input.Strides();
+  const std::vector<int64_t> output_strides = output.Strides();
+  const std::vector<int64_t> safe_input_shape = ExtendToMaxDim(input.shape);
+
+  int a = 0;
+  do {
+    int b = 0;
+    do {
+      int c = 0;
+      do {
+        int d = 0;
+        do {
+          int e = 0;
+          do {
+            int f = 0;
+            do {
+              const int i_idx = a * strides[0] + b * strides[1] +
+                                c * strides[2] + d * strides[3] +
+                                e * strides[4] + f * strides[5];
+              const int o_idx = (a + safe_padding[0]) * output_strides[0] +
+                                (b + safe_padding[2]) * output_strides[1] +
+                                (c + safe_padding[4]) * output_strides[2] +
+                                (d + safe_padding[6]) * output_strides[3] +
+                                (e + safe_padding[8]) * output_strides[4] +
+                                (f + safe_padding[10]) * output_strides[5];
+              output.data[o_idx] = input.data[i_idx];
+            } while (++f < safe_input_shape[5]);
+          } while (++e < safe_input_shape[4]);
+        } while (++d < safe_input_shape[3]);
+      } while (++c < safe_input_shape[2]);
+    } while (++b < safe_input_shape[1]);
+  } while (++a < safe_input_shape[0]);
+
+  return output;
+}
+
+// Crops the input tensor.
+//
+// Only negative values are taken into account for cropping.
+template <class T>
+Tensor<T> Crop(const Tensor<T>& input, const std::vector<int64_t>& cropping) {
+  // Keep only negative values in the cropping.
+  std::vector<int64_t> safe_cropping(kMaxDims * 2, 0);
+  absl::c_transform(cropping, safe_cropping.begin(),
+                    [](int64_t p) { return std::min<int64_t>(p, 0); });
+
+  Tensor<T> output =
+      Tensor<T>::FromShape(PadCropShape(input.shape, safe_cropping));
+
+  const std::vector<int64_t> strides = input.Strides();
+  const std::vector<int64_t> output_strides = output.Strides();
+  const std::vector<int64_t> safe_output_shape = ExtendToMaxDim(output.shape);
+
+  int a = 0;
+  do {
+    int b = 0;
+    do {
+      int c = 0;
+      do {
+        int d = 0;
+        do {
+          int e = 0;
+          do {
+            int f = 0;
+            do {
+              const int i_idx = (a - safe_cropping[0]) * strides[0] +
+                                (b - safe_cropping[2]) * strides[1] +
+                                (c - safe_cropping[4]) * strides[2] +
+                                (d - safe_cropping[6]) * strides[3] +
+                                (e - safe_cropping[8]) * strides[4] +
+                                (f - safe_cropping[10]) * strides[5];
+              const int o_idx = a * output_strides[0] + b * output_strides[1] +
+                                c * output_strides[2] + d * output_strides[3] +
+                                e * output_strides[4] + f * output_strides[5];
+              output.data[o_idx] = input.data[i_idx];
+            } while (++f < safe_output_shape[5]);
+          } while (++e < safe_output_shape[4]);
+        } while (++d < safe_output_shape[3]);
+      } while (++c < safe_output_shape[2]);
+    } while (++b < safe_output_shape[1]);
+  } while (++a < safe_output_shape[0]);
+
+  return output;
+}
+
+// Gathers the elements that are visible through the given window specification
+// into a new tensor.
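+//
+// For example, with window_dimensions {2, 2}, window_dilations {2, 2} and
+// window_offset {2, 1} on a 6x4 iota tensor, the gathered elements are
+// {10, 12, 18, 20}.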
+template <class T>
+Tensor<T> WindowCopy(const Tensor<T>& input,
+                     const std::vector<int64_t>& window_dimensions,
+                     const std::vector<int64_t>& window_dilations,
+                     const std::vector<int64_t>& window_offset) {
+  Tensor<T> output = Tensor<T>::FromShape(window_dimensions);
+
+  const std::vector<int64_t> safe_window_dimensions =
+      ExtendToMaxDim(window_dimensions);
+  const std::vector<int64_t> safe_window_dilations =
+      ExtendToMaxDim(window_dilations, 1);
+  const std::vector<int64_t> safe_window_offset = ExtendToMaxDim(window_offset);
+
+  const std::vector<int64_t> strides = input.Strides();
+  const std::vector<int64_t> output_strides = output.Strides();
+
+  int a = 0;
+  do {
+    int b = 0;
+    do {
+      int c = 0;
+      do {
+        int d = 0;
+        do {
+          int e = 0;
+          do {
+            int f = 0;
+            do {
+              const int i_idx =
+                  (a * safe_window_dilations[0] + safe_window_offset[0]) *
+                      strides[0] +
+                  (b * safe_window_dilations[1] + safe_window_offset[1]) *
+                      strides[1] +
+                  (c * safe_window_dilations[2] + safe_window_offset[2]) *
+                      strides[2] +
+                  (d * safe_window_dilations[3] + safe_window_offset[3]) *
+                      strides[3] +
+                  (e * safe_window_dilations[4] + safe_window_offset[4]) *
+                      strides[4] +
+                  (f * safe_window_dilations[5] + safe_window_offset[5]) *
+                      strides[5];
+              const int o_idx = a * output_strides[0] + b * output_strides[1] +
+                                c * output_strides[2] + d * output_strides[3] +
+                                e * output_strides[4] + f * output_strides[5];
+              output.data[o_idx] = input.data[i_idx];
+            } while (++f < safe_window_dimensions[5]);
+          } while (++e < safe_window_dimensions[4]);
+        } while (++d < safe_window_dimensions[3]);
+      } while (++c < safe_window_dimensions[2]);
+    } while (++b < safe_window_dimensions[1]);
+  } while (++a < safe_window_dimensions[0]);
+
+  return output;
+}
+
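+// Computes the output shape of a reduce window operation: the input shape is
+// dilated and padded/cropped, the window dimensions are dilated, and each
+// output dimension is `(base - dilated_window) / stride + 1`, or 0 if the
+// dilated window does not fit in the base shape.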
+inline std::vector<int64_t> ReduceWindowShape(
+    std::vector<int64_t> shape, const std::vector<int64_t>& base_dilations,
+    const std::vector<int64_t>& padding,
+    const std::vector<int64_t>& window_dimensions,
+    const std::vector<int64_t>& window_dilations,
+    const std::vector<int64_t>& window_strides) {
+  const std::vector<int64_t> base_shape =
+      PadCropShape(DilateShape(shape, base_dilations), padding);
+  const std::vector<int64_t> dilated_window_dimensions =
+      DilateShape(window_dimensions, window_dilations);
+  shape.assign(base_shape.size(), 0);
+  for (int i = 0; i < base_shape.size(); ++i) {
+    if (base_shape[i] >= dilated_window_dimensions[i]) {
+      shape[i] =
+          (base_shape[i] - dilated_window_dimensions[i]) / window_strides[i] +
+          1;
+    }
+  }
+  return shape;
+}
+
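+// Reference implementation of a reduce window operation: the input is dilated,
+// padded and cropped to build the base tensor, then every window is gathered
+// with `WindowCopy` and reduced with `body` starting from `init_value`.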
+template <class T, class F>
+Tensor<T> ReduceWindow(const Tensor<T>& input,
+                       const std::vector<int64_t>& base_dilations,
+                       const std::vector<int64_t>& padding, const T& init_value,
+                       const std::vector<int64_t>& window_dimensions,
+                       const std::vector<int64_t>& window_dilations,
+                       const std::vector<int64_t>& window_strides, F&& body) {
+  Tensor<T> output = Tensor<T>::FromShape(
+      ReduceWindowShape(input.shape, base_dilations, padding, window_dimensions,
+                        window_dilations, window_strides),
+      init_value);
+
+  if (output.data.empty()) {
+    return output;
+  }
+
+  const std::vector<int64_t> safe_output_shape = ExtendToMaxDim(output.shape);
+  const std::vector<int64_t> safe_window_strides =
+      ExtendToMaxDim(window_strides);
+  const std::vector<int64_t> output_strides = output.Strides();
+
+  const Tensor<T> dilated = Dilate<T>(input, base_dilations, init_value);
+  const Tensor<T> padded = Pad<T>(dilated, padding, init_value);
+  const Tensor<T> base = Crop<T>(padded, padding);
+
+  std::vector<int64_t> output_offsets(6, 0);
+  std::vector<int64_t> window_offsets(6, 0);
+  do {
+    output_offsets[1] = 0;
+    window_offsets[1] = 0;
+    do {
+      output_offsets[2] = 0;
+      window_offsets[2] = 0;
+      do {
+        output_offsets[3] = 0;
+        window_offsets[3] = 0;
+        do {
+          output_offsets[4] = 0;
+          window_offsets[4] = 0;
+          do {
+            output_offsets[5] = 0;
+            window_offsets[5] = 0;
+            do {
+              const int64_t o_idx = output_offsets[0] * output_strides[0] +
+                                    output_offsets[1] * output_strides[1] +
+                                    output_offsets[2] * output_strides[2] +
+                                    output_offsets[3] * output_strides[3] +
+                                    output_offsets[4] * output_strides[4] +
+                                    output_offsets[5] * output_strides[5];
+              const Tensor<T> window = WindowCopy(
+                  base, window_dimensions, window_dilations, window_offsets);
+              if (window.data.empty()) {
+                output.data[o_idx] = init_value;
+              } else {
+                output.data[o_idx] = std::accumulate(
+                    window.data.begin(), window.data.end(), init_value, body);
+              }
+              window_offsets[5] += safe_window_strides[5];
+            } while (++output_offsets[5] < safe_output_shape[5]);
+            window_offsets[4] += safe_window_strides[4];
+          } while (++output_offsets[4] < safe_output_shape[4]);
+          window_offsets[3] += safe_window_strides[3];
+        } while (++output_offsets[3] < safe_output_shape[3]);
+        window_offsets[2] += safe_window_strides[2];
+      } while (++output_offsets[2] < safe_output_shape[2]);
+      window_offsets[1] += safe_window_strides[1];
+    } while (++output_offsets[1] < safe_output_shape[1]);
+    window_offsets[0] += safe_window_strides[0];
+  } while (++output_offsets[0] < safe_output_shape[0]);
+  return output;
+}
+
+}  // namespace reference
+}  // namespace reduce_window
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_STABLEHLO_REDUCE_WINDOW_TEST_UTIL_H_
diff --git a/tensorflow/lite/kernels/stablehlo_reduce_window_test_util_test.cc b/tensorflow/lite/kernels/stablehlo_reduce_window_test_util_test.cc
new file mode 100644
index 0000000..61f2d90
--- /dev/null
+++ b/tensorflow/lite/kernels/stablehlo_reduce_window_test_util_test.cc
@@ -0,0 +1,5941 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/stablehlo_reduce_window_test_util.h"
+
+#include <functional>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite::reduce_window::reference {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+TEST(ReferenceTest, DilateWorks) {
+  reference::Tensor<int> input = reference::Tensor<int>::iota(/*shape=*/{3, 3});
+  reference::Tensor<int> output =
+      reference::Dilate(input, /*dilations=*/{2, 3}, /*padding_value=*/-1);
+
+  EXPECT_THAT(output.data, ElementsAreArray({
+                               // clang-format off
+                                1, -1, -1,  2, -1, -1,  3,
+                               -1, -1, -1, -1, -1, -1, -1,
+                                4, -1, -1,  5, -1, -1,  6,
+                               -1, -1, -1, -1, -1, -1, -1,
+                                7, -1, -1,  8, -1, -1,  9
+                               // clang-format on
+                           }));
+}
+
+TEST(ReferenceTest, PadWorks) {
+  reference::Tensor<int> input = reference::Tensor<int>::iota(/*shape=*/{3, 3});
+  reference::Tensor<int> output =
+      reference::Pad(input, /*padding=*/{1, 2, 3, 4}, /*padding_value=*/-1);
+
+  EXPECT_THAT(output.data,
+              ElementsAreArray({
+                  // clang-format off
+                  -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
+                  -1, -1, -1,  1,  2,  3, -1, -1, -1, -1,
+                  -1, -1, -1,  4,  5,  6, -1, -1, -1, -1,
+                  -1, -1, -1,  7,  8,  9, -1, -1, -1, -1,
+                  -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
+                  -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
+                  // clang-format on
+              }));
+}
+
+TEST(ReferenceTest, PadIgnoresNegativeValues) {
+  reference::Tensor<int> input = reference::Tensor<int>::iota(/*shape=*/{3, 3});
+  reference::Tensor<int> output =
+      reference::Pad(input, /*padding=*/{-1, -1, -1, -1}, /*padding_value=*/-1);
+
+  EXPECT_THAT(output.data, ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9}));
+}
+
+TEST(ReferenceTest, CropWorks) {
+  reference::Tensor<int> input =
+      reference::Tensor<int>::iota(/*shape=*/{6, 10});
+  reference::Tensor<int> output =
+      reference::Crop(input, /*cropping=*/{-4, -1, -2, -3});
+
+  EXPECT_THAT(output.data, ElementsAreArray({43, 44, 45, 46, 47}));
+}
+
+TEST(ReferenceTest, CropIgnoresPositiveValues) {
+  reference::Tensor<int> input = reference::Tensor<int>::iota(/*shape=*/{3, 3});
+  reference::Tensor<int> output =
+      reference::Crop(input, /*cropping=*/{0, 0, 0, 0});
+
+  EXPECT_THAT(output.data, ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9}));
+}
+
+TEST(ReferenceTest, WindowCopyWorks) {
+  reference::Tensor<int> input = reference::Tensor<int>::iota(/*shape=*/{6, 4});
+  EXPECT_THAT(reference::WindowCopy(input, /*window_dimensions=*/{2, 2},
+                                    /*window_dilations=*/{2, 2},
+                                    /*window_offset=*/{2, 1})
+                  .data,
+              ElementsAreArray({10, 12, 18, 20}));
+}
+
+TEST(ReferenceTest, RandomJaxReference0) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{1, -1, 0, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(19, 8));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {0, 0, 0, 0, 0, 0, 0, 0, 4,   6,   8,   10,  12,  14,  16,  18,
+           0, 0, 0, 0, 0, 0, 0, 0, 24,  26,  28,  30,  32,  34,  36,  38,
+           0, 0, 0, 0, 0, 0, 0, 0, 44,  46,  48,  50,  52,  54,  56,  58,
+           0, 0, 0, 0, 0, 0, 0, 0, 64,  66,  68,  70,  72,  74,  76,  78,
+           0, 0, 0, 0, 0, 0, 0, 0, 84,  86,  88,  90,  92,  94,  96,  98,
+           0, 0, 0, 0, 0, 0, 0, 0, 104, 106, 108, 110, 112, 114, 116, 118,
+           0, 0, 0, 0, 0, 0, 0, 0, 124, 126, 128, 130, 132, 134, 136, 138,
+           0, 0, 0, 0, 0, 0, 0, 0, 144, 146, 148, 150, 152, 154, 156, 158,
+           0, 0, 0, 0, 0, 0, 0, 0, 164, 166, 168, 170, 172, 174, 176, 178,
+           0, 0, 0, 0, 0, 0, 0, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference1) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{2, -1, 1, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(6, 18));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({0, 0,   0, 0,   0, 0,   0, 0,   0, 0,   0, 0,   0, 0,
+                        0, 0,   0, 0,   0, 3,   0, 5,   0, 7,   0, 9,   0, 11,
+                        0, 13,  0, 15,  0, 17,  0, 19,  0, 43,  0, 45,  0, 47,
+                        0, 49,  0, 51,  0, 53,  0, 55,  0, 57,  0, 59,  0, 83,
+                        0, 85,  0, 87,  0, 89,  0, 91,  0, 93,  0, 95,  0, 97,
+                        0, 99,  0, 123, 0, 125, 0, 127, 0, 129, 0, 131, 0, 133,
+                        0, 135, 0, 137, 0, 139, 0, 163, 0, 165, 0, 167, 0, 169,
+                        0, 171, 0, 173, 0, 175, 0, 177, 0, 179}));
+}
+
+TEST(ReferenceTest, RandomJaxReference2) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{2, -2, -2, 2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(8, 4));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({5,  7,  9,  9,  15, 17, 19, 19, 25, 27, 29,
+                                29, 35, 37, 39, 39, 45, 47, 49, 49, 55, 57,
+                                59, 59, 65, 67, 69, 69, 75, 77, 79, 79}));
+}
+
+TEST(ReferenceTest, RandomJaxReference3) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{0, 1, -1, 1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(6, 19));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, 2,           -2147483647, 3,           -2147483647,
+           4,           -2147483647, 5,           -2147483647, 6,
+           -2147483647, 7,           -2147483647, 8,           -2147483647,
+           9,           -2147483647, 10,          -2147483647, -2147483647,
+           22,          -2147483647, 23,          -2147483647, 24,
+           -2147483647, 25,          -2147483647, 26,          -2147483647,
+           27,          -2147483647, 28,          -2147483647, 29,
+           -2147483647, 30,          -2147483647, -2147483647, 42,
+           -2147483647, 43,          -2147483647, 44,          -2147483647,
+           45,          -2147483647, 46,          -2147483647, 47,
+           -2147483647, 48,          -2147483647, 49,          -2147483647,
+           50,          -2147483647, -2147483647, 62,          -2147483647,
+           63,          -2147483647, 64,          -2147483647, 65,
+           -2147483647, 66,          -2147483647, 67,          -2147483647,
+           68,          -2147483647, 69,          -2147483647, 70,
+           -2147483647, -2147483647, 82,          -2147483647, 83,
+           -2147483647, 84,          -2147483647, 85,          -2147483647,
+           86,          -2147483647, 87,          -2147483647, 88,
+           -2147483647, 89,          -2147483647, 90,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference4) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-2, -2, -1, -2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(3, 3));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({46, 50, 54, 86, 90, 94, 126, 130, 134}));
+}
+
+TEST(ReferenceTest, RandomJaxReference5) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{1, 2, 1, 1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(11, 6));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,    12,   14,   16,   18,   20,   1,    44,   96,   156,  224,
+           300,  1,    384,  476,  576,  684,  800,  1,    924,  1056, 1196,
+           1344, 1500, 1,    1664, 1836, 2016, 2204, 2400, 1,    2604, 2816,
+           3036, 3264, 3500, 1,    3744, 3996, 4256, 4524, 4800, 1,    5084,
+           5376, 5676, 5984, 6300, 1,    6624, 6956, 7296, 7644, 8000, 1,
+           82,   84,   86,   88,   90,   1,    92,   94,   96,   98,   100}));
+}
+
+TEST(ReferenceTest, RandomJaxReference6) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{2, -1, 0, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(9, 17));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,           -2147483647, 2,           -2147483647, 3,
+           -2147483647, 4,           -2147483647, 5,           -2147483647,
+           6,           -2147483647, 7,           -2147483647, 8,
+           -2147483647, 9,           11,          -2147483647, 12,
+           -2147483647, 13,          -2147483647, 14,          -2147483647,
+           15,          -2147483647, 16,          -2147483647, 17,
+           -2147483647, 18,          -2147483647, 19,          21,
+           -2147483647, 22,          -2147483647, 23,          -2147483647,
+           24,          -2147483647, 25,          -2147483647, 26,
+           -2147483647, 27,          -2147483647, 28,          -2147483647,
+           29,          31,          -2147483647, 32,          -2147483647,
+           33,          -2147483647, 34,          -2147483647, 35,
+           -2147483647, 36,          -2147483647, 37,          -2147483647,
+           38,          -2147483647, 39,          41,          -2147483647,
+           42,          -2147483647, 43,          -2147483647, 44,
+           -2147483647, 45,          -2147483647, 46,          -2147483647,
+           47,          -2147483647, 48,          -2147483647, 49,
+           51,          -2147483647, 52,          -2147483647, 53,
+           -2147483647, 54,          -2147483647, 55,          -2147483647,
+           56,          -2147483647, 57,          -2147483647, 58,
+           -2147483647, 59,          61,          -2147483647, 62,
+           -2147483647, 63,          -2147483647, 64,          -2147483647,
+           65,          -2147483647, 66,          -2147483647, 67,
+           -2147483647, 68,          -2147483647, 69,          71,
+           -2147483647, 72,          -2147483647, 73,          -2147483647,
+           74,          -2147483647, 75,          -2147483647, 76,
+           -2147483647, 77,          -2147483647, 78,          -2147483647,
+           79,          81,          -2147483647, 82,          -2147483647,
+           83,          -2147483647, 84,          -2147483647, 85,
+           -2147483647, 86,          -2147483647, 87,          -2147483647,
+           88,          -2147483647, 89}));
+}
+
+TEST(ReferenceTest, RandomJaxReference7) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-2, -2, 1, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(3, 11));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({0, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                                0, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
+                                0, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70}));
+}
+
+TEST(ReferenceTest, RandomJaxReference8) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{2, 1, -2, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(13, 3));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, 4,           6,           8,           14,
+           16,          18,          24,          26,          28,
+           34,          36,          38,          44,          46,
+           48,          54,          56,          58,          64,
+           66,          68,          74,          76,          78,
+           84,          86,          88,          94,          96,
+           98,          -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference9) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-1, 2, -2, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(9, 7));
+
+  EXPECT_THAT(res.data, ElementsAreArray(
+                            {32, 33, 34, 35, 36, 37, 38, 42, 43, 44, 45, 46, 47,
+                             48, 52, 53, 54, 55, 56, 57, 58, 62, 63, 64, 65, 66,
+                             67, 68, 72, 73, 74, 75, 76, 77, 78, 82, 83, 84, 85,
+                             86, 87, 88, 92, 93, 94, 95, 96, 97, 98, 82, 83, 84,
+                             85, 86, 87, 88, 92, 93, 94, 95, 96, 97, 98}));
+}
+
+TEST(ReferenceTest, RandomJaxReference10) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{0, -1, 0, 2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(17, 10));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 16,
+           17, 18, 19, 20, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
+           23, 24, 25, 26, 27, 28, 29, 30, 21, 22, 23, 24, 25, 26, 27, 28,
+           29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 31, 32, 33, 34,
+           35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
+           41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56,
+           57, 58, 59, 60, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62,
+           63, 64, 65, 66, 67, 68, 69, 70, 61, 62, 63, 64, 65, 66, 67, 68,
+           69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 71, 72, 73, 74,
+           75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,
+           81, 82, 83, 84, 85, 86, 87, 88, 89, 90}));
+}
+
+TEST(ReferenceTest, RandomJaxReference11) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{0, 0, 2, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(4, 6));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({0,   22,  26, 30,  34,  38,  0,   62,
+                                66,  70,  74, 78,  0,   102, 106, 110,
+                                114, 118, 0,  142, 146, 150, 154, 158}));
+}
+
+TEST(ReferenceTest, RandomJaxReference12) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{1, -2, 1, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(9, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference13) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{1, 2, 1, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(13, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, 2,           4,           6,           8,
+           -2147483647, 12,          14,          16,          18,
+           -2147483647, 22,          24,          26,          28,
+           -2147483647, 32,          34,          36,          38,
+           -2147483647, 42,          44,          46,          48,
+           -2147483647, 52,          54,          56,          58,
+           -2147483647, 62,          64,          66,          68,
+           -2147483647, 72,          74,          76,          78,
+           -2147483647, 82,          84,          86,          88,
+           -2147483647, 92,          94,          96,          98,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference14) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{1, 2, 1, -1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(11, 9));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                   1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                   1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                   1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                   1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference15) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-2, -2, 1, 2},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(3, 11));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 2147483646,
+                        41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 2147483646,
+                        61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference16) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{0, 0, 0, 0},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(5, 19));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {1,          2147483646, 2,          2147483646, 3,
+                   2147483646, 4,          2147483646, 5,          2147483646,
+                   6,          2147483646, 7,          2147483646, 8,
+                   2147483646, 9,          2147483646, 10,         21,
+                   2147483646, 22,         2147483646, 23,         2147483646,
+                   24,         2147483646, 25,         2147483646, 26,
+                   2147483646, 27,         2147483646, 28,         2147483646,
+                   29,         2147483646, 30,         41,         2147483646,
+                   42,         2147483646, 43,         2147483646, 44,
+                   2147483646, 45,         2147483646, 46,         2147483646,
+                   47,         2147483646, 48,         2147483646, 49,
+                   2147483646, 50,         61,         2147483646, 62,
+                   2147483646, 63,         2147483646, 64,         2147483646,
+                   65,         2147483646, 66,         2147483646, 67,
+                   2147483646, 68,         2147483646, 69,         2147483646,
+                   70,         81,         2147483646, 82,         2147483646,
+                   83,         2147483646, 84,         2147483646, 85,
+                   2147483646, 86,         2147483646, 87,         2147483646,
+                   88,         2147483646, 89,         2147483646, 90}));
+}
+
+TEST(ReferenceTest, RandomJaxReference17) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{2, -1, 2, 1},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(10, 20));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           1,          2147483646, 1,          2147483646, 2,
+           2147483646, 3,          2147483646, 4,          2147483646,
+           5,          2147483646, 6,          2147483646, 7,
+           2147483646, 8,          2147483646, 9,          2147483646,
+           1,          2147483646, 1,          2147483646, 2,
+           2147483646, 3,          2147483646, 4,          2147483646,
+           5,          2147483646, 6,          2147483646, 7,
+           2147483646, 8,          2147483646, 9,          2147483646,
+           11,         2147483646, 11,         2147483646, 12,
+           2147483646, 13,         2147483646, 14,         2147483646,
+           15,         2147483646, 16,         2147483646, 17,
+           2147483646, 18,         2147483646, 19,         2147483646,
+           21,         2147483646, 21,         2147483646, 22,
+           2147483646, 23,         2147483646, 24,         2147483646,
+           25,         2147483646, 26,         2147483646, 27,
+           2147483646, 28,         2147483646, 29,         2147483646,
+           31,         2147483646, 31,         2147483646, 32,
+           2147483646, 33,         2147483646, 34,         2147483646,
+           35,         2147483646, 36,         2147483646, 37,
+           2147483646, 38,         2147483646, 39,         2147483646,
+           41,         2147483646, 41,         2147483646, 42,
+           2147483646, 43,         2147483646, 44,         2147483646,
+           45,         2147483646, 46,         2147483646, 47,
+           2147483646, 48,         2147483646, 49,         2147483646,
+           51,         2147483646, 51,         2147483646, 52,
+           2147483646, 53,         2147483646, 54,         2147483646,
+           55,         2147483646, 56,         2147483646, 57,
+           2147483646, 58,         2147483646, 59,         2147483646,
+           61,         2147483646, 61,         2147483646, 62,
+           2147483646, 63,         2147483646, 64,         2147483646,
+           65,         2147483646, 66,         2147483646, 67,
+           2147483646, 68,         2147483646, 69,         2147483646,
+           71,         2147483646, 71,         2147483646, 72,
+           2147483646, 73,         2147483646, 74,         2147483646,
+           75,         2147483646, 76,         2147483646, 77,
+           2147483646, 78,         2147483646, 79,         2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference18) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{1, -2, -1, 0},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(9, 18));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1, 1,  1, 1,  1, 1,  1, 1,  1, 1,  1, 1,  1, 1,  1, 1,  1, 1,
+           1, 2,  1, 3,  1, 4,  1, 5,  1, 6,  1, 7,  1, 8,  1, 9,  1, 10,
+           1, 12, 1, 13, 1, 14, 1, 15, 1, 16, 1, 17, 1, 18, 1, 19, 1, 20,
+           1, 22, 1, 23, 1, 24, 1, 25, 1, 26, 1, 27, 1, 28, 1, 29, 1, 30,
+           1, 32, 1, 33, 1, 34, 1, 35, 1, 36, 1, 37, 1, 38, 1, 39, 1, 40,
+           1, 42, 1, 43, 1, 44, 1, 45, 1, 46, 1, 47, 1, 48, 1, 49, 1, 50,
+           1, 52, 1, 53, 1, 54, 1, 55, 1, 56, 1, 57, 1, 58, 1, 59, 1, 60,
+           1, 62, 1, 63, 1, 64, 1, 65, 1, 66, 1, 67, 1, 68, 1, 69, 1, 70,
+           1, 72, 1, 73, 1, 74, 1, 75, 1, 76, 1, 77, 1, 78, 1, 79, 1, 80}));
+}
+
+TEST(ReferenceTest, RandomJaxReference19) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{1, 0, 0, -1},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(10, 9));
+
+  EXPECT_THAT(res.data, ElementsAreArray(
+                            {1,  2,  3,  4,  5,  6,  7,  8,  9,  1,  2,  3,  4,
+                             5,  6,  7,  8,  9,  11, 12, 13, 14, 15, 16, 17, 18,
+                             19, 21, 22, 23, 24, 25, 26, 27, 28, 29, 31, 32, 33,
+                             34, 35, 36, 37, 38, 39, 41, 42, 43, 44, 45, 46, 47,
+                             48, 49, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 62,
+                             63, 64, 65, 66, 67, 68, 69, 71, 72, 73, 74, 75, 76,
+                             77, 78, 79, 81, 82, 83, 84, 85, 86, 87, 88, 89}));
+}
+
+TEST(ReferenceTest, RandomJaxReference20) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{1, 2, 1, -1},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(11, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,  2,          4,          6,          8,          11,        12,
+           14, 16,         18,         21,         22,         24,        26,
+           28, 31,         32,         34,         36,         38,        41,
+           42, 44,         46,         48,         51,         52,        54,
+           56, 58,         61,         62,         64,         66,        68,
+           71, 72,         74,         76,         78,         81,        82,
+           84, 86,         88,         91,         92,         94,        96,
+           98, 2147483646, 2147483646, 2147483646, 2147483646, 2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference21) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{1, 0, 1, -1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(5, 9));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({1,   2,   3,   4,   5,   6,   7,   8,   9,
+                                32,  34,  36,  38,  40,  42,  44,  46,  48,
+                                72,  74,  76,  78,  80,  82,  84,  86,  88,
+                                112, 114, 116, 118, 120, 122, 124, 126, 128,
+                                152, 154, 156, 158, 160, 162, 164, 166, 168}));
+}
+
+TEST(ReferenceTest, RandomJaxReference22) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-2, 2, -2, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(10, 7));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {23,          24,          25,          26,          27,
+           28,          29,          33,          34,          35,
+           36,          37,          38,          39,          43,
+           44,          45,          46,          47,          48,
+           49,          53,          54,          55,          56,
+           57,          58,          59,          63,          64,
+           65,          66,          67,          68,          69,
+           73,          74,          75,          76,          77,
+           78,          79,          83,          84,          85,
+           86,          87,          88,          89,          93,
+           94,          95,          96,          97,          98,
+           99,          -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference23) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{2, -2, 2, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(10, 11));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
+           0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   3,   5,   7,   9,
+           11,  13,  15,  17,  19,  0,   11,  23,  25,  27,  29,  31,  33,  35,
+           37,  39,  0,   21,  43,  45,  47,  49,  51,  53,  55,  57,  59,  0,
+           31,  63,  65,  67,  69,  71,  73,  75,  77,  79,  0,   41,  83,  85,
+           87,  89,  91,  93,  95,  97,  99,  0,   51,  103, 105, 107, 109, 111,
+           113, 115, 117, 119, 0,   61,  123, 125, 127, 129, 131, 133, 135, 137,
+           139, 0,   71,  143, 145, 147, 149, 151, 153, 155, 157, 159}));
+}
+
+TEST(ReferenceTest, RandomJaxReference24) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{2, 2, -2, -2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(11, 6));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({3,   4,   5,   6,   7,   8,   16,  18,  20,  22,  24,
+                        26,  36,  38,  40,  42,  44,  46,  56,  58,  60,  62,
+                        64,  66,  76,  78,  80,  82,  84,  86,  96,  98,  100,
+                        102, 104, 106, 116, 118, 120, 122, 124, 126, 136, 138,
+                        140, 142, 144, 146, 156, 158, 160, 162, 164, 166, 176,
+                        178, 180, 182, 184, 186, 93,  94,  95,  96,  97,  98}));
+}
+
+TEST(ReferenceTest, RandomJaxReference25) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{2, -1, 2, 2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(10, 14));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({1, 1, 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1, 1,
+                        1, 1, 1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 1, 1,
+                        1, 1, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 1, 1,
+                        1, 1, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 1, 1,
+                        1, 1, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 1, 1,
+                        1, 1, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 1, 1,
+                        1, 1, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 1, 1,
+                        1, 1, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 1, 1,
+                        1, 1, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 1, 1,
+                        1, 1, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 1, 1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference26) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{-1, 1, -1, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(17, 7));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647}));
+}
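+
+// Explanatory note (not a generated case), assuming the reference treats
+// base-dilation holes like explicit padding, i.e. as init_value contributions:
+// when the strides and window dilations make every sampled position land on a
+// hole or on padding, the window reduces over nothing and the output is
+// init_value everywhere. That is why this case is all -2147483647, and why
+// e.g. RandomJaxReference30 and RandomJaxReference32 below are all zeros.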
+
+TEST(ReferenceTest, RandomJaxReference27) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{1, -2, 2, -2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(8, 5));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({0, 1,   3,   5,   7,   0, 12,  16,  20,  24,
+                                0, 32,  36,  40,  44,  0, 52,  56,  60,  64,
+                                0, 72,  76,  80,  84,  0, 92,  96,  100, 104,
+                                0, 112, 116, 120, 124, 0, 132, 136, 140, 144}));
+}
+
+TEST(ReferenceTest, RandomJaxReference28) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-2, -2, 0, -2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(6, 8));
+
+  EXPECT_THAT(res.data, ElementsAreArray(
+                            {21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33, 34,
+                             35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48,
+                             51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64,
+                             65, 66, 67, 68, 71, 72, 73, 74, 75, 76, 77, 78}));
+}
+
+TEST(ReferenceTest, RandomJaxReference29) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-1, -1, 2, 0},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(4, 21));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {2147483646, 2147483646, 11,         2147483646, 12,
+                   2147483646, 13,         2147483646, 14,         2147483646,
+                   15,         2147483646, 16,         2147483646, 17,
+                   2147483646, 18,         2147483646, 19,         2147483646,
+                   20,         2147483646, 2147483646, 31,         2147483646,
+                   32,         2147483646, 33,         2147483646, 34,
+                   2147483646, 35,         2147483646, 36,         2147483646,
+                   37,         2147483646, 38,         2147483646, 39,
+                   2147483646, 40,         2147483646, 2147483646, 51,
+                   2147483646, 52,         2147483646, 53,         2147483646,
+                   54,         2147483646, 55,         2147483646, 56,
+                   2147483646, 57,         2147483646, 58,         2147483646,
+                   59,         2147483646, 60,         2147483646, 2147483646,
+                   71,         2147483646, 72,         2147483646, 73,
+                   2147483646, 74,         2147483646, 75,         2147483646,
+                   76,         2147483646, 77,         2147483646, 78,
+                   2147483646, 79,         2147483646, 80}));
+}
+
+TEST(ReferenceTest, RandomJaxReference30) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-1, 1, -2, -1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(10, 4));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference31) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-2, 1, -1, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(5, 16));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, 22,          -2147483647, 23,          -2147483647,
+           24,          -2147483647, 25,          -2147483647, 26,
+           -2147483647, 27,          -2147483647, 28,          -2147483647,
+           29,          -2147483647, 42,          -2147483647, 43,
+           -2147483647, 44,          -2147483647, 45,          -2147483647,
+           46,          -2147483647, 47,          -2147483647, 48,
+           -2147483647, 49,          -2147483647, 62,          -2147483647,
+           63,          -2147483647, 64,          -2147483647, 65,
+           -2147483647, 66,          -2147483647, 67,          -2147483647,
+           68,          -2147483647, 69,          -2147483647, 82,
+           -2147483647, 83,          -2147483647, 84,          -2147483647,
+           85,          -2147483647, 86,          -2147483647, 87,
+           -2147483647, 88,          -2147483647, 89,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference32) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-1, 2, -1, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(9, 5));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference33) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{1, -1, 2, 1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(17, 10));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           11,          12,          13,          14,          15,
+           16,          17,          18,          19,          20,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           21,          22,          23,          24,          25,
+           26,          27,          28,          29,          30,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           31,          32,          33,          34,          35,
+           36,          37,          38,          39,          40,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           41,          42,          43,          44,          45,
+           46,          47,          48,          49,          50,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           51,          52,          53,          54,          55,
+           56,          57,          58,          59,          60,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           61,          62,          63,          64,          65,
+           66,          67,          68,          69,          70,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           71,          72,          73,          74,          75,
+           76,          77,          78,          79,          80,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           81,          82,          83,          84,          85,
+           86,          87,          88,          89,          90,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference34) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{0, 2, 2, 2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(12, 12));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,   2,   4,   6,   8,   10,  12,  14,  16,  18,  9,   10,  11,  12,
+           24,  26,  28,  30,  32,  34,  36,  38,  19,  20,  21,  22,  44,  46,
+           48,  50,  52,  54,  56,  58,  29,  30,  31,  32,  64,  66,  68,  70,
+           72,  74,  76,  78,  39,  40,  41,  42,  84,  86,  88,  90,  92,  94,
+           96,  98,  49,  50,  51,  52,  104, 106, 108, 110, 112, 114, 116, 118,
+           59,  60,  61,  62,  124, 126, 128, 130, 132, 134, 136, 138, 69,  70,
+           71,  72,  144, 146, 148, 150, 152, 154, 156, 158, 79,  80,  81,  82,
+           164, 166, 168, 170, 172, 174, 176, 178, 89,  90,  91,  92,  184, 186,
+           188, 190, 192, 194, 196, 198, 99,  100, 0,   0,   0,   0,   0,   0,
+           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
+           0,   0,   0,   0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference35) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{1, 2, 1, -1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(6, 9));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({11,  12,  13,  14,  15,  16,  17,  18,  19,  42,  44,
+                        46,  48,  50,  52,  54,  56,  58,  82,  84,  86,  88,
+                        90,  92,  94,  96,  98,  122, 124, 126, 128, 130, 132,
+                        134, 136, 138, 162, 164, 166, 168, 170, 172, 174, 176,
+                        178, 91,  92,  93,  94,  95,  96,  97,  98,  99}));
+}
+
+TEST(ReferenceTest, RandomJaxReference36) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{2, 2, 2, 1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(11, 22));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, 1,           -2147483647, 2,
+           -2147483647, 3,           -2147483647, 4,           -2147483647,
+           5,           -2147483647, 6,           -2147483647, 7,
+           -2147483647, 8,           -2147483647, 9,           -2147483647,
+           10,          -2147483647, -2147483647, -2147483647, 11,
+           -2147483647, 12,          -2147483647, 13,          -2147483647,
+           14,          -2147483647, 15,          -2147483647, 16,
+           -2147483647, 17,          -2147483647, 18,          -2147483647,
+           19,          -2147483647, 20,          -2147483647, -2147483647,
+           -2147483647, 21,          -2147483647, 22,          -2147483647,
+           23,          -2147483647, 24,          -2147483647, 25,
+           -2147483647, 26,          -2147483647, 27,          -2147483647,
+           28,          -2147483647, 29,          -2147483647, 30,
+           -2147483647, -2147483647, -2147483647, 31,          -2147483647,
+           32,          -2147483647, 33,          -2147483647, 34,
+           -2147483647, 35,          -2147483647, 36,          -2147483647,
+           37,          -2147483647, 38,          -2147483647, 39,
+           -2147483647, 40,          -2147483647, -2147483647, -2147483647,
+           41,          -2147483647, 42,          -2147483647, 43,
+           -2147483647, 44,          -2147483647, 45,          -2147483647,
+           46,          -2147483647, 47,          -2147483647, 48,
+           -2147483647, 49,          -2147483647, 50,          -2147483647,
+           -2147483647, -2147483647, 51,          -2147483647, 52,
+           -2147483647, 53,          -2147483647, 54,          -2147483647,
+           55,          -2147483647, 56,          -2147483647, 57,
+           -2147483647, 58,          -2147483647, 59,          -2147483647,
+           60,          -2147483647, -2147483647, -2147483647, 61,
+           -2147483647, 62,          -2147483647, 63,          -2147483647,
+           64,          -2147483647, 65,          -2147483647, 66,
+           -2147483647, 67,          -2147483647, 68,          -2147483647,
+           69,          -2147483647, 70,          -2147483647, -2147483647,
+           -2147483647, 71,          -2147483647, 72,          -2147483647,
+           73,          -2147483647, 74,          -2147483647, 75,
+           -2147483647, 76,          -2147483647, 77,          -2147483647,
+           78,          -2147483647, 79,          -2147483647, 80,
+           -2147483647, -2147483647, -2147483647, 81,          -2147483647,
+           82,          -2147483647, 83,          -2147483647, 84,
+           -2147483647, 85,          -2147483647, 86,          -2147483647,
+           87,          -2147483647, 88,          -2147483647, 89,
+           -2147483647, 90,          -2147483647, -2147483647, -2147483647,
+           91,          -2147483647, 92,          -2147483647, 93,
+           -2147483647, 94,          -2147483647, 95,          -2147483647,
+           96,          -2147483647, 97,          -2147483647, 98,
+           -2147483647, 99,          -2147483647, 100,         -2147483647,
+           -2147483647, -2147483647, 91,          -2147483647, 92,
+           -2147483647, 93,          -2147483647, 94,          -2147483647,
+           95,          -2147483647, 96,          -2147483647, 97,
+           -2147483647, 98,          -2147483647, 99,          -2147483647,
+           100,         -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference37) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-2, 2, 1, 2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(18, 6));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {12,          14,          16,          18,          20,
+           20,          22,          24,          26,          28,
+           30,          30,          22,          24,          26,
+           28,          30,          30,          32,          34,
+           36,          38,          40,          40,          32,
+           34,          36,          38,          40,          40,
+           42,          44,          46,          48,          50,
+           50,          42,          44,          46,          48,
+           50,          50,          52,          54,          56,
+           58,          60,          60,          52,          54,
+           56,          58,          60,          60,          62,
+           64,          66,          68,          70,          70,
+           62,          64,          66,          68,          70,
+           70,          72,          74,          76,          78,
+           80,          80,          72,          74,          76,
+           78,          80,          80,          82,          84,
+           86,          88,          90,          90,          82,
+           84,          86,          88,          90,          90,
+           92,          94,          96,          98,          100,
+           100,         92,          94,          96,          98,
+           100,         100,         -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference38) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{0, -2, 1, 1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(17, 11));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,           2,           3,           4,           5,
+           6,           7,           8,           9,           10,
+           10,          -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, 11,          12,          13,
+           14,          15,          16,          17,          18,
+           19,          20,          20,          -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, 21,
+           22,          23,          24,          25,          26,
+           27,          28,          29,          30,          30,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, 31,          32,          33,          34,
+           35,          36,          37,          38,          39,
+           40,          40,          -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, 41,          42,
+           43,          44,          45,          46,          47,
+           48,          49,          50,          50,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           51,          52,          53,          54,          55,
+           56,          57,          58,          59,          60,
+           60,          -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, 61,          62,          63,
+           64,          65,          66,          67,          68,
+           69,          70,          70,          -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, 71,
+           72,          73,          74,          75,          76,
+           77,          78,          79,          80,          80,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, 81,          82,          83,          84,
+           85,          86,          87,          88,          89,
+           90,          90}));
+}
+
+TEST(ReferenceTest, RandomJaxReference39) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-1, -1, -2, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(15, 8));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {0,   0,   0,   0,   0,   0,   0,   0, 36, 38, 40, 42, 44,  46,  48,
+           50,  0,   0,   0,   0,   0,   0,   0, 0,  56, 58, 60, 62,  64,  66,
+           68,  70,  0,   0,   0,   0,   0,   0, 0,  0,  76, 78, 80,  82,  84,
+           86,  88,  90,  0,   0,   0,   0,   0, 0,  0,  0,  96, 98,  100, 102,
+           104, 106, 108, 110, 0,   0,   0,   0, 0,  0,  0,  0,  116, 118, 120,
+           122, 124, 126, 128, 130, 0,   0,   0, 0,  0,  0,  0,  0,   136, 138,
+           140, 142, 144, 146, 148, 150, 0,   0, 0,  0,  0,  0,  0,   0,   156,
+           158, 160, 162, 164, 166, 168, 170, 0, 0,  0,  0,  0,  0,   0,   0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference40) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{2, -1, -2, 2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(19, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({1,  1,  1,  1,  1,  3,  5,  7,  9,  1,  3,  5,  7,  9,
+                        1,  13, 15, 17, 19, 1,  13, 15, 17, 19, 1,  23, 25, 27,
+                        29, 1,  23, 25, 27, 29, 1,  33, 35, 37, 39, 1,  33, 35,
+                        37, 39, 1,  43, 45, 47, 49, 1,  43, 45, 47, 49, 1,  53,
+                        55, 57, 59, 1,  53, 55, 57, 59, 1,  63, 65, 67, 69, 1,
+                        63, 65, 67, 69, 1,  73, 75, 77, 79, 1,  73, 75, 77, 79,
+                        1,  83, 85, 87, 89, 1,  83, 85, 87, 89, 1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference41) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{-1, 2, -2, 0},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(18, 8));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, 23,          24,
+           25,          26,          27,          28,          29,
+           30,          -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, 33,
+           34,          35,          36,          37,          38,
+           39,          40,          -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           43,          44,          45,          46,          47,
+           48,          49,          50,          -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, 53,          54,          55,          56,
+           57,          58,          59,          60,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, 63,          64,          65,
+           66,          67,          68,          69,          70,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, 73,          74,
+           75,          76,          77,          78,          79,
+           80,          -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, 83,
+           84,          85,          86,          87,          88,
+           89,          90,          -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           93,          94,          95,          96,          97,
+           98,          99,          100,         -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, 93,          94,          95,          96,
+           97,          98,          99,          100}));
+}
+
+TEST(ReferenceTest, RandomJaxReference42) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-2, -1, -1, 1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(15, 9));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {156,  182,  210,  240,  272,  306,  342,  380,  20,   506,  552,
+           600,  650,  702,  756,  812,  870,  30,   506,  552,  600,  650,
+           702,  756,  812,  870,  30,   1056, 1122, 1190, 1260, 1332, 1406,
+           1482, 1560, 40,   1056, 1122, 1190, 1260, 1332, 1406, 1482, 1560,
+           40,   1806, 1892, 1980, 2070, 2162, 2256, 2352, 2450, 50,   1806,
+           1892, 1980, 2070, 2162, 2256, 2352, 2450, 50,   2756, 2862, 2970,
+           3080, 3192, 3306, 3422, 3540, 60,   2756, 2862, 2970, 3080, 3192,
+           3306, 3422, 3540, 60,   3906, 4032, 4160, 4290, 4422, 4556, 4692,
+           4830, 70,   3906, 4032, 4160, 4290, 4422, 4556, 4692, 4830, 70,
+           5256, 5402, 5550, 5700, 5852, 6006, 6162, 6320, 80,   5256, 5402,
+           5550, 5700, 5852, 6006, 6162, 6320, 80,   6806, 6972, 7140, 7310,
+           7482, 7656, 7832, 8010, 90,   6806, 6972, 7140, 7310, 7482, 7656,
+           7832, 8010, 90}));
+}
+
+TEST(ReferenceTest, RandomJaxReference43) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{1, 0, -2, 1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(19, 18));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {2,  -2147483647, 3,  -2147483647, 4,   -2147483647, 5,  -2147483647,
+           6,  -2147483647, 7,  -2147483647, 8,   -2147483647, 9,  -2147483647,
+           10, -2147483647, 2,  -2147483647, 3,   -2147483647, 4,  -2147483647,
+           5,  -2147483647, 6,  -2147483647, 7,   -2147483647, 8,  -2147483647,
+           9,  -2147483647, 10, -2147483647, 12,  -2147483647, 13, -2147483647,
+           14, -2147483647, 15, -2147483647, 16,  -2147483647, 17, -2147483647,
+           18, -2147483647, 19, -2147483647, 20,  -2147483647, 12, -2147483647,
+           13, -2147483647, 14, -2147483647, 15,  -2147483647, 16, -2147483647,
+           17, -2147483647, 18, -2147483647, 19,  -2147483647, 20, -2147483647,
+           22, -2147483647, 23, -2147483647, 24,  -2147483647, 25, -2147483647,
+           26, -2147483647, 27, -2147483647, 28,  -2147483647, 29, -2147483647,
+           30, -2147483647, 22, -2147483647, 23,  -2147483647, 24, -2147483647,
+           25, -2147483647, 26, -2147483647, 27,  -2147483647, 28, -2147483647,
+           29, -2147483647, 30, -2147483647, 32,  -2147483647, 33, -2147483647,
+           34, -2147483647, 35, -2147483647, 36,  -2147483647, 37, -2147483647,
+           38, -2147483647, 39, -2147483647, 40,  -2147483647, 32, -2147483647,
+           33, -2147483647, 34, -2147483647, 35,  -2147483647, 36, -2147483647,
+           37, -2147483647, 38, -2147483647, 39,  -2147483647, 40, -2147483647,
+           42, -2147483647, 43, -2147483647, 44,  -2147483647, 45, -2147483647,
+           46, -2147483647, 47, -2147483647, 48,  -2147483647, 49, -2147483647,
+           50, -2147483647, 42, -2147483647, 43,  -2147483647, 44, -2147483647,
+           45, -2147483647, 46, -2147483647, 47,  -2147483647, 48, -2147483647,
+           49, -2147483647, 50, -2147483647, 52,  -2147483647, 53, -2147483647,
+           54, -2147483647, 55, -2147483647, 56,  -2147483647, 57, -2147483647,
+           58, -2147483647, 59, -2147483647, 60,  -2147483647, 52, -2147483647,
+           53, -2147483647, 54, -2147483647, 55,  -2147483647, 56, -2147483647,
+           57, -2147483647, 58, -2147483647, 59,  -2147483647, 60, -2147483647,
+           62, -2147483647, 63, -2147483647, 64,  -2147483647, 65, -2147483647,
+           66, -2147483647, 67, -2147483647, 68,  -2147483647, 69, -2147483647,
+           70, -2147483647, 62, -2147483647, 63,  -2147483647, 64, -2147483647,
+           65, -2147483647, 66, -2147483647, 67,  -2147483647, 68, -2147483647,
+           69, -2147483647, 70, -2147483647, 72,  -2147483647, 73, -2147483647,
+           74, -2147483647, 75, -2147483647, 76,  -2147483647, 77, -2147483647,
+           78, -2147483647, 79, -2147483647, 80,  -2147483647, 72, -2147483647,
+           73, -2147483647, 74, -2147483647, 75,  -2147483647, 76, -2147483647,
+           77, -2147483647, 78, -2147483647, 79,  -2147483647, 80, -2147483647,
+           82, -2147483647, 83, -2147483647, 84,  -2147483647, 85, -2147483647,
+           86, -2147483647, 87, -2147483647, 88,  -2147483647, 89, -2147483647,
+           90, -2147483647, 82, -2147483647, 83,  -2147483647, 84, -2147483647,
+           85, -2147483647, 86, -2147483647, 87,  -2147483647, 88, -2147483647,
+           89, -2147483647, 90, -2147483647, 92,  -2147483647, 93, -2147483647,
+           94, -2147483647, 95, -2147483647, 96,  -2147483647, 97, -2147483647,
+           98, -2147483647, 99, -2147483647, 100, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference44) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{0, -2, 2, -1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(17, 11));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, 1,           2,           3,
+           4,           5,           6,           7,           8,
+           9,           -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, 11,
+           12,          13,          14,          15,          16,
+           17,          18,          19,          -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, 21,          22,          23,          24,
+           25,          26,          27,          28,          29,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, 31,          32,
+           33,          34,          35,          36,          37,
+           38,          39,          -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           41,          42,          43,          44,          45,
+           46,          47,          48,          49,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, 51,          52,          53,
+           54,          55,          56,          57,          58,
+           59,          -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, 61,
+           62,          63,          64,          65,          66,
+           67,          68,          69,          -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, 71,          72,          73,          74,
+           75,          76,          77,          78,          79,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, 81,          82,
+           83,          84,          85,          86,          87,
+           88,          89}));
+}
+
+TEST(ReferenceTest, RandomJaxReference45) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{0, -1, -2, -2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(18, 6));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({3,  4,  5,  6,  7,  8,  1,  1,  1,  1,  1,  1,  13, 14,
+                        15, 16, 17, 18, 1,  1,  1,  1,  1,  1,  23, 24, 25, 26,
+                        27, 28, 1,  1,  1,  1,  1,  1,  33, 34, 35, 36, 37, 38,
+                        1,  1,  1,  1,  1,  1,  43, 44, 45, 46, 47, 48, 1,  1,
+                        1,  1,  1,  1,  53, 54, 55, 56, 57, 58, 1,  1,  1,  1,
+                        1,  1,  63, 64, 65, 66, 67, 68, 1,  1,  1,  1,  1,  1,
+                        73, 74, 75, 76, 77, 78, 1,  1,  1,  1,  1,  1,  83, 84,
+                        85, 86, 87, 88, 1,  1,  1,  1,  1,  1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference46) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{-1, 2, 0, -1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(10, 17));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19,
+           21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29,
+           31, 32, 32, 33, 33, 34, 34, 35, 35, 36, 36, 37, 37, 38, 38, 39, 39,
+           41, 42, 42, 43, 43, 44, 44, 45, 45, 46, 46, 47, 47, 48, 48, 49, 49,
+           51, 52, 52, 53, 53, 54, 54, 55, 55, 56, 56, 57, 57, 58, 58, 59, 59,
+           61, 62, 62, 63, 63, 64, 64, 65, 65, 66, 66, 67, 67, 68, 68, 69, 69,
+           71, 72, 72, 73, 73, 74, 74, 75, 75, 76, 76, 77, 77, 78, 78, 79, 79,
+           81, 82, 82, 83, 83, 84, 84, 85, 85, 86, 86, 87, 87, 88, 88, 89, 89,
+           91, 92, 92, 93, 93, 94, 94, 95, 95, 96, 96, 97, 97, 98, 98, 99, 99,
+           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference47) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{0, -1, 0, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(18, 10));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 0,  0,  0,  0,  0,  0,  0,
+           0,  0,  0,  11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0,  0,  0,  0,
+           0,  0,  0,  0,  0,  0,  21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 0,
+           0,  0,  0,  0,  0,  0,  0,  0,  0,  31, 32, 33, 34, 35, 36, 37, 38,
+           39, 40, 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  41, 42, 43, 44, 45,
+           46, 47, 48, 49, 50, 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  51, 52,
+           53, 54, 55, 56, 57, 58, 59, 60, 0,  0,  0,  0,  0,  0,  0,  0,  0,
+           0,  61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 0,  0,  0,  0,  0,  0,
+           0,  0,  0,  0,  71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 0,  0,  0,
+           0,  0,  0,  0,  0,  0,  0,  81, 82, 83, 84, 85, 86, 87, 88, 89, 90,
+           0,  0,  0,  0,  0,  0,  0,  0,  0,  0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference48) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-2, -1, 1, 2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(16, 6));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({11, 156,  210,  272,  342,  20, 1, 1, 1, 1, 1, 1,
+                        21, 506,  600,  702,  812,  30, 1, 1, 1, 1, 1, 1,
+                        31, 1056, 1190, 1332, 1482, 40, 1, 1, 1, 1, 1, 1,
+                        41, 1806, 1980, 2162, 2352, 50, 1, 1, 1, 1, 1, 1,
+                        51, 2756, 2970, 3192, 3422, 60, 1, 1, 1, 1, 1, 1,
+                        61, 3906, 4160, 4422, 4692, 70, 1, 1, 1, 1, 1, 1,
+                        71, 5256, 5550, 5852, 6162, 80, 1, 1, 1, 1, 1, 1,
+                        81, 6806, 7140, 7482, 7832, 90, 1, 1, 1, 1, 1, 1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference49) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{0, 1, -2, 0},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(10, 17));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {2,          2147483646, 3,          2147483646, 4,
+                   2147483646, 5,          2147483646, 6,          2147483646,
+                   7,          2147483646, 8,          2147483646, 9,
+                   2147483646, 10,         12,         2147483646, 13,
+                   2147483646, 14,         2147483646, 15,         2147483646,
+                   16,         2147483646, 17,         2147483646, 18,
+                   2147483646, 19,         2147483646, 20,         22,
+                   2147483646, 23,         2147483646, 24,         2147483646,
+                   25,         2147483646, 26,         2147483646, 27,
+                   2147483646, 28,         2147483646, 29,         2147483646,
+                   30,         32,         2147483646, 33,         2147483646,
+                   34,         2147483646, 35,         2147483646, 36,
+                   2147483646, 37,         2147483646, 38,         2147483646,
+                   39,         2147483646, 40,         42,         2147483646,
+                   43,         2147483646, 44,         2147483646, 45,
+                   2147483646, 46,         2147483646, 47,         2147483646,
+                   48,         2147483646, 49,         2147483646, 50,
+                   52,         2147483646, 53,         2147483646, 54,
+                   2147483646, 55,         2147483646, 56,         2147483646,
+                   57,         2147483646, 58,         2147483646, 59,
+                   2147483646, 60,         62,         2147483646, 63,
+                   2147483646, 64,         2147483646, 65,         2147483646,
+                   66,         2147483646, 67,         2147483646, 68,
+                   2147483646, 69,         2147483646, 70,         72,
+                   2147483646, 73,         2147483646, 74,         2147483646,
+                   75,         2147483646, 76,         2147483646, 77,
+                   2147483646, 78,         2147483646, 79,         2147483646,
+                   80,         82,         2147483646, 83,         2147483646,
+                   84,         2147483646, 85,         2147483646, 86,
+                   2147483646, 87,         2147483646, 88,         2147483646,
+                   89,         2147483646, 90,         92,         2147483646,
+                   93,         2147483646, 94,         2147483646, 95,
+                   2147483646, 96,         2147483646, 97,         2147483646,
+                   98,         2147483646, 99,         2147483646, 100}));
+}
+
+TEST(ReferenceTest, RandomJaxReference50) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{-1, -1, 1, 0},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(16, 10));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference51) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{0, 2, -2, -1},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(19, 7));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {2,          3,          4,          5,          6,
+                   7,          8,          2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 12,
+                   13,         14,         15,         16,         17,
+                   18,         2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 22,         23,
+                   24,         25,         26,         27,         28,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 32,         33,         34,
+                   35,         36,         37,         38,         2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 42,         43,         44,         45,
+                   46,         47,         48,         2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   52,         53,         54,         55,         56,
+                   57,         58,         2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 62,
+                   63,         64,         65,         66,         67,
+                   68,         2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 72,         73,
+                   74,         75,         76,         77,         78,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 82,         83,         84,
+                   85,         86,         87,         88,         2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 92,         93,         94,         95,
+                   96,         97,         98}));
+}
+
+TEST(ReferenceTest, RandomJaxReference52) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-2, 0, 1, 2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(8, 11));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 0,  31,  32, 33, 34,
+                   35, 36, 37, 38, 39, 40, 0,  41, 42, 43, 44, 45,  46, 47, 48,
+                   49, 50, 0,  51, 52, 53, 54, 55, 56, 57, 58, 59,  60, 0,  61,
+                   62, 63, 64, 65, 66, 67, 68, 69, 70, 0,  71, 72,  73, 74, 75,
+                   76, 77, 78, 79, 80, 0,  81, 82, 83, 84, 85, 86,  87, 88, 89,
+                   90, 0,  91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference53) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{2, 1, 0, 2},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(11, 10));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   1,          2,          3,          4,          5,
+                   6,          7,          8,          9,          10,
+                   11,         12,         13,         14,         15,
+                   16,         17,         18,         19,         20,
+                   21,         22,         23,         24,         25,
+                   26,         27,         28,         29,         30,
+                   31,         32,         33,         34,         35,
+                   36,         37,         38,         39,         40,
+                   41,         42,         43,         44,         45,
+                   46,         47,         48,         49,         50,
+                   51,         52,         53,         54,         55,
+                   56,         57,         58,         59,         60,
+                   61,         62,         63,         64,         65,
+                   66,         67,         68,         69,         70,
+                   71,         72,         73,         74,         75,
+                   76,         77,         78,         79,         80,
+                   81,         82,         83,         84,         85,
+                   86,         87,         88,         89,         90,
+                   91,         92,         93,         94,         95,
+                   96,         97,         98,         99,         100}));
+}
+
+TEST(ReferenceTest, RandomJaxReference54) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-2, 0, 0, 2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(9, 12));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {11, 12, 13, 14, 15, 16, 17, 18, 19, 20,  -2147483647, -2147483647,
+           21, 22, 23, 24, 25, 26, 27, 28, 29, 30,  -2147483647, -2147483647,
+           31, 32, 33, 34, 35, 36, 37, 38, 39, 40,  -2147483647, -2147483647,
+           41, 42, 43, 44, 45, 46, 47, 48, 49, 50,  -2147483647, -2147483647,
+           51, 52, 53, 54, 55, 56, 57, 58, 59, 60,  -2147483647, -2147483647,
+           61, 62, 63, 64, 65, 66, 67, 68, 69, 70,  -2147483647, -2147483647,
+           71, 72, 73, 74, 75, 76, 77, 78, 79, 80,  -2147483647, -2147483647,
+           81, 82, 83, 84, 85, 86, 87, 88, 89, 90,  -2147483647, -2147483647,
+           91, 92, 93, 94, 95, 96, 97, 98, 99, 100, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference55) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{2, 1, -2, 2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(20, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {3,           5,           7,           9,           -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           13,          15,          17,          19,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           23,          25,          27,          29,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           33,          35,          37,          39,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           43,          45,          47,          49,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           53,          55,          57,          59,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           63,          65,          67,          69,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           73,          75,          77,          79,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           83,          85,          87,          89,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           93,          95,          97,          99,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference56) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{0, 0, 0, 1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(18, 11));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,  2,  3,  4,  5,  6,  7,  8,  9,  10,  1,  11, 12, 13, 14, 15, 16,
+           17, 18, 19, 20, 1,  11, 12, 13, 14, 15,  16, 17, 18, 19, 20, 1,  21,
+           22, 23, 24, 25, 26, 27, 28, 29, 30, 1,   21, 22, 23, 24, 25, 26, 27,
+           28, 29, 30, 1,  31, 32, 33, 34, 35, 36,  37, 38, 39, 40, 1,  31, 32,
+           33, 34, 35, 36, 37, 38, 39, 40, 1,  41,  42, 43, 44, 45, 46, 47, 48,
+           49, 50, 1,  41, 42, 43, 44, 45, 46, 47,  48, 49, 50, 1,  51, 52, 53,
+           54, 55, 56, 57, 58, 59, 60, 1,  51, 52,  53, 54, 55, 56, 57, 58, 59,
+           60, 1,  61, 62, 63, 64, 65, 66, 67, 68,  69, 70, 1,  61, 62, 63, 64,
+           65, 66, 67, 68, 69, 70, 1,  71, 72, 73,  74, 75, 76, 77, 78, 79, 80,
+           1,  71, 72, 73, 74, 75, 76, 77, 78, 79,  80, 1,  81, 82, 83, 84, 85,
+           86, 87, 88, 89, 90, 1,  81, 82, 83, 84,  85, 86, 87, 88, 89, 90, 1,
+           91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference57) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{0, 0, -2, 2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(10, 9));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({3,  4,  5,  6,  7,  8,  9,  10, 10, 13, 14,  15, 16,
+                        17, 18, 19, 20, 20, 23, 24, 25, 26, 27, 28,  29, 30,
+                        30, 33, 34, 35, 36, 37, 38, 39, 40, 40, 43,  44, 45,
+                        46, 47, 48, 49, 50, 50, 53, 54, 55, 56, 57,  58, 59,
+                        60, 60, 63, 64, 65, 66, 67, 68, 69, 70, 70,  73, 74,
+                        75, 76, 77, 78, 79, 80, 80, 83, 84, 85, 86,  87, 88,
+                        89, 90, 90, 93, 94, 95, 96, 97, 98, 99, 100, 100}));
+}
+
+TEST(ReferenceTest, RandomJaxReference58) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-1, 2, 1, -2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(11, 9));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference59) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{2, -2, 2, 2},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(18, 11));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 1,          1,          2,          3,
+                   4,          5,          6,          7,          8,
+                   9,          10,         1,          1,          2,
+                   3,          4,          5,          6,          7,
+                   8,          9,          10,         11,         11,
+                   12,         13,         14,         15,         16,
+                   17,         18,         19,         20,         11,
+                   11,         12,         13,         14,         15,
+                   16,         17,         18,         19,         20,
+                   21,         21,         22,         23,         24,
+                   25,         26,         27,         28,         29,
+                   30,         21,         21,         22,         23,
+                   24,         25,         26,         27,         28,
+                   29,         30,         31,         31,         32,
+                   33,         34,         35,         36,         37,
+                   38,         39,         40,         31,         31,
+                   32,         33,         34,         35,         36,
+                   37,         38,         39,         40,         41,
+                   41,         42,         43,         44,         45,
+                   46,         47,         48,         49,         50,
+                   41,         41,         42,         43,         44,
+                   45,         46,         47,         48,         49,
+                   50,         51,         51,         52,         53,
+                   54,         55,         56,         57,         58,
+                   59,         60,         51,         51,         52,
+                   53,         54,         55,         56,         57,
+                   58,         59,         60,         61,         61,
+                   62,         63,         64,         65,         66,
+                   67,         68,         69,         70,         61,
+                   61,         62,         63,         64,         65,
+                   66,         67,         68,         69,         70,
+                   71,         71,         72,         73,         74,
+                   75,         76,         77,         78,         79,
+                   80,         71,         71,         72,         73,
+                   74,         75,         76,         77,         78,
+                   79,         80,         81,         81,         82,
+                   83,         84,         85,         86,         87,
+                   88,         89,         90}));
+}
+
+TEST(ReferenceTest, RandomJaxReference60) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{0, 2, -1, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(6, 4));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({5,   9,   13,  17,  45,  49,  53,  57,
+                                85,  89,  93,  97,  125, 129, 133, 137,
+                                165, 169, 173, 177, 0,   0,   0,   0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference61) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{0, -1, 2, -1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(17, 20));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {0,  0, 1,  0, 2,  0, 3,  0, 4,  0, 5,  0, 6,  0, 7,  0, 8,  0,
+           9,  0, 0,  0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, 16, 0, 17, 0,
+           18, 0, 19, 0, 0,  0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, 16, 0,
+           17, 0, 18, 0, 19, 0, 0,  0, 21, 0, 22, 0, 23, 0, 24, 0, 25, 0,
+           26, 0, 27, 0, 28, 0, 29, 0, 0,  0, 21, 0, 22, 0, 23, 0, 24, 0,
+           25, 0, 26, 0, 27, 0, 28, 0, 29, 0, 0,  0, 31, 0, 32, 0, 33, 0,
+           34, 0, 35, 0, 36, 0, 37, 0, 38, 0, 39, 0, 0,  0, 31, 0, 32, 0,
+           33, 0, 34, 0, 35, 0, 36, 0, 37, 0, 38, 0, 39, 0, 0,  0, 41, 0,
+           42, 0, 43, 0, 44, 0, 45, 0, 46, 0, 47, 0, 48, 0, 49, 0, 0,  0,
+           41, 0, 42, 0, 43, 0, 44, 0, 45, 0, 46, 0, 47, 0, 48, 0, 49, 0,
+           0,  0, 51, 0, 52, 0, 53, 0, 54, 0, 55, 0, 56, 0, 57, 0, 58, 0,
+           59, 0, 0,  0, 51, 0, 52, 0, 53, 0, 54, 0, 55, 0, 56, 0, 57, 0,
+           58, 0, 59, 0, 0,  0, 61, 0, 62, 0, 63, 0, 64, 0, 65, 0, 66, 0,
+           67, 0, 68, 0, 69, 0, 0,  0, 61, 0, 62, 0, 63, 0, 64, 0, 65, 0,
+           66, 0, 67, 0, 68, 0, 69, 0, 0,  0, 71, 0, 72, 0, 73, 0, 74, 0,
+           75, 0, 76, 0, 77, 0, 78, 0, 79, 0, 0,  0, 71, 0, 72, 0, 73, 0,
+           74, 0, 75, 0, 76, 0, 77, 0, 78, 0, 79, 0, 0,  0, 81, 0, 82, 0,
+           83, 0, 84, 0, 85, 0, 86, 0, 87, 0, 88, 0, 89, 0, 0,  0, 81, 0,
+           82, 0, 83, 0, 84, 0, 85, 0, 86, 0, 87, 0, 88, 0, 89, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference62) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-2, -1, 2, 0},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(3, 12));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
+           -2147483647, -2147483647, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70,
+           -2147483647, -2147483647, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90}));
+}
+
+TEST(ReferenceTest, RandomJaxReference63) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-1, 0, 2, -2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(16, 10));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({1, 1, 1,    1,    1,    1,    1,    1,    1,    1,
+                        1, 1, 231,  264,  299,  336,  375,  416,  459,  504,
+                        1, 1, 1,    1,    1,    1,    1,    1,    1,    1,
+                        1, 1, 651,  704,  759,  816,  875,  936,  999,  1064,
+                        1, 1, 1,    1,    1,    1,    1,    1,    1,    1,
+                        1, 1, 1271, 1344, 1419, 1496, 1575, 1656, 1739, 1824,
+                        1, 1, 1,    1,    1,    1,    1,    1,    1,    1,
+                        1, 1, 2091, 2184, 2279, 2376, 2475, 2576, 2679, 2784,
+                        1, 1, 1,    1,    1,    1,    1,    1,    1,    1,
+                        1, 1, 3111, 3224, 3339, 3456, 3575, 3696, 3819, 3944,
+                        1, 1, 1,    1,    1,    1,    1,    1,    1,    1,
+                        1, 1, 4331, 4464, 4599, 4736, 4875, 5016, 5159, 5304,
+                        1, 1, 1,    1,    1,    1,    1,    1,    1,    1,
+                        1, 1, 5751, 5904, 6059, 6216, 6375, 6536, 6699, 6864,
+                        1, 1, 1,    1,    1,    1,    1,    1,    1,    1,
+                        1, 1, 7371, 7544, 7719, 7896, 8075, 8256, 8439, 8624}));
+}
+
+TEST(ReferenceTest, RandomJaxReference64) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{1, 2, 0, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(11, 3));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {3,  5,  7,  13,          15,          17,         23, 25, 27,
+                   33, 35, 37, 43,          45,          47,         53, 55, 57,
+                   63, 65, 67, 73,          75,          77,         83, 85, 87,
+                   93, 95, 97, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference65) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-1, 0, 2, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(4, 11));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({0, 32,  34,  36,  38,  40,  42,  44,  46,  48,  50,
+                        0, 72,  74,  76,  78,  80,  82,  84,  86,  88,  90,
+                        0, 112, 114, 116, 118, 120, 122, 124, 126, 128, 130,
+                        0, 152, 154, 156, 158, 160, 162, 164, 166, 168, 170}));
+}
+
+TEST(ReferenceTest, RandomJaxReference66) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{0, 0, -1, -1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(5, 8));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({14,  16,  18,  20,  22,  24,  26,  28,  54,  56,
+                        58,  60,  62,  64,  66,  68,  94,  96,  98,  100,
+                        102, 104, 106, 108, 134, 136, 138, 140, 142, 144,
+                        146, 148, 174, 176, 178, 180, 182, 184, 186, 188}));
+}
+
+TEST(ReferenceTest, RandomJaxReference67) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{1, 0, 2, 2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(6, 13));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {0, 0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
+           0, 11, 23,  25,  27,  29,  31,  33,  35,  37,  39,  20,  0,
+           0, 31, 63,  65,  67,  69,  71,  73,  75,  77,  79,  40,  0,
+           0, 51, 103, 105, 107, 109, 111, 113, 115, 117, 119, 60,  0,
+           0, 71, 143, 145, 147, 149, 151, 153, 155, 157, 159, 80,  0,
+           0, 91, 183, 185, 187, 189, 191, 193, 195, 197, 199, 100, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference68) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{2, 2, 1, -2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(13, 9));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference69) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-2, 1, -2, -1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(8, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({25, 26, 27, 28, 29, 35, 36, 37, 38, 39, 45, 46, 47, 48,
+                        49, 55, 56, 57, 58, 59, 65, 66, 67, 68, 69, 75, 76, 77,
+                        78, 79, 85, 86, 87, 88, 89, 95, 96, 97, 98, 99}));
+}
+
+TEST(ReferenceTest, RandomJaxReference70) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-1, -2, 0, 2},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(4, 6));
+
+  EXPECT_THAT(res.data, ElementsAreArray({11, 13, 15, 17, 19, 2147483646,
+                                          31, 33, 35, 37, 39, 2147483646,
+                                          51, 53, 55, 57, 59, 2147483646,
+                                          71, 73, 75, 77, 79, 2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference71) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{1, 2, -2, 2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(21, 10));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {3,           4,           5,           6,           7,
+           8,           9,           10,          -2147483647, -2147483647,
+           3,           4,           5,           6,           7,
+           8,           9,           10,          -2147483647, -2147483647,
+           13,          14,          15,          16,          17,
+           18,          19,          20,          -2147483647, -2147483647,
+           13,          14,          15,          16,          17,
+           18,          19,          20,          -2147483647, -2147483647,
+           23,          24,          25,          26,          27,
+           28,          29,          30,          -2147483647, -2147483647,
+           23,          24,          25,          26,          27,
+           28,          29,          30,          -2147483647, -2147483647,
+           33,          34,          35,          36,          37,
+           38,          39,          40,          -2147483647, -2147483647,
+           33,          34,          35,          36,          37,
+           38,          39,          40,          -2147483647, -2147483647,
+           43,          44,          45,          46,          47,
+           48,          49,          50,          -2147483647, -2147483647,
+           43,          44,          45,          46,          47,
+           48,          49,          50,          -2147483647, -2147483647,
+           53,          54,          55,          56,          57,
+           58,          59,          60,          -2147483647, -2147483647,
+           53,          54,          55,          56,          57,
+           58,          59,          60,          -2147483647, -2147483647,
+           63,          64,          65,          66,          67,
+           68,          69,          70,          -2147483647, -2147483647,
+           63,          64,          65,          66,          67,
+           68,          69,          70,          -2147483647, -2147483647,
+           73,          74,          75,          76,          77,
+           78,          79,          80,          -2147483647, -2147483647,
+           73,          74,          75,          76,          77,
+           78,          79,          80,          -2147483647, -2147483647,
+           83,          84,          85,          86,          87,
+           88,          89,          90,          -2147483647, -2147483647,
+           83,          84,          85,          86,          87,
+           88,          89,          90,          -2147483647, -2147483647,
+           93,          94,          95,          96,          97,
+           98,          99,          100,         -2147483647, -2147483647,
+           93,          94,          95,          96,          97,
+           98,          99,          100,         -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference72) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{0, -1, 2, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(5, 5));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({1,   4,   8,  12,  16,  21,  44, 48,  52,
+                                56,  41,  84, 88,  92,  96,  61, 124, 128,
+                                132, 136, 81, 164, 168, 172, 176}));
+}
+
+TEST(ReferenceTest, RandomJaxReference73) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{0, 0, 0, 0},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(10, 8));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({3,  4,  5,  6,  7,  8,  9,  10, 13, 14, 15, 16, 17, 18,
+                        19, 20, 23, 24, 25, 26, 27, 28, 29, 30, 33, 34, 35, 36,
+                        37, 38, 39, 40, 43, 44, 45, 46, 47, 48, 49, 50, 53, 54,
+                        55, 56, 57, 58, 59, 60, 63, 64, 65, 66, 67, 68, 69, 70,
+                        73, 74, 75, 76, 77, 78, 79, 80, 83, 84, 85, 86, 87, 88,
+                        89, 90, 93, 94, 95, 96, 97, 98, 99, 100}));
+}
+
+TEST(ReferenceTest, RandomJaxReference74) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{0, -2, -2, -1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(7, 5));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({36,  40,  44,  48,  52,  76,  80,  84,  88,
+                                92,  116, 120, 124, 128, 132, 156, 160, 164,
+                                168, 172, 196, 200, 204, 208, 212, 236, 240,
+                                244, 248, 252, 276, 280, 284, 288, 292}));
+}
+
+TEST(ReferenceTest, RandomJaxReference75) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{0, 1, -2, 1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(9, 5));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({16,  20,  24,  28,  0,   36,  40,  44,  48,
+                                0,   56,  60,  64,  68,  0,   76,  80,  84,
+                                88,  0,   96,  100, 104, 108, 0,   116, 120,
+                                124, 128, 0,   136, 140, 144, 148, 0,   156,
+                                160, 164, 168, 0,   176, 180, 184, 188, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference76) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{2, -1, -1, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(6, 18));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0, 0,  0, 0,
+                        0, 0,  0, 2,  0, 3,  0, 4,  0, 5,  0, 6,  0, 7,  0, 8,
+                        0, 9,  0, 10, 0, 22, 0, 23, 0, 24, 0, 25, 0, 26, 0, 27,
+                        0, 28, 0, 29, 0, 30, 0, 42, 0, 43, 0, 44, 0, 45, 0, 46,
+                        0, 47, 0, 48, 0, 49, 0, 50, 0, 62, 0, 63, 0, 64, 0, 65,
+                        0, 66, 0, 67, 0, 68, 0, 69, 0, 70, 0, 82, 0, 83, 0, 84,
+                        0, 85, 0, 86, 0, 87, 0, 88, 0, 89, 0, 90}));
+}
+
+TEST(ReferenceTest, RandomJaxReference77) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-1, 2, -1, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(10, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference78) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-2, 1, 2, -1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(18, 6));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({0,  11, 13, 15, 17, 19, 0,  0,  0,  0,  0,  0,  0,  21,
+                        23, 25, 27, 29, 0,  0,  0,  0,  0,  0,  0,  31, 33, 35,
+                        37, 39, 0,  0,  0,  0,  0,  0,  0,  41, 43, 45, 47, 49,
+                        0,  0,  0,  0,  0,  0,  0,  51, 53, 55, 57, 59, 0,  0,
+                        0,  0,  0,  0,  0,  61, 63, 65, 67, 69, 0,  0,  0,  0,
+                        0,  0,  0,  71, 73, 75, 77, 79, 0,  0,  0,  0,  0,  0,
+                        0,  81, 83, 85, 87, 89, 0,  0,  0,  0,  0,  0,  0,  91,
+                        93, 95, 97, 99, 0,  0,  0,  0,  0,  0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference79) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-1, -1, -2, 1},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(8, 9));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27,
+                   28, 29, 30, 32, 33, 34, 35, 36, 37, 38, 39, 40, 42, 43, 44,
+                   45, 46, 47, 48, 49, 50, 52, 53, 54, 55, 56, 57, 58, 59, 60,
+                   62, 63, 64, 65, 66, 67, 68, 69, 70, 72, 73, 74, 75, 76, 77,
+                   78, 79, 80, 82, 83, 84, 85, 86, 87, 88, 89, 90}));
+}
+
+TEST(ReferenceTest, RandomJaxReference80) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{0, 2, 1, -1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(10, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({1, 24,   56,   96,   144,  1, 264,  336,  416,  504,
+                        1, 704,  816,  936,  1064, 1, 1344, 1496, 1656, 1824,
+                        1, 2184, 2376, 2576, 2784, 1, 3224, 3456, 3696, 3944,
+                        1, 4464, 4736, 5016, 5304, 1, 5904, 6216, 6536, 6864,
+                        1, 7544, 7896, 8256, 8624, 1, 92,   94,   96,   98}));
+}
+
+TEST(ReferenceTest, RandomJaxReference81) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{0, -1, 0, 2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(5, 6));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({1,  3,  5,  7,  9,  1,  21, 23, 25, 27,
+                                29, 1,  41, 43, 45, 47, 49, 1,  61, 63,
+                                65, 67, 69, 1,  81, 83, 85, 87, 89, 1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference82) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{0, 2, 0, 2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(11, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {429,      2925,     8925,     20349,    171,      69069,    112125,
+           172125,   252909,   551,      494109,   664125,   874125,   1129869,
+           1131,     1803549,  2234925,  2738925,  3323229,  1911,     4765389,
+           5640525,  6630525,  7744989,  2891,     10387629, 11936925, 13652925,
+           15547149, 4071,     19918269, 22420125, 25150125, 28121709, 5451,
+           34845309, 38626125, 42706125, 47100669, 7031,     56896749, 62330925,
+           68144925, 74356029, 8811,     8463,     8835,     9215,     9603,
+           99,       1,        1,        1,        1,        1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference83) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{2, -1, -2, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(10, 8));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, 2,           3,
+           4,           5,           6,           7,           8,
+           9,           12,          13,          14,          15,
+           16,          17,          18,          19,          22,
+           23,          24,          25,          26,          27,
+           28,          29,          32,          33,          34,
+           35,          36,          37,          38,          39,
+           42,          43,          44,          45,          46,
+           47,          48,          49,          52,          53,
+           54,          55,          56,          57,          58,
+           59,          62,          63,          64,          65,
+           66,          67,          68,          69,          72,
+           73,          74,          75,          76,          77,
+           78,          79,          82,          83,          84,
+           85,          86,          87,          88,          89}));
+}
+
+TEST(ReferenceTest, RandomJaxReference84) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{2, -2, -2, 2},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(19, 10));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2,          3,          4,          5,          6,
+           7,          8,          9,          10,         2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           12,         13,         14,         15,         16,
+           17,         18,         19,         20,         2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           22,         23,         24,         25,         26,
+           27,         28,         29,         30,         2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           32,         33,         34,         35,         36,
+           37,         38,         39,         40,         2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           42,         43,         44,         45,         46,
+           47,         48,         49,         50,         2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           52,         53,         54,         55,         56,
+           57,         58,         59,         60,         2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           62,         63,         64,         65,         66,
+           67,         68,         69,         70,         2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           72,         73,         74,         75,         76,
+           77,         78,         79,         80,         2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           82,         83,         84,         85,         86,
+           87,         88,         89,         90,         2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference85) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{1, 2, -2, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(11, 2));
+
+  EXPECT_THAT(res.data, ElementsAreArray(
+                            {-2147483647, -2147483647, -2147483647, -2147483647,
+                             -2147483647, -2147483647, -2147483647, -2147483647,
+                             -2147483647, -2147483647, -2147483647, -2147483647,
+                             -2147483647, -2147483647, -2147483647, -2147483647,
+                             -2147483647, -2147483647, -2147483647, -2147483647,
+                             -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference86) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-2, -1, 2, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(8, 5));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {-2147483647, 12, 14, 16, 18, -2147483647, 22, 24, 26, 28,
+                   -2147483647, 32, 34, 36, 38, -2147483647, 42, 44, 46, 48,
+                   -2147483647, 52, 54, 56, 58, -2147483647, 62, 64, 66, 68,
+                   -2147483647, 72, 74, 76, 78, -2147483647, 82, 84, 86, 88}));
+}
+
+TEST(ReferenceTest, RandomJaxReference87) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-2, 0, 2, -1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(8, 10));
+
+  EXPECT_THAT(res.data, ElementsAreArray(
+                            {-2147483647, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+                             -2147483647, 31, 32, 33, 34, 35, 36, 37, 38, 39,
+                             -2147483647, 41, 42, 43, 44, 45, 46, 47, 48, 49,
+                             -2147483647, 51, 52, 53, 54, 55, 56, 57, 58, 59,
+                             -2147483647, 61, 62, 63, 64, 65, 66, 67, 68, 69,
+                             -2147483647, 71, 72, 73, 74, 75, 76, 77, 78, 79,
+                             -2147483647, 81, 82, 83, 84, 85, 86, 87, 88, 89,
+                             -2147483647, 91, 92, 93, 94, 95, 96, 97, 98, 99}));
+}
+
+TEST(ReferenceTest, RandomJaxReference88) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-2, 1, 2, 0},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(4, 11));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({-2147483647, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
+                        -2147483647, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70,
+                        -2147483647, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,
+                        -2147483647, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90}));
+}
+
+TEST(ReferenceTest, RandomJaxReference89) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{1, -2, 2, 2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(9, 14));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference90) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-2, 0, 1, 1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(4, 11));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference91) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-2, -2, 1, 2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(5, 6));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({704,  574464,   763776,   995904,   1276800,  1200,
+                        1344, 2010624,  2477376,  3020544,  3648000,  2000,
+                        2184, 5189184,  6120576,  7171584,  8352000,  3000,
+                        3224, 11142144, 12773376, 14577024, 16564800, 4200,
+                        4464, 21141504, 23755776, 26604864, 29702400, 5600}));
+}
+
+TEST(ReferenceTest, RandomJaxReference92) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{0, 0, 0, 2},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(9, 10));
+
+  EXPECT_THAT(res.data, ElementsAreArray(
+                            {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13,
+                             14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
+                             27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
+                             40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52,
+                             53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65,
+                             66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78,
+                             79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90}));
+}
+
+TEST(ReferenceTest, RandomJaxReference93) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{0, -1, 0, -2},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(9, 17));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {1,          2147483646, 2,          2147483646, 3,
+                   2147483646, 4,          2147483646, 5,          2147483646,
+                   6,          2147483646, 7,          2147483646, 8,
+                   2147483646, 9,          11,         2147483646, 12,
+                   2147483646, 13,         2147483646, 14,         2147483646,
+                   15,         2147483646, 16,         2147483646, 17,
+                   2147483646, 18,         2147483646, 19,         21,
+                   2147483646, 22,         2147483646, 23,         2147483646,
+                   24,         2147483646, 25,         2147483646, 26,
+                   2147483646, 27,         2147483646, 28,         2147483646,
+                   29,         31,         2147483646, 32,         2147483646,
+                   33,         2147483646, 34,         2147483646, 35,
+                   2147483646, 36,         2147483646, 37,         2147483646,
+                   38,         2147483646, 39,         41,         2147483646,
+                   42,         2147483646, 43,         2147483646, 44,
+                   2147483646, 45,         2147483646, 46,         2147483646,
+                   47,         2147483646, 48,         2147483646, 49,
+                   51,         2147483646, 52,         2147483646, 53,
+                   2147483646, 54,         2147483646, 55,         2147483646,
+                   56,         2147483646, 57,         2147483646, 58,
+                   2147483646, 59,         61,         2147483646, 62,
+                   2147483646, 63,         2147483646, 64,         2147483646,
+                   65,         2147483646, 66,         2147483646, 67,
+                   2147483646, 68,         2147483646, 69,         71,
+                   2147483646, 72,         2147483646, 73,         2147483646,
+                   74,         2147483646, 75,         2147483646, 76,
+                   2147483646, 77,         2147483646, 78,         2147483646,
+                   79,         81,         2147483646, 82,         2147483646,
+                   83,         2147483646, 84,         2147483646, 85,
+                   2147483646, 86,         2147483646, 87,         2147483646,
+                   88,         2147483646, 89}));
+}
+
+TEST(ReferenceTest, RandomJaxReference94) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-2, 0, -1, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(8, 3));
+
+  EXPECT_THAT(res.data, ElementsAreArray({23, 25, 27, 33, 35, 37, 43, 45,
+                                          47, 53, 55, 57, 63, 65, 67, 73,
+                                          75, 77, 83, 85, 87, 93, 95, 97}));
+}
+
+TEST(ReferenceTest, RandomJaxReference95) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{0, 0, 2, 2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(10, 23));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,  1,  1,  1,  2,  1,  3,   1,  4,  1,  5,  1,  6,  1,  7,  1,  8,
+           1,  9,  1,  10, 1,  1,  1,   1,  11, 1,  12, 1,  13, 1,  14, 1,  15,
+           1,  16, 1,  17, 1,  18, 1,   19, 1,  20, 1,  1,  1,  1,  21, 1,  22,
+           1,  23, 1,  24, 1,  25, 1,   26, 1,  27, 1,  28, 1,  29, 1,  30, 1,
+           1,  1,  1,  31, 1,  32, 1,   33, 1,  34, 1,  35, 1,  36, 1,  37, 1,
+           38, 1,  39, 1,  40, 1,  1,   1,  1,  41, 1,  42, 1,  43, 1,  44, 1,
+           45, 1,  46, 1,  47, 1,  48,  1,  49, 1,  50, 1,  1,  1,  1,  51, 1,
+           52, 1,  53, 1,  54, 1,  55,  1,  56, 1,  57, 1,  58, 1,  59, 1,  60,
+           1,  1,  1,  1,  61, 1,  62,  1,  63, 1,  64, 1,  65, 1,  66, 1,  67,
+           1,  68, 1,  69, 1,  70, 1,   1,  1,  1,  71, 1,  72, 1,  73, 1,  74,
+           1,  75, 1,  76, 1,  77, 1,   78, 1,  79, 1,  80, 1,  1,  1,  1,  81,
+           1,  82, 1,  83, 1,  84, 1,   85, 1,  86, 1,  87, 1,  88, 1,  89, 1,
+           90, 1,  1,  1,  1,  91, 1,   92, 1,  93, 1,  94, 1,  95, 1,  96, 1,
+           97, 1,  98, 1,  99, 1,  100, 1,  1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference96) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{2, -1, -1, 2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(10, 10));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   5,   7,   9,
+           11,  13,  15,  17,  19,  10,  0,   30,  34,  38,  42,  46,  50,
+           54,  58,  30,  0,   70,  74,  78,  82,  86,  90,  94,  98,  50,
+           0,   110, 114, 118, 122, 126, 130, 134, 138, 70,  0,   150, 154,
+           158, 162, 166, 170, 174, 178, 90,  0,   190, 194, 198, 202, 206,
+           210, 214, 218, 110, 0,   230, 234, 238, 242, 246, 250, 254, 258,
+           130, 0,   270, 274, 278, 282, 286, 290, 294, 298, 150, 0,   310,
+           314, 318, 322, 326, 330, 334, 338, 170, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference97) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{2, 2, -1, 1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(12, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({5,   9,   13,  17,  10,  25,  29,  33,  37,  20,
+                        50,  58,  66,  74,  40,  90,  98,  106, 114, 60,
+                        130, 138, 146, 154, 80,  170, 178, 186, 194, 100,
+                        210, 218, 226, 234, 120, 250, 258, 266, 274, 140,
+                        290, 298, 306, 314, 160, 330, 338, 346, 354, 180,
+                        165, 169, 173, 177, 90,  185, 189, 193, 197, 100}));
+}
+
+TEST(ReferenceTest, RandomJaxReference98) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{2, -2, -1, 0},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(18, 17));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2,          2,          3,
+                   3,          4,          4,          5,          5,
+                   6,          6,          7,          7,          8,
+                   8,          9,          9,          10,         2,
+                   2,          3,          3,          4,          4,
+                   5,          5,          6,          6,          7,
+                   7,          8,          8,          9,          9,
+                   10,         12,         12,         13,         13,
+                   14,         14,         15,         15,         16,
+                   16,         17,         17,         18,         18,
+                   19,         19,         20,         12,         12,
+                   13,         13,         14,         14,         15,
+                   15,         16,         16,         17,         17,
+                   18,         18,         19,         19,         20,
+                   22,         22,         23,         23,         24,
+                   24,         25,         25,         26,         26,
+                   27,         27,         28,         28,         29,
+                   29,         30,         22,         22,         23,
+                   23,         24,         24,         25,         25,
+                   26,         26,         27,         27,         28,
+                   28,         29,         29,         30,         32,
+                   32,         33,         33,         34,         34,
+                   35,         35,         36,         36,         37,
+                   37,         38,         38,         39,         39,
+                   40,         32,         32,         33,         33,
+                   34,         34,         35,         35,         36,
+                   36,         37,         37,         38,         38,
+                   39,         39,         40,         42,         42,
+                   43,         43,         44,         44,         45,
+                   45,         46,         46,         47,         47,
+                   48,         48,         49,         49,         50,
+                   42,         42,         43,         43,         44,
+                   44,         45,         45,         46,         46,
+                   47,         47,         48,         48,         49,
+                   49,         50,         52,         52,         53,
+                   53,         54,         54,         55,         55,
+                   56,         56,         57,         57,         58,
+                   58,         59,         59,         60,         52,
+                   52,         53,         53,         54,         54,
+                   55,         55,         56,         56,         57,
+                   57,         58,         58,         59,         59,
+                   60,         62,         62,         63,         63,
+                   64,         64,         65,         65,         66,
+                   66,         67,         67,         68,         68,
+                   69,         69,         70,         62,         62,
+                   63,         63,         64,         64,         65,
+                   65,         66,         66,         67,         67,
+                   68,         68,         69,         69,         70,
+                   72,         72,         73,         73,         74,
+                   74,         75,         75,         76,         76,
+                   77,         77,         78,         78,         79,
+                   79,         80,         72,         72,         73,
+                   73,         74,         74,         75,         75,
+                   76,         76,         77,         77,         78,
+                   78,         79,         79,         80,         82,
+                   82,         83,         83,         84,         84,
+                   85,         85,         86,         86,         87,
+                   87,         88,         88,         89,         89,
+                   90}));
+}
+
+TEST(ReferenceTest, RandomJaxReference99) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-1, -1, -2, 1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(4, 9));
+
+  EXPECT_THAT(res.data, ElementsAreArray({12, 13, 14, 15, 16, 17, 18, 19, 20,
+                                          32, 33, 34, 35, 36, 37, 38, 39, 40,
+                                          52, 53, 54, 55, 56, 57, 58, 59, 60,
+                                          72, 73, 74, 75, 76, 77, 78, 79, 80}));
+}
+
+TEST(ReferenceTest, RandomJaxReference100) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{0, 1, 1, 1},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(10, 20));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,   7,  7,  8,  8,  9,
+           9,  10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15,  15, 16, 16, 17, 17,
+           18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23,  24, 24, 25, 25, 26,
+           26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31, 32,  32, 33, 33, 34, 34,
+           35, 35, 36, 36, 37, 37, 38, 38, 39, 39, 40, 40,  41, 41, 42, 42, 43,
+           43, 44, 44, 45, 45, 46, 46, 47, 47, 48, 48, 49,  49, 50, 50, 51, 51,
+           52, 52, 53, 53, 54, 54, 55, 55, 56, 56, 57, 57,  58, 58, 59, 59, 60,
+           60, 61, 61, 62, 62, 63, 63, 64, 64, 65, 65, 66,  66, 67, 67, 68, 68,
+           69, 69, 70, 70, 71, 71, 72, 72, 73, 73, 74, 74,  75, 75, 76, 76, 77,
+           77, 78, 78, 79, 79, 80, 80, 81, 81, 82, 82, 83,  83, 84, 84, 85, 85,
+           86, 86, 87, 87, 88, 88, 89, 89, 90, 90, 91, 91,  92, 92, 93, 93, 94,
+           94, 95, 95, 96, 96, 97, 97, 98, 98, 99, 99, 100, 100}));
+}
+
+TEST(ReferenceTest, RandomJaxReference101) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-2, 2, 2, 0},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(17, 12));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1, 1, 231,  264,  299,  336,  375,  416,  459,  504,  551,  600,
+           1, 1, 1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
+           1, 1, 651,  704,  759,  816,  875,  936,  999,  1064, 1131, 1200,
+           1, 1, 1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
+           1, 1, 1271, 1344, 1419, 1496, 1575, 1656, 1739, 1824, 1911, 2000,
+           1, 1, 1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
+           1, 1, 2091, 2184, 2279, 2376, 2475, 2576, 2679, 2784, 2891, 3000,
+           1, 1, 1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
+           1, 1, 3111, 3224, 3339, 3456, 3575, 3696, 3819, 3944, 4071, 4200,
+           1, 1, 1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
+           1, 1, 4331, 4464, 4599, 4736, 4875, 5016, 5159, 5304, 5451, 5600,
+           1, 1, 1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
+           1, 1, 5751, 5904, 6059, 6216, 6375, 6536, 6699, 6864, 7031, 7200,
+           1, 1, 1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
+           1, 1, 7371, 7544, 7719, 7896, 8075, 8256, 8439, 8624, 8811, 9000,
+           1, 1, 1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
+           1, 1, 91,   92,   93,   94,   95,   96,   97,   98,   99,   100}));
+}
+
+TEST(ReferenceTest, RandomJaxReference102) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{1, 1, -2, 1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(11, 16));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647}));
+}
+
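+// NOTE (editorial sketch): in RandomJaxReference102 above, one row of top
+// padding combined with a row stride of 2 and a row base dilation of 2 means
+// every sampled row is either a dilation gap or a padded row; both are filled
+// with the init value, so every output element stays at -2147483647 (the
+// `a >= b ? a : b` body with that sentinel init is effectively a max
+// reduction). Later all-constant expectations (e.g. RandomJaxReference121,
+// 125) appear to arise from the same stride/padding misalignment.
+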
+TEST(ReferenceTest, RandomJaxReference103) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{0, 1, 1, -1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(11, 8));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {2,    3,    8,    15,   24,   35,   48,   63,   12,   143,  168,
+           195,  224,  255,  288,  323,  22,   483,  528,  575,  624,  675,
+           728,  783,  32,   1023, 1088, 1155, 1224, 1295, 1368, 1443, 42,
+           1763, 1848, 1935, 2024, 2115, 2208, 2303, 52,   2703, 2808, 2915,
+           3024, 3135, 3248, 3363, 62,   3843, 3968, 4095, 4224, 4355, 4488,
+           4623, 72,   5183, 5328, 5475, 5624, 5775, 5928, 6083, 82,   6723,
+           6888, 7055, 7224, 7395, 7568, 7743, 92,   8463, 8648, 8835, 9024,
+           9215, 9408, 9603, 1,    1,    1,    1,    1,    1,    1,    1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference104) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{2, -1, 1, -1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(18, 9));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,   3,   5,   7,   9,   11,  13,  15,  17,  0,   0,   0,   0,   0,
+           0,   0,   0,   0,   12,  26,  30,  34,  38,  42,  46,  50,  54,  0,
+           0,   0,   0,   0,   0,   0,   0,   0,   32,  66,  70,  74,  78,  82,
+           86,  90,  94,  0,   0,   0,   0,   0,   0,   0,   0,   0,   52,  106,
+           110, 114, 118, 122, 126, 130, 134, 0,   0,   0,   0,   0,   0,   0,
+           0,   0,   72,  146, 150, 154, 158, 162, 166, 170, 174, 0,   0,   0,
+           0,   0,   0,   0,   0,   0,   92,  186, 190, 194, 198, 202, 206, 210,
+           214, 0,   0,   0,   0,   0,   0,   0,   0,   0,   112, 226, 230, 234,
+           238, 242, 246, 250, 254, 0,   0,   0,   0,   0,   0,   0,   0,   0,
+           132, 266, 270, 274, 278, 282, 286, 290, 294, 0,   0,   0,   0,   0,
+           0,   0,   0,   0,   152, 306, 310, 314, 318, 322, 326, 330, 334, 0,
+           0,   0,   0,   0,   0,   0,   0,   0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference105) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-1, 2, 1, -1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(18, 10));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
+           231,  264,  299,  336,  375,  416,  459,  504,  551,  1,    1,
+           1,    1,    1,    1,    1,    1,    1,    1,    1,    651,  704,
+           759,  816,  875,  936,  999,  1064, 1131, 1,    1,    1,    1,
+           1,    1,    1,    1,    1,    1,    1,    1271, 1344, 1419, 1496,
+           1575, 1656, 1739, 1824, 1911, 1,    1,    1,    1,    1,    1,
+           1,    1,    1,    1,    1,    2091, 2184, 2279, 2376, 2475, 2576,
+           2679, 2784, 2891, 1,    1,    1,    1,    1,    1,    1,    1,
+           1,    1,    1,    3111, 3224, 3339, 3456, 3575, 3696, 3819, 3944,
+           4071, 1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
+           1,    4331, 4464, 4599, 4736, 4875, 5016, 5159, 5304, 5451, 1,
+           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    5751,
+           5904, 6059, 6216, 6375, 6536, 6699, 6864, 7031, 1,    1,    1,
+           1,    1,    1,    1,    1,    1,    1,    1,    7371, 7544, 7719,
+           7896, 8075, 8256, 8439, 8624, 8811, 1,    1,    1,    1,    1,
+           1,    1,    1,    1,    1,    1,    91,   92,   93,   94,   95,
+           96,   97,   98,   99}));
+}
+
+TEST(ReferenceTest, RandomJaxReference106) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{2, 2, -2, 2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(7, 18));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+           1,  1,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,  8,  9,
+           9,  10, 10, 1,  22, 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28,
+           28, 29, 29, 30, 30, 1,  42, 43, 43, 44, 44, 45, 45, 46, 46, 47,
+           47, 48, 48, 49, 49, 50, 50, 1,  62, 63, 63, 64, 64, 65, 65, 66,
+           66, 67, 67, 68, 68, 69, 69, 70, 70, 1,  82, 83, 83, 84, 84, 85,
+           85, 86, 86, 87, 87, 88, 88, 89, 89, 90, 90, 1,  1,  1,  1,  1,
+           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference107) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{0, 1, 2, 0},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(10, 11));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({2147483646, 1,  2,  3,  4,  5,  6,  7,  8,  9,  10,
+                        2147483646, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
+                        2147483646, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                        2147483646, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
+                        2147483646, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
+                        2147483646, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
+                        2147483646, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70,
+                        2147483646, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
+                        2147483646, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90,
+                        2147483646, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100}));
+}
+
+TEST(ReferenceTest, RandomJaxReference108) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{2, -1, 2, -1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(11, 20));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, 1,           -2147483647, 2,
+           -2147483647, 3,           -2147483647, 4,           -2147483647,
+           5,           -2147483647, 6,           -2147483647, 7,
+           -2147483647, 8,           -2147483647, 9,           -2147483647,
+           -2147483647, -2147483647, 11,          -2147483647, 12,
+           -2147483647, 13,          -2147483647, 14,          -2147483647,
+           15,          -2147483647, 16,          -2147483647, 17,
+           -2147483647, 18,          -2147483647, 19,          -2147483647,
+           -2147483647, -2147483647, 21,          -2147483647, 22,
+           -2147483647, 23,          -2147483647, 24,          -2147483647,
+           25,          -2147483647, 26,          -2147483647, 27,
+           -2147483647, 28,          -2147483647, 29,          -2147483647,
+           -2147483647, -2147483647, 31,          -2147483647, 32,
+           -2147483647, 33,          -2147483647, 34,          -2147483647,
+           35,          -2147483647, 36,          -2147483647, 37,
+           -2147483647, 38,          -2147483647, 39,          -2147483647,
+           -2147483647, -2147483647, 41,          -2147483647, 42,
+           -2147483647, 43,          -2147483647, 44,          -2147483647,
+           45,          -2147483647, 46,          -2147483647, 47,
+           -2147483647, 48,          -2147483647, 49,          -2147483647,
+           -2147483647, -2147483647, 51,          -2147483647, 52,
+           -2147483647, 53,          -2147483647, 54,          -2147483647,
+           55,          -2147483647, 56,          -2147483647, 57,
+           -2147483647, 58,          -2147483647, 59,          -2147483647,
+           -2147483647, -2147483647, 61,          -2147483647, 62,
+           -2147483647, 63,          -2147483647, 64,          -2147483647,
+           65,          -2147483647, 66,          -2147483647, 67,
+           -2147483647, 68,          -2147483647, 69,          -2147483647,
+           -2147483647, -2147483647, 71,          -2147483647, 72,
+           -2147483647, 73,          -2147483647, 74,          -2147483647,
+           75,          -2147483647, 76,          -2147483647, 77,
+           -2147483647, 78,          -2147483647, 79,          -2147483647,
+           -2147483647, -2147483647, 81,          -2147483647, 82,
+           -2147483647, 83,          -2147483647, 84,          -2147483647,
+           85,          -2147483647, 86,          -2147483647, 87,
+           -2147483647, 88,          -2147483647, 89,          -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference109) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{0, -2, 0, 0},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(8, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({1,  3,  5,  7,  9,  11, 13, 15, 17, 19, 21, 23, 25, 27,
+                        29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55,
+                        57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77, 79}));
+}
+
+TEST(ReferenceTest, RandomJaxReference110) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{-1, -1, 2, 0},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(17, 20));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+           1,  1,  1,  1,  11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17,
+           17, 18, 18, 19, 19, 20, 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  21, 21, 22, 22, 23, 23, 24,
+           24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 1,  1,  1,  1,  1,
+           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  31,
+           31, 32, 32, 33, 33, 34, 34, 35, 35, 36, 36, 37, 37, 38, 38, 39, 39,
+           40, 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+           1,  1,  1,  1,  1,  41, 41, 42, 42, 43, 43, 44, 44, 45, 45, 46, 46,
+           47, 47, 48, 48, 49, 49, 50, 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  51, 51, 52, 52, 53, 53,
+           54, 54, 55, 55, 56, 56, 57, 57, 58, 58, 59, 59, 60, 1,  1,  1,  1,
+           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+           61, 61, 62, 62, 63, 63, 64, 64, 65, 65, 66, 66, 67, 67, 68, 68, 69,
+           69, 70, 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+           1,  1,  1,  1,  1,  1,  71, 71, 72, 72, 73, 73, 74, 74, 75, 75, 76,
+           76, 77, 77, 78, 78, 79, 79, 80, 1,  1,  1,  1,  1,  1,  1,  1,  1,
+           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  81, 81, 82, 82, 83,
+           83, 84, 84, 85, 85, 86, 86, 87, 87, 88, 88, 89, 89, 90, 1,  1,  1,
+           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference111) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-1, 0, 2, -1},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(8, 11));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {2147483646, 2147483646, 11, 12, 13, 14, 15, 16, 17, 18, 19,
+           2147483646, 2147483646, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+           2147483646, 2147483646, 31, 32, 33, 34, 35, 36, 37, 38, 39,
+           2147483646, 2147483646, 41, 42, 43, 44, 45, 46, 47, 48, 49,
+           2147483646, 2147483646, 51, 52, 53, 54, 55, 56, 57, 58, 59,
+           2147483646, 2147483646, 61, 62, 63, 64, 65, 66, 67, 68, 69,
+           2147483646, 2147483646, 71, 72, 73, 74, 75, 76, 77, 78, 79,
+           2147483646, 2147483646, 81, 82, 83, 84, 85, 86, 87, 88, 89}));
+}
+
+TEST(ReferenceTest, RandomJaxReference112) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{1, 1, 1, 2},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(20, 13));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {2147483646, 1,          2,          3,          4,
+           5,          6,          7,          8,          9,
+           10,         2147483646, 2147483646, 2147483646, 1,
+           2,          3,          4,          5,          6,
+           7,          8,          9,          10,         2147483646,
+           2147483646, 2147483646, 11,         12,         13,
+           14,         15,         16,         17,         18,
+           19,         20,         2147483646, 2147483646, 2147483646,
+           11,         12,         13,         14,         15,
+           16,         17,         18,         19,         20,
+           2147483646, 2147483646, 2147483646, 21,         22,
+           23,         24,         25,         26,         27,
+           28,         29,         30,         2147483646, 2147483646,
+           2147483646, 21,         22,         23,         24,
+           25,         26,         27,         28,         29,
+           30,         2147483646, 2147483646, 2147483646, 31,
+           32,         33,         34,         35,         36,
+           37,         38,         39,         40,         2147483646,
+           2147483646, 2147483646, 31,         32,         33,
+           34,         35,         36,         37,         38,
+           39,         40,         2147483646, 2147483646, 2147483646,
+           41,         42,         43,         44,         45,
+           46,         47,         48,         49,         50,
+           2147483646, 2147483646, 2147483646, 41,         42,
+           43,         44,         45,         46,         47,
+           48,         49,         50,         2147483646, 2147483646,
+           2147483646, 51,         52,         53,         54,
+           55,         56,         57,         58,         59,
+           60,         2147483646, 2147483646, 2147483646, 51,
+           52,         53,         54,         55,         56,
+           57,         58,         59,         60,         2147483646,
+           2147483646, 2147483646, 61,         62,         63,
+           64,         65,         66,         67,         68,
+           69,         70,         2147483646, 2147483646, 2147483646,
+           61,         62,         63,         64,         65,
+           66,         67,         68,         69,         70,
+           2147483646, 2147483646, 2147483646, 71,         72,
+           73,         74,         75,         76,         77,
+           78,         79,         80,         2147483646, 2147483646,
+           2147483646, 71,         72,         73,         74,
+           75,         76,         77,         78,         79,
+           80,         2147483646, 2147483646, 2147483646, 81,
+           82,         83,         84,         85,         86,
+           87,         88,         89,         90,         2147483646,
+           2147483646, 2147483646, 81,         82,         83,
+           84,         85,         86,         87,         88,
+           89,         90,         2147483646, 2147483646, 2147483646,
+           91,         92,         93,         94,         95,
+           96,         97,         98,         99,         100,
+           2147483646, 2147483646, 2147483646, 91,         92,
+           93,         94,         95,         96,         97,
+           98,         99,         100,        2147483646, 2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference113) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{2, -2, 1, 0},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(5, 10));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           1,           2,           3,           4,           5,
+           6,           7,           8,           9,           10,
+           21,          22,          23,          24,          25,
+           26,          27,          28,          29,          30,
+           41,          42,          43,          44,          45,
+           46,          47,          48,          49,          50,
+           61,          62,          63,          64,          65,
+           66,          67,          68,          69,          70}));
+}
+
+TEST(ReferenceTest, RandomJaxReference114) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-2, -1, 1, 1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(8, 10));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {12,  24,  26,  28,  30,  32,  34,  36,  38,  19,  22,  44,  46,  48,
+           50,  52,  54,  56,  58,  29,  32,  64,  66,  68,  70,  72,  74,  76,
+           78,  39,  42,  84,  86,  88,  90,  92,  94,  96,  98,  49,  52,  104,
+           106, 108, 110, 112, 114, 116, 118, 59,  62,  124, 126, 128, 130, 132,
+           134, 136, 138, 69,  72,  144, 146, 148, 150, 152, 154, 156, 158, 79,
+           82,  164, 166, 168, 170, 172, 174, 176, 178, 89}));
+}
+
+TEST(ReferenceTest, RandomJaxReference115) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-2, -2, -2, 1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(7, 4));
+
+  EXPECT_THAT(res.data, ElementsAreArray({74,  82,  90,  98,  114, 122, 130,
+                                          138, 154, 162, 170, 178, 194, 202,
+                                          210, 218, 234, 242, 250, 258, 274,
+                                          282, 290, 298, 314, 322, 330, 338}));
+}
+
+TEST(ReferenceTest, RandomJaxReference116) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{0, -2, 1, 1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(16, 21));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, 1,           -2147483647, 2,           -2147483647,
+           3,           -2147483647, 4,           -2147483647, 5,
+           -2147483647, 6,           -2147483647, 7,           -2147483647,
+           8,           -2147483647, 9,           -2147483647, 10,
+           -2147483647, -2147483647, 11,          -2147483647, 12,
+           -2147483647, 13,          -2147483647, 14,          -2147483647,
+           15,          -2147483647, 16,          -2147483647, 17,
+           -2147483647, 18,          -2147483647, 19,          -2147483647,
+           20,          -2147483647, -2147483647, 11,          -2147483647,
+           12,          -2147483647, 13,          -2147483647, 14,
+           -2147483647, 15,          -2147483647, 16,          -2147483647,
+           17,          -2147483647, 18,          -2147483647, 19,
+           -2147483647, 20,          -2147483647, -2147483647, 21,
+           -2147483647, 22,          -2147483647, 23,          -2147483647,
+           24,          -2147483647, 25,          -2147483647, 26,
+           -2147483647, 27,          -2147483647, 28,          -2147483647,
+           29,          -2147483647, 30,          -2147483647, -2147483647,
+           21,          -2147483647, 22,          -2147483647, 23,
+           -2147483647, 24,          -2147483647, 25,          -2147483647,
+           26,          -2147483647, 27,          -2147483647, 28,
+           -2147483647, 29,          -2147483647, 30,          -2147483647,
+           -2147483647, 31,          -2147483647, 32,          -2147483647,
+           33,          -2147483647, 34,          -2147483647, 35,
+           -2147483647, 36,          -2147483647, 37,          -2147483647,
+           38,          -2147483647, 39,          -2147483647, 40,
+           -2147483647, -2147483647, 31,          -2147483647, 32,
+           -2147483647, 33,          -2147483647, 34,          -2147483647,
+           35,          -2147483647, 36,          -2147483647, 37,
+           -2147483647, 38,          -2147483647, 39,          -2147483647,
+           40,          -2147483647, -2147483647, 41,          -2147483647,
+           42,          -2147483647, 43,          -2147483647, 44,
+           -2147483647, 45,          -2147483647, 46,          -2147483647,
+           47,          -2147483647, 48,          -2147483647, 49,
+           -2147483647, 50,          -2147483647, -2147483647, 41,
+           -2147483647, 42,          -2147483647, 43,          -2147483647,
+           44,          -2147483647, 45,          -2147483647, 46,
+           -2147483647, 47,          -2147483647, 48,          -2147483647,
+           49,          -2147483647, 50,          -2147483647, -2147483647,
+           51,          -2147483647, 52,          -2147483647, 53,
+           -2147483647, 54,          -2147483647, 55,          -2147483647,
+           56,          -2147483647, 57,          -2147483647, 58,
+           -2147483647, 59,          -2147483647, 60,          -2147483647,
+           -2147483647, 51,          -2147483647, 52,          -2147483647,
+           53,          -2147483647, 54,          -2147483647, 55,
+           -2147483647, 56,          -2147483647, 57,          -2147483647,
+           58,          -2147483647, 59,          -2147483647, 60,
+           -2147483647, -2147483647, 61,          -2147483647, 62,
+           -2147483647, 63,          -2147483647, 64,          -2147483647,
+           65,          -2147483647, 66,          -2147483647, 67,
+           -2147483647, 68,          -2147483647, 69,          -2147483647,
+           70,          -2147483647, -2147483647, 61,          -2147483647,
+           62,          -2147483647, 63,          -2147483647, 64,
+           -2147483647, 65,          -2147483647, 66,          -2147483647,
+           67,          -2147483647, 68,          -2147483647, 69,
+           -2147483647, 70,          -2147483647, -2147483647, 71,
+           -2147483647, 72,          -2147483647, 73,          -2147483647,
+           74,          -2147483647, 75,          -2147483647, 76,
+           -2147483647, 77,          -2147483647, 78,          -2147483647,
+           79,          -2147483647, 80,          -2147483647, -2147483647,
+           71,          -2147483647, 72,          -2147483647, 73,
+           -2147483647, 74,          -2147483647, 75,          -2147483647,
+           76,          -2147483647, 77,          -2147483647, 78,
+           -2147483647, 79,          -2147483647, 80,          -2147483647,
+           -2147483647, 81,          -2147483647, 82,          -2147483647,
+           83,          -2147483647, 84,          -2147483647, 85,
+           -2147483647, 86,          -2147483647, 87,          -2147483647,
+           88,          -2147483647, 89,          -2147483647, 90,
+           -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference117) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-2, -2, -1, 0},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(8, 8));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {156,  182,  210,  240,  272,  306,  342,  380,  506,  552,  600,
+           650,  702,  756,  812,  870,  1056, 1122, 1190, 1260, 1332, 1406,
+           1482, 1560, 1806, 1892, 1980, 2070, 2162, 2256, 2352, 2450, 2756,
+           2862, 2970, 3080, 3192, 3306, 3422, 3540, 3906, 4032, 4160, 4290,
+           4422, 4556, 4692, 4830, 5256, 5402, 5550, 5700, 5852, 6006, 6162,
+           6320, 6806, 6972, 7140, 7310, 7482, 7656, 7832, 8010}));
+}
+
+TEST(ReferenceTest, RandomJaxReference118) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{-2, -1, -2, 1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(8, 8));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({25,  27,  29,  31,  33,  35,  37,  39,  45,  47,  49,
+                        51,  53,  55,  57,  59,  65,  67,  69,  71,  73,  75,
+                        77,  79,  85,  87,  89,  91,  93,  95,  97,  99,  105,
+                        107, 109, 111, 113, 115, 117, 119, 125, 127, 129, 131,
+                        133, 135, 137, 139, 145, 147, 149, 151, 153, 155, 157,
+                        159, 165, 167, 169, 171, 173, 175, 177, 179}));
+}
+
+TEST(ReferenceTest, RandomJaxReference119) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-2, 0, 1, 2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(6, 22));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({1, 861,  1, 924,  1, 989,  1, 1056, 1, 1125, 1, 1196,
+                        1, 1269, 1, 1344, 1, 1421, 1, 1500, 1, 1,    1, 1581,
+                        1, 1664, 1, 1749, 1, 1836, 1, 1925, 1, 2016, 1, 2109,
+                        1, 2204, 1, 2301, 1, 2400, 1, 1,    1, 2501, 1, 2604,
+                        1, 2709, 1, 2816, 1, 2925, 1, 3036, 1, 3149, 1, 3264,
+                        1, 3381, 1, 3500, 1, 1,    1, 3621, 1, 3744, 1, 3869,
+                        1, 3996, 1, 4125, 1, 4256, 1, 4389, 1, 4524, 1, 4661,
+                        1, 4800, 1, 1,    1, 4941, 1, 5084, 1, 5229, 1, 5376,
+                        1, 5525, 1, 5676, 1, 5829, 1, 5984, 1, 6141, 1, 6300,
+                        1, 1,    1, 6461, 1, 6624, 1, 6789, 1, 6956, 1, 7125,
+                        1, 7296, 1, 7469, 1, 7644, 1, 7821, 1, 8000, 1, 1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference120) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{-2, 1, 2, 0},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(9, 21));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {2147483646, 2147483646, 11,         2147483646, 12,
+                   2147483646, 13,         2147483646, 14,         2147483646,
+                   15,         2147483646, 16,         2147483646, 17,
+                   2147483646, 18,         2147483646, 19,         2147483646,
+                   20,         2147483646, 2147483646, 21,         2147483646,
+                   22,         2147483646, 23,         2147483646, 24,
+                   2147483646, 25,         2147483646, 26,         2147483646,
+                   27,         2147483646, 28,         2147483646, 29,
+                   2147483646, 30,         2147483646, 2147483646, 31,
+                   2147483646, 32,         2147483646, 33,         2147483646,
+                   34,         2147483646, 35,         2147483646, 36,
+                   2147483646, 37,         2147483646, 38,         2147483646,
+                   39,         2147483646, 40,         2147483646, 2147483646,
+                   41,         2147483646, 42,         2147483646, 43,
+                   2147483646, 44,         2147483646, 45,         2147483646,
+                   46,         2147483646, 47,         2147483646, 48,
+                   2147483646, 49,         2147483646, 50,         2147483646,
+                   2147483646, 51,         2147483646, 52,         2147483646,
+                   53,         2147483646, 54,         2147483646, 55,
+                   2147483646, 56,         2147483646, 57,         2147483646,
+                   58,         2147483646, 59,         2147483646, 60,
+                   2147483646, 2147483646, 61,         2147483646, 62,
+                   2147483646, 63,         2147483646, 64,         2147483646,
+                   65,         2147483646, 66,         2147483646, 67,
+                   2147483646, 68,         2147483646, 69,         2147483646,
+                   70,         2147483646, 2147483646, 71,         2147483646,
+                   72,         2147483646, 73,         2147483646, 74,
+                   2147483646, 75,         2147483646, 76,         2147483646,
+                   77,         2147483646, 78,         2147483646, 79,
+                   2147483646, 80,         2147483646, 2147483646, 81,
+                   2147483646, 82,         2147483646, 83,         2147483646,
+                   84,         2147483646, 85,         2147483646, 86,
+                   2147483646, 87,         2147483646, 88,         2147483646,
+                   89,         2147483646, 90,         2147483646, 2147483646,
+                   91,         2147483646, 92,         2147483646, 93,
+                   2147483646, 94,         2147483646, 95,         2147483646,
+                   96,         2147483646, 97,         2147483646, 98,
+                   2147483646, 99,         2147483646, 100}));
+}
+
+TEST(ReferenceTest, RandomJaxReference121) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{0, 2, -1, 1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(20, 9));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference122) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{0, 2, -1, 1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(5, 10));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference123) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-2, -2, 0, 2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(6, 6));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({43,  47,  51,  55,  59,  0,   63,  67,  71,
+                                75,  79,  0,   83,  87,  91,  95,  99,  0,
+                                103, 107, 111, 115, 119, 0,   123, 127, 131,
+                                135, 139, 0,   143, 147, 151, 155, 159, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference124) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{0, 2, -2, 0},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(10, 4));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({69,   125,  189,  261,  429,  525,  629,  741,
+                                989,  1125, 1269, 1421, 1749, 1925, 2109, 2301,
+                                2709, 2925, 3149, 3381, 3869, 4125, 4389, 4661,
+                                5229, 5525, 5829, 6141, 6789, 7125, 7469, 7821,
+                                83,   85,   87,   89,   93,   95,   97,   99}));
+}
+
+TEST(ReferenceTest, RandomJaxReference125) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{-1, -1, 2, 1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(9, 21));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference126) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{0, 1, 0, 0},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(20, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,           3,           5,           7,           9,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           11,          13,          15,          17,          19,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           21,          23,          25,          27,          29,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           31,          33,          35,          37,          39,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           41,          43,          45,          47,          49,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           51,          53,          55,          57,          59,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           61,          63,          65,          67,          69,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           71,          73,          75,          77,          79,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           81,          83,          85,          87,          89,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           91,          93,          95,          97,          99,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference127) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{1, -2, 0, -2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(16, 4));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {0, 0, 0, 0, 12,  16,  20,  24,  0, 0, 0, 0, 32,  36,  40,  44,
+           0, 0, 0, 0, 52,  56,  60,  64,  0, 0, 0, 0, 72,  76,  80,  84,
+           0, 0, 0, 0, 92,  96,  100, 104, 0, 0, 0, 0, 112, 116, 120, 124,
+           0, 0, 0, 0, 132, 136, 140, 144, 0, 0, 0, 0, 152, 156, 160, 164}));
+}
+
+TEST(ReferenceTest, RandomJaxReference128) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-1, -2, 0, -2},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(7, 3));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({11, 13, 15, 21, 23, 25, 31, 33, 35, 41, 43,
+                                45, 51, 53, 55, 61, 63, 65, 71, 73, 75}));
+}
+
+TEST(ReferenceTest, RandomJaxReference129) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{1, 2, -1, 2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(12, 9));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference130) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-1, 1, 1, 1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(19, 6));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,  1,  1,  1,  1,  1,   1,  12, 14, 16, 18, 20, 1,  1,  1,  1,  1,
+           1,  1,  22, 24, 26, 28,  30, 1,  1,  1,  1,  1,  1,  1,  32, 34, 36,
+           38, 40, 1,  1,  1,  1,   1,  1,  1,  42, 44, 46, 48, 50, 1,  1,  1,
+           1,  1,  1,  1,  52, 54,  56, 58, 60, 1,  1,  1,  1,  1,  1,  1,  62,
+           64, 66, 68, 70, 1,  1,   1,  1,  1,  1,  1,  72, 74, 76, 78, 80, 1,
+           1,  1,  1,  1,  1,  1,   82, 84, 86, 88, 90, 1,  1,  1,  1,  1,  1,
+           1,  92, 94, 96, 98, 100, 1,  1,  1,  1,  1,  1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference131) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{1, -1, -2, -1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(9, 16));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({2,  -2147483647, 3,  -2147483647, 4,  -2147483647,
+                        5,  -2147483647, 6,  -2147483647, 7,  -2147483647,
+                        8,  -2147483647, 9,  -2147483647, 12, -2147483647,
+                        13, -2147483647, 14, -2147483647, 15, -2147483647,
+                        16, -2147483647, 17, -2147483647, 18, -2147483647,
+                        19, -2147483647, 22, -2147483647, 23, -2147483647,
+                        24, -2147483647, 25, -2147483647, 26, -2147483647,
+                        27, -2147483647, 28, -2147483647, 29, -2147483647,
+                        32, -2147483647, 33, -2147483647, 34, -2147483647,
+                        35, -2147483647, 36, -2147483647, 37, -2147483647,
+                        38, -2147483647, 39, -2147483647, 42, -2147483647,
+                        43, -2147483647, 44, -2147483647, 45, -2147483647,
+                        46, -2147483647, 47, -2147483647, 48, -2147483647,
+                        49, -2147483647, 52, -2147483647, 53, -2147483647,
+                        54, -2147483647, 55, -2147483647, 56, -2147483647,
+                        57, -2147483647, 58, -2147483647, 59, -2147483647,
+                        62, -2147483647, 63, -2147483647, 64, -2147483647,
+                        65, -2147483647, 66, -2147483647, 67, -2147483647,
+                        68, -2147483647, 69, -2147483647, 72, -2147483647,
+                        73, -2147483647, 74, -2147483647, 75, -2147483647,
+                        76, -2147483647, 77, -2147483647, 78, -2147483647,
+                        79, -2147483647, 82, -2147483647, 83, -2147483647,
+                        84, -2147483647, 85, -2147483647, 86, -2147483647,
+                        87, -2147483647, 88, -2147483647, 89, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference132) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{0, 2, 2, -1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(10, 10));
+
+  EXPECT_THAT(res.data, ElementsAreArray(
+                            {-2147483647, 11, 12, 13, 14, 15, 16, 17, 18, 19,
+                             -2147483647, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+                             -2147483647, 31, 32, 33, 34, 35, 36, 37, 38, 39,
+                             -2147483647, 41, 42, 43, 44, 45, 46, 47, 48, 49,
+                             -2147483647, 51, 52, 53, 54, 55, 56, 57, 58, 59,
+                             -2147483647, 61, 62, 63, 64, 65, 66, 67, 68, 69,
+                             -2147483647, 71, 72, 73, 74, 75, 76, 77, 78, 79,
+                             -2147483647, 81, 82, 83, 84, 85, 86, 87, 88, 89,
+                             -2147483647, 91, 92, 93, 94, 95, 96, 97, 98, 99,
+                             -2147483647, 91, 92, 93, 94, 95, 96, 97, 98, 99}));
+}
+
+TEST(ReferenceTest, RandomJaxReference133) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{0, 2, 1, -1},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(10, 4));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({2,  2,  4,  6,  12, 12, 14, 16, 22, 22, 24, 26, 32, 32,
+                        34, 36, 42, 42, 44, 46, 52, 52, 54, 56, 62, 62, 64, 66,
+                        72, 72, 74, 76, 82, 82, 84, 86, 92, 92, 94, 96}));
+}
+
+TEST(ReferenceTest, RandomJaxReference134) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-2, 2, 2, 1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(5, 22));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, 31,          -2147483647, 32,
+           -2147483647, 33,          -2147483647, 34,          -2147483647,
+           35,          -2147483647, 36,          -2147483647, 37,
+           -2147483647, 38,          -2147483647, 39,          -2147483647,
+           40,          -2147483647, -2147483647, -2147483647, 51,
+           -2147483647, 52,          -2147483647, 53,          -2147483647,
+           54,          -2147483647, 55,          -2147483647, 56,
+           -2147483647, 57,          -2147483647, 58,          -2147483647,
+           59,          -2147483647, 60,          -2147483647, -2147483647,
+           -2147483647, 71,          -2147483647, 72,          -2147483647,
+           73,          -2147483647, 74,          -2147483647, 75,
+           -2147483647, 76,          -2147483647, 77,          -2147483647,
+           78,          -2147483647, 79,          -2147483647, 80,
+           -2147483647, -2147483647, -2147483647, 91,          -2147483647,
+           92,          -2147483647, 93,          -2147483647, 94,
+           -2147483647, 95,          -2147483647, 96,          -2147483647,
+           97,          -2147483647, 98,          -2147483647, 99,
+           -2147483647, 100,         -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference135) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{1, 0, 0, 2},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(10, 12));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference136) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{0, 1, 0, 0},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(11, 9));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {1,          2,          3,          4,          5,
+                   6,          7,          8,          9,          11,
+                   12,         13,         14,         15,         16,
+                   17,         18,         19,         21,         22,
+                   23,         24,         25,         26,         27,
+                   28,         29,         31,         32,         33,
+                   34,         35,         36,         37,         38,
+                   39,         41,         42,         43,         44,
+                   45,         46,         47,         48,         49,
+                   51,         52,         53,         54,         55,
+                   56,         57,         58,         59,         61,
+                   62,         63,         64,         65,         66,
+                   67,         68,         69,         71,         72,
+                   73,         74,         75,         76,         77,
+                   78,         79,         81,         82,         83,
+                   84,         85,         86,         87,         88,
+                   89,         91,         92,         93,         94,
+                   95,         96,         97,         98,         99,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference137) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{0, -1, 2, -1},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(4, 5));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({1,  1,  3,  5,  7,  21, 21, 23, 25, 27,
+                                41, 41, 43, 45, 47, 61, 61, 63, 65, 67}));
+}
+
+TEST(ReferenceTest, RandomJaxReference138) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{0, -1, 1, 2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(18, 22));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, 1,           -2147483647, 2,           -2147483647,
+           3,           -2147483647, 4,           -2147483647, 5,
+           -2147483647, 6,           -2147483647, 7,           -2147483647,
+           8,           -2147483647, 9,           -2147483647, 10,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           11,          -2147483647, 12,          -2147483647, 13,
+           -2147483647, 14,          -2147483647, 15,          -2147483647,
+           16,          -2147483647, 17,          -2147483647, 18,
+           -2147483647, 19,          -2147483647, 20,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, 21,
+           -2147483647, 22,          -2147483647, 23,          -2147483647,
+           24,          -2147483647, 25,          -2147483647, 26,
+           -2147483647, 27,          -2147483647, 28,          -2147483647,
+           29,          -2147483647, 30,          -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, 31,          -2147483647,
+           32,          -2147483647, 33,          -2147483647, 34,
+           -2147483647, 35,          -2147483647, 36,          -2147483647,
+           37,          -2147483647, 38,          -2147483647, 39,
+           -2147483647, 40,          -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, 41,          -2147483647, 42,
+           -2147483647, 43,          -2147483647, 44,          -2147483647,
+           45,          -2147483647, 46,          -2147483647, 47,
+           -2147483647, 48,          -2147483647, 49,          -2147483647,
+           50,          -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, 51,          -2147483647, 52,          -2147483647,
+           53,          -2147483647, 54,          -2147483647, 55,
+           -2147483647, 56,          -2147483647, 57,          -2147483647,
+           58,          -2147483647, 59,          -2147483647, 60,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           61,          -2147483647, 62,          -2147483647, 63,
+           -2147483647, 64,          -2147483647, 65,          -2147483647,
+           66,          -2147483647, 67,          -2147483647, 68,
+           -2147483647, 69,          -2147483647, 70,          -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, 71,
+           -2147483647, 72,          -2147483647, 73,          -2147483647,
+           74,          -2147483647, 75,          -2147483647, 76,
+           -2147483647, 77,          -2147483647, 78,          -2147483647,
+           79,          -2147483647, 80,          -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, 81,          -2147483647,
+           82,          -2147483647, 83,          -2147483647, 84,
+           -2147483647, 85,          -2147483647, 86,          -2147483647,
+           87,          -2147483647, 88,          -2147483647, 89,
+           -2147483647, 90,          -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference139) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{1, 2, 0, 2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(6, 10));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {132,      156,      182,      210,      240,      272,      306,
+           342,      380,      20,       130944,   164736,   204204,   249900,
+           302400,   362304,   430236,   506844,   592800,   800,      2630784,
+           2910336,  3211164,  3534300,  3880800,  4251744,  4648236,  5071404,
+           5522400,  2400,     13557024, 14485536, 15460524, 16483500, 17556000,
+           18679584, 19855836, 21086364, 22372800, 4800,     42797664, 44970336,
+           47224284, 49561500, 51984000, 54493824, 57093036, 59783724, 62568000,
+           8000,     8372,     8556,     8742,     8930,     9120,     9312,
+           9506,     9702,     9900,     100}));
+}
+
+TEST(ReferenceTest, RandomJaxReference140) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{-2, 0, -1, -2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(15, 8));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference141) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-2, -2, 1, 1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(3, 6));
+
+  EXPECT_THAT(res.data, ElementsAreArray({-2147483647, 22, 24, 26, 28, 30,
+                                          -2147483647, 42, 44, 46, 48, 50,
+                                          -2147483647, 62, 64, 66, 68, 70}));
+}
+
+TEST(ReferenceTest, RandomJaxReference142) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{1, 0, 0, -1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(18, 9));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,    1,    1,    1,    1,    1,    1,    1,    1,    11,   24,
+           39,   56,   75,   96,   119,  144,  171,  1,    1,    1,    1,
+           1,    1,    1,    1,    1,    231,  264,  299,  336,  375,  416,
+           459,  504,  551,  1,    1,    1,    1,    1,    1,    1,    1,
+           1,    651,  704,  759,  816,  875,  936,  999,  1064, 1131, 1,
+           1,    1,    1,    1,    1,    1,    1,    1,    1271, 1344, 1419,
+           1496, 1575, 1656, 1739, 1824, 1911, 1,    1,    1,    1,    1,
+           1,    1,    1,    1,    2091, 2184, 2279, 2376, 2475, 2576, 2679,
+           2784, 2891, 1,    1,    1,    1,    1,    1,    1,    1,    1,
+           3111, 3224, 3339, 3456, 3575, 3696, 3819, 3944, 4071, 1,    1,
+           1,    1,    1,    1,    1,    1,    1,    4331, 4464, 4599, 4736,
+           4875, 5016, 5159, 5304, 5451, 1,    1,    1,    1,    1,    1,
+           1,    1,    1,    5751, 5904, 6059, 6216, 6375, 6536, 6699, 6864,
+           7031, 1,    1,    1,    1,    1,    1,    1,    1,    1,    7371,
+           7544, 7719, 7896, 8075, 8256, 8439, 8624, 8811}));
+}
+
+TEST(ReferenceTest, RandomJaxReference143) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{2, -1, -2, -1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(5, 8));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({2,  3,  4,  5,  6,  7,  8,  9,  22, 23, 24, 25, 26, 27,
+                        28, 29, 42, 43, 44, 45, 46, 47, 48, 49, 62, 63, 64, 65,
+                        66, 67, 68, 69, 82, 83, 84, 85, 86, 87, 88, 89}));
+}
+
+TEST(ReferenceTest, RandomJaxReference144) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{0, 1, 2, 2},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(6, 23));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {2147483646, 2147483646, 1,          2147483646, 2,
+                   2147483646, 3,          2147483646, 4,          2147483646,
+                   5,          2147483646, 6,          2147483646, 7,
+                   2147483646, 8,          2147483646, 9,          2147483646,
+                   10,         2147483646, 2147483646, 2147483646, 2147483646,
+                   21,         2147483646, 22,         2147483646, 23,
+                   2147483646, 24,         2147483646, 25,         2147483646,
+                   26,         2147483646, 27,         2147483646, 28,
+                   2147483646, 29,         2147483646, 30,         2147483646,
+                   2147483646, 2147483646, 2147483646, 41,         2147483646,
+                   42,         2147483646, 43,         2147483646, 44,
+                   2147483646, 45,         2147483646, 46,         2147483646,
+                   47,         2147483646, 48,         2147483646, 49,
+                   2147483646, 50,         2147483646, 2147483646, 2147483646,
+                   2147483646, 61,         2147483646, 62,         2147483646,
+                   63,         2147483646, 64,         2147483646, 65,
+                   2147483646, 66,         2147483646, 67,         2147483646,
+                   68,         2147483646, 69,         2147483646, 70,
+                   2147483646, 2147483646, 2147483646, 2147483646, 81,
+                   2147483646, 82,         2147483646, 83,         2147483646,
+                   84,         2147483646, 85,         2147483646, 86,
+                   2147483646, 87,         2147483646, 88,         2147483646,
+                   89,         2147483646, 90,         2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference145) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{2, -2, 2, -2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(10, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({1, 1,    1,    1,    1,    1, 2,    12,   30,   56,
+                        1, 132,  182,  240,  306,  1, 462,  552,  650,  756,
+                        1, 992,  1122, 1260, 1406, 1, 1722, 1892, 2070, 2256,
+                        1, 2652, 2862, 3080, 3306, 1, 3782, 4032, 4290, 4556,
+                        1, 5112, 5402, 5700, 6006, 1, 6642, 6972, 7310, 7656}));
+}
+
+TEST(ReferenceTest, RandomJaxReference146) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{1, -2, 1, 0},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(17, 6));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {2147483646, 2,  4,  6,  8,  10, 2147483646, 2,  4,  6,  8,  10,
+           2147483646, 12, 14, 16, 18, 20, 2147483646, 12, 14, 16, 18, 20,
+           2147483646, 22, 24, 26, 28, 30, 2147483646, 22, 24, 26, 28, 30,
+           2147483646, 32, 34, 36, 38, 40, 2147483646, 32, 34, 36, 38, 40,
+           2147483646, 42, 44, 46, 48, 50, 2147483646, 42, 44, 46, 48, 50,
+           2147483646, 52, 54, 56, 58, 60, 2147483646, 52, 54, 56, 58, 60,
+           2147483646, 62, 64, 66, 68, 70, 2147483646, 62, 64, 66, 68, 70,
+           2147483646, 72, 74, 76, 78, 80, 2147483646, 72, 74, 76, 78, 80,
+           2147483646, 82, 84, 86, 88, 90}));
+}
+
+TEST(ReferenceTest, RandomJaxReference147) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-2, 0, 2, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(8, 5));
+
+  EXPECT_THAT(res.data, ElementsAreArray(
+                            {11, 24,  28,  32,  36,  21, 44,  48,  52,  56,
+                             31, 64,  68,  72,  76,  41, 84,  88,  92,  96,
+                             51, 104, 108, 112, 116, 61, 124, 128, 132, 136,
+                             71, 144, 148, 152, 156, 81, 164, 168, 172, 176}));
+}
+
+TEST(ReferenceTest, RandomJaxReference148) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{1, -2, 2, 1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(17, 22));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,  1, 1,  1, 2,  1, 3,  1, 4,  1, 5,  1, 6,  1, 7,  1, 8,  1, 9,  1,
+           10, 1, 1,  1, 1,  1, 2,  1, 3,  1, 4,  1, 5,  1, 6,  1, 7,  1, 8,  1,
+           9,  1, 10, 1, 1,  1, 11, 1, 12, 1, 13, 1, 14, 1, 15, 1, 16, 1, 17, 1,
+           18, 1, 19, 1, 20, 1, 1,  1, 11, 1, 12, 1, 13, 1, 14, 1, 15, 1, 16, 1,
+           17, 1, 18, 1, 19, 1, 20, 1, 1,  1, 21, 1, 22, 1, 23, 1, 24, 1, 25, 1,
+           26, 1, 27, 1, 28, 1, 29, 1, 30, 1, 1,  1, 21, 1, 22, 1, 23, 1, 24, 1,
+           25, 1, 26, 1, 27, 1, 28, 1, 29, 1, 30, 1, 1,  1, 31, 1, 32, 1, 33, 1,
+           34, 1, 35, 1, 36, 1, 37, 1, 38, 1, 39, 1, 40, 1, 1,  1, 31, 1, 32, 1,
+           33, 1, 34, 1, 35, 1, 36, 1, 37, 1, 38, 1, 39, 1, 40, 1, 1,  1, 41, 1,
+           42, 1, 43, 1, 44, 1, 45, 1, 46, 1, 47, 1, 48, 1, 49, 1, 50, 1, 1,  1,
+           41, 1, 42, 1, 43, 1, 44, 1, 45, 1, 46, 1, 47, 1, 48, 1, 49, 1, 50, 1,
+           1,  1, 51, 1, 52, 1, 53, 1, 54, 1, 55, 1, 56, 1, 57, 1, 58, 1, 59, 1,
+           60, 1, 1,  1, 51, 1, 52, 1, 53, 1, 54, 1, 55, 1, 56, 1, 57, 1, 58, 1,
+           59, 1, 60, 1, 1,  1, 61, 1, 62, 1, 63, 1, 64, 1, 65, 1, 66, 1, 67, 1,
+           68, 1, 69, 1, 70, 1, 1,  1, 61, 1, 62, 1, 63, 1, 64, 1, 65, 1, 66, 1,
+           67, 1, 68, 1, 69, 1, 70, 1, 1,  1, 71, 1, 72, 1, 73, 1, 74, 1, 75, 1,
+           76, 1, 77, 1, 78, 1, 79, 1, 80, 1, 1,  1, 71, 1, 72, 1, 73, 1, 74, 1,
+           75, 1, 76, 1, 77, 1, 78, 1, 79, 1, 80, 1, 1,  1, 81, 1, 82, 1, 83, 1,
+           84, 1, 85, 1, 86, 1, 87, 1, 88, 1, 89, 1, 90, 1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference149) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-1, -2, -2, 2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(6, 5));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {23, 25, 27, 29, -2147483647, 33, 35, 37, 39, -2147483647,
+                   43, 45, 47, 49, -2147483647, 53, 55, 57, 59, -2147483647,
+                   63, 65, 67, 69, -2147483647, 73, 75, 77, 79, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference150) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{2, -1, -2, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(6, 17));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({0, 0,  0, 0,  0,  0,  0,  0,  0,  0,  0,  0, 0,  0, 0,
+                        0, 0,  2, 0,  3,  0,  4,  0,  5,  0,  6,  0, 7,  0, 8,
+                        0, 9,  0, 10, 22, 0,  23, 0,  24, 0,  25, 0, 26, 0, 27,
+                        0, 28, 0, 29, 0,  30, 42, 0,  43, 0,  44, 0, 45, 0, 46,
+                        0, 47, 0, 48, 0,  49, 0,  50, 62, 0,  63, 0, 64, 0, 65,
+                        0, 66, 0, 67, 0,  68, 0,  69, 0,  70, 82, 0, 83, 0, 84,
+                        0, 85, 0, 86, 0,  87, 0,  88, 0,  89, 0,  90}));
+}
+
+TEST(ReferenceTest, RandomJaxReference151) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{2, 1, 2, -1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(10, 19));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,    1,    1,    2,    2,    3,    3,    4,    4,    5,    5,
+           6,    6,    7,    7,    8,    8,    9,    9,    1,    11,   11,
+           24,   24,   39,   39,   56,   56,   75,   75,   96,   96,   119,
+           119,  144,  144,  171,  171,  1,    231,  231,  264,  264,  299,
+           299,  336,  336,  375,  375,  416,  416,  459,  459,  504,  504,
+           551,  551,  1,    651,  651,  704,  704,  759,  759,  816,  816,
+           875,  875,  936,  936,  999,  999,  1064, 1064, 1131, 1131, 1,
+           1271, 1271, 1344, 1344, 1419, 1419, 1496, 1496, 1575, 1575, 1656,
+           1656, 1739, 1739, 1824, 1824, 1911, 1911, 1,    2091, 2091, 2184,
+           2184, 2279, 2279, 2376, 2376, 2475, 2475, 2576, 2576, 2679, 2679,
+           2784, 2784, 2891, 2891, 1,    3111, 3111, 3224, 3224, 3339, 3339,
+           3456, 3456, 3575, 3575, 3696, 3696, 3819, 3819, 3944, 3944, 4071,
+           4071, 1,    4331, 4331, 4464, 4464, 4599, 4599, 4736, 4736, 4875,
+           4875, 5016, 5016, 5159, 5159, 5304, 5304, 5451, 5451, 1,    5751,
+           5751, 5904, 5904, 6059, 6059, 6216, 6216, 6375, 6375, 6536, 6536,
+           6699, 6699, 6864, 6864, 7031, 7031, 1,    7371, 7371, 7544, 7544,
+           7719, 7719, 7896, 7896, 8075, 8075, 8256, 8256, 8439, 8439, 8624,
+           8624, 8811, 8811}));
+}
+
+TEST(ReferenceTest, RandomJaxReference152) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-1, 2, -2, 1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(9, 7));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference153) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-2, 2, -1, 2},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(9, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {88704,    139776,   209664,   302400,   600,      574464,   763776,
+           995904,   1276800,  1200,     2010624,  2477376,  3020544,  3648000,
+           2000,     5189184,  6120576,  7171584,  8352000,  3000,     11142144,
+           12773376, 14577024, 16564800, 4200,     21141504, 23755776, 26604864,
+           29702400, 5600,     36699264, 40627776, 44863104, 49420800, 7200,
+           59567424, 65189376, 71199744, 77616000, 9000,     8648,     9024,
+           9408,     9800,     100}));
+}
+
+TEST(ReferenceTest, RandomJaxReference154) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{2, 2, -1, 0},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(7, 18));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, 2,
+           -2147483647, 3,           -2147483647, 4,           -2147483647,
+           5,           -2147483647, 6,           -2147483647, 7,
+           -2147483647, 8,           -2147483647, 9,           -2147483647,
+           10,          -2147483647, 22,          -2147483647, 23,
+           -2147483647, 24,          -2147483647, 25,          -2147483647,
+           26,          -2147483647, 27,          -2147483647, 28,
+           -2147483647, 29,          -2147483647, 30,          -2147483647,
+           42,          -2147483647, 43,          -2147483647, 44,
+           -2147483647, 45,          -2147483647, 46,          -2147483647,
+           47,          -2147483647, 48,          -2147483647, 49,
+           -2147483647, 50,          -2147483647, 62,          -2147483647,
+           63,          -2147483647, 64,          -2147483647, 65,
+           -2147483647, 66,          -2147483647, 67,          -2147483647,
+           68,          -2147483647, 69,          -2147483647, 70,
+           -2147483647, 82,          -2147483647, 83,          -2147483647,
+           84,          -2147483647, 85,          -2147483647, 86,
+           -2147483647, 87,          -2147483647, 88,          -2147483647,
+           89,          -2147483647, 90,          -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference155) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{0, 2, 2, 1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(11, 6));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({0,   3,   7,   11,  15,  19,  0,   23,  27,  31,  35,
+                        39,  0,   43,  47,  51,  55,  59,  0,   63,  67,  71,
+                        75,  79,  0,   83,  87,  91,  95,  99,  0,   103, 107,
+                        111, 115, 119, 0,   123, 127, 131, 135, 139, 0,   143,
+                        147, 151, 155, 159, 0,   163, 167, 171, 175, 179, 0,
+                        183, 187, 191, 195, 199, 0,   0,   0,   0,   0,   0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference156) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{2, -1, -1, -1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(11, 3));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, 4,           6,           8,           14,
+           16,          18,          24,          26,          28,
+           34,          36,          38,          44,          46,
+           48,          54,          56,          58,          64,
+           66,          68,          74,          76,          78,
+           84,          86,          88}));
+}
+
+TEST(ReferenceTest, RandomJaxReference157) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-1, -1, -2, -1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(16, 7));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {13, 14, 15, 16, 17, 18, 19, 13, 14, 15, 16, 17, 18, 19, 23, 24,
+           25, 26, 27, 28, 29, 23, 24, 25, 26, 27, 28, 29, 33, 34, 35, 36,
+           37, 38, 39, 33, 34, 35, 36, 37, 38, 39, 43, 44, 45, 46, 47, 48,
+           49, 43, 44, 45, 46, 47, 48, 49, 53, 54, 55, 56, 57, 58, 59, 53,
+           54, 55, 56, 57, 58, 59, 63, 64, 65, 66, 67, 68, 69, 63, 64, 65,
+           66, 67, 68, 69, 73, 74, 75, 76, 77, 78, 79, 73, 74, 75, 76, 77,
+           78, 79, 83, 84, 85, 86, 87, 88, 89, 83, 84, 85, 86, 87, 88, 89}));
+}
+
+TEST(ReferenceTest, RandomJaxReference158) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{2, -2, 2, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(9, 6));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({0,  0,  0,  0,  0,  0,  0,  1,  3,  5,  7,  9,  0,  11,
+                        13, 15, 17, 19, 0,  21, 23, 25, 27, 29, 0,  31, 33, 35,
+                        37, 39, 0,  41, 43, 45, 47, 49, 0,  51, 53, 55, 57, 59,
+                        0,  61, 63, 65, 67, 69, 0,  71, 73, 75, 77, 79}));
+}
+
+TEST(ReferenceTest, RandomJaxReference159) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{1, 2, -1, 1},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(13, 9));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2,
+                   3,          4,          5,          6,          7,
+                   8,          9,          10,         12,         13,
+                   14,         15,         16,         17,         18,
+                   19,         20,         22,         23,         24,
+                   25,         26,         27,         28,         29,
+                   30,         32,         33,         34,         35,
+                   36,         37,         38,         39,         40,
+                   42,         43,         44,         45,         46,
+                   47,         48,         49,         50,         52,
+                   53,         54,         55,         56,         57,
+                   58,         59,         60,         62,         63,
+                   64,         65,         66,         67,         68,
+                   69,         70,         72,         73,         74,
+                   75,         76,         77,         78,         79,
+                   80,         82,         83,         84,         85,
+                   86,         87,         88,         89,         90,
+                   92,         93,         94,         95,         96,
+                   97,         98,         99,         100,        2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference160) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-1, 0, -2, 1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(4, 4));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({74, 82, 90, 98, 154, 162, 170, 178, 234, 242,
+                                250, 258, 314, 322, 330, 338}));
+}
+
+TEST(ReferenceTest, RandomJaxReference161) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{0, 2, -1, 1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(11, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({5,   9,   13,  17,  10,  25,  29,  33,  37,  20,  45,
+                        49,  53,  57,  30,  65,  69,  73,  77,  40,  85,  89,
+                        93,  97,  50,  105, 109, 113, 117, 60,  125, 129, 133,
+                        137, 70,  145, 149, 153, 157, 80,  165, 169, 173, 177,
+                        90,  185, 189, 193, 197, 100, 0,   0,   0,   0,   0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference162) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{2, -1, -1, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(10, 17));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
+           0,   0,   0,   2,   2,   3,   3,   4,   4,   5,   5,   6,   6,   7,
+           7,   8,   8,   9,   9,   10,  14,  14,  16,  16,  18,  18,  20,  20,
+           22,  22,  24,  24,  26,  26,  28,  28,  30,  34,  34,  36,  36,  38,
+           38,  40,  40,  42,  42,  44,  44,  46,  46,  48,  48,  50,  54,  54,
+           56,  56,  58,  58,  60,  60,  62,  62,  64,  64,  66,  66,  68,  68,
+           70,  74,  74,  76,  76,  78,  78,  80,  80,  82,  82,  84,  84,  86,
+           86,  88,  88,  90,  94,  94,  96,  96,  98,  98,  100, 100, 102, 102,
+           104, 104, 106, 106, 108, 108, 110, 114, 114, 116, 116, 118, 118, 120,
+           120, 122, 122, 124, 124, 126, 126, 128, 128, 130, 134, 134, 136, 136,
+           138, 138, 140, 140, 142, 142, 144, 144, 146, 146, 148, 148, 150, 154,
+           154, 156, 156, 158, 158, 160, 160, 162, 162, 164, 164, 166, 166, 168,
+           168, 170}));
+}
+
+TEST(ReferenceTest, RandomJaxReference163) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{0, 0, 0, 2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(10, 12));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({1,  2,  3,  4,  5,  6,   7,  8,  9,  10, 0,  0,  11, 12,
+                        13, 14, 15, 16, 17, 18,  19, 20, 0,  0,  21, 22, 23, 24,
+                        25, 26, 27, 28, 29, 30,  0,  0,  31, 32, 33, 34, 35, 36,
+                        37, 38, 39, 40, 0,  0,   41, 42, 43, 44, 45, 46, 47, 48,
+                        49, 50, 0,  0,  51, 52,  53, 54, 55, 56, 57, 58, 59, 60,
+                        0,  0,  61, 62, 63, 64,  65, 66, 67, 68, 69, 70, 0,  0,
+                        71, 72, 73, 74, 75, 76,  77, 78, 79, 80, 0,  0,  81, 82,
+                        83, 84, 85, 86, 87, 88,  89, 90, 0,  0,  91, 92, 93, 94,
+                        95, 96, 97, 98, 99, 100, 0,  0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference164) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-1, 0, 2, 1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(9, 7));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({-2147483647, 11, 13, 15, 17, 19, -2147483647,
+                                -2147483647, 21, 23, 25, 27, 29, -2147483647,
+                                -2147483647, 31, 33, 35, 37, 39, -2147483647,
+                                -2147483647, 41, 43, 45, 47, 49, -2147483647,
+                                -2147483647, 51, 53, 55, 57, 59, -2147483647,
+                                -2147483647, 61, 63, 65, 67, 69, -2147483647,
+                                -2147483647, 71, 73, 75, 77, 79, -2147483647,
+                                -2147483647, 81, 83, 85, 87, 89, -2147483647,
+                                -2147483647, 91, 93, 95, 97, 99, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference165) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{2, -2, 2, -1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(5, 18));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({1,    1, 1,    1, 1,    1, 1,    1, 1,    1, 1,    1,
+                        1,    1, 1,    1, 1,    1, 1,    1, 2,    1, 6,    1,
+                        12,   1, 20,   1, 30,   1, 42,   1, 56,   1, 72,   1,
+                        21,   1, 462,  1, 506,  1, 552,  1, 600,  1, 650,  1,
+                        702,  1, 756,  1, 812,  1, 41,   1, 1722, 1, 1806, 1,
+                        1892, 1, 1980, 1, 2070, 1, 2162, 1, 2256, 1, 2352, 1,
+                        61,   1, 3782, 1, 3906, 1, 4032, 1, 4160, 1, 4290, 1,
+                        4422, 1, 4556, 1, 4692, 1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference166) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{1, -1, 0, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(8, 3));
+
+  EXPECT_THAT(res.data, ElementsAreArray({13, 15, 17, 23, 25, 27, 33, 35,
+                                          37, 43, 45, 47, 53, 55, 57, 63,
+                                          65, 67, 73, 75, 77, 83, 85, 87}));
+}
+
+TEST(ReferenceTest, RandomJaxReference167) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-2, 2, 0, 1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(9, 5));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({66,  74,  82,  90,  98,  106, 114, 122, 130,
+                                138, 146, 154, 162, 170, 178, 186, 194, 202,
+                                210, 218, 226, 234, 242, 250, 258, 266, 274,
+                                282, 290, 298, 306, 314, 322, 330, 338, 346,
+                                354, 362, 370, 378, 183, 187, 191, 195, 199}));
+}
+
+TEST(ReferenceTest, RandomJaxReference168) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{0, -1, -1, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(4, 7));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({-2147483647, -2147483647, -2147483647, -2147483647,
+                        -2147483647, -2147483647, -2147483647, -2147483647,
+                        -2147483647, -2147483647, -2147483647, -2147483647,
+                        -2147483647, -2147483647, -2147483647, -2147483647,
+                        -2147483647, -2147483647, -2147483647, -2147483647,
+                        -2147483647, -2147483647, -2147483647, -2147483647,
+                        -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference169) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{1, -2, 0, 1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(9, 9));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({1,    1,    1,    1,    1,    1,    1,    1,    1,
+                        2,    6,    12,   20,   30,   42,   56,   72,   90,
+                        132,  156,  182,  210,  240,  272,  306,  342,  380,
+                        462,  506,  552,  600,  650,  702,  756,  812,  870,
+                        992,  1056, 1122, 1190, 1260, 1332, 1406, 1482, 1560,
+                        1722, 1806, 1892, 1980, 2070, 2162, 2256, 2352, 2450,
+                        2652, 2756, 2862, 2970, 3080, 3192, 3306, 3422, 3540,
+                        3782, 3906, 4032, 4160, 4290, 4422, 4556, 4692, 4830,
+                        5112, 5256, 5402, 5550, 5700, 5852, 6006, 6162, 6320}));
+}
+
+TEST(ReferenceTest, RandomJaxReference170) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{2, -1, 0, -1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(6, 18));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({1,  1, 1,  1, 1,  1, 1,  1, 1,  1, 1,  1, 1,  1, 1,  1,
+                        1,  1, 1,  1, 2,  1, 3,  1, 4,  1, 5,  1, 6,  1, 7,  1,
+                        8,  1, 9,  1, 21, 1, 22, 1, 23, 1, 24, 1, 25, 1, 26, 1,
+                        27, 1, 28, 1, 29, 1, 41, 1, 42, 1, 43, 1, 44, 1, 45, 1,
+                        46, 1, 47, 1, 48, 1, 49, 1, 61, 1, 62, 1, 63, 1, 64, 1,
+                        65, 1, 66, 1, 67, 1, 68, 1, 69, 1, 81, 1, 82, 1, 83, 1,
+                        84, 1, 85, 1, 86, 1, 87, 1, 88, 1, 89, 1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference171) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{0, -2, 2, 0},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(7, 10));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {1, 11,   24,   39,   56,   75,   96,   119,  144,  171,
+                   1, 231,  264,  299,  336,  375,  416,  459,  504,  551,
+                   1, 651,  704,  759,  816,  875,  936,  999,  1064, 1131,
+                   1, 1271, 1344, 1419, 1496, 1575, 1656, 1739, 1824, 1911,
+                   1, 2091, 2184, 2279, 2376, 2475, 2576, 2679, 2784, 2891,
+                   1, 3111, 3224, 3339, 3456, 3575, 3696, 3819, 3944, 4071,
+                   1, 4331, 4464, 4599, 4736, 4875, 5016, 5159, 5304, 5451}));
+}
+
+TEST(ReferenceTest, RandomJaxReference172) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{-1, 1, 2, 2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(10, 12));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference173) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{1, 1, 1, 0},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(21, 10));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           1,           2,           3,           4,           5,
+           6,           7,           8,           9,           10,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           11,          12,          13,          14,          15,
+           16,          17,          18,          19,          20,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           21,          22,          23,          24,          25,
+           26,          27,          28,          29,          30,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           31,          32,          33,          34,          35,
+           36,          37,          38,          39,          40,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           41,          42,          43,          44,          45,
+           46,          47,          48,          49,          50,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           51,          52,          53,          54,          55,
+           56,          57,          58,          59,          60,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           61,          62,          63,          64,          65,
+           66,          67,          68,          69,          70,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           71,          72,          73,          74,          75,
+           76,          77,          78,          79,          80,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           81,          82,          83,          84,          85,
+           86,          87,          88,          89,          90,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           91,          92,          93,          94,          95,
+           96,          97,          98,          99,          100,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference174) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{0, -1, -2, -1},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(18, 7));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {2,          3,          4,          5,          6,
+                   7,          8,          2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 12,
+                   13,         14,         15,         16,         17,
+                   18,         2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 22,         23,
+                   24,         25,         26,         27,         28,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 32,         33,         34,
+                   35,         36,         37,         38,         2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 42,         43,         44,         45,
+                   46,         47,         48,         2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   52,         53,         54,         55,         56,
+                   57,         58,         2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 62,
+                   63,         64,         65,         66,         67,
+                   68,         2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 72,         73,
+                   74,         75,         76,         77,         78,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 82,         83,         84,
+                   85,         86,         87,         88,         2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference175) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-2, 2, 0, 0},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(9, 9));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {458304,   534336,   619344,   714000,   819000,   935064,   1062936,
+           1203384,  1357200,  1708224,  1907136,  2122824,  2356200,  2608200,
+           2879784,  3171936,  3485664,  3822000,  4566744,  4977336,  5414904,
+           5880600,  6375600,  6901104,  7458336,  8048544,  8673000,  10029864,
+           10764936, 11539584, 12355200, 13213200, 14115024, 15062136, 16056024,
+           17098200, 19333584, 20529936, 21780864, 23088000, 24453000, 25877544,
+           27363336, 28912104, 30525600, 33953904, 35772336, 37662744, 39627000,
+           41667000, 43784664, 45981936, 48260784, 50623200, 55606824, 58232136,
+           60949224, 63760200, 66667200, 69672384, 72777936, 75986064, 79299000,
+           8372,     8556,     8742,     8930,     9120,     9312,     9506,
+           9702,     9900,     1,        1,        1,        1,        1,
+           1,        1,        1,        1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference176) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{1, -2, -1, 2},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(8, 10));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference177) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{0, 1, 0, 2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(10, 5));
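+  // A sketch of the shape arithmetic for this case, assuming XLA-style
+  // reduce_window semantics (padding listed as {row_lo, row_hi, col_lo,
+  // col_hi}, dilation gaps and padding filled with init_value): rows are
+  // (10-1)*2+1 = 19 after base dilation and 20 after padding {0, 1}; cols
+  // are 10 and then 12 after padding {0, 2}. The effective window is
+  // ((1-1)*2+1, (2-1)*2+1) = (1, 3), so strides {2, 2} give
+  // (floor((20-1)/2)+1, floor((12-3)/2)+1) = (10, 5).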
+
+  EXPECT_THAT(res.data, ElementsAreArray(
+                            {4,   8,   12,  16,  9,  24,  28,  32,  36,  19,
+                             44,  48,  52,  56,  29, 64,  68,  72,  76,  39,
+                             84,  88,  92,  96,  49, 104, 108, 112, 116, 59,
+                             124, 128, 132, 136, 69, 144, 148, 152, 156, 79,
+                             164, 168, 172, 176, 89, 184, 188, 192, 196, 99}));
+}
+
+TEST(ReferenceTest, RandomJaxReference178) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{2, 1, 2, 1},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(7, 11));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 1,          2,          1,          2,
+                   3,          4,          5,          6,          7,
+                   8,          9,          21,         22,         21,
+                   22,         23,         24,         25,         26,
+                   27,         28,         29,         41,         42,
+                   41,         42,         43,         44,         45,
+                   46,         47,         48,         49,         61,
+                   62,         61,         62,         63,         64,
+                   65,         66,         67,         68,         69,
+                   81,         82,         81,         82,         83,
+                   84,         85,         86,         87,         88,
+                   89,         2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference179) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-2, -2, 2, 0},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(3, 11));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({1, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                                1, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
+                                1, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70}));
+}
+
+TEST(ReferenceTest, RandomJaxReference180) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{1, -2, 1, 0},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(9, 6));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, 2,           4,           6,
+           8,           10,          -2147483647, 12,          14,
+           16,          18,          20,          -2147483647, 22,
+           24,          26,          28,          30,          -2147483647,
+           32,          34,          36,          38,          40,
+           -2147483647, 42,          44,          46,          48,
+           50,          -2147483647, 52,          54,          56,
+           58,          60,          -2147483647, 62,          64,
+           66,          68,          70,          -2147483647, 72,
+           74,          76,          78,          80}));
+}
+
+TEST(ReferenceTest, RandomJaxReference181) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-2, -1, -1, -2},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(7, 8));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference182) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-1, -1, 2, -1},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(3, 20));
+
+  EXPECT_THAT(res.data, ElementsAreArray(
+                            {0,   0, 42,  0, 44,  0, 46,  0, 48,  0, 50,  0,
+                             52,  0, 54,  0, 56,  0, 58,  0, 0,   0, 82,  0,
+                             84,  0, 86,  0, 88,  0, 90,  0, 92,  0, 94,  0,
+                             96,  0, 98,  0, 0,   0, 122, 0, 124, 0, 126, 0,
+                             128, 0, 130, 0, 132, 0, 134, 0, 136, 0, 138, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference183) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{1, -1, -2, -1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(17, 8));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {1,    1,    1,    1,    1,    1,    1,    1,    24,   39,   56,
+           75,   96,   119,  144,  171,  1,    1,    1,    1,    1,    1,
+           1,    1,    264,  299,  336,  375,  416,  459,  504,  551,  1,
+           1,    1,    1,    1,    1,    1,    1,    704,  759,  816,  875,
+           936,  999,  1064, 1131, 1,    1,    1,    1,    1,    1,    1,
+           1,    1344, 1419, 1496, 1575, 1656, 1739, 1824, 1911, 1,    1,
+           1,    1,    1,    1,    1,    1,    2184, 2279, 2376, 2475, 2576,
+           2679, 2784, 2891, 1,    1,    1,    1,    1,    1,    1,    1,
+           3224, 3339, 3456, 3575, 3696, 3819, 3944, 4071, 1,    1,    1,
+           1,    1,    1,    1,    1,    4464, 4599, 4736, 4875, 5016, 5159,
+           5304, 5451, 1,    1,    1,    1,    1,    1,    1,    1,    5904,
+           6059, 6216, 6375, 6536, 6699, 6864, 7031, 1,    1,    1,    1,
+           1,    1,    1,    1}));
+}
+
+TEST(ReferenceTest, RandomJaxReference184) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-1, 1, 2, -1},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(9, 5));
+
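+  // Every element below is the init value 2147483646 (a sketch, assuming
+  // dilation gaps and padding cells are filled with init_value): base
+  // dilation {2, 1} puts the data on even dilated rows, row padding {-1, 1}
+  // shifts it onto odd rows, and each window reads rows start and start+2
+  // (window dilation 2) from even starts (stride 2), so the minimum only
+  // ever sees fill cells.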
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+           2147483646, 2147483646, 2147483646, 2147483646, 2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference185) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{2, -1, -2, 0},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(11, 9));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
+           0,  2,  3,  4,  5,  6,  7,  8,  9,  10, 12, 13, 14, 15, 16, 17, 18,
+           19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 33, 34, 35, 36, 37,
+           38, 39, 40, 42, 43, 44, 45, 46, 47, 48, 49, 50, 52, 53, 54, 55, 56,
+           57, 58, 59, 60, 62, 63, 64, 65, 66, 67, 68, 69, 70, 72, 73, 74, 75,
+           76, 77, 78, 79, 80, 82, 83, 84, 85, 86, 87, 88, 89, 90}));
+}
+
+TEST(ReferenceTest, RandomJaxReference186) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{0, 0, 0, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{1, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(9, 7));
+
+  EXPECT_THAT(res.data, ElementsAreArray(
+                            {12, 13, 14, 15, 16, 17, 18, 22, 23, 24, 25, 26, 27,
+                             28, 32, 33, 34, 35, 36, 37, 38, 42, 43, 44, 45, 46,
+                             47, 48, 52, 53, 54, 55, 56, 57, 58, 62, 63, 64, 65,
+                             66, 67, 68, 72, 73, 74, 75, 76, 77, 78, 82, 83, 84,
+                             85, 86, 87, 88, 92, 93, 94, 95, 96, 97, 98}));
+}
+
+TEST(ReferenceTest, RandomJaxReference187) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{0, 0, 0, -2},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(4, 4));
+
+  EXPECT_THAT(res.data, ElementsAreArray({1, 3, 5, 7, 21, 23, 25, 27, 41, 43,
+                                          45, 47, 61, 63, 65, 67}));
+}
+
+TEST(ReferenceTest, RandomJaxReference188) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{2, 2, -1, -1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 1},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(11, 17));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, 2,           -2147483647, 3,           -2147483647,
+           4,           -2147483647, 5,           -2147483647, 6,
+           -2147483647, 7,           -2147483647, 8,           -2147483647,
+           9,           -2147483647, -2147483647, 12,          -2147483647,
+           13,          -2147483647, 14,          -2147483647, 15,
+           -2147483647, 16,          -2147483647, 17,          -2147483647,
+           18,          -2147483647, 19,          -2147483647, -2147483647,
+           22,          -2147483647, 23,          -2147483647, 24,
+           -2147483647, 25,          -2147483647, 26,          -2147483647,
+           27,          -2147483647, 28,          -2147483647, 29,
+           -2147483647, -2147483647, 32,          -2147483647, 33,
+           -2147483647, 34,          -2147483647, 35,          -2147483647,
+           36,          -2147483647, 37,          -2147483647, 38,
+           -2147483647, 39,          -2147483647, -2147483647, 42,
+           -2147483647, 43,          -2147483647, 44,          -2147483647,
+           45,          -2147483647, 46,          -2147483647, 47,
+           -2147483647, 48,          -2147483647, 49,          -2147483647,
+           -2147483647, 52,          -2147483647, 53,          -2147483647,
+           54,          -2147483647, 55,          -2147483647, 56,
+           -2147483647, 57,          -2147483647, 58,          -2147483647,
+           59,          -2147483647, -2147483647, 62,          -2147483647,
+           63,          -2147483647, 64,          -2147483647, 65,
+           -2147483647, 66,          -2147483647, 67,          -2147483647,
+           68,          -2147483647, 69,          -2147483647, -2147483647,
+           72,          -2147483647, 73,          -2147483647, 74,
+           -2147483647, 75,          -2147483647, 76,          -2147483647,
+           77,          -2147483647, 78,          -2147483647, 79,
+           -2147483647, -2147483647, 82,          -2147483647, 83,
+           -2147483647, 84,          -2147483647, 85,          -2147483647,
+           86,          -2147483647, 87,          -2147483647, 88,
+           -2147483647, 89,          -2147483647, -2147483647, 92,
+           -2147483647, 93,          -2147483647, 94,          -2147483647,
+           95,          -2147483647, 96,          -2147483647, 97,
+           -2147483647, 98,          -2147483647, 99,          -2147483647,
+           -2147483647, 92,          -2147483647, 93,          -2147483647,
+           94,          -2147483647, 95,          -2147483647, 96,
+           -2147483647, 97,          -2147483647, 98,          -2147483647,
+           99,          -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference189) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{1, 0, -2, 2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{1, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(5, 5));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({7,   11,  15,  19,  0,   74,  82,  90,  98,
+                                0,   154, 162, 170, 178, 0,   234, 242, 250,
+                                258, 0,   314, 322, 330, 338, 0}));
+}
+
+TEST(ReferenceTest, RandomJaxReference190) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{1, -1, 2, 0},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(9, 6));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference191) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{1, 0, 2, -2},
+      /*init_value=*/0,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::plus<>());
+  EXPECT_THAT(res.shape, ElementsAre(11, 5));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {0,   0,   0,   0,   0,   0,   3,   7,   11,  15,  0,   23,  27, 31,
+           35,  0,   43,  47,  51,  55,  0,   63,  67,  71,  75,  0,   83, 87,
+           91,  95,  0,   103, 107, 111, 115, 0,   123, 127, 131, 135, 0,  143,
+           147, 151, 155, 0,   163, 167, 171, 175, 0,   183, 187, 191, 195}));
+}
+
+TEST(ReferenceTest, RandomJaxReference192) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{-1, 2, 0, -1},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(11, 4));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {11,         13,         15,         17,         21,
+                   23,         25,         27,         31,         33,
+                   35,         37,         41,         43,         45,
+                   47,         51,         53,         55,         57,
+                   61,         63,         65,         67,         71,
+                   73,         75,         77,         81,         83,
+                   85,         87,         91,         93,         95,
+                   97,         2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference193) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{2, 2, 1, 0},
+      /*init_value=*/1,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{1, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(21, 6));
+
+  EXPECT_THAT(res.data, ElementsAreArray(
+                            {1, 2,    4,    6,    8,    10,   1, 1, 1, 1, 1, 1,
+                             1, 24,   56,   96,   144,  200,  1, 1, 1, 1, 1, 1,
+                             1, 264,  336,  416,  504,  600,  1, 1, 1, 1, 1, 1,
+                             1, 704,  816,  936,  1064, 1200, 1, 1, 1, 1, 1, 1,
+                             1, 1344, 1496, 1656, 1824, 2000, 1, 1, 1, 1, 1, 1,
+                             1, 2184, 2376, 2576, 2784, 3000, 1, 1, 1, 1, 1, 1,
+                             1, 3224, 3456, 3696, 3944, 4200, 1, 1, 1, 1, 1, 1,
+                             1, 4464, 4736, 5016, 5304, 5600, 1, 1, 1, 1, 1, 1,
+                             1, 5904, 6216, 6536, 6864, 7200, 1, 1, 1, 1, 1, 1,
+                             1, 7544, 7896, 8256, 8624, 9000, 1, 1, 1, 1, 1, 1,
+                             1, 92,   94,   96,   98,   100}));
+}
+
+TEST(ReferenceTest, RandomJaxReference194) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{-2, -1, -2, -1},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(7, 8));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({22, 23, 24, 25, 26, 27, 28, 29, 32, 33, 34, 35,
+                                36, 37, 38, 39, 42, 43, 44, 45, 46, 47, 48, 49,
+                                52, 53, 54, 55, 56, 57, 58, 59, 62, 63, 64, 65,
+                                66, 67, 68, 69, 72, 73, 74, 75, 76, 77, 78, 79,
+                                82, 83, 84, 85, 86, 87, 88, 89}));
+}
+
+TEST(ReferenceTest, RandomJaxReference195) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{-2, 1, -2, 2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(5, 9));
+
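+  // The final row below is all -2147483647 (a sketch, assuming padding
+  // cells are filled with init_value): row padding {-2, 1} leaves 9 rows
+  // with the appended padding row at index 8, and with row stride 2 the
+  // fifth (last) window row starts exactly on that padding row, so the
+  // maximum over it is just the init value.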
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {23,          24,          25,          26,          27,
+           28,          29,          30,          30,          43,
+           44,          45,          46,          47,          48,
+           49,          50,          50,          63,          64,
+           65,          66,          67,          68,          69,
+           70,          70,          83,          84,          85,
+           86,          87,          88,          89,          90,
+           90,          -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647}));
+}
+
+TEST(ReferenceTest, RandomJaxReference196) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 1},
+      /*padding=*/{1, 1, 1, -1},
+      /*init_value=*/1,
+      /*window_dimensions=*/{1, 2},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{2, 2},
+      /*body=*/std::multiplies<>());
+  EXPECT_THAT(res.shape, ElementsAre(6, 5));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({1,    1,    1,    1,    1,    11,   156,  210,
+                                272,  342,  31,   1056, 1190, 1332, 1482, 51,
+                                2756, 2970, 3192, 3422, 71,   5256, 5550, 5852,
+                                6162, 91,   8556, 8930, 9312, 9702}));
+}
+
+TEST(ReferenceTest, RandomJaxReference197) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 1},
+      /*padding=*/{-2, -2, -2, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 2},
+      /*window_strides=*/{2, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(7, 3));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray({23, 25, 27, 33, 35, 37, 43, 45, 47, 53, 55,
+                                57, 63, 65, 67, 73, 75, 77, 83, 85, 87}));
+}
+
+TEST(ReferenceTest, RandomJaxReference198) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{1, 2},
+      /*padding=*/{1, 1, -2, 0},
+      /*init_value=*/2147483646,
+      /*window_dimensions=*/{1, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a <= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(12, 9));
+
+  EXPECT_THAT(res.data,
+              ElementsAreArray(
+                  {2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2,
+                   3,          4,          5,          6,          7,
+                   8,          9,          10,         12,         13,
+                   14,         15,         16,         17,         18,
+                   19,         20,         22,         23,         24,
+                   25,         26,         27,         28,         29,
+                   30,         32,         33,         34,         35,
+                   36,         37,         38,         39,         40,
+                   42,         43,         44,         45,         46,
+                   47,         48,         49,         50,         52,
+                   53,         54,         55,         56,         57,
+                   58,         59,         60,         62,         63,
+                   64,         65,         66,         67,         68,
+                   69,         70,         72,         73,         74,
+                   75,         76,         77,         78,         79,
+                   80,         82,         83,         84,         85,
+                   86,         87,         88,         89,         90,
+                   92,         93,         94,         95,         96,
+                   97,         98,         99,         100,        2147483646,
+                   2147483646, 2147483646, 2147483646, 2147483646, 2147483646,
+                   2147483646, 2147483646, 2147483646}));
+}
+
+TEST(ReferenceTest, RandomJaxReference199) {
+  const Tensor<int> res = ReduceWindow<int>(
+      /*input=*/Tensor<int>::iota(/*shape=*/{10, 10}),
+      /*base_dilations=*/{2, 2},
+      /*padding=*/{-1, 1, -1, -2},
+      /*init_value=*/-2147483647,
+      /*window_dimensions=*/{2, 1},
+      /*window_dilations=*/{2, 1},
+      /*window_strides=*/{1, 2},
+      /*body=*/[](auto a, auto b) { return a >= b ? a : b; });
+  EXPECT_THAT(res.shape, ElementsAre(17, 8));
+
+  EXPECT_THAT(
+      res.data,
+      ElementsAreArray(
+          {-2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647, -2147483647, -2147483647, -2147483647, -2147483647,
+           -2147483647}));
+}
+
+}  // namespace
+}  // namespace tflite::reduce_window::reference
diff --git a/tensorflow/lite/kernels/stablehlo_scatter.cc b/tensorflow/lite/kernels/stablehlo_scatter.cc
index 0fd7d50..9885e8c 100644
--- a/tensorflow/lite/kernels/stablehlo_scatter.cc
+++ b/tensorflow/lite/kernels/stablehlo_scatter.cc
@@ -110,13 +110,13 @@
 
 static ComputationType OpCodeToComputationType(int op_code) {
   switch (op_code) {
-    case kTfLiteBuiltinAdd:
+    case kTfLiteBuiltinStablehloAdd:
       return ComputationType::kAdd;
-    case kTfLiteBuiltinMul:
+    case kTfLiteBuiltinStablehloMultiply:
       return ComputationType::kMultiply;
-    case kTfLiteBuiltinMaximum:
+    case kTfLiteBuiltinStablehloMaximum:
       return ComputationType::kMaximum;
-    case kTfLiteBuiltinMinimum:
+    case kTfLiteBuiltinStablehloMinimum:
       return ComputationType::kMinimum;
     default:
       return ComputationType::kOther;
diff --git a/tensorflow/lite/kernels/stablehlo_scatter_test.cc b/tensorflow/lite/kernels/stablehlo_scatter_test.cc
index 961ad41..fb2faea 100644
--- a/tensorflow/lite/kernels/stablehlo_scatter_test.cc
+++ b/tensorflow/lite/kernels/stablehlo_scatter_test.cc
@@ -69,13 +69,15 @@
     int* dummy = nullptr;
     AddSubgraphs(1, dummy);
     if (op_type == StablehloScatterOpType::kAdd) {
-      subgraph_builder_.BuildAddSubgraph(interpreter_->subgraph(1));
+      subgraph_builder_.BuildStablehloAddSubgraph(interpreter_->subgraph(1));
     } else if (op_type == StablehloScatterOpType::kMul) {
-      subgraph_builder_.BuildMulSubgraph(interpreter_->subgraph(1));
+      subgraph_builder_.BuildStablehloMulSubgraph(interpreter_->subgraph(1));
     } else if (op_type == StablehloScatterOpType::kMax) {
-      subgraph_builder_.BuildMaximumSubgraph(interpreter_->subgraph(1));
+      subgraph_builder_.BuildStablehloMaximumSubgraph(
+          interpreter_->subgraph(1));
     } else if (op_type == StablehloScatterOpType::kMin) {
-      subgraph_builder_.BuildMinimumSubgraph(interpreter_->subgraph(1));
+      subgraph_builder_.BuildStablehloMinimumSubgraph(
+          interpreter_->subgraph(1));
     } else if (op_type == StablehloScatterOpType::kUpdate) {
       subgraph_builder_.BuildOutputIsSecondInputSubgraph(
           interpreter_->subgraph(1));
diff --git a/tensorflow/lite/kernels/subgraph_test_util.cc b/tensorflow/lite/kernels/subgraph_test_util.cc
index 31030b0..2bd0992 100644
--- a/tensorflow/lite/kernels/subgraph_test_util.cc
+++ b/tensorflow/lite/kernels/subgraph_test_util.cc
@@ -424,6 +424,13 @@
                         params, operand_type, operand_type, operand_type);
 }
 
+void SubgraphBuilder::BuildStablehloAddSubgraph(Subgraph* subgraph,
+                                                const TfLiteType operand_type) {
+  BuildBinaryOpSubgraph(subgraph, ops::builtin::Register_STABLEHLO_ADD,
+                        kTfLiteBuiltinStablehloAdd, nullptr, operand_type,
+                        operand_type, operand_type);
+}
+
 // This body subgraph has arena and dynamic output tensors which are not
 // in-place, in order to verify that body subgraph outputs are written
 // directly to node outputs. It also has in-place dynamic and arena outputs.
@@ -573,6 +580,13 @@
                         /*output_type=*/operand_type);
 }
 
+void SubgraphBuilder::BuildStablehloMaximumSubgraph(
+    Subgraph* subgraph, const TfLiteType operand_type) {
+  BuildBinaryOpSubgraph(subgraph, ops::builtin::Register_STABLEHLO_MAXIMUM,
+                        kTfLiteBuiltinStablehloMaximum, nullptr, operand_type,
+                        operand_type, operand_type);
+}
+
 void SubgraphBuilder::BuildMinimumSubgraph(Subgraph* subgraph,
                                            const TfLiteType operand_type) {
   BuildBinaryOpSubgraph(subgraph, ops::builtin::Register_MINIMUM,
@@ -582,6 +596,29 @@
                         /*output_type=*/operand_type);
 }
 
+void SubgraphBuilder::BuildStablehloMinimumSubgraph(
+    Subgraph* subgraph, const TfLiteType operand_type) {
+  BuildBinaryOpSubgraph(subgraph, ops::builtin::Register_STABLEHLO_MINIMUM,
+                        kTfLiteBuiltinStablehloMinimum, nullptr, operand_type,
+                        operand_type, operand_type);
+}
+
+void SubgraphBuilder::BuildLogicalOrSubgraph(Subgraph* subgraph) {
+  BuildBinaryOpSubgraph(subgraph, ops::builtin::Register_LOGICAL_OR,
+                        kTfLiteBuiltinLogicalOr, /*params=*/nullptr,
+                        /*input1_type=*/kTfLiteBool,
+                        /*input2_type=*/kTfLiteBool,
+                        /*output_type=*/kTfLiteBool);
+}
+
+void SubgraphBuilder::BuildLogicalAndSubgraph(Subgraph* subgraph) {
+  BuildBinaryOpSubgraph(subgraph, ops::builtin::Register_LOGICAL_AND,
+                        kTfLiteBuiltinLogicalAnd, /*params=*/nullptr,
+                        /*input1_type=*/kTfLiteBool,
+                        /*input2_type=*/kTfLiteBool,
+                        /*output_type=*/kTfLiteBool);
+}
+
 void SubgraphBuilder::BuildOutputIsSecondInputSubgraph(Subgraph* subgraph) {
   const int kInput1 = 0;
   const int kInput2 = 1;
@@ -613,6 +650,13 @@
                         /*output_type=*/operand_type);
 }
 
+void SubgraphBuilder::BuildStablehloMulSubgraph(Subgraph* subgraph,
+                                                const TfLiteType operand_type) {
+  BuildBinaryOpSubgraph(subgraph, ops::builtin::Register_STABLEHLO_MULTIPLY,
+                        kTfLiteBuiltinStablehloMultiply, nullptr, operand_type,
+                        operand_type, operand_type);
+}
+
 // Build a subgraph with a pad op. Helper function for testing.
 void SubgraphBuilder::BuildPadSubgraph(Subgraph* subgraph) {
   const int kInput1 = 0;
diff --git a/tensorflow/lite/kernels/subgraph_test_util.h b/tensorflow/lite/kernels/subgraph_test_util.h
index 73d36ee..311397b 100644
--- a/tensorflow/lite/kernels/subgraph_test_util.h
+++ b/tensorflow/lite/kernels/subgraph_test_util.h
@@ -107,16 +107,39 @@
   void BuildAddSubgraph(Subgraph* subgraph,
                         TfLiteType operand_type = kTfLiteInt32);
 
+  // Build a subgraph with a single stablehlo Add op.
+  // 2 inputs. 1 output.
+  void BuildStablehloAddSubgraph(Subgraph* subgraph,
+                                 TfLiteType operand_type = kTfLiteInt32);
+
   // Build a subgraph with a single Maximum op.
   // 2 inputs. 1 output.
   void BuildMaximumSubgraph(Subgraph* subgraph,
                             TfLiteType operand_type = kTfLiteInt32);
 
+  // Build a subgraph with a single stablehlo Maximum op.
+  // 2 inputs. 1 output.
+  void BuildStablehloMaximumSubgraph(Subgraph* subgraph,
+                                     TfLiteType operand_type = kTfLiteInt32);
+
   // Build a subgraph with a single Minimum op.
   // 2 inputs. 1 output.
   void BuildMinimumSubgraph(Subgraph* subgraph,
                             TfLiteType operand_type = kTfLiteInt32);
 
+  // Build a subgraph with a single stablehlo Minimum op.
+  // 2 inputs. 1 output.
+  void BuildStablehloMinimumSubgraph(Subgraph* subgraph,
+                                     TfLiteType operand_type = kTfLiteInt32);
+
+  // Build a subgraph with a single LogicalOr op.
+  // 2 inputs. 1 output.
+  void BuildLogicalOrSubgraph(Subgraph* subgraph);
+
+  // Build a subgraph with a single LogicalAnd op.
+  // 2 inputs. 1 output.
+  void BuildLogicalAndSubgraph(Subgraph* subgraph);
+
   // Build a subgraph with no ops inside.
   // 2 inputs. 1 output. Routes the second input to the output.
   void BuildOutputIsSecondInputSubgraph(Subgraph* subgraph);
@@ -126,6 +149,11 @@
   void BuildMulSubgraph(Subgraph* subgraph,
                         TfLiteType operand_type = kTfLiteInt32);
 
+  // Build a subgraph with a single stablehlo Multiply op.
+  // 2 inputs. 1 output.
+  void BuildStablehloMulSubgraph(Subgraph* subgraph,
+                                 TfLiteType operand_type = kTfLiteInt32);
+
   // Build a subgraph with a single Pad op.
   // 2 inputs. 1 output.
   void BuildPadSubgraph(Subgraph* subgraph);
diff --git a/tensorflow/lite/python/authoring/BUILD b/tensorflow/lite/python/authoring/BUILD
index cf1b467..b6f13fc 100644
--- a/tensorflow/lite/python/authoring/BUILD
+++ b/tensorflow/lite/python/authoring/BUILD
@@ -2,12 +2,6 @@
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
-    default_visibility = [
-        "//tensorflow:internal",
-        "//tensorflow_estimator:__subpackages__",
-        "//tensorflow_federated:__subpackages__",
-        "//third_party/py/tensorflow:__subpackages__",
-    ],
     licenses = ["notice"],
 )
 
diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.cc b/tensorflow/lite/python/interpreter_wrapper/numpy.cc
index b0776e5..0e07563 100644
--- a/tensorflow/lite/python/interpreter_wrapper/numpy.cc
+++ b/tensorflow/lite/python/interpreter_wrapper/numpy.cc
@@ -91,6 +91,8 @@
       return kTfLiteUInt32;
     case NPY_INT16:
       return kTfLiteInt16;
+    case NPY_UINT16:
+      return kTfLiteUInt16;
     case NPY_UINT8:
       return kTfLiteUInt8;
     case NPY_INT8:
diff --git a/tensorflow/lite/schema/BUILD b/tensorflow/lite/schema/BUILD
index a8b0ec8..02b7fd4 100644
--- a/tensorflow/lite/schema/BUILD
+++ b/tensorflow/lite/schema/BUILD
@@ -11,6 +11,18 @@
     licenses = ["notice"],
 )
 
+filegroup(
+    name = "tflite_internal_cc_3p_api_deps_src",
+    srcs = [
+        ":schema_fbs_srcs",
+        ":schema_utils.cc",
+        ":schema_utils.h",
+    ],
+    visibility = [
+        "//tensorflow/lite:__pkg__",
+    ],
+)
+
 # This is the package group declaration to which targets for TensorFlow Lite
 # Flatbuffer schema utilities.
 #
diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs
index 993c069..6bffead 100644
--- a/tensorflow/lite/schema/schema.fbs
+++ b/tensorflow/lite/schema/schema.fbs
@@ -427,10 +427,10 @@
   // All operators that start with the STABLEHLO_ prefix are subject to change
   // Many of the ops below cannot be executed by the TFLite runtime
   STABLEHLO_LOGISTIC = 162, // WARNING: Do not have runtime support
-  STABLEHLO_ADD = 163, // WARNING: No runtime support yet
+  STABLEHLO_ADD = 163,
   STABLEHLO_DIVIDE = 164, // WARNING: No runtime support yet
-  STABLEHLO_MULTIPLY = 165, // WARNING: No runtime support yet
-  STABLEHLO_MAXIMUM = 166, // WARNING: No runtime support yet
+  STABLEHLO_MULTIPLY = 165,
+  STABLEHLO_MAXIMUM = 166,
   STABLEHLO_RESHAPE = 167, // WARNING: No runtime support yet
   STABLEHLO_CLAMP = 168, // WARNING: No runtime support
   STABLEHLO_CONCATENATE = 169, // WARNING: No runtime support
@@ -445,7 +445,7 @@
   STABLEHLO_EXPONENTIAL = 178, // WARNING: No runtime support
   STABLEHLO_FLOOR = 179, // WARNING: No runtime support
   STABLEHLO_LOG = 180, // WARNING: No runtime support
-  STABLEHLO_MINIMUM = 181, // WARNING: No runtime support
+  STABLEHLO_MINIMUM = 181,
   STABLEHLO_NEGATE = 182, // WARNING: No runtime support
   STABLEHLO_OR = 183, // WARNING: No runtime support
   STABLEHLO_POWER = 184, // WARNING: No runtime support
@@ -462,14 +462,14 @@
   STABLEHLO_PAD = 195, // WARNING: No runtime support
   STABLEHLO_IOTA = 196, // WARNING: No runtime support
   STABLEHLO_DOT_GENERAL = 197, // WARNING: No runtime support
-  STABLEHLO_REDUCE_WINDOW = 198, // WARNING: No runtime support
+  STABLEHLO_REDUCE_WINDOW = 198,
   STABLEHLO_SORT = 199, // WARNING: No runtime support
   STABLEHLO_WHILE = 200, // WARNING: No runtime support
-  STABLEHLO_GATHER = 201, // WARNING: No runtime support
+  STABLEHLO_GATHER = 201,
   STABLEHLO_TRANSPOSE = 202, // WARNING: No runtime support
   DILATE = 203,
   STABLEHLO_RNG_BIT_GENERATOR = 204,
-  REDUCE_WINDOW = 205,
+  REDUCE_WINDOW = 205 (deprecated),
 }
 // LINT.ThenChange(nnapi_linter/linter.proto)
 
@@ -626,7 +626,7 @@
   StablehloTransposeOptions,
   DilateOptions,
   StablehloRngBitGeneratorOptions,
-  ReduceWindowOptions,
+  ReduceWindowOptions (deprecated),
 }
 
 table StablehloGatherOptions{
@@ -1458,7 +1458,7 @@
   ANY,
 }
 
-table ReduceWindowOptions{
+table ReduceWindowOptions (deprecated) {
   reduce_function: ReduceWindowFunction;
 }
 
diff --git a/tensorflow/lite/special_rules.bzl b/tensorflow/lite/special_rules.bzl
index 83d1a65..2e2ec72 100644
--- a/tensorflow/lite/special_rules.bzl
+++ b/tensorflow/lite/special_rules.bzl
@@ -78,6 +78,18 @@
     This is a no-op outside of Google."""
     return []
 
+def xnnpack_plugin_impl_visibility_allowlist():
+    """Returns a list of packages that can depend on tensorflow/lite/core/acceleration/configuration:xnnpack_plugin.
+
+    This is a no-op outside of Google."""
+    return []
+
+def tflite_internal_cc_3p_api_deps_src_all_visibility_allowlist():
+    """Returns a list of packages that can depend on tensorflow/lite:tflite_internal_cc_3p_api_deps_src_all.
+
+    This is a no-op outside of Google."""
+    return []
+
 def tflite_extra_gles_deps():
     """This is a no-op outside of Google."""
     return []
@@ -137,11 +149,12 @@
 
     return [
         "//third_party/fft2d:fft2d_headers",
-        "//third_party/eigen3",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/types:optional",
+        "@eigen_archive//:eigen3",
         "@gemmlowp",
         "@icu//:common",
         "//third_party/icu/data:conversion_data",
diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/proto/BUILD b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/proto/BUILD
index 69ac63b..9bf70a2 100644
--- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/proto/BUILD
+++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/proto/BUILD
@@ -2,6 +2,7 @@
 #  Holds model-agnostic files and proto definitions. The app will bundle the files into assets.
 
 load("//tensorflow/lite/tools/benchmark/experimental/delegate_performance/android:proto.bzl", "proto_data")
+# copybara:uncomment load("//tools/build_defs/proto/cpp:cc_proto_library.bzl", "cc_proto_library")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
diff --git a/tensorflow/lite/tools/cmake/modules/abseil-cpp.cmake b/tensorflow/lite/tools/cmake/modules/abseil-cpp.cmake
index 92e746a..28fbb62 100644
--- a/tensorflow/lite/tools/cmake/modules/abseil-cpp.cmake
+++ b/tensorflow/lite/tools/cmake/modules/abseil-cpp.cmake
@@ -24,7 +24,7 @@
   abseil-cpp
   GIT_REPOSITORY https://github.com/abseil/abseil-cpp
   # Sync with tensorflow/third_party/absl/workspace.bzl
-  GIT_TAG b971ac5250ea8de900eae9f95e06548d14cd95fe
+  GIT_TAG fb3621f4f897824c0dbe0615fa94543df6192f30
   GIT_SHALLOW TRUE
   GIT_PROGRESS TRUE
   PREFIX "${CMAKE_BINARY_DIR}"
diff --git a/tensorflow/lite/tools/cmake/modules/eigen.cmake b/tensorflow/lite/tools/cmake/modules/eigen.cmake
index 3a0b67b..1bb2033 100644
--- a/tensorflow/lite/tools/cmake/modules/eigen.cmake
+++ b/tensorflow/lite/tools/cmake/modules/eigen.cmake
@@ -23,7 +23,7 @@
   eigen
   GIT_REPOSITORY https://gitlab.com/libeigen/eigen.git
   # Sync with tensorflow/third_party/eigen3/workspace.bzl
-  GIT_TAG 66e8f38891841bf88ee976a316c0c78a52f0cee5
+  GIT_TAG aa6964bf3a34fd607837dd8123bc42465185c4f8
   # It's not currently (cmake 3.17) possible to shallow clone with a GIT TAG
   # as cmake attempts to git checkout the commit hash after the clone
   # which doesn't work as it's a shallow clone hence a different commit hash.
diff --git a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake
index cb40b56..a6b3645 100644
--- a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake
+++ b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake
@@ -23,7 +23,7 @@
   xnnpack
   GIT_REPOSITORY https://github.com/google/XNNPACK
   # Sync with tensorflow/workspace2.bzl
-  GIT_TAG bbbaa7352a3ea729987d3e654d37be93e8009691
+  GIT_TAG c7e7cde37615a81a529c326aa278bfab4cd6fe5a
   GIT_PROGRESS TRUE
   PREFIX "${CMAKE_BINARY_DIR}"
   SOURCE_DIR "${CMAKE_BINARY_DIR}/xnnpack"
diff --git a/tensorflow/lite/tools/evaluation/proto/BUILD b/tensorflow/lite/tools/evaluation/proto/BUILD
index efea792..696876a 100644
--- a/tensorflow/lite/tools/evaluation/proto/BUILD
+++ b/tensorflow/lite/tools/evaluation/proto/BUILD
@@ -18,6 +18,7 @@
     "//tensorflow/core/platform:build_config.bzl",
     "tf_proto_library",
 )
+# copybara:uncomment load("//tools/build_defs/proto/cpp:cc_proto_library.bzl", "cc_proto_library")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
diff --git a/tensorflow/lite/tools/optimize/sparsity/BUILD b/tensorflow/lite/tools/optimize/sparsity/BUILD
index f18d3a4..13a95f0 100644
--- a/tensorflow/lite/tools/optimize/sparsity/BUILD
+++ b/tensorflow/lite/tools/optimize/sparsity/BUILD
@@ -16,6 +16,10 @@
         "-fexceptions",
         "-fno-strict-aliasing",
     ],
+    data = [
+        "format_converter_wrapper_pybind11.pyi",
+    ],
+    enable_stub_generation = True,
     features = ["-use_header_modules"],
     deps = [
         "//tensorflow/lite/core/c:common",
diff --git a/tensorflow/lite/tools/optimize/sparsity/format_converter_wrapper_pybind11.cc b/tensorflow/lite/tools/optimize/sparsity/format_converter_wrapper_pybind11.cc
index 32705aa..7d04291 100644
--- a/tensorflow/lite/tools/optimize/sparsity/format_converter_wrapper_pybind11.cc
+++ b/tensorflow/lite/tools/optimize/sparsity/format_converter_wrapper_pybind11.cc
@@ -15,6 +15,7 @@
 
 #include <vector>
 
+#include "pybind11/attr.h"  // from @pybind11
 #include "pybind11/pybind11.h"  // from @pybind11
 #include "pybind11/stl.h"  // from @pybind11
 #include "tensorflow/lite/core/c/common.h"
@@ -37,6 +38,9 @@
       .value("TF_LITE_DIM_SPARSE_CSR", TfLiteDimensionType::kTfLiteDimSparseCSR)
       .export_values();
 
+  py::class_<TfLiteSparsity> sparsity_class(m, "TfLiteSparsity",
+                                            py::module_local());
+
   py::class_<FormatConverterFp32>(m, "FormatConverterFp32")
       .def(py::init</*shape=*/const std::vector<int>&,
                     /*traversal_order=*/const std::vector<int>&,
diff --git a/tensorflow/lite/tools/optimize/sparsity/format_converter_wrapper_pybind11.pyi b/tensorflow/lite/tools/optimize/sparsity/format_converter_wrapper_pybind11.pyi
new file mode 100644
index 0000000..8010487
--- /dev/null
+++ b/tensorflow/lite/tools/optimize/sparsity/format_converter_wrapper_pybind11.pyi
@@ -0,0 +1,71 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import ClassVar
+
+from typing import overload
+TF_LITE_DIM_DENSE: TfLiteDimensionType
+TF_LITE_DIM_SPARSE_CSR: TfLiteDimensionType
+TF_LITE_ERROR: TfLiteStatus
+TF_LITE_OK: TfLiteStatus
+
+class FormatConverterFp32:
+    @overload
+    def __init__(self, arg0: list[int], arg1: list[int], arg2: list[TfLiteDimensionType], arg3: list[int], arg4: list[int]) -> None: ...
+    @overload
+    def __init__(self, arg0: list[int], arg1: TfLiteSparsity) -> None: ...
+    def DenseToSparse(self, arg0) -> TfLiteStatus: ...
+    def GetData(self) -> list[float]: ...
+    def GetDimMetadata(self) -> list[list[int]]: ...
+    def SparseToDense(self, arg0) -> TfLiteStatus: ...
+
+class TfLiteDimensionType:
+    __members__: ClassVar[dict] = ...  # read-only
+    TF_LITE_DIM_DENSE: ClassVar[TfLiteDimensionType] = ...
+    TF_LITE_DIM_SPARSE_CSR: ClassVar[TfLiteDimensionType] = ...
+    __entries: ClassVar[dict] = ...
+    def __init__(self, value: int) -> None: ...
+    def __eq__(self, other: object) -> bool: ...
+    def __getstate__(self) -> int: ...
+    def __hash__(self) -> int: ...
+    def __index__(self) -> int: ...
+    def __int__(self) -> int: ...
+    def __ne__(self, other: object) -> bool: ...
+    def __setstate__(self, state: int) -> None: ...
+    @property
+    def name(self) -> str: ...
+    @property
+    def value(self) -> int: ...
+
+class TfLiteSparsity:
+    def __init__(self, *args, **kwargs) -> None: ...
+
+class TfLiteStatus:
+    __members__: ClassVar[dict] = ...  # read-only
+    TF_LITE_ERROR: ClassVar[TfLiteStatus] = ...
+    TF_LITE_OK: ClassVar[TfLiteStatus] = ...
+    __entries: ClassVar[dict] = ...
+    def __init__(self, value: int) -> None: ...
+    def __eq__(self, other: object) -> bool: ...
+    def __getstate__(self) -> int: ...
+    def __hash__(self) -> int: ...
+    def __index__(self) -> int: ...
+    def __int__(self) -> int: ...
+    def __ne__(self, other: object) -> bool: ...
+    def __setstate__(self, state: int) -> None: ...
+    @property
+    def name(self) -> str: ...
+    @property
+    def value(self) -> int: ...
diff --git a/tensorflow/lite/tools/tflite-android.Dockerfile b/tensorflow/lite/tools/tflite-android.Dockerfile
index aca45fc..d1981b0 100644
--- a/tensorflow/lite/tools/tflite-android.Dockerfile
+++ b/tensorflow/lite/tools/tflite-android.Dockerfile
@@ -23,7 +23,7 @@
     rm ${ANDROID_SDK_FILENAME}
 
 # Install Android NDK.
-ENV ANDROID_NDK_FILENAME android-ndk-r25b-linux.zip
+ENV ANDROID_NDK_FILENAME android-ndk-r21e-linux-x86_64.zip
 ENV ANDROID_NDK_URL https://dl.google.com/android/repository/${ANDROID_NDK_FILENAME}
 ENV ANDROID_NDK_HOME ${ANDROID_DEV_HOME}/ndk
 ENV PATH ${PATH}:${ANDROID_NDK_HOME}
diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc
index d149caa..47282cb 100644
--- a/tensorflow/lite/tools/versioning/runtime_version.cc
+++ b/tensorflow/lite/tools/versioning/runtime_version.cc
@@ -433,7 +433,13 @@
            {{BuiltinOperator_STABLEHLO_SCATTER, 1}, "2.15.0"},
            {{BuiltinOperator_DILATE, 1}, "2.15.0"},
            {{BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR, 1}, "2.15.0"},
-           {{BuiltinOperator_REDUCE_WINDOW, 1}, "2.15.0"}});
+           {{BuiltinOperator_REDUCE_WINDOW, 1}, "2.15.0"},
+           {{BuiltinOperator_STABLEHLO_GATHER, 1}, "2.16.0"},
+           {{BuiltinOperator_STABLEHLO_ADD, 1}, "2.16.0"},
+           {{BuiltinOperator_STABLEHLO_MULTIPLY, 1}, "2.16.0"},
+           {{BuiltinOperator_STABLEHLO_REDUCE_WINDOW, 1}, "2.16.0"},
+           {{BuiltinOperator_STABLEHLO_MAXIMUM, 1}, "2.16.0"},
+           {{BuiltinOperator_STABLEHLO_MINIMUM, 1}, "2.16.0"}});
 
   std::pair<BuiltinOperator, int> version_key = {op_code, op_version};
   auto it = op_version_map->find(version_key);
@@ -460,8 +466,8 @@
         continue;
       }
       if (CompareRuntimeVersion(model_min_version, runtime_version)) {
-        // Current min model runtime version should be bumped if we see a higher
-        // op version.
+        // Current min model runtime version should be bumped if we see a
+        // higher op version.
         model_min_version = runtime_version;
       }
     }
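The map additions tie each new StableHLO builtin to the first TFLite runtime release that supports it; the surrounding logic scans every (op, version) pair in a model and bumps the minimum required runtime whenever an op needs a newer one. A minimal Python sketch of that reduction, with a toy map and hypothetical helper names:

    # Toy op->runtime map; keys and versions are illustrative, not the full table.
    OP_VERSION_MAP = {
        ("REDUCE_WINDOW", 1): "2.15.0",
        ("STABLEHLO_GATHER", 1): "2.16.0",
    }

    def _newer(candidate, current):
        # Compare dotted versions numerically, e.g. "2.16.0" > "2.15.0".
        return tuple(map(int, candidate.split("."))) > tuple(map(int, current.split(".")))

    def min_runtime_version(ops):
        model_min = "1.5.0"
        for key in ops:
            runtime = OP_VERSION_MAP.get(key)
            if runtime is None:
                continue  # Unknown ops do not constrain the minimum version.
            if _newer(runtime, model_min):
                model_min = runtime  # Bump when an op needs a newer runtime.
        return model_min

    print(min_runtime_version([("REDUCE_WINDOW", 1), ("STABLEHLO_GATHER", 1)]))  # 2.16.0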
diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files
index 4271727..68f19f8 100644
--- a/tensorflow/opensource_only.files
+++ b/tensorflow/opensource_only.files
@@ -94,7 +94,13 @@
 tf_staging/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h:
 tf_staging/tensorflow/lite/delegates/hexagon/hexagon_nn/BUILD:
 tf_staging/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/BUILD:
+tf_staging/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_app_using_stable_delegate.cc:
 tf_staging/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external.cc:
+tf_staging/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc:
+tf_staging/tensorflow/lite/delegates/utils/experimental/stable_delegate/BUILD:
+tf_staging/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.cc:
+tf_staging/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.h:
+tf_staging/tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader_test.cc:
 tf_staging/tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h:
 tf_staging/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h:
 tf_staging/tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg.h:
@@ -202,24 +208,13 @@
 tf_staging/third_party/cython.BUILD:
 tf_staging/third_party/ducc/BUILD:
 tf_staging/third_party/ducc/ducc0_custom_lowlevel_threading.h:
+tf_staging/third_party/ducc/fft.cc:
+tf_staging/third_party/ducc/fft.h:
 tf_staging/third_party/ducc/threading.cc:
 tf_staging/third_party/ducc/threading.h:
 tf_staging/third_party/eigen3/BUILD:
-tf_staging/third_party/eigen3/Eigen/Cholesky:
-tf_staging/third_party/eigen3/Eigen/Core:
-tf_staging/third_party/eigen3/Eigen/Eigenvalues:
-tf_staging/third_party/eigen3/Eigen/LU:
-tf_staging/third_party/eigen3/Eigen/OrderingMethods:
-tf_staging/third_party/eigen3/Eigen/QR:
-tf_staging/third_party/eigen3/Eigen/SVD:
-tf_staging/third_party/eigen3/Eigen/SparseCholesky:
-tf_staging/third_party/eigen3/Eigen/SparseCore:
 tf_staging/third_party/eigen3/LICENSE:
 tf_staging/third_party/eigen3/eigen_archive.BUILD:
-tf_staging/third_party/eigen3/unsupported/Eigen/CXX11/Tensor:
-tf_staging/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool:
-tf_staging/third_party/eigen3/unsupported/Eigen/MatrixFunctions:
-tf_staging/third_party/eigen3/unsupported/Eigen/SpecialFunctions:
 tf_staging/third_party/fft2d/BUILD:
 tf_staging/third_party/fft2d/LICENSE:
 tf_staging/third_party/fft2d/fft.h:
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 354a02c..c38726b 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -189,7 +189,9 @@
         "//tensorflow/python/grappler:tf_cluster",
         "//tensorflow/python/grappler:tf_item",
         "//tensorflow/python/grappler:tf_optimizer",
-        "//tensorflow/python/lib/io:lib",
+        "//tensorflow/python/lib/io:file_io",
+        "//tensorflow/python/lib/io:python_io",
+        "//tensorflow/python/lib/io:tf_record",
         "//tensorflow/python/module",
         "//tensorflow/python/ops:array_ops",
         "//tensorflow/python/ops:array_ops_stack",
@@ -376,17 +378,19 @@
         "//tensorflow/python/eager:monitoring",
         "//tensorflow/python/eager:remote",
         "//tensorflow/python/feature_column:feature_column_py",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:combinations",
         "//tensorflow/python/framework:composite_tensor",
         "//tensorflow/python/framework:config",
         "//tensorflow/python/framework:errors",
         "//tensorflow/python/framework:extension_type",
+        "//tensorflow/python/framework:framework_lib",
         "//tensorflow/python/framework:graph_util",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:test_combinations_lib",
         "//tensorflow/python/framework:versions",
-        "//tensorflow/python/lib/io:lib",
+        "//tensorflow/python/lib/io:file_io",
+        "//tensorflow/python/lib/io:python_io",
+        "//tensorflow/python/lib/io:tf_record",
         "//tensorflow/python/module",
         "//tensorflow/python/ops:audio_ops_gen",
         "//tensorflow/python/ops:bincount_ops",
@@ -524,12 +528,17 @@
         "//conditions:default": ["//tensorflow:libtensorflow_framework.so.%s" % VERSION],
         "//tensorflow:windows": [],
     }),
+    enable_stub_generation = True,
+    pytype_srcs = [
+        "_pywrap_py_exception_registry.pyi",
+    ],
     static_deps = tf_python_pybind_static_deps(),
     # Do not sort: core:py_exception_registry must come before platform:status
     deps = [
         "@com_google_absl//absl/container:fixed_array",
         "@pybind11",
         "//third_party/python_runtime:headers",
+        "//tensorflow/c:tf_status_headers",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/python/lib/core:py_exception_registry",
         "//tensorflow/core/platform:status",
@@ -1078,6 +1087,10 @@
         "//tensorflow/compiler/mlir/python:pywrap_mlir_hdrs",
         "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs",
     ],
+    enable_stub_generation = True,
+    pytype_srcs = [
+        "_pywrap_mlir.pyi",
+    ],
     deps = [
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/platform:status",
@@ -1331,6 +1344,10 @@
 tf_python_pybind_extension(
     name = "_pywrap_dtensor_device",
     srcs = ["pywrap_dtensor_device.cc"],
+    data = [
+        "_pywrap_dtensor_device.pyi",
+    ],
+    enable_stub_generation = True,
     features = ["-layering_check"],
     deps = [
         ":pywrap_densor_device_headers",
diff --git a/tensorflow/python/_pywrap_dtensor_device.pyi b/tensorflow/python/_pywrap_dtensor_device.pyi
new file mode 100644
index 0000000..bf8f123
--- /dev/null
+++ b/tensorflow/python/_pywrap_dtensor_device.pyi
@@ -0,0 +1,130 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Any, ClassVar
+
+from typing import overload
+
+class Layout:
+    __hash__: ClassVar[None] = ...
+    @overload
+    def __init__(self, layout: Layout) -> None: ...
+    @overload
+    def __init__(self, type: LayoutType, sharding_specs: list[str], mesh: Mesh) -> None: ...
+    @overload
+    def __init__(self, layout_proto) -> None: ...
+    @overload
+    def __init__(self, layout_str: str) -> None: ...
+    @overload
+    def __init__(self, mesh: Mesh, rank: int) -> None: ...
+    @overload
+    def __init__(self, mesh: Mesh, rank: int, batch_dim: str, axis: int) -> None: ...
+    @overload
+    def __init__(self, mesh: Mesh) -> None: ...
+    def as_proto(self, *args, **kwargs) -> Any: ...
+    def global_shape_from_local_shape(self, local_shape: list[int]) -> tuple: ...
+    def is_batch_parallel(self) -> bool: ...
+    def is_fully_replicated(self) -> bool: ...
+    def is_single_device(self) -> bool: ...
+    def local_shape_from_global_shape(self, global_shape: list[int]) -> tuple: ...
+    def num_shards(self, idx: int) -> int: ...
+    def to_parted(self) -> Layout: ...
+    def to_string(self) -> str: ...
+    def __eq__(self, arg0: Layout) -> bool: ...
+    @property
+    def mesh(self) -> Mesh: ...
+    @property
+    def rank(self) -> int: ...
+    @property
+    def sharding_specs(self) -> list[str]: ...
+    @property
+    def type(self) -> LayoutType: ...
+
+class LayoutType:
+    __members__: ClassVar[dict] = ...  # read-only
+    PARTED: ClassVar[LayoutType] = ...
+    SINGLE_DEVICE: ClassVar[LayoutType] = ...
+    STATIC: ClassVar[LayoutType] = ...
+    __entries: ClassVar[dict] = ...
+    def __init__(self, value: int) -> None: ...
+    def __eq__(self, other: object) -> bool: ...
+    def __getstate__(self) -> int: ...
+    def __hash__(self) -> int: ...
+    def __index__(self) -> int: ...
+    def __int__(self) -> int: ...
+    def __ne__(self, other: object) -> bool: ...
+    def __setstate__(self, state: int) -> None: ...
+    @property
+    def name(self) -> str: ...
+    @property
+    def value(self) -> int: ...
+
+class Mesh:
+    __hash__: ClassVar[None] = ...
+    @overload
+    def __init__(self, mesh: Mesh) -> None: ...
+    @overload
+    def __init__(self, arg0: str, arg1: list[str], arg2: list[int], arg3: list[int], arg4: list[str], arg5: list[int], arg6: list[str], arg7: bool) -> None: ...
+    @overload
+    def __init__(self, single_device: str) -> None: ...
+    @overload
+    def __init__(self, mesh_proto) -> None: ...
+    @overload
+    def __init__(self, mesh_str: str) -> None: ...
+    def as_proto(self, *args, **kwargs) -> Any: ...
+    def contains_dim(self, dim_name: str) -> bool: ...
+    def device_location(self, arg0: int) -> list[int]: ...
+    def device_type(self) -> str: ...
+    def dim_size(self, dim_name: str) -> int: ...
+    def global_device_ids(self): ...
+    def global_devices(self) -> list[str]: ...
+    def host_mesh(self) -> Mesh: ...
+    def is_remote(self) -> bool: ...
+    def is_single_device(self) -> bool: ...
+    def local_device_ids(self): ...
+    def local_devices(self): ...
+    def min_global_device_id(self) -> int: ...
+    def num_local_devices(self) -> int: ...
+    def shape(self) -> list[int]: ...
+    def to_string(self) -> str: ...
+    def use_xla_spmd(self) -> bool: ...
+    def __contains__(self, dim_name: str) -> bool: ...
+    def __eq__(self, arg0: Mesh) -> bool: ...
+    @property
+    def dim_names(self) -> list[str]: ...
+    @property
+    def name(self) -> str: ...
+    @property
+    def single_device(self) -> str: ...
+    @property
+    def size(self) -> int: ...
+
+def AddMesh(arg0, arg1: str, arg2: bool) -> None: ...
+def Allocate(arg0: str, arg1: bool, arg2: int) -> object: ...
+def ClearTPUCoreIDs(arg0) -> None: ...
+def ExperimentalClearDefaultLayout(arg0) -> None: ...
+def ExperimentalClearDefaultMesh(arg0) -> None: ...
+def ExperimentalSetDefaultLayout(arg0, arg1: str) -> None: ...
+def ExperimentalSetDefaultMesh(arg0, arg1: str) -> None: ...
+def FetchLayout(arg0: object, arg1: object, arg2) -> object: ...
+def GetStats(arg0: object, arg1) -> dict[str, int]: ...
+def IsDTensor(arg0: object, arg1: object, arg2) -> bool: ...
+def IsSparseDTensor(arg0: object, arg1: object, arg2) -> bool: ...
+def Pack(arg0: object, arg1: object, arg2: str, arg3, arg4: bool) -> object: ...
+def SetIteratorElementLayouts(arg0: object, arg1: object, arg2: list[str], arg3) -> None: ...
+def SetTPUCoreIDs(arg0, arg1: str, arg2: list[int]) -> None: ...
+def TPUCoreIDsToLocations(arg0: object, arg1, arg2: list[int]) -> list[list[int]]: ...
+def TPUCoreLocationsToIDs(arg0: object, arg1, arg2: list[list[int]]) -> list[int]: ...
+def Unpack(arg0: object, arg1: object, arg2) -> object: ...
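With the stub in place, attribute access on the pybind `Layout` and `Mesh` objects can be statically checked. A small hedged sketch (the module path follows the extension name in the BUILD rule, and the helper below is hypothetical):

    from tensorflow.python._pywrap_dtensor_device import Layout, Mesh

    def describe_layout(layout: Layout) -> str:
        # Every attribute used here is declared in the generated .pyi, so a type
        # checker can verify this function without importing the C extension.
        mesh: Mesh = layout.mesh
        return (f"rank={layout.rank} specs={layout.sharding_specs} "
                f"mesh={mesh.name} local_devices={mesh.num_local_devices()}")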
diff --git a/tensorflow/python/_pywrap_mlir.pyi b/tensorflow/python/_pywrap_mlir.pyi
new file mode 100644
index 0000000..d1375e1
--- /dev/null
+++ b/tensorflow/python/_pywrap_mlir.pyi
@@ -0,0 +1,28 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import overload
+
+def ExperimentalConvertSavedModelToMlir(arg0: str, arg1: str, arg2: bool) -> str: ...
+def ExperimentalConvertSavedModelV1ToMlir(arg0: str, arg1: str, arg2: str, arg3: bool, arg4: bool, arg5: bool, arg6: bool) -> str: ...
+def ExperimentalConvertSavedModelV1ToMlirLite(arg0: str, arg1: str, arg2: str, arg3: bool, arg4: bool) -> str: ...
+def ExperimentalRunPassPipeline(arg0: str, arg1: str, arg2: bool) -> str: ...
+def ExperimentalTFLiteToTosaBytecode(arg0: str, arg1: str, arg2: bool, arg3: list[str], arg4: list[str]) -> None: ...
+def ExperimentalWriteBytecode(arg0: str, arg1: str) -> None: ...
+def ImportFunction(arg0: object, arg1: str, arg2: str, arg3: bool) -> str: ...
+@overload
+def ImportGraphDef(arg0: str, arg1: str, arg2: bool) -> str: ...
+@overload
+def ImportGraphDef(arg0: str, arg1: str, arg2: bool, arg3: str, arg4: str, arg5: str, arg6: str) -> str: ...
diff --git a/tensorflow/python/_pywrap_py_exception_registry.pyi b/tensorflow/python/_pywrap_py_exception_registry.pyi
new file mode 100644
index 0000000..2fe8027
--- /dev/null
+++ b/tensorflow/python/_pywrap_py_exception_registry.pyi
@@ -0,0 +1,64 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import ClassVar
+
+TF_ABORTED: TF_Code
+TF_CANCELLED: TF_Code
+TF_DATA_LOSS: TF_Code
+TF_DEADLINE_EXCEEDED: TF_Code
+TF_FAILED_PRECONDITION: TF_Code
+TF_INTERNAL: TF_Code
+TF_INVALID_ARGUMENT: TF_Code
+TF_OK: TF_Code
+TF_OUT_OF_RANGE: TF_Code
+TF_PERMISSION_DENIED: TF_Code
+TF_RESOURCE_EXHAUSTED: TF_Code
+TF_UNAUTHENTICATED: TF_Code
+TF_UNIMPLEMENTED: TF_Code
+TF_UNKNOWN: TF_Code
+
+class TF_Code:
+    __members__: ClassVar[dict] = ...  # read-only
+    TF_ABORTED: ClassVar[TF_Code] = ...
+    TF_CANCELLED: ClassVar[TF_Code] = ...
+    TF_DATA_LOSS: ClassVar[TF_Code] = ...
+    TF_DEADLINE_EXCEEDED: ClassVar[TF_Code] = ...
+    TF_FAILED_PRECONDITION: ClassVar[TF_Code] = ...
+    TF_INTERNAL: ClassVar[TF_Code] = ...
+    TF_INVALID_ARGUMENT: ClassVar[TF_Code] = ...
+    TF_OK: ClassVar[TF_Code] = ...
+    TF_OUT_OF_RANGE: ClassVar[TF_Code] = ...
+    TF_PERMISSION_DENIED: ClassVar[TF_Code] = ...
+    TF_RESOURCE_EXHAUSTED: ClassVar[TF_Code] = ...
+    TF_UNAUTHENTICATED: ClassVar[TF_Code] = ...
+    TF_UNIMPLEMENTED: ClassVar[TF_Code] = ...
+    TF_UNKNOWN: ClassVar[TF_Code] = ...
+    __entries: ClassVar[dict] = ...
+    def __init__(self, value: int) -> None: ...
+    def __eq__(self, other: object) -> bool: ...
+    def __getstate__(self) -> int: ...
+    def __hash__(self) -> int: ...
+    def __index__(self) -> int: ...
+    def __int__(self) -> int: ...
+    def __ne__(self, other: object) -> bool: ...
+    def __setstate__(self, state: int) -> None: ...
+    @property
+    def name(self) -> str: ...
+    @property
+    def value(self) -> int: ...
+
+def PyExceptionRegistry_Init(arg0: object) -> None: ...
+def PyExceptionRegistry_Lookup(arg0: TF_Code) -> None: ...
diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
index 31da05e..cfb8c14 100644
--- a/tensorflow/python/autograph/impl/api_test.py
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -1291,22 +1291,30 @@
       else:
         return a + a
 
+    @def_function.function
+    def test_func2(a):
+      if constant_op.constant(True):
+        return a
+      else:
+        return a + a
+
     patch = test.mock.patch
     with patch.dict(os.environ, {'AUTOGRAPH_STRICT_CONVERSION': '0'}), \
-         patch.object(inspect, 'findsource', side_effect=OSError()), \
-         patch.object(ag_logging, 'warning') as warning_log_mock:
+        patch.object(inspect, 'findsource', side_effect=OSError()), \
+        self.assertLogs(level='WARNING') as logs:
 
-      with patch.object(ag_ctx, 'INSPECT_SOURCE_SUPPORTED', False):
-        with self.assertRaisesRegex(tf_errors.OperatorNotAllowedInGraphError,
-                                    'source code may not be visible'):
-          test_func(2)
-      warning_log_mock.assert_not_called()
+      with patch.object(ag_ctx, 'INSPECT_SOURCE_SUPPORTED', False), \
+          self.assertRaisesRegex(tf_errors.OperatorNotAllowedInGraphError,
+                                 'source code may not be visible'):
+        test_func(2)
+      self.assertEmpty(logs.output)
 
-      with patch.object(ag_ctx, 'INSPECT_SOURCE_SUPPORTED', True):
-        with self.assertRaisesRegex(tf_errors.OperatorNotAllowedInGraphError,
-                                    'using an unsupported feature'):
-          test_func(2)
-      warning_log_mock.called_once_with('AutoGraph could not transform')
+      with patch.object(ag_ctx, 'INSPECT_SOURCE_SUPPORTED', True), \
+          self.assertRaisesRegex(tf_errors.OperatorNotAllowedInGraphError,
+                                 'using an unsupported feature'):
+        test_func2(2)
+    self.assertLen(logs.output, 1)
+    self.assertRegex(logs.output[0], r'^.+:AutoGraph could not transform.+')
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py
index 41894a5..f69942e 100644
--- a/tensorflow/python/autograph/pyct/templates_test.py
+++ b/tensorflow/python/autograph/pyct/templates_test.py
@@ -113,14 +113,14 @@
         template,
         block=[
             gast.Assign(
-                [
+                targets=[
                     gast.Name(
                         'a',
                         ctx=ShouldBeReplaced,
                         annotation=None,
                         type_comment=None)
                 ],
-                gast.BinOp(
+                value=gast.BinOp(
                     gast.Name(
                         'a',
                         ctx=ShouldBeReplaced,
diff --git a/tensorflow/python/checkpoint/BUILD b/tensorflow/python/checkpoint/BUILD
index c34f499..4f36736 100644
--- a/tensorflow/python/checkpoint/BUILD
+++ b/tensorflow/python/checkpoint/BUILD
@@ -2,11 +2,11 @@
 #   Utilities for reading and writing object-based checkpoints.
 
 load("//tensorflow:strict.default.bzl", "py_strict_binary", "py_strict_library")
+load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict_test")
 load(
     "//tensorflow/tools/test:performance.bzl",
     "tf_py_logged_benchmark",
 )
-load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict_test")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -404,15 +404,20 @@
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/framework:constant_op",
+        "//tensorflow/python/framework:device",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:ops",
+        "//tensorflow/python/framework:tensor",
+        "//tensorflow/python/framework:tensor_shape",
         "//tensorflow/python/framework:tensor_spec",
         "//tensorflow/python/framework:tensor_util",
         "//tensorflow/python/ops:array_ops",
         "//tensorflow/python/ops:io_ops",
         "//tensorflow/python/ops:io_ops_gen",
         "//tensorflow/python/ops:string_ops",
+        "//tensorflow/python/ops:variables",
         "//tensorflow/python/saved_model/registration",
+        "//tensorflow/python/trackable:base",
         "//tensorflow/python/trackable:trackable_utils",
         "//tensorflow/python/training/saving:saveable_object",
         "//tensorflow/python/training/saving:saveable_object_util",
diff --git a/tensorflow/python/checkpoint/async_checkpoint_helper.py b/tensorflow/python/checkpoint/async_checkpoint_helper.py
index 45868da..aa801c4 100644
--- a/tensorflow/python/checkpoint/async_checkpoint_helper.py
+++ b/tensorflow/python/checkpoint/async_checkpoint_helper.py
@@ -263,10 +263,10 @@
     # custom __getattr__ code, see b/152031870 for context.
     for t in all_trackables:
       # Special case 1: TPU Embedding, populate object_map here
-    # Special case 1: Handle TPU Embedding by addnig a dummy instance to the
-    # object map. Also add TPUEmbedding to separate list for special handling
-    # with values copy.
-      if hasattr(t, _TPU_EMBEDDING_ATTR):
+      # Special case 1: Handle TPU Embedding by adding a dummy instance to the
+      # object map. Also add TPUEmbedding to separate list for special handling
+      # with values copy.
+      if hasattr(type(t), _TPU_EMBEDDING_ATTR):
         self._handle_tpu_embedding(t)
       # Special case 2: handle slot variables. The object_map is populated later
       # when the variable values are being copied to host CPU for the first
@@ -414,9 +414,9 @@
     Raises:
       AttributeError: if the input trackable is not TPUEmbedding type.
     """
-    if not hasattr(
-        tpu_embedding, _TPU_EMBEDDING_ATTR
-    ) or not callable(tpu_embedding._create_copy_for_async_checkpoint):  # pylint: disable=protected-access
+    if not hasattr(type(tpu_embedding), _TPU_EMBEDDING_ATTR) or not callable(
+        tpu_embedding._create_copy_for_async_checkpoint  # pylint: disable=protected-access
+    ):
       raise AttributeError(
           "Expecting TPUEmbedding type; got %s" % type(tpu_embedding)
       )
diff --git a/tensorflow/python/checkpoint/functional_saver.py b/tensorflow/python/checkpoint/functional_saver.py
index dfdefdf..2695fce 100644
--- a/tensorflow/python/checkpoint/functional_saver.py
+++ b/tensorflow/python/checkpoint/functional_saver.py
@@ -14,20 +14,28 @@
 # ==============================================================================
 """Saves and restore variables inside traced @tf.functions."""
 
+import dataclasses
+from typing import Callable, Dict, List
+
 from tensorflow.core.protobuf import saver_pb2
 from tensorflow.python.checkpoint import checkpoint_options
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import device as device_lib
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor as tensor_lib
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_io_ops
 from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.saved_model import registration
+from tensorflow.python.trackable import base
 from tensorflow.python.trackable import trackable_utils
 from tensorflow.python.training.saving import saveable_object
 from tensorflow.python.training.saving import saveable_object_util
@@ -35,6 +43,39 @@
 from tensorflow.python.util import object_identity
 
 
+@dataclasses.dataclass(frozen=True)
+class ShardableTensor:
+  """Tensor wrapper containing data necessary for sharding."""
+  _tensor_save_spec: saveable_object.SaveSpec
+  tensor: tensor_lib.Tensor
+  dtype: dtypes.DType
+  device: device_lib.DeviceSpec
+  name: str
+  shape: tensor_shape.TensorShape
+  slice_spec: variables.Variable.SaveSliceInfo
+  checkpoint_key: str
+  trackable: base.Trackable
+
+  def __hash__(self):
+    return hash((self.name, self.dtype, str(self.device), self.checkpoint_key))
+
+
+@dataclasses.dataclass(frozen=True)
+class ShardingCallback:
+  """Checkpoint sharding callback function, along with a text description."""
+  callback: Callable[
+      [List[ShardableTensor], ...],
+      List[Dict[str, Dict[tensor_spec.TensorSpec, saveable_object.SaveSpec]]]]
+  description: str
+
+  def __hash__(self):
+    if hasattr(self.callback, "__name__"):
+      callback_hash = hash((self.callback.__module__, self.callback.__name__))
+    else:
+      callback_hash = id(self.callback)
+    return hash((callback_hash, self.description))
+
+
 class _SingleDeviceSaver(object):
   """Saves and restores checkpoints from the current device."""
 
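The two dataclasses above are the building blocks for checkpoint sharding: ShardableTensor carries the per-tensor metadata a sharding policy needs, and ShardingCallback pairs the policy function with a human-readable description and a stable hash. A hedged sketch of the hashing behavior (the callback below is hypothetical and deliberately trivial):

    def by_device(shardable_tensors, *args):
        # A do-nothing policy used only to illustrate hashing semantics.
        return []

    cb1 = ShardingCallback(callback=by_device, description="shard by device")
    cb2 = ShardingCallback(callback=by_device, description="shard by device")
    # Named callbacks hash by (module, __name__), so identical definitions hash
    # equal; callables without __name__ (e.g. functools.partial objects) fall
    # back to id(self.callback).
    print(hash(cb1) == hash(cb2))  # True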
diff --git a/tensorflow/python/client/BUILD b/tensorflow/python/client/BUILD
index ff470d0..782ffe5 100644
--- a/tensorflow/python/client/BUILD
+++ b/tensorflow/python/client/BUILD
@@ -123,11 +123,7 @@
 
 py_strict_library(
     name = "client",
-    srcs = [
-        "client_lib.py",
-        "device_lib.py",
-        "timeline.py",
-    ],
+    srcs = ["client_lib.py"],
     srcs_version = "PY3",
     visibility = [
         "//tensorflow:internal",
@@ -205,7 +201,7 @@
     ],
     python_version = "PY3",
     deps = [
-        ":client",
+        ":device_lib",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/platform:client_testlib",
@@ -341,13 +337,13 @@
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:config",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:device",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:errors",
         "//tensorflow/python/framework:function",
+        "//tensorflow/python/framework:importer",
         "//tensorflow/python/framework:indexed_slices",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:sparse_tensor",
@@ -463,8 +459,8 @@
     ],
     xla_enable_strict_auto_jit = False,  # Graph structure is different with autojit
     deps = [
-        ":client",
         ":session",
+        ":timeline",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:ops",
diff --git a/tensorflow/python/client/_pywrap_tf_session.pyi b/tensorflow/python/client/_pywrap_tf_session.pyi
index a453092..62ea3ff 100644
--- a/tensorflow/python/client/_pywrap_tf_session.pyi
+++ b/tensorflow/python/client/_pywrap_tf_session.pyi
@@ -95,7 +95,7 @@
     def __contains__(self, arg0: object) -> bool: ...
     def __delitem__(self, arg0: int) -> None: ...
     def __getitem__(self, arg0: int) -> object: ...
-    def __iter__(self) -> Iterator: ...
+    def __iter__(self) -> Iterator[int]: ...
     def __len__(self) -> int: ...
     def __setitem__(self, arg0: int, arg1: object) -> None: ...
 
@@ -111,7 +111,7 @@
     def __contains__(self, arg0: object) -> bool: ...
     def __delitem__(self, arg0: str) -> None: ...
     def __getitem__(self, arg0: str) -> object: ...
-    def __iter__(self) -> Iterator: ...
+    def __iter__(self) -> Iterator[str]: ...
     def __len__(self) -> int: ...
     def __setitem__(self, arg0: str, arg1: object) -> None: ...
 
@@ -376,6 +376,7 @@
 def TF_DeviceListType(arg0: TF_DeviceList, arg1: int) -> str: ...
 def TF_FinishOperation(arg0: TF_OperationDescription) -> TF_Operation: ...
 def TF_FunctionImportFunctionDef(arg0: bytes) -> TF_Function: ...
+def TF_FunctionImportFunctionDefNoSerialization(arg0) -> TF_Function: ...
 def TF_FunctionSetAttrValueProto(arg0: TF_Function, arg1: str, arg2: bytes) -> None: ...
 def TF_FunctionToFunctionDef(arg0: TF_Function, arg1: TF_Buffer) -> None: ...
 def TF_GetAllOpList() -> TF_Buffer: ...
@@ -388,6 +389,7 @@
 def TF_GetXlaConstantFoldingDisabled() -> int: ...
 def TF_GraphCopyFunction(arg0: PyGraph, arg1: TF_Function, arg2: TF_Function) -> None: ...
 def TF_GraphImportGraphDefWithResults(arg0: PyGraph, arg1: TF_Buffer, arg2: TF_ImportGraphDefOptions) -> TF_ImportGraphDefResults: ...
+def TF_GraphImportGraphDefWithResultsNoSerialization(arg0: PyGraph, arg1, arg2: TF_ImportGraphDefOptions) -> TF_ImportGraphDefResults: ...
 def TF_GraphNextOperation(arg0: PyGraph, arg1: int) -> tuple: ...
 def TF_GraphRemoveFunction(arg0: PyGraph, arg1: str) -> None: ...
 def TF_GraphSetOutputHandleShapesAndTypes_wrapper(arg0: PyGraph, arg1: TF_Output, arg2: list[Optional[list[int]]], arg3: list[int], arg4: object) -> None: ...
diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc
index 6bbe4d6..790629c 100644
--- a/tensorflow/python/client/tf_session_wrapper.cc
+++ b/tensorflow/python/client/tf_session_wrapper.cc
@@ -38,6 +38,7 @@
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/python_api.h"
 #include "tensorflow/c/safe_ptr.h"
+#include "tensorflow/c/tf_buffer.h"
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/core/distributed_runtime/server_lib.h"
 #include "tensorflow/core/framework/full_type.pb.h"
@@ -1709,6 +1710,25 @@
       py::return_value_policy::reference);
 
   m.def(
+      "TF_GraphImportGraphDefWithResultsNoSerialization",
+      [](PyGraph* graph, const tensorflow::GraphDef* graph_def,
+         const TF_ImportGraphDefOptions* options) {
+        tensorflow::Safe_TF_StatusPtr status =
+            tensorflow::make_safe(TF_NewStatus());
+        TF_ImportGraphDefResults* output;
+        {
+          TF_Buffer graph_def_buffer;
+          graph_def_buffer.data = reinterpret_cast<const void*>(graph_def);
+          graph_def_buffer.length = sizeof(tensorflow::GraphDef*);
+          output = TF_GraphImportGraphDefWithResultsNoSerialization(
+              graph->tf_graph(), &graph_def_buffer, options, status.get());
+        }
+        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+        return output;
+      },
+      py::return_value_policy::reference);
+
+  m.def(
       "TF_GraphNextOperation",
       [](PyGraph* graph, size_t pos) {
         tensorflow::Safe_TF_StatusPtr status =
@@ -1848,6 +1868,25 @@
       },
       py::return_value_policy::reference);
 
+  m.def(
+      "TF_FunctionImportFunctionDefNoSerialization",
+      [](tensorflow::FunctionDef fdef) {
+        tensorflow::Safe_TF_StatusPtr status =
+            tensorflow::make_safe(TF_NewStatus());
+
+        // Release GIL.
+        py::gil_scoped_release release;
+        TF_Function* func = new TF_Function();
+        func->record =
+            new tensorflow::FunctionRecord(std::move(fdef), {}, false);
+        status.get()->status = ::tensorflow::OkStatus();
+        // Acquire GIL for returning output.
+        pybind11::gil_scoped_acquire acquire;
+        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+        return func;
+      },
+      py::return_value_policy::reference);
+
   m.def("EqualAttrValueWrapper", tensorflow::EqualAttrValueWrapper,
         py::call_guard<py::gil_scoped_release>());
 
diff --git a/tensorflow/python/compat/BUILD b/tensorflow/python/compat/BUILD
index 404927a..f6930ea 100644
--- a/tensorflow/python/compat/BUILD
+++ b/tensorflow/python/compat/BUILD
@@ -24,7 +24,7 @@
         "//tensorflow/python/framework:tensor",
         "//tensorflow/python/framework:tensor_shape",
         "//tensorflow/python/ops:control_flow_v2_toggles",
-        "//tensorflow/python/ops:variable_scope",
+        "//tensorflow/python/ops:resource_variables_toggle",
         "//tensorflow/python/util:tf_export",
     ],
 )
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index c16566a..59c8fad 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -29,7 +29,7 @@
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 10, 19)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 2)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 
diff --git a/tensorflow/python/compat/v2_compat.py b/tensorflow/python/compat/v2_compat.py
index 481c3c6..cef625b 100644
--- a/tensorflow/python/compat/v2_compat.py
+++ b/tensorflow/python/compat/v2_compat.py
@@ -26,7 +26,7 @@
 from tensorflow.python.framework import tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import control_flow_v2_toggles
-from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import resource_variables_toggle
 
 from tensorflow.python.util.tf_export import tf_export
 
@@ -60,7 +60,7 @@
   tf2.enable()
   ops.enable_eager_execution()
   tensor_shape.enable_v2_tensorshape()  # Also switched by tf2
-  variable_scope.enable_resource_variables()
+  resource_variables_toggle.enable_resource_variables()
   tensor.enable_tensor_equality()
   # Enables TensorArrayV2 and control flow V2.
   control_flow_v2_toggles.enable_control_flow_v2()
@@ -105,7 +105,7 @@
   tf2.disable()
   ops.disable_eager_execution()
   tensor_shape.disable_v2_tensorshape()  # Also switched by tf2
-  variable_scope.disable_resource_variables()
+  resource_variables_toggle.disable_resource_variables()
   tensor.disable_tensor_equality()
   # Disables TensorArrayV2 and control flow V2.
   control_flow_v2_toggles.disable_control_flow_v2()
diff --git a/tensorflow/python/compiler/tensorrt/BUILD b/tensorflow/python/compiler/tensorrt/BUILD
index f3ca24c..9fbdcf5 100644
--- a/tensorflow/python/compiler/tensorrt/BUILD
+++ b/tensorflow/python/compiler/tensorrt/BUILD
@@ -44,10 +44,10 @@
         "//tensorflow/python/client:session",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:wrap_function",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:convert_to_constants",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:errors",
+        "//tensorflow/python/framework:importer",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:tensor",
         "//tensorflow/python/grappler:tf_optimizer",
@@ -83,8 +83,8 @@
         "//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:def_function",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:config",
+        "//tensorflow/python/framework:graph_io",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:tensor_spec",
         "//tensorflow/python/framework:test_lib",
@@ -189,9 +189,9 @@
         "//tensorflow/python/estimator",
         "//tensorflow/python/estimator:model_fn",
         "//tensorflow/python/estimator:run_config",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:convert_to_constants",
         "//tensorflow/python/framework:dtypes",
+        "//tensorflow/python/framework:importer",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/keras:metrics",
diff --git a/tensorflow/python/compiler/tensorrt/model_tests/BUILD b/tensorflow/python/compiler/tensorrt/model_tests/BUILD
index bce6b05..d4034d6 100644
--- a/tensorflow/python/compiler/tensorrt/model_tests/BUILD
+++ b/tensorflow/python/compiler/tensorrt/model_tests/BUILD
@@ -22,9 +22,9 @@
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/client:session",
         "//tensorflow/python/compiler/tensorrt:trt_convert_py",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:convert_to_constants",
         "//tensorflow/python/framework:dtypes",
+        "//tensorflow/python/framework:importer",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:math_ops",
         "//tensorflow/python/ops:random_ops",
diff --git a/tensorflow/python/compiler/tensorrt/test/BUILD b/tensorflow/python/compiler/tensorrt/test/BUILD
index cadb9fe..4bea640 100644
--- a/tensorflow/python/compiler/tensorrt/test/BUILD
+++ b/tensorflow/python/compiler/tensorrt/test/BUILD
@@ -31,8 +31,8 @@
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/compiler/tensorrt:trt_convert_py",
         "//tensorflow/python/eager:def_function",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:config",
+        "//tensorflow/python/framework:graph_io",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:tensor_spec",
         "//tensorflow/python/framework:test_lib",
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py
index 836654c..03f795a 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py
@@ -55,7 +55,8 @@
   def reduce_fn(x, y):
     name, functions = y
     return x + combinations.combine(
-        functions=combinations.NamedObject(name, functions))
+        functions=combinations.NamedObject(name, functions)
+    )
 
   return functools.reduce(reduce_fn, cases, [])
 
@@ -63,13 +64,36 @@
 class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
-                         _test_combinations()))
-  def testMapFusion(self, functions):
-    dataset = dataset_ops.Dataset.range(5).apply(
-        testing.assert_next(["Map", "MemoryCacheImpl"]))
+      combinations.times(
+          test_base.default_test_combinations(),
+          _test_combinations(),
+          combinations.combine(
+              num_parallel_calls=[None, 2, dataset_ops.AUTOTUNE]
+          ),
+          combinations.combine(deterministic=[None, True, False]),
+      )
+  )
+  def testMapFusion(self, functions, num_parallel_calls, deterministic):
+    dataset = dataset_ops.Dataset.range(5)
+    if num_parallel_calls is None:
+      dataset = dataset.apply(testing.assert_next(["Map", "MemoryCacheImpl"]))
+    elif num_parallel_calls in [dataset_ops.AUTOTUNE]:
+      # TODO(b/148614504): Support fusion of parallel maps with
+      # non-AUTOTUNE value.
+      dataset = dataset.apply(
+          testing.assert_next(["ParallelMap", "MemoryCacheImpl"])
+      )
+    else:
+      dataset = dataset.apply(
+          testing.assert_next(["ParallelMap", "ParallelMap"])
+      )
+
     for function in functions:
-      dataset = dataset.map(function)
+      dataset = dataset.map(
+          function,
+          num_parallel_calls=num_parallel_calls,
+          deterministic=deterministic,
+      )
 
     dataset = dataset.cache()
     options = options_lib.Options()
@@ -85,18 +109,36 @@
         else:
           r = function(r)
       expected_output.append(r)
-    self.assertDatasetProduces(dataset, expected_output=expected_output)
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testCapturedInputs(self):
+    if num_parallel_calls is None or deterministic in [None, True]:
+      self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              num_parallel_calls=[None, 2, dataset_ops.AUTOTUNE]
+          ),
+      )
+  )
+  def testCapturedInputs(self, num_parallel_calls):
     a = constant_op.constant(3, dtype=dtypes.int64)
     b = constant_op.constant(4, dtype=dtypes.int64)
     some_tensor = math_ops.mul(a, b)
 
+    dataset = dataset_ops.Dataset.range(1)
     # We currently do not support functions with captured inputs.
-    dataset = dataset_ops.Dataset.range(1).apply(
-        testing.assert_next(["Map", "Map"
-                            ])).map(lambda x: some_tensor).map(lambda x: x)
+    if num_parallel_calls in [2, dataset_ops.AUTOTUNE]:
+      dataset = dataset.apply(
+          testing.assert_next(["ParallelMap", "ParallelMap"])
+      )
+    else:
+      dataset = dataset.apply(testing.assert_next(["Map", "Map"]))
+
+    dataset = dataset.map(
+        lambda x: some_tensor, num_parallel_calls=num_parallel_calls
+    ).map(lambda x: x, num_parallel_calls=num_parallel_calls)
+
     options = options_lib.Options()
     options.experimental_optimization.apply_default_optimizations = False
     options.experimental_optimization.map_fusion = True
diff --git a/tensorflow/python/data/experimental/kernel_tests/service/fault_tolerance_test.py b/tensorflow/python/data/experimental/kernel_tests/service/fault_tolerance_test.py
index 4000a9f..ef947f5 100644
--- a/tensorflow/python/data/experimental/kernel_tests/service/fault_tolerance_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/service/fault_tolerance_test.py
@@ -121,6 +121,18 @@
     self.assertDatasetProduces(ds, list(range(num_elements)))
 
   @combinations.generate(test_base.eager_only_combinations())
+  def testDispatcherRestartWithMultipleDatasets(self):
+    cluster = data_service_test_base.TestCluster(num_workers=1)
+    num_elements = 100
+    datasets = []
+    for _ in range(10):
+      datasets.append(self.make_distributed_range_dataset(100, cluster))
+      cluster.restart_dispatcher()
+
+    for ds in datasets:
+      self.assertDatasetProduces(ds, list(range(num_elements)))
+
+  @combinations.generate(test_base.eager_only_combinations())
   def testDispatcherManyRestarts(self):
     cluster = data_service_test_base.TestCluster(num_workers=1)
     num_elements_start = 10
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 53f6caa..9c2b344 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -369,7 +369,6 @@
             tpu_cluster_resolver,
             device_assignment=experimental_device_assignment,
             use_spmd_for_xla_partitioning=experimental_spmd_xla_partitioning,
-            enable_data_reorder=experimental_device_assignment is not None,
         )
     )
     distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
@@ -710,7 +709,6 @@
             self,
             tpu_cluster_resolver,
             device_assignment=device_assignment,
-            enable_data_reorder=device_assignment is not None,
         )
     )
     distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
@@ -864,7 +862,6 @@
       steps_per_run=None,
       device_assignment=None,
       use_spmd_for_xla_partitioning=False,
-      enable_data_reorder=False,
   ):
     super().__init__(container_strategy)
 
@@ -926,15 +923,6 @@
       self._host_input_worker_devices.setdefault(host_device, [])
       self._host_input_worker_devices[host_device].append(host_device)
 
-    # Create the replica order based on the assigned device order.
-    # This replica order will be used to match the IteratorGetNext ops
-    # with the device assigment.
-    self._replica_order = (
-        self._get_replica_order(self._tpu_devices[:, 0])
-        if enable_data_reorder
-        else None
-    )
-
     # TODO(sourabhbajaj): Remove this once performance of running one step
     # at a time is comparable to multiple steps.
     self.steps_per_run = steps_per_run
@@ -965,7 +953,11 @@
         self._using_custom_device = True
         break
 
-  def _get_replica_order(self, tpu_devices):
+    # This is a flag to enable data reorder which is used
+    # to match IteratorGetNext's device with the TPUExecute device.
+    self._enable_data_reorder = False
+
+  def _get_replica_order(self):
     """Get the replica order based on the tpu device order.
 
     For example, if the tpu_devices are:
@@ -985,13 +977,14 @@
     iterators,
     so that they can be placed on the same node as their computation graphs.
 
-    Args:
-      tpu_devices (List[str]): A list of tpu device names in the order of
-        replicas.
-
     Returns:
       A list containing the order ids of corresponding TPU devices.
     """
+    if not self._enable_data_reorder:
+      return None
+
+    tpu_devices = self._tpu_devices[:, 0]
+
     devices_with_ids = []
     for i, tpu_device in enumerate(tpu_devices):
       spec = tf_device.DeviceSpec.from_string(tpu_device)
@@ -1083,7 +1076,7 @@
         self._container_strategy(),
         num_replicas_in_sync=self._num_replicas_in_sync,
         options=options,
-        replica_order=self._replica_order,
+        replica_order=self._get_replica_order(),
     )
 
   def _distribute_datasets_from_function(self, dataset_fn, options):
@@ -1109,7 +1102,7 @@
         input_contexts,
         self._container_strategy(),
         options=options,
-        replica_order=self._replica_order,
+        replica_order=self._get_replica_order(),
     )
 
     # We can only check after the dataset_fn is called.
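The refactor above moves replica-order computation behind an `_enable_data_reorder` flag and a lazy `_get_replica_order()` call: when enabled, replicas are ordered by their TPU device placement so that IteratorGetNext ops are placed next to the TPUExecute ops that consume them. A rough, hedged sketch of one plausible ordering rule, sorting replica indices by (task, device index); this is not the actual implementation:

    def replica_order(device_coords):
        # device_coords: one (task, tpu_core) pair per replica, e.g. parsed from
        # device strings; the values below are illustrative.
        return sorted(range(len(device_coords)), key=lambda i: device_coords[i])

    print(replica_order([(1, 0), (0, 0), (0, 1)]))  # [1, 2, 0]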
diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py
index a3fc8ce..131c1db 100644
--- a/tensorflow/python/distribute/tpu_strategy_test.py
+++ b/tensorflow/python/distribute/tpu_strategy_test.py
@@ -1208,6 +1208,7 @@
     strategy = tpu_lib.TPUStrategyV2(
         resolver, experimental_device_assignment=device_assignment
     )
+    strategy.extended._enable_data_reorder = True
 
     dist_dataset = create_dist_dataset_fn(strategy)
     iterator = iter(dist_dataset)
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 5053bb3..cb1e3df 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -1053,9 +1053,9 @@
         "//tensorflow/core:protos_all_py",
         "//tensorflow/core/function/polymorphism:function_type",
         "//tensorflow/python/eager/polymorphic_function:atomic_function",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:composite_tensor",
         "//tensorflow/python/framework:func_graph",
+        "//tensorflow/python/framework:importer",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:sparse_tensor",
         "//tensorflow/python/framework:tensor",
@@ -1081,9 +1081,9 @@
         ":wrap_function",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
+        "//tensorflow/python/framework:importer",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:tensor_spec",
         "//tensorflow/python/framework:test_lib",
@@ -1110,10 +1110,10 @@
         ":def_function",
         ":wrap_function",
         "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:config",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
+        "//tensorflow/python/framework:importer",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/platform:client_testlib",
         "@absl_py//absl/testing:parameterized",
diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py
index facd626..70f6e0e 100644
--- a/tensorflow/python/eager/forwardprop_test.py
+++ b/tensorflow/python/eager/forwardprop_test.py
@@ -24,7 +24,6 @@
 from tensorflow.python import pywrap_tfe
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.eager import backprop
-from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import forwardprop
 from tensorflow.python.eager import forwardprop_util
@@ -440,20 +439,14 @@
       return math_ops.reduce_prod(
           pointwise + math_ops.reduce_sum(pointwise), axis=1)
 
-    if (context.run_eager_op_as_function_enabled() and
-        test_util.is_xla_enabled()):
-      # Autoclustering kicks in when eager_op_as_function is enabled.
-      # Under XLA the symbolic tolerances are less than under TF.
-      # Ref: b/202559426
-      _test_gradients(
-          self,
-          f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])],
-          order=3,
-          srtol=1e-6,
-          satol=1e-3)
-    else:
-      _test_gradients(
-          self, f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])], order=3)
+    _test_gradients(
+        self,
+        f,
+        [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])],
+        order=3,
+        srtol=1e-6,
+        satol=1e-3,
+    )
 
   @test_util.assert_no_new_pyobjects_executing_eagerly
   def testNumericHigherOrderFloat64(self):
diff --git a/tensorflow/python/eager/polymorphic_function/BUILD b/tensorflow/python/eager/polymorphic_function/BUILD
index 57c3ebc..974d99b 100644
--- a/tensorflow/python/eager/polymorphic_function/BUILD
+++ b/tensorflow/python/eager/polymorphic_function/BUILD
@@ -290,6 +290,7 @@
         "//tensorflow/python/ops:script_ops",
         "//tensorflow/python/ops:sendrecv_ops_gen",
         "//tensorflow/python/ops:string_ops",
+        "//tensorflow/python/ops:training_ops_gen",
         "//tensorflow/python/ops:variable_scope",
         "//tensorflow/python/ops:variables",
         "//tensorflow/python/ops/ragged:ragged_factory_ops",
@@ -300,7 +301,6 @@
         "//tensorflow/python/saved_model:save",
         "//tensorflow/python/saved_model:save_context",
         "//tensorflow/python/saved_model:save_options",
-        "//tensorflow/python/training:training_ops",
         "//tensorflow/python/util:compat",
         "//tensorflow/python/util:nest",
         "//tensorflow/python/util:tf_decorator",
diff --git a/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py b/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py
index 1fb7777..64aab16 100644
--- a/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py
+++ b/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py
@@ -69,6 +69,7 @@
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gen_random_ops
 from tensorflow.python.ops import gen_sendrecv_ops
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import math_ops
@@ -86,7 +87,6 @@
 from tensorflow.python.saved_model import save_options
 from tensorflow.python.saved_model.load import load
 from tensorflow.python.saved_model.save import save
-from tensorflow.python.training import training_ops
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
@@ -1483,7 +1483,7 @@
 
     @polymorphic_function.function
     def resource_apply_adam():
-      training_ops.resource_apply_adam(
+      gen_training_ops.resource_apply_adam(
           v_cpu.handle,
           v_gpu.handle,
           v_also_cpu.handle,
diff --git a/tensorflow/python/eager/polymorphic_function/polymorphic_function_xla_jit_test.py b/tensorflow/python/eager/polymorphic_function/polymorphic_function_xla_jit_test.py
index 9c777e6..e3bc7c7 100644
--- a/tensorflow/python/eager/polymorphic_function/polymorphic_function_xla_jit_test.py
+++ b/tensorflow/python/eager/polymorphic_function/polymorphic_function_xla_jit_test.py
@@ -213,6 +213,21 @@
       self.assertNotEqual(matches[0], matches[1])
       self._compareTwoMethodsCompilerIROutput(fn, [inputs, inputs], {})
 
+  def testCollectiveReduceReplicaGroups(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      @polymorphic_function.function(jit_compile=True)
+      def fn(x):
+        t0 = collective_ops.all_reduce_v2(
+            t=x, group_size=2, group_key=1, instance_key=1)
+        return t0
+
+      inputs = constant_op.constant([1.0, 2.0, 3.0])
+      # Make sure replica groups are assigned
+      hlo_str = fn.experimental_get_compiler_ir(inputs)()
+      self.assertIn('replica_groups={{', hlo_str)
+      self._compareTwoMethodsCompilerIROutput(fn, [inputs], {})
+
   def testCollectiveReduceGroupAssignment(self):
     if not test_util.is_mlir_bridge_enabled():
       self.skipTest('AssignGroup is only supported in the MLIR bridge.')
diff --git a/tensorflow/python/flags_pybind.pyi b/tensorflow/python/flags_pybind.pyi
index 90aa0a7..1d78af8 100644
--- a/tensorflow/python/flags_pybind.pyi
+++ b/tensorflow/python/flags_pybind.pyi
@@ -24,6 +24,7 @@
     graph_building_optimization: Flag
     more_stack_traces: Flag
     op_building_optimization: Flag
+    publish_function_graphs: Flag
     replicate_small_constants: Flag
     saved_model_fingerprinting: Flag
     test_only_experiment_1: Flag
diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD
index 3e32c7f..2a44b2d 100644
--- a/tensorflow/python/framework/BUILD
+++ b/tensorflow/python/framework/BUILD
@@ -532,9 +532,16 @@
     ],
 )
 
+tf_cc_binary(
+    name = "_native_proto_caster",
+    linkshared = True,
+    deps = ["@pybind11_protobuf//pybind11_protobuf:native_proto_caster"],
+)
+
 py_strict_library(
     name = "function",
     srcs = ["function.py"],
+    data = [":_native_proto_caster"],  # copybara:comment
     srcs_version = "PY3",
     visibility = visibility + [
         "//smartass/brain:__subpackages__",
@@ -558,6 +565,7 @@
         "//tensorflow/python/util:compat",
         "//tensorflow/python/util:function_utils",
         "//tensorflow/python/util:tf_decorator",
+        # copybara:uncomment "@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
     ],
 )
 
@@ -1929,7 +1937,7 @@
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/client:pywrap_tf_session",
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/lib/io:lib",
+        "//tensorflow/python/lib/io:file_io",
         "//tensorflow/python/platform:tf_logging",
         "//tensorflow/python/util:compat",
         "@pypi_packaging//:pkg",
@@ -1963,7 +1971,7 @@
     deps = [
         ":byte_swap_tensor",
         ":ops",
-        "//tensorflow/python/lib/io:lib",
+        "//tensorflow/python/lib/io:file_io",
         "//tensorflow/python/util:tf_export",
     ],
 )
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 068a685..848a4c8 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -38,6 +38,8 @@
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_inspect
 
+is_oss = True  # updated by copybara
+
 
 # TODO(b/136040013): Drop support for Defun.
 class Defun(object):
@@ -1161,8 +1163,11 @@
   result = _DefinedFunction(func, argnames, input_types, func_name, grad_func,
                             python_grad_func, out_names)
   # pylint: disable=protected-access
-  serialized = fdef.SerializeToString()
-  c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+  if is_oss:
+    serialized = fdef.SerializeToString()
+    c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+  else:
+    c_func = c_api.TF_FunctionImportFunctionDefNoSerialization(fdef)
   result._c_func = c_api_util.ScopedTFFunction(c_func, func_name)
   result._extra_inputs = []
   result._op_def = fdef.signature
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index 359e9f4..d49c62d 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -31,6 +31,10 @@
 from tensorflow.python.util.tf_export import tf_export
 
 
+# TODO(b/307794935): Remove after bug is fixed.
+is_oss = True  # Updated by copybara.
+
+
 def _IsControlInput(input_name):
   # Expected format: '^operation_name' (control input).
   return input_name.startswith('^')
@@ -505,15 +509,26 @@
   # TF_GraphImportGraphDefWithResults call and mutating the them in
   # _ProcessNewOps.
   with graph._mutation_lock():  # pylint: disable=protected-access
-    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
-      try:
-        with graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
-          results = c_api.TF_GraphImportGraphDefWithResults(
-              c_graph, serialized, options)
-        results = c_api_util.ScopedTFImportGraphDefResults(results)
-      except errors.InvalidArgumentError as e:
-        # Convert to ValueError for backwards compatibility.
-        raise ValueError(str(e))
+    if is_oss:
+      graph_def_input = c_api.TF_NewBufferFromString(
+          compat.as_bytes(graph_def.SerializeToString())
+      )
+      graph_import_graphdef = c_api.TF_GraphImportGraphDefWithResults
+    else:
+      graph_def_input = graph_def
+      graph_import_graphdef = (
+          c_api.TF_GraphImportGraphDefWithResultsNoSerialization
+      )
+    try:
+      with graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
+        results = graph_import_graphdef(c_graph, graph_def_input, options)
+      results = c_api_util.ScopedTFImportGraphDefResults(results)
+    except errors.InvalidArgumentError as e:
+      # Convert to ValueError for backwards compatibility.
+      raise ValueError(str(e))
+    finally:
+      if is_oss:
+        c_api.TF_DeleteBuffer(graph_def_input)
 
     # Create _DefinedFunctions for any imported functions.
     #
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 622208c..c736d55 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -81,6 +81,9 @@
 from tensorflow.python.util.tf_export import tf_export
 
 
+# TODO(b/307794935): Remove after bug is fixed.
+is_oss = True  # Updated by copybara
+
 # Temporary global switches determining if we should enable the work-in-progress
 # calls to the C API. These will be removed once all functionality is supported.
 _USE_C_API: bool = True
@@ -2406,7 +2409,9 @@
 
     return graph, self.version
 
-  def as_graph_def(self, from_version=None, add_shapes=False):
+  def as_graph_def(
+      self, from_version=None, add_shapes=False, use_pybind11_proto=False
+  ):
     # pylint: disable=line-too-long
     """Returns a serialized `GraphDef` representation of this graph.
 
@@ -2422,6 +2427,9 @@
         property had the given value.
       add_shapes: If true, adds an "_output_shapes" list attr to each node with
         the inferred shapes of each of its outputs.
+      use_pybind11_proto: If true, uses the C++ pybind11_proto API to get the
+        GraphDef proto directly from C++, instead of through a TF buffer. See
+        https://github.com/pybind/pybind11_protobuf for reference.
 
     Returns:
       A
@@ -2432,7 +2440,11 @@
       ValueError: If the `graph_def` would be too large.
     """
     # pylint: enable=line-too-long
-    result, _ = self._as_graph_def(from_version, add_shapes)
+    if is_oss:
+      use_pybind11_proto = False
+    result, _ = self._as_graph_def(
+        from_version, add_shapes, use_pybind11_proto=use_pybind11_proto
+    )
     return result
 
   def _is_function(self, name):
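A minimal sketch of the new keyword, assuming this patch is applied: in OSS builds the `is_oss` guard forces `use_pybind11_proto` back to False inside `as_graph_def()`, so the call behaves exactly like a plain `as_graph_def()`; the pybind11_protobuf fast path is only exercised internally.

    import tensorflow as tf

    g = tf.Graph()
    with g.as_default():
        tf.constant(1.0, name="c")

    # In OSS the flag is overridden to False, so this returns the same
    # GraphDef as g.as_graph_def(add_shapes=True).
    graph_def = g.as_graph_def(add_shapes=True, use_pybind11_proto=True)
    print(len(graph_def.node))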
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index 1785c09..4aea8bd 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -84,7 +84,7 @@
         "//tensorflow/python/framework:errors",
         "//tensorflow/python/keras:backend",
         "//tensorflow/python/keras/utils:mode_keys",
-        "//tensorflow/python/lib/io:lib",
+        "//tensorflow/python/lib/io:file_io",
         "//tensorflow/python/ops:variables",
         "//tensorflow/python/training:checkpoint_management",
     ],
@@ -98,7 +98,7 @@
     srcs_version = "PY3",
     deps = [
         "//tensorflow/python/distribute:distribute_lib",
-        "//tensorflow/python/lib/io:lib",
+        "//tensorflow/python/lib/io:file_io",
     ],
 )
 
diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD
index a24dc27..50f4725 100644
--- a/tensorflow/python/keras/engine/BUILD
+++ b/tensorflow/python/keras/engine/BUILD
@@ -214,7 +214,6 @@
         "//tensorflow/python/framework:tensor_shape",
         "//tensorflow/python/framework:tensor_spec",
         "//tensorflow/python/keras:backend",
-        "//tensorflow/python/lib/io:lib",
         "//tensorflow/python/util:nest",
         "//tensorflow/python/util:tf_export",
     ],
@@ -229,7 +228,6 @@
         "//tensorflow/python/framework:tensor",
         "//tensorflow/python/framework:tensor_shape",
         "//tensorflow/python/keras/utils:object_identity",
-        "//tensorflow/python/lib/io:lib",
         "//tensorflow/python/util:nest",
     ],
 )
diff --git a/tensorflow/python/keras/optimizer_v2/adadelta.py b/tensorflow/python/keras/optimizer_v2/adadelta.py
index f2264bd..a0671a1 100644
--- a/tensorflow/python/keras/optimizer_v2/adadelta.py
+++ b/tensorflow/python/keras/optimizer_v2/adadelta.py
@@ -20,7 +20,7 @@
 from tensorflow.python.keras import backend_config
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import array_ops
-from tensorflow.python.training import gen_training_ops
+from tensorflow.python.ops import gen_training_ops
 
 
 class Adadelta(optimizer_v2.OptimizerV2):
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad.py b/tensorflow/python/keras/optimizer_v2/adagrad.py
index c59e165..93f8aac 100644
--- a/tensorflow/python/keras/optimizer_v2/adagrad.py
+++ b/tensorflow/python/keras/optimizer_v2/adagrad.py
@@ -22,8 +22,8 @@
 from tensorflow.python.keras import backend_config
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import init_ops
-from tensorflow.python.training import gen_training_ops
 
 
 class Adagrad(optimizer_v2.OptimizerV2):
diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py
index 7b1e90a..8820b9b 100644
--- a/tensorflow/python/keras/optimizer_v2/adam.py
+++ b/tensorflow/python/keras/optimizer_v2/adam.py
@@ -23,9 +23,9 @@
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
-from tensorflow.python.training import gen_training_ops
 
 
 class Adam(optimizer_v2.OptimizerV2):
diff --git a/tensorflow/python/keras/optimizer_v2/adamax.py b/tensorflow/python/keras/optimizer_v2/adamax.py
index f5e0093..016e932 100644
--- a/tensorflow/python/keras/optimizer_v2/adamax.py
+++ b/tensorflow/python/keras/optimizer_v2/adamax.py
@@ -22,8 +22,8 @@
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.training import gen_training_ops
 
 
 class Adamax(optimizer_v2.OptimizerV2):
diff --git a/tensorflow/python/keras/optimizer_v2/ftrl.py b/tensorflow/python/keras/optimizer_v2/ftrl.py
index 6e9ba72..d6bb32c 100644
--- a/tensorflow/python/keras/optimizer_v2/ftrl.py
+++ b/tensorflow/python/keras/optimizer_v2/ftrl.py
@@ -17,9 +17,9 @@
 
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.training import gen_training_ops
 
 
 class Ftrl(optimizer_v2.OptimizerV2):
diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent.py b/tensorflow/python/keras/optimizer_v2/gradient_descent.py
index fe4e01d..29423429 100644
--- a/tensorflow/python/keras/optimizer_v2/gradient_descent.py
+++ b/tensorflow/python/keras/optimizer_v2/gradient_descent.py
@@ -19,7 +19,7 @@
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_resource_variable_ops
-from tensorflow.python.training import gen_training_ops
+from tensorflow.python.ops import gen_training_ops
 
 
 class SGD(optimizer_v2.OptimizerV2):
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop.py b/tensorflow/python/keras/optimizer_v2/rmsprop.py
index 39fa85a..c2e4d44 100644
--- a/tensorflow/python/keras/optimizer_v2/rmsprop.py
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop.py
@@ -24,9 +24,9 @@
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
-from tensorflow.python.training import gen_training_ops
 
 
 class RMSprop(optimizer_v2.OptimizerV2):
diff --git a/tensorflow/python/keras/saving/BUILD b/tensorflow/python/keras/saving/BUILD
index db1d2d8..a8e9553 100644
--- a/tensorflow/python/keras/saving/BUILD
+++ b/tensorflow/python/keras/saving/BUILD
@@ -49,7 +49,7 @@
         "//tensorflow/python/keras/utils:engine_utils",
         "//tensorflow/python/keras/utils:metrics_utils",
         "//tensorflow/python/keras/utils:mode_keys",
-        "//tensorflow/python/lib/io:lib",
+        "//tensorflow/python/lib/io:file_io",
         "//tensorflow/python/ops:math_ops",
         "//tensorflow/python/platform:gfile",
         "//tensorflow/python/platform:tf_logging",
diff --git a/tensorflow/python/kernel_tests/array_ops/BUILD b/tensorflow/python/kernel_tests/array_ops/BUILD
index 618af15..4852a3c 100644
--- a/tensorflow/python/kernel_tests/array_ops/BUILD
+++ b/tensorflow/python/kernel_tests/array_ops/BUILD
@@ -216,10 +216,10 @@
     deps = [
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:def_function",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:errors",
+        "//tensorflow/python/framework:importer",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:tensor",
         "//tensorflow/python/framework:tensor_shape",
diff --git a/tensorflow/python/kernel_tests/control_flow/BUILD b/tensorflow/python/kernel_tests/control_flow/BUILD
index ed84b0c..0826a30 100644
--- a/tensorflow/python/kernel_tests/control_flow/BUILD
+++ b/tensorflow/python/kernel_tests/control_flow/BUILD
@@ -20,7 +20,6 @@
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:remote",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:ops",
@@ -285,10 +284,11 @@
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:function",
+        "//tensorflow/python/framework:importer",
+        "//tensorflow/python/framework:meta_graph",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:tensor_shape",
         "//tensorflow/python/framework:tensor_spec",
diff --git a/tensorflow/python/kernel_tests/custom_ops/BUILD b/tensorflow/python/kernel_tests/custom_ops/BUILD
index 7404adc..45360b1 100644
--- a/tensorflow/python/kernel_tests/custom_ops/BUILD
+++ b/tensorflow/python/kernel_tests/custom_ops/BUILD
@@ -24,7 +24,7 @@
         "notap",
     ],
     deps = [
-        "//tensorflow/python/framework",
+        "//tensorflow/python/framework:load_library",
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/platform:client_testlib",
         "//tensorflow/python/platform:resource_loader",
@@ -46,7 +46,7 @@
         "notap",
     ],
     deps = [
-        "//tensorflow/python/framework",
+        "//tensorflow/python/framework:load_library",
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/ops:math_ops",
         "//tensorflow/python/platform:client_testlib",
@@ -69,8 +69,8 @@
         "notap",
     ],
     deps = [
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:errors",
+        "//tensorflow/python/framework:load_library",
         "//tensorflow/python/platform:client_testlib",
         "//tensorflow/python/platform:resource_loader",
     ],
diff --git a/tensorflow/python/kernel_tests/variables/BUILD b/tensorflow/python/kernel_tests/variables/BUILD
index 9f2dcc0..1e2cf18 100644
--- a/tensorflow/python/kernel_tests/variables/BUILD
+++ b/tensorflow/python/kernel_tests/variables/BUILD
@@ -149,8 +149,9 @@
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:wrap_function",
         "//tensorflow/python/framework:constant_op",
+        "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:errors",
-        "//tensorflow/python/framework:for_generated_wrappers",
+        "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/layers",
         "//tensorflow/python/ops:array_ops",
@@ -158,6 +159,7 @@
         "//tensorflow/python/ops:init_ops",
         "//tensorflow/python/ops:math_ops",
         "//tensorflow/python/ops:resource_variable_ops",
+        "//tensorflow/python/ops:resource_variables_toggle",
         "//tensorflow/python/ops:state_ops",
         "//tensorflow/python/ops:variable_scope",
         "//tensorflow/python/ops:variable_v1",
diff --git a/tensorflow/python/kernel_tests/variables/variable_scope_test.py b/tensorflow/python/kernel_tests/variables/variable_scope_test.py
index 41eb82a..44e6587 100644
--- a/tensorflow/python/kernel_tests/variables/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variables/variable_scope_test.py
@@ -33,6 +33,7 @@
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import resource_variables_toggle
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variable_v1
@@ -456,18 +457,18 @@
   # AssertionError: True is not false (last assertFalse)
   @test_util.run_deprecated_v1
   def testEnableResourceVariables(self):
-    old = variable_scope._DEFAULT_USE_RESOURCE
+    old = resource_variables_toggle._DEFAULT_USE_RESOURCE
     try:
-      variable_scope.enable_resource_variables()
+      resource_variables_toggle.enable_resource_variables()
       self.assertIsInstance(
           variable_v1.VariableV1(1.0),
           resource_variable_ops.ResourceVariable)
-      variable_scope.disable_resource_variables()
+      resource_variables_toggle.disable_resource_variables()
       self.assertNotIsInstance(
           variable_v1.VariableV1(1.0),
           resource_variable_ops.ResourceVariable)
     finally:
-      variable_scope._DEFAULT_USE_RESOURCE = old
+      resource_variables_toggle._DEFAULT_USE_RESOURCE = old
 
   # Not converted to use wrap_function because of
   # TypeError: Fetch argument None has invalid type <type 'NoneType'>
diff --git a/tensorflow/python/lib/io/BUILD b/tensorflow/python/lib/io/BUILD
index c15359b..5f7b8e0 100644
--- a/tensorflow/python/lib/io/BUILD
+++ b/tensorflow/python/lib/io/BUILD
@@ -34,23 +34,6 @@
     ],
 )
 
-py_strict_library(
-    name = "lib",
-    deprecation = "This target has been split. Depend on the sub-targets instead.",
-    srcs_version = "PY3",
-    visibility = visibility + [
-        "//tensorflow:internal",
-        "//third_party/cloud_tpu/convergence_tools:__subpackages__",
-        "//third_party/py/tf_agents:__subpackages__",
-        "//third_party/py/tf_slim:__subpackages__",
-    ],
-    deps = [
-        ":file_io",
-        ":python_io",
-        ":tf_record",
-    ],
-)
-
 tf_python_pybind_extension(
     name = "_pywrap_record_io",
     srcs = ["record_io_wrapper.cc"],
diff --git a/tensorflow/python/modules_with_exports.py b/tensorflow/python/modules_with_exports.py
index 4b89ea8..5f86568 100644
--- a/tensorflow/python/modules_with_exports.py
+++ b/tensorflow/python/modules_with_exports.py
@@ -91,7 +91,6 @@
 from tensorflow.python.module import module
 
 # Ops
-from tensorflow.python.ops.standard_ops import *  # pylint: disable=redefined-builtin
 from tensorflow.python.ops.random_crop_ops import *
 from tensorflow.python.ops import bincount_ops
 from tensorflow.python.ops import bitwise_ops as bitwise
diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD
index b4d46a7..9077f8d 100644
--- a/tensorflow/python/ops/BUILD
+++ b/tensorflow/python/ops/BUILD
@@ -620,7 +620,10 @@
 tf_gen_op_strict_wrapper_private_py(
     name = "training_ops_gen",
     visibility = [
+        "//tensorflow/compiler/tests:__pkg__",
+        "//tensorflow/contrib/opt:__pkg__",
         "//tensorflow/python:__pkg__",
+        "//tensorflow/python/eager/polymorphic_function:__pkg__",
         "//tensorflow/python/training:__pkg__",
     ],
 )
@@ -1283,6 +1286,7 @@
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/util:keras_deps",
         "//tensorflow/python/util:tf_decorator",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
@@ -1302,7 +1306,6 @@
     srcs_version = "PY3",
     deps = [
         ":control_flow_util",
-        ":control_flow_util_v2",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/platform:tf_logging",
         "//tensorflow/python/util:tf_export",
@@ -1316,7 +1319,6 @@
     python_version = "PY3",
     deps = [
         ":control_flow_util_v2",
-        ":control_flow_v2_toggles",
         "//tensorflow/python/platform:client_testlib",
         "//tensorflow/python/platform:test",
     ],
@@ -1630,11 +1632,13 @@
         ":linalg_grad",
         ":linalg_ops",
         ":logging_ops",
+        ":lookup_grad",
         ":manip_grad",
         ":math_grad",
         ":math_ops",
         ":nccl_ops",
         ":optional_grad",
+        ":parsing_grad",
         ":proto_ops",
         ":random_grad",
         ":rnn_grad",
@@ -1962,6 +1966,7 @@
     deps = [
         ":array_ops",
         ":control_flow_ops",
+        ":lookup_grad",
         ":lookup_ops_gen",
         ":math_ops",
         ":string_ops",
@@ -1986,6 +1991,14 @@
     ],
 )
 
+py_strict_library(
+    name = "lookup_grad",
+    srcs = ["lookup_grad.py"],
+    deps = [
+        "//tensorflow/python/framework:ops",
+    ],
+)
+
 tf_py_strict_test(
     name = "lookup_ops_async_checkpoint_test",
     srcs = ["lookup_ops_async_checkpoint_test.py"],
@@ -2307,6 +2320,7 @@
         ":control_flow_ops",
         ":math_ops",
         ":parsing_config",
+        ":parsing_grad",
         ":parsing_ops_gen",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:sparse_tensor",
@@ -2317,6 +2331,12 @@
 )
 
 py_strict_library(
+    name = "parsing_grad",
+    srcs = ["parsing_grad.py"],
+    deps = ["//tensorflow/python/framework:ops"],
+)
+
+py_strict_library(
     name = "partitioned_variables",
     srcs = ["partitioned_variables.py"],
     srcs_version = "PY3",
@@ -3080,11 +3100,10 @@
         ":array_ops",
         ":init_ops",
         ":resource_variable_ops",
+        ":resource_variables_toggle",
         ":variables",
-        "//tensorflow/python:tf2",
         "//tensorflow/python/client:session",
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/eager:monitoring",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:tensor",
@@ -3093,7 +3112,6 @@
         "//tensorflow/python/platform:tf_logging",
         "//tensorflow/python/types:core",
         "//tensorflow/python/util:compat",
-        "//tensorflow/python/util:deprecation",
         "//tensorflow/python/util:function_utils",
         "//tensorflow/python/util:tf_decorator",
         "//tensorflow/python/util:tf_export",
@@ -3101,6 +3119,19 @@
 )
 
 py_strict_library(
+    name = "resource_variables_toggle",
+    srcs = ["resource_variables_toggle.py"],
+    srcs_version = "PY3",
+    deps = [
+        "//tensorflow/python:tf2",
+        "//tensorflow/python/eager:monitoring",
+        "//tensorflow/python/platform:tf_logging",
+        "//tensorflow/python/util:deprecation",
+        "//tensorflow/python/util:tf_export",
+    ],
+)
+
+py_strict_library(
     name = "variables",
     srcs = ["variables.py"],
     srcs_version = "PY3",
@@ -3137,6 +3168,7 @@
         ":array_ops",
         ":array_ops_gen",
         ":resource_variable_ops",
+        ":resource_variables_toggle",
         ":state_ops",
         ":state_ops_gen",
         ":variable_scope",
@@ -4755,10 +4787,10 @@
         ":weak_tensor_test_util",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:def_function",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:flexible_dtypes",
+        "//tensorflow/python/framework:importer",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:tensor",
         "//tensorflow/python/framework:tensor_util",
diff --git a/tensorflow/python/ops/control_flow_util_v2.py b/tensorflow/python/ops/control_flow_util_v2.py
index 213698c..3d36180 100644
--- a/tensorflow/python/ops/control_flow_util_v2.py
+++ b/tensorflow/python/ops/control_flow_util_v2.py
@@ -29,6 +29,7 @@
 from tensorflow.python.ops import gradients_util
 from tensorflow.python.util import keras_deps
 from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util.tf_export import tf_export
 
 _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None
 _DISABLE_LOWER_USING_SWITCH_MERGE = False
@@ -398,3 +399,31 @@
     return results
   else:
     return make_op(inputs)
+
+
+@tf_export(v1=["experimental.output_all_intermediates"])
+def set_output_all_intermediates(state):  # pylint: disable=invalid-name
+  """Whether to output all intermediates from functional control flow ops.
+
+  The "default" behavior is to output all intermediates when using v2 control
+  flow inside Keras models in graph mode (possibly inside Estimators). This is
+  needed to support taking gradients of v2 control flow. In graph mode, Keras
+  can sometimes freeze the forward graph before the gradient computation which
+  does not work for v2 control flow since it requires updating the forward ops
+  to output the needed intermediates. We work around this by proactively
+  outputting the needed intermediates when building the forward pass itself.
+  Ideally any such extra tensors should be pruned out at runtime. However, if
+  for any reason this doesn't work for you or if you have an inference-only
+  model you can turn this behavior off using
+  `tf.compat.v1.experimental.output_all_intermediates(False)`.
+
+  If, with the default behavior, you are still seeing errors of the form
+  "Connecting to invalid output X of source node Y which has Z outputs" try
+  setting `tf.compat.v1.experimental.output_all_intermediates(True)` and
+  please file an issue at https://github.com/tensorflow/tensorflow/issues.
+
+  Args:
+    state: True, False or None. None restores the default behavior.
+  """
+  global _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
+  _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = state  # pylint: disable=protected-access
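The public v1 endpoint keeps its name after the move from `control_flow_v2_toggles.py`; a minimal sketch of toggling it, e.g. for an inference-only v1 Keras graph:

    import tensorflow as tf

    # Turn the intermediate-outputting work-around off for inference-only use...
    tf.compat.v1.experimental.output_all_intermediates(False)
    # ...and restore the default behavior afterwards.
    tf.compat.v1.experimental.output_all_intermediates(None)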
diff --git a/tensorflow/python/ops/control_flow_v2_toggles.py b/tensorflow/python/ops/control_flow_v2_toggles.py
index 0985ce9..8ce7d59 100644
--- a/tensorflow/python/ops/control_flow_v2_toggles.py
+++ b/tensorflow/python/ops/control_flow_v2_toggles.py
@@ -17,7 +17,6 @@
 
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_util
-from tensorflow.python.ops import control_flow_util_v2
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.tf_export import tf_export
 
@@ -68,30 +67,3 @@
   Note: v2 control flow is always enabled inside of tf.function.
   """
   return control_flow_util.EnableControlFlowV2(ops.get_default_graph())
-
-
-@tf_export(v1=["experimental.output_all_intermediates"])
-def output_all_intermediates(state):  # pylint: disable=invalid-name
-  """Whether to output all intermediates from functional control flow ops.
-
-  The "default" behavior to is to output all intermediates when using v2 control
-  flow inside Keras models in graph mode (possibly inside Estimators). This is
-  needed to support taking gradients of v2 control flow. In graph mode, Keras
-  can sometimes freeze the forward graph before the gradient computation which
-  does not work for v2 control flow since it requires updating the forward ops
-  to output the needed intermediates. We work around this by proactively
-  outputting the needed intermediates when building the forward pass itself.
-  Ideally any such extra tensors should be pruned out at runtime. However, if
-  for any reason this doesn't work for you or if you have an inference-only
-  model you can turn this behavior off using
-  `tf.compat.v1.experimental.output_all_intermediates(False)`.
-
-  If with the default behavior you are still seeing errors of the form
-  "Connecting to invalid output X of source node Y which has Z outputs" try
-  setting `tf.compat.v1.experimental.output_all_intermediates(True)` and
-  please file an issue at https://github.com/tensorflow/tensorflow/issues.
-
-  Args:
-    state: True, False or None. None restores the default behavior.
-  """
-  control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = state  # pylint: disable=protected-access
diff --git a/tensorflow/python/ops/control_flow_v2_toggles_test.py b/tensorflow/python/ops/control_flow_v2_toggles_test.py
index 51febb4..f91778e 100644
--- a/tensorflow/python/ops/control_flow_v2_toggles_test.py
+++ b/tensorflow/python/ops/control_flow_v2_toggles_test.py
@@ -15,7 +15,6 @@
 """Tests for control_flow_v2_toggles.py."""
 
 from tensorflow.python.ops import control_flow_util_v2
-from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.platform import googletest
 from tensorflow.python.platform import test
 
@@ -25,13 +24,13 @@
   def testOutputAllIntermediates(self):
     self.assertIsNone(
         control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
-    control_flow_v2_toggles.output_all_intermediates(True)
+    control_flow_util_v2.set_output_all_intermediates(True)
     self.assertTrue(
         control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
-    control_flow_v2_toggles.output_all_intermediates(False)
+    control_flow_util_v2.set_output_all_intermediates(False)
     self.assertFalse(
         control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
-    control_flow_v2_toggles.output_all_intermediates(None)
+    control_flow_util_v2.set_output_all_intermediates(None)
     self.assertIsNone(
         control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
 
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index b7d2e77..ae88a6d 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -27,11 +27,13 @@
 from tensorflow.python.ops import linalg_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import linalg_ops  # pylint: disable=unused-import
 from tensorflow.python.ops import logging_ops  # pylint: disable=unused-import
+from tensorflow.python.ops import lookup_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import manip_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nccl_ops  # pylint: disable=unused-import
 from tensorflow.python.ops import optional_grad  # pylint: disable=unused-import
+from tensorflow.python.ops import parsing_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import proto_ops  # pylint: disable=unused-import
 from tensorflow.python.ops import random_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import rnn_grad  # pylint: disable=unused-import
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 4195f44..fd8995f 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -15,6 +15,7 @@
 """Implementation of image ops."""
 
 import functools
+
 import numpy as np
 
 from tensorflow.python.eager import context
@@ -2919,7 +2920,7 @@
 
 @tf_export('image.adjust_jpeg_quality')
 @dispatch.add_dispatch_support
-def adjust_jpeg_quality(image, jpeg_quality, name=None):
+def adjust_jpeg_quality(image, jpeg_quality, dct_method='', name=None):
   """Adjust jpeg encoding quality of an image.
 
   This is a convenience method that converts an image to uint8 representation,
@@ -2955,7 +2956,7 @@
          [[1., 1., 1.],
           [1., 1., 1.]]], dtype=float32)>
 
-  Note that `jpeg_quality` 100 is still lossy compresson.
+  Note that `jpeg_quality` 100 is still lossy compression.
 
   >>> x = tf.constant([[[1, 2, 3],
   ...                   [4, 5, 6]],
@@ -2971,6 +2972,10 @@
   Args:
     image: 3D image. The size of the last dimension must be None, 1 or 3.
     jpeg_quality: Python int or Tensor of type int32. jpeg encoding quality.
+    dct_method: An optional string. Specifies the DCT method to use for JPEG
+      decompression. Currently available options are ["INTEGER_FAST",
+      "INTEGER_ACCURATE"]. Defaults to "", which maps to "INTEGER_FAST",
+      sacrificing image quality for speed.
     name: A name for this operation (optional).
 
   Returns:
@@ -2991,7 +2996,9 @@
       jpeg_quality = ops.convert_to_tensor(jpeg_quality, dtype=dtypes.int32)
     image = gen_image_ops.encode_jpeg_variable_quality(image, jpeg_quality)
 
-    image = gen_image_ops.decode_jpeg(image, channels=channels)
+    image = gen_image_ops.decode_jpeg(
+        image, channels=channels, dct_method=dct_method
+    )
     return convert_image_dtype(image, orig_dtype, saturate=True)
 
 
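A short sketch of the new `dct_method` argument, assuming this patch is applied; the "INTEGER_ACCURATE" value is taken from the docstring above and trades decode speed for fidelity.

    import tensorflow as tf

    x = tf.random.uniform([64, 64, 3])  # any float image in [0, 1]
    # Default ("") maps to "INTEGER_FAST"; request the accurate DCT instead.
    y = tf.image.adjust_jpeg_quality(x, 75, dct_method="INTEGER_ACCURATE")
    print(y.shape)  # (64, 64, 3)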
diff --git a/tensorflow/python/ops/lookup_grad.py b/tensorflow/python/ops/lookup_grad.py
new file mode 100644
index 0000000..3ae89b6
--- /dev/null
+++ b/tensorflow/python/ops/lookup_grad.py
@@ -0,0 +1,37 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Gradients for lookup operations."""
+
+from tensorflow.python.framework import ops
+
+
+ops.NotDifferentiable("LookupTableFind")
+ops.NotDifferentiable("LookupTableFindV2")
+ops.NotDifferentiable("LookupTableInsert")
+ops.NotDifferentiable("LookupTableInsertV2")
+ops.NotDifferentiable("LookupTableSize")
+ops.NotDifferentiable("LookupTableSizeV2")
+ops.NotDifferentiable("HashTable")
+ops.NotDifferentiable("HashTableV2")
+ops.NotDifferentiable("InitializeTable")
+ops.NotDifferentiable("InitializeTableV2")
+ops.NotDifferentiable("InitializeTableFromTextFile")
+ops.NotDifferentiable("InitializeTableFromTextFileV2")
+ops.NotDifferentiable("MutableDenseHashTable")
+ops.NotDifferentiable("MutableDenseHashTableV2")
+ops.NotDifferentiable("MutableHashTable")
+ops.NotDifferentiable("MutableHashTableV2")
+ops.NotDifferentiable("MutableHashTableOfTensors")
+ops.NotDifferentiable("MutableHashTableOfTensorsV2")
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index 1f3ed12..9731fff 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -30,6 +30,8 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_lookup_ops
+# Ensure lookup gradients are registered
+from tensorflow.python.ops import lookup_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import string_ops
 # go/tf-wildcard-import
@@ -2458,23 +2460,3 @@
           return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
                                                        restored_tensors[0],
                                                        restored_tensors[1])
-
-
-ops.NotDifferentiable("LookupTableFind")
-ops.NotDifferentiable("LookupTableFindV2")
-ops.NotDifferentiable("LookupTableInsert")
-ops.NotDifferentiable("LookupTableInsertV2")
-ops.NotDifferentiable("LookupTableSize")
-ops.NotDifferentiable("LookupTableSizeV2")
-ops.NotDifferentiable("HashTable")
-ops.NotDifferentiable("HashTableV2")
-ops.NotDifferentiable("InitializeTable")
-ops.NotDifferentiable("InitializeTableV2")
-ops.NotDifferentiable("InitializeTableFromTextFile")
-ops.NotDifferentiable("InitializeTableFromTextFileV2")
-ops.NotDifferentiable("MutableDenseHashTable")
-ops.NotDifferentiable("MutableDenseHashTableV2")
-ops.NotDifferentiable("MutableHashTable")
-ops.NotDifferentiable("MutableHashTableV2")
-ops.NotDifferentiable("MutableHashTableOfTensors")
-ops.NotDifferentiable("MutableHashTableOfTensorsV2")
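Moving the `NotDifferentiable` registrations into `lookup_grad.py` (re-imported by `lookup_ops.py` above) is intended to be behavior-preserving; a small sketch showing that gradients still flow around, not through, a table lookup:

    import tensorflow as tf

    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            tf.constant(["a", "b"]), tf.constant([1.0, 2.0])),
        default_value=0.0)

    x = tf.constant(3.0)
    with tf.GradientTape() as tape:
      tape.watch(x)
      # LookupTableFindV2 is registered as NotDifferentiable, so only the
      # multiplication contributes to the gradient.
      y = table.lookup(tf.constant("a")) * x
    print(tape.gradient(y, x))  # 1.0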
diff --git a/tensorflow/core/tfrt/saved_model/python/_pywrap_saved_model_aot_compile.pyi b/tensorflow/python/ops/parsing_grad.py
similarity index 70%
rename from tensorflow/core/tfrt/saved_model/python/_pywrap_saved_model_aot_compile.pyi
rename to tensorflow/python/ops/parsing_grad.py
index 05aae4b..bcfb2a0 100644
--- a/tensorflow/core/tfrt/saved_model/python/_pywrap_saved_model_aot_compile.pyi
+++ b/tensorflow/python/ops/parsing_grad.py
@@ -13,7 +13,12 @@
 # limitations under the License.
 # ==============================================================================
 
-class AotOptions:
-    def __init__(self) -> None: ...
+"""Gradient registrations for parsing ops."""
+from tensorflow.python.framework import ops
 
-def AotCompileSavedModel(input_model_dir: str = ..., aot_options: AotOptions = ..., output_model_dir: str = ...) -> None: ...
+
+ops.NotDifferentiable("DecodeRaw")
+ops.NotDifferentiable("DecodePaddedRaw")
+ops.NotDifferentiable("ParseTensor")
+ops.NotDifferentiable("SerializeTensor")
+ops.NotDifferentiable("StringToNumber")
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index 66b2ea3..507bb33 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -22,6 +22,8 @@
 from tensorflow.python.ops import gen_parsing_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_config
+# Ensure parsing_ops gradients are registered
+from tensorflow.python.ops import parsing_grad  # pylint: disable=unused-import
 # go/tf-wildcard-import
 # pylint: disable=wildcard-import,undefined-variable
 from tensorflow.python.ops.gen_parsing_ops import *
@@ -30,14 +32,6 @@
 from tensorflow.python.util import dispatch
 from tensorflow.python.util.tf_export import tf_export
 
-
-ops.NotDifferentiable("DecodeRaw")
-ops.NotDifferentiable("DecodePaddedRaw")
-ops.NotDifferentiable("ParseTensor")
-ops.NotDifferentiable("SerializeTensor")
-ops.NotDifferentiable("StringToNumber")
-
-
 VarLenFeature = parsing_config.VarLenFeature
 RaggedFeature = parsing_config.RaggedFeature
 SparseFeature = parsing_config.SparseFeature
diff --git a/tensorflow/python/ops/ref_variable.py b/tensorflow/python/ops/ref_variable.py
index aebba73..7e51288 100644
--- a/tensorflow/python/ops/ref_variable.py
+++ b/tensorflow/python/ops/ref_variable.py
@@ -26,6 +26,7 @@
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import gen_state_ops
 from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import resource_variables_toggle
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variable_v1
@@ -59,7 +60,7 @@
   if use_resource is None:
     use_resource = variable_scope.get_variable_scope().use_resource
   if use_resource is None:
-    use_resource = variable_scope._DEFAULT_USE_RESOURCE  # pylint: disable=protected-access
+    use_resource = resource_variables_toggle.resource_variables_enabled()
   use_resource = use_resource or context.executing_eagerly()
   if use_resource:
     distribute_strategy = kwargs.get("distribute_strategy", None)
diff --git a/tensorflow/python/ops/resource_variables_toggle.py b/tensorflow/python/ops/resource_variables_toggle.py
new file mode 100644
index 0000000..0521a16
--- /dev/null
+++ b/tensorflow/python/ops/resource_variables_toggle.py
@@ -0,0 +1,84 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Toggle to enable/disable resource variables."""
+
+from tensorflow.python import tf2
+from tensorflow.python.eager import monitoring
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
+
+
+_api_usage_gauge = monitoring.BoolGauge(
+    "/tensorflow/api/resource_variables",
+    "Whether resource_variables_toggle.enable_resource_variables() is called.")
+
+_DEFAULT_USE_RESOURCE = tf2.enabled()
+
+
+@tf_export(v1=["enable_resource_variables"])
+def enable_resource_variables() -> None:
+  """Creates resource variables by default.
+
+  Resource variables are improved versions of TensorFlow variables with a
+  well-defined memory model. Accessing a resource variable reads its value, and
+  all ops which access a specific read value of the variable are guaranteed to
+  see the same value for that tensor. Writes which happen after a read (by
+  having a control or data dependency on the read) are guaranteed not to affect
+  the value of the read tensor, and similarly writes which happen before a read
+  are guaranteed to affect the value. No guarantees are made about unordered
+  read/write pairs.
+
+  Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
+  feature.
+  """
+  global _DEFAULT_USE_RESOURCE
+  _DEFAULT_USE_RESOURCE = True
+  logging.vlog(1, "Enabling resource variables")
+  _api_usage_gauge.get_cell().set(True)
+
+
+@deprecation.deprecated(
+    None, "non-resource variables are not supported in the long term")
+@tf_export(v1=["disable_resource_variables"])
+def disable_resource_variables() -> None:
+  """Opts out of resource variables.
+
+  If your code needs tf.disable_resource_variables() to be called to work
+  properly please file a bug.
+  """
+  global _DEFAULT_USE_RESOURCE
+  _DEFAULT_USE_RESOURCE = False
+  logging.vlog(1, "Disabling resource variables")
+  _api_usage_gauge.get_cell().set(False)
+
+
+@tf_export(v1=["resource_variables_enabled"])
+def resource_variables_enabled() -> bool:
+  """Returns `True` if resource variables are enabled.
+
+  Resource variables are improved versions of TensorFlow variables with a
+  well-defined memory model. Accessing a resource variable reads its value, and
+  all ops which access a specific read value of the variable are guaranteed to
+  see the same value for that tensor. Writes which happen after a read (by
+  having a control or data dependency on the read) are guaranteed not to affect
+  the value of the read tensor, and similarly writes which happen before a read
+  are guaranteed to affect the value. No guarantees are made about unordered
+  read/write pairs.
+
+  Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
+  feature.
+  """
+  return _DEFAULT_USE_RESOURCE
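The tf_export names are unchanged by the relocation out of `variable_scope.py`, so existing v1 call sites keep working; a minimal sketch:

    import tensorflow as tf

    tf.compat.v1.enable_resource_variables()
    assert tf.compat.v1.resource_variables_enabled()
    # Deprecated, but still available for code that needs RefVariables.
    tf.compat.v1.disable_resource_variables()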
diff --git a/tensorflow/python/ops/structured/BUILD b/tensorflow/python/ops/structured/BUILD
index 708242c..f8d659f 100644
--- a/tensorflow/python/ops/structured/BUILD
+++ b/tensorflow/python/ops/structured/BUILD
@@ -32,13 +32,9 @@
 
 py_strict_library(
     name = "structured_tensor",
-    srcs = [
-        "structured_array_ops.py",
-        "structured_tensor.py",
-    ],
+    srcs = ["structured_tensor.py"],
     srcs_version = "PY3",
     deps = [
-        "//tensorflow/core/config:flags_py",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:extension_type",
@@ -50,14 +46,11 @@
         "//tensorflow/python/ops:check_ops",
         "//tensorflow/python/ops:control_flow_ops",
         "//tensorflow/python/ops:math_ops",
-        "//tensorflow/python/ops:random_ops",
         "//tensorflow/python/ops/ragged:dynamic_ragged_shape",
         "//tensorflow/python/ops/ragged:ragged_factory_ops",
         "//tensorflow/python/ops/ragged:ragged_tensor",
         "//tensorflow/python/ops/ragged:row_partition",
         "//tensorflow/python/util:compat",
-        "//tensorflow/python/util:deprecation",
-        "//tensorflow/python/util:dispatch",
         "//tensorflow/python/util:nest",
         "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index ad066f2..5c1e7cd 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -21,10 +21,8 @@
 import threading
 import traceback
 
-from tensorflow.python import tf2
 from tensorflow.python.client import session
 from tensorflow.python.eager import context
-from tensorflow.python.eager import monitoring
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor
@@ -33,10 +31,10 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import resource_variables_toggle
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.types import core
-from tensorflow.python.util import deprecation
 from tensorflow.python.util import function_utils
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_inspect
@@ -50,10 +48,6 @@
     "no_regularizer", "VariableSynchronization", "VariableAggregation"
 ]
 
-_api_usage_gauge = monitoring.BoolGauge(
-    "/tensorflow/api/resource_variables",
-    "Whether variable_scope.enable_resource_variables() is called.")
-
 
 class _PartitionInfo:
   """Holds partition info used by initializer functions."""
@@ -219,65 +213,6 @@
 it does exist, simply return it.
 """
 
-_DEFAULT_USE_RESOURCE = tf2.enabled()
-
-
-@tf_export(v1=["enable_resource_variables"])
-def enable_resource_variables():
-  """Creates resource variables by default.
-
-  Resource variables are improved versions of TensorFlow variables with a
-  well-defined memory model. Accessing a resource variable reads its value, and
-  all ops which access a specific read value of the variable are guaranteed to
-  see the same value for that tensor. Writes which happen after a read (by
-  having a control or data dependency on the read) are guaranteed not to affect
-  the value of the read tensor, and similarly writes which happen before a read
-  are guaranteed to affect the value. No guarantees are made about unordered
-  read/write pairs.
-
-  Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
-  feature.
-  """
-  global _DEFAULT_USE_RESOURCE
-  _DEFAULT_USE_RESOURCE = True
-  logging.vlog(1, "Enabling resource variables")
-  _api_usage_gauge.get_cell().set(True)
-
-
-@tf_export(v1=["resource_variables_enabled"])
-def resource_variables_enabled():
-  """Returns `True` if resource variables are enabled.
-
-  Resource variables are improved versions of TensorFlow variables with a
-  well-defined memory model. Accessing a resource variable reads its value, and
-  all ops which access a specific read value of the variable are guaranteed to
-  see the same value for that tensor. Writes which happen after a read (by
-  having a control or data dependency on the read) are guaranteed not to affect
-  the value of the read tensor, and similarly writes which happen before a read
-  are guaranteed to affect the value. No guarantees are made about unordered
-  read/write pairs.
-
-  Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
-  feature.
-  """
-  global _DEFAULT_USE_RESOURCE
-  return _DEFAULT_USE_RESOURCE
-
-
-@deprecation.deprecated(
-    None, "non-resource variables are not supported in the long term")
-@tf_export(v1=["disable_resource_variables"])
-def disable_resource_variables():
-  """Opts out of resource variables.
-
-  If your code needs tf.disable_resource_variables() to be called to work
-  properly please file a bug.
-  """
-  global _DEFAULT_USE_RESOURCE
-  _DEFAULT_USE_RESOURCE = False
-  logging.vlog(1, "Disabling resource variables")
-  _api_usage_gauge.get_cell().set(False)
-
 
 def _needs_no_arguments(python_callable):
   """Returns true if the callable needs no arguments to call."""
@@ -964,7 +899,7 @@
     # Create the variable.
     if use_resource is None:
       # Set the default value if unspecified.
-      use_resource = _DEFAULT_USE_RESOURCE
+      use_resource = resource_variables_toggle.resource_variables_enabled()
     v = _variable_v1(
         initial_value=init_val,
         name=name,
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index d99b7e1..5208dd1 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -352,7 +352,7 @@
         variable and return the Tensor for the projected value (which must have
         the same shape). Constraints are not safe to use when doing asynchronous
         distributed training.
-      synchronization: Indicates when a distributed a variable will be
+      synchronization: Indicates when a distributed variable will be
         aggregated. Accepted values are constants defined in the class
         `tf.VariableSynchronization`. By default the synchronization is set to
         `AUTO` and the current `DistributionStrategy` chooses when to
diff --git a/tensorflow/python/platform/BUILD b/tensorflow/python/platform/BUILD
index b3a7f64..1e344f1 100644
--- a/tensorflow/python/platform/BUILD
+++ b/tensorflow/python/platform/BUILD
@@ -91,7 +91,6 @@
         "@absl_py//absl:app",
         "@absl_py//absl/testing:absltest",
         "//tensorflow/python/framework:errors",
-        "//tensorflow/python/lib/io:lib",
         "//tensorflow/python/util:tf_decorator",
         "//tensorflow/python/util:tf_inspect",
     ]),
@@ -347,10 +346,7 @@
 tf_py_strict_test(
     name = "build_info_test",
     size = "small",
-    srcs = [
-        "build_info.py",
-        "build_info_test.py",
-    ],
+    srcs = ["build_info_test.py"],
     main = "build_info_test.py",
     python_version = "PY3",
     tags = [
@@ -367,26 +363,16 @@
 tf_py_strict_test(
     name = "benchmark_test",
     size = "small",
-    srcs = [
-        "benchmark.py",
-        "benchmark_test.py",
-    ],
+    srcs = ["benchmark_test.py"],
     main = "benchmark_test.py",
     python_version = "PY3",
     tags = [
         "no_pip",
     ],
     deps = [
+        ":benchmark",
         ":client_testlib",
-        ":gfile",
-        ":tf_logging",
         "//tensorflow/core:protos_all_py",
-        "//tensorflow/python/client",
-        "//tensorflow/python/client:timeline",
-        "//tensorflow/python/framework:ops",
-        "//tensorflow/python/util:tf_export",
-        "//tensorflow/python/util:tf_inspect",
-        "@absl_py//absl:app",
     ],
 )
 
diff --git a/tensorflow/python/py_exception_registry_wrapper.cc b/tensorflow/python/py_exception_registry_wrapper.cc
index 49739ba..0f1a17c 100644
--- a/tensorflow/python/py_exception_registry_wrapper.cc
+++ b/tensorflow/python/py_exception_registry_wrapper.cc
@@ -17,13 +17,32 @@
 
 #include <array>
 
+#include "pybind11/attr.h"  // from @pybind11
 #include "pybind11/pybind11.h"  // from @pybind11
 #include "pybind11/pytypes.h"  // from @pybind11
+#include "tensorflow/c/tf_status.h"
 #include "tensorflow/python/lib/core/py_exception_registry.h"
 
 namespace py = pybind11;
 
 PYBIND11_MODULE(_pywrap_py_exception_registry, m) {
+  py::enum_<TF_Code>(m, "TF_Code", py::module_local())
+      .value("TF_OK", TF_OK)
+      .value("TF_CANCELLED", TF_CANCELLED)
+      .value("TF_UNKNOWN", TF_UNKNOWN)
+      .value("TF_INVALID_ARGUMENT", TF_INVALID_ARGUMENT)
+      .value("TF_DEADLINE_EXCEEDED", TF_DEADLINE_EXCEEDED)
+      .value("TF_PERMISSION_DENIED", TF_PERMISSION_DENIED)
+      .value("TF_UNAUTHENTICATED", TF_UNAUTHENTICATED)
+      .value("TF_RESOURCE_EXHAUSTED", TF_RESOURCE_EXHAUSTED)
+      .value("TF_FAILED_PRECONDITION", TF_FAILED_PRECONDITION)
+      .value("TF_ABORTED", TF_ABORTED)
+      .value("TF_OUT_OF_RANGE", TF_OUT_OF_RANGE)
+      .value("TF_UNIMPLEMENTED", TF_UNIMPLEMENTED)
+      .value("TF_INTERNAL", TF_INTERNAL)
+      .value("TF_DATA_LOSS", TF_DATA_LOSS)
+      .export_values();
+
   m.def("PyExceptionRegistry_Init", [](py::object& code_to_exc_type_map) {
     tensorflow::PyExceptionRegistry::Init(code_to_exc_type_map.ptr());
   });
diff --git a/tensorflow/python/pywrap_dtensor_device.cc b/tensorflow/python/pywrap_dtensor_device.cc
index 4330c01..8cd5fe8 100644
--- a/tensorflow/python/pywrap_dtensor_device.cc
+++ b/tensorflow/python/pywrap_dtensor_device.cc
@@ -563,5 +563,21 @@
             return layout.num_shards_for_dim(dim);
           },
           py::arg("idx"),
-          "Returns the number of shards for tensor dimension `idx`.");
+          "Returns the number of shards for tensor dimension `idx`.")
+      .def(
+          "global_shape_from_local_shape",
+          [](const Layout& layout, std::vector<int64_t> local_shape) {
+            return py::tuple(
+                py::cast(layout.GlobalShapeFromLocalShape(local_shape)));
+          },
+          py::arg("local_shape"),
+          "Returns the global shape computed from this local shape.")
+      .def(
+          "local_shape_from_global_shape",
+          [](const Layout& layout, std::vector<int64_t> global_shape) {
+            return py::tuple(
+                py::cast(layout.LocalShapeFromGlobalShape(global_shape)));
+          },
+          py::arg("global_shape"),
+          "Returns the local shape computed from this global shape.");
 }
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 32e85b0..2b522cd 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -84,6 +84,7 @@
         "//tensorflow/python/util:compat",
         "//tensorflow/python/util:deprecation",
         "//tensorflow/python/util:tf_export",
+        # copybara:uncomment "//tensorflow/tools/proto_splitter/python:saved_model",
     ],
 )
 
@@ -125,8 +126,8 @@
         "//tensorflow/python/framework:errors",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:control_flow_ops",
+        "//tensorflow/python/ops:resource_variables_toggle",
         "//tensorflow/python/ops:state_ops",
-        "//tensorflow/python/ops:variable_scope",
         "//tensorflow/python/ops:variable_v1",
         "//tensorflow/python/ops:variables",
         "//tensorflow/python/platform:client_testlib",
@@ -194,8 +195,8 @@
         "//tensorflow/python/lib/io:file_io",
         "//tensorflow/python/ops:control_flow_ops",
         "//tensorflow/python/ops:math_ops",
+        "//tensorflow/python/ops:resource_variables_toggle",
         "//tensorflow/python/ops:state_ops",
-        "//tensorflow/python/ops:variable_scope",
         "//tensorflow/python/ops:variable_v1",
         "//tensorflow/python/ops:variables",
         "//tensorflow/python/ops/ragged:ragged_factory_ops",
@@ -390,11 +391,11 @@
         "//tensorflow/python/eager/polymorphic_function:concrete_function",
         "//tensorflow/python/eager/polymorphic_function:saved_model_exported_concrete",
         "//tensorflow/python/eager/polymorphic_function:saved_model_utils",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:error_interpolation",
         "//tensorflow/python/framework:errors",
         "//tensorflow/python/framework:function",
+        "//tensorflow/python/framework:meta_graph",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:tensor_util",
         "//tensorflow/python/framework:versions",
@@ -453,9 +454,9 @@
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:test",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
+        "//tensorflow/python/framework:meta_graph",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:tensor_spec",
         "//tensorflow/python/framework:test_lib",
@@ -852,7 +853,6 @@
         ":method_name_updater",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:test",
-        "//tensorflow/python/framework",
         "//tensorflow/python/lib/io:file_io",
         "//tensorflow/python/platform:client_testlib",
         "//tensorflow/python/util:compat",
@@ -993,7 +993,6 @@
         ":pywrap_saved_model",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:test",
-        "//tensorflow/python/lib/io:lib",
         "//tensorflow/python/platform:client_testlib",
     ],
 )
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 18bdc53..ba0cb58 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -18,7 +18,6 @@
 import os
 
 from google.protobuf.any_pb2 import Any
-
 from tensorflow.core.framework import types_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
 from tensorflow.core.protobuf import saved_model_pb2
@@ -38,7 +37,7 @@
 from tensorflow.python.util import compat
 from tensorflow.python.util.deprecation import deprecated_args
 from tensorflow.python.util.tf_export import tf_export
-
+# Placeholder for protosplitter import.  # copybara:comment
 # API label for SavedModel metrics.
 _SAVE_BUILDER_LABEL = "save_v1_builder"
 
@@ -397,19 +396,29 @@
     # subsequent attempts to save variables will fail.
     self._has_saved_variables = True
 
-  def save(self, as_text=False):
+  def save(self, as_text=False, experimental_image_format=False):
     """Writes a `SavedModel` protocol buffer to disk.
 
     The function writes the SavedModel protocol buffer to the export directory
     in a serialized format.
 
     Args:
-      as_text: Writes the SavedModel protocol buffer in text format to
-        disk. Protocol buffers in text format are useful for debugging, but
-        parsing fails when it encounters an unknown field and so is not forward
+      as_text: Writes the SavedModel protocol buffer in text format to disk.
+        Protocol buffers in text format are useful for debugging, but parsing
+        fails when it encounters an unknown field and so is not forward
         compatible. This means changes to TensorFlow may prevent deployment of
         new text format SavedModels to existing serving binaries. Do not deploy
         `as_text` SavedModels to production.
+      experimental_image_format: Writes the SavedModel protobuf in the
+        experimental image format. See
+        https://www.tensorflow.org/api_docs/python/tf/saved_model/SaveOptions
+        for more details. This allows `SavedModelBuilder` to save models
+        larger than 2 GiB.
+
+    Raises:
+      RuntimeError: When `experimental_image_format` is requested but
+        `proto_splitter` is not imported. This check exists because
+        `proto_splitter` is not available in OSS at the moment.
 
     Returns:
       The path to which the SavedModel protocol buffer was written.
@@ -424,11 +433,30 @@
           compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
       file_io.write_string_to_file(path, str(self._saved_model))
     else:
-      path = file_io.join(
-          compat.as_bytes(self._export_dir),
-          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
-      file_io.write_string_to_file(
-          path, self._saved_model.SerializeToString(deterministic=True))
+      if experimental_image_format:
+        path = file_io.join(
+            self._export_dir,
+            constants.SAVED_MODEL_FILENAME_PREFIX,
+        )
+        if (
+            locals().get("proto_splitter", globals().get("proto_splitter"))
+            is None
+        ):
+          raise RuntimeError(
+              "No proto_splitter is provided, cannot use"
+              " experimental_image_format."
+          )
+        # Overwrites path to record whether the saved_model is split, i.e.,
+        # whether the suffix is `.pb` or `.cpb`.
+        path = proto_splitter.SavedModelSplitter(self._saved_model).write(path)
+      else:
+        path = file_io.join(
+            compat.as_bytes(self._export_dir),
+            compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB),
+        )
+        file_io.write_string_to_file(
+            path, self._saved_model.SerializeToString(deterministic=True)
+        )
       # Placeholder for internal TF1 model fingerprint write
     tf_logging.info("SavedModel written to: %s", compat.as_text(path))
     metrics.IncrementWrite(write_version="1")
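
For readers unfamiliar with the new flag, a minimal usage sketch of `save(experimental_image_format=True)` follows. It assumes a TF1-style graph and that the copybara-only `proto_splitter` dependency is present in the build; otherwise `save` raises the RuntimeError documented above.

    import tensorflow.compat.v1 as tf
    from tensorflow.python.saved_model import builder_impl

    builder = builder_impl.SavedModelBuilder("/tmp/big_model")
    with tf.Session(graph=tf.Graph()) as sess:
      v = tf.get_variable("v", shape=[1], initializer=tf.zeros_initializer())
      sess.run(tf.global_variables_initializer())
      builder.add_meta_graph_and_variables(sess, ["serve"])
    # The returned path may end in `.pb` or `.cpb` depending on whether the
    # proto was split by proto_splitter.SavedModelSplitter.
    path = builder.save(experimental_image_format=True)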
diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py
index 52433cd..ab621cc 100644
--- a/tensorflow/python/saved_model/loader_test.py
+++ b/tensorflow/python/saved_model/loader_test.py
@@ -23,8 +23,8 @@
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import resource_variables_toggle
 from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variable_v1
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -40,7 +40,7 @@
 
 
 def _tensor_name(name):
-  if variable_scope.resource_variables_enabled():
+  if resource_variables_toggle.resource_variables_enabled():
     return name + "/Read/ReadVariableOp:0"
   return name + ":0"
 
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 87fd915..36a2813 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -30,8 +30,8 @@
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variables_toggle
 from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variable_v1
 from tensorflow.python.ops import variables
 from tensorflow.python.ops.ragged import ragged_factory_ops
@@ -96,7 +96,7 @@
     index = "0"
     if ":" in name:
       name, index = name.split(":")
-    if variable_scope.resource_variables_enabled():
+    if resource_variables_toggle.resource_variables_enabled():
       name = name + "/Read/ReadVariableOp"
     return self.evaluate(name + ":" + index)
 
@@ -934,7 +934,7 @@
         meta_graph_def = loader.load(sess, ["foo"], export_dir)
         self.assertEqual(3, self._eval("v1"))
         self.assertEqual(2, self._eval("v2"))
-        if variable_scope.resource_variables_enabled():
+        if resource_variables_toggle.resource_variables_enabled():
           self.assertEqual(
               loader_impl.get_train_op(meta_graph_def).type,
               "AssignAddVariableOp")
@@ -990,7 +990,7 @@
 
       with self.session(graph=ops.Graph()) as sess:
         meta_graph_def = loader.load(sess, ["foo"], export_dir)
-        if variable_scope.resource_variables_enabled():
+        if resource_variables_toggle.resource_variables_enabled():
           self.assertEqual(
               loader_impl.get_train_op(meta_graph_def).type,
               "AssignAddVariableOp")
diff --git a/tensorflow/python/summary/BUILD b/tensorflow/python/summary/BUILD
index 2b26dff..5ed5b0f 100644
--- a/tensorflow/python/summary/BUILD
+++ b/tensorflow/python/summary/BUILD
@@ -92,10 +92,10 @@
     deps = [
         ":summary_py",
         "//tensorflow/core:protos_all_py",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:errors",
+        "//tensorflow/python/framework:meta_graph",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/ops:array_ops",
diff --git a/tensorflow/python/summary/writer/BUILD b/tensorflow/python/summary/writer/BUILD
index 8acb0c7..96a509a 100644
--- a/tensorflow/python/summary/writer/BUILD
+++ b/tensorflow/python/summary/writer/BUILD
@@ -29,7 +29,7 @@
         ":event_file_writer_v2",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/framework",
+        "//tensorflow/python/framework:meta_graph",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/platform:gfile",
         "//tensorflow/python/platform:tf_logging",
@@ -87,10 +87,10 @@
         ":writer_cache",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/client:session",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:errors",
+        "//tensorflow/python/framework:meta_graph",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/ops:summary_ops_v2",
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index d5f0c8c..97b8d54 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -80,8 +80,8 @@
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/checkpoint:checkpoint_management",
         "//tensorflow/python/client:session",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:convert_to_constants",
+        "//tensorflow/python/framework:importer",
         "//tensorflow/python/platform:gfile",
         "//tensorflow/python/saved_model:loader",
         "//tensorflow/python/saved_model:tag_constants",
@@ -114,7 +114,7 @@
     deps = [
         ":saved_model_utils",
         "//tensorflow/python/client:session",
-        "//tensorflow/python/framework",
+        "//tensorflow/python/framework:importer",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/summary:summary_py",
         "@absl_py//absl:app",
@@ -375,7 +375,7 @@
         "//tensorflow/python/debug/wrappers:local_cli_wrapper",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:function",
-        "//tensorflow/python/framework",
+        "//tensorflow/python/framework:meta_graph",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:tensor_spec",
         "//tensorflow/python/lib/io:file_io",
diff --git a/tensorflow/python/tools/api/generator/BUILD b/tensorflow/python/tools/api/generator/BUILD
index c7b90c3..f019e7a 100644
--- a/tensorflow/python/tools/api/generator/BUILD
+++ b/tensorflow/python/tools/api/generator/BUILD
@@ -42,17 +42,12 @@
 
 py_strict_test(
     name = "create_python_api_test",
-    srcs = [
-        "create_python_api.py",
-        "create_python_api_test.py",
-    ],
+    srcs = ["create_python_api_test.py"],
     python_version = "PY3",
     srcs_version = "PY3",
     deps = [
-        ":doc_srcs",
-        "//tensorflow/python/eager:wrap_function",
+        ":create_python_api",
         "//tensorflow/python/platform:client_testlib",
-        "//tensorflow/python/util:tf_decorator",
         "//tensorflow/python/util:tf_export",
     ],
 )
diff --git a/tensorflow/python/tools/optimize_for_inference.py b/tensorflow/python/tools/optimize_for_inference.py
index d580e75..9cef953 100644
--- a/tensorflow/python/tools/optimize_for_inference.py
+++ b/tensorflow/python/tools/optimize_for_inference.py
@@ -19,6 +19,8 @@
 created to train a model, that help reduce the amount of computation needed when
 the network is used only for inference. These include:
 
+ - Converting given PlaceholderWithDefault or Placeholder nodes to Constant.
+
  - Removing training-only operations like checkpoint saving.
 
  - Stripping out parts of the graph that are never reached.
@@ -85,7 +87,9 @@
       FLAGS.input_names.split(","),
       FLAGS.output_names.split(","),
       _parse_placeholder_types(FLAGS.placeholder_type_enum),
-      FLAGS.toco_compatible)
+      FLAGS.toco_compatible,
+      FLAGS.placeholder_to_const_names.split(","),
+  )
 
   if FLAGS.frozen_graph:
     f = gfile.GFile(FLAGS.output, "w")
@@ -153,6 +157,16 @@
       If true, only use ops compatible with Tensorflow
       Lite Optimizing Converter.\
       """)
+  parser.add_argument(
+      "--placeholder_to_const_names",
+      type=str,
+      default="",
+      help="""\
+      Comma-separated list of PlaceholderWithDefault or Placeholder node
+      names and their new values, to be converted to Constant nodes.
+      e.g. --placeholder_to_const_names=phase_train=False\
+      """,
+  )
   return parser.parse_known_args()
 
 
diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py
index 53edbf2..c5ffb8e 100644
--- a/tensorflow/python/tools/optimize_for_inference_lib.py
+++ b/tensorflow/python/tools/optimize_for_inference_lib.py
@@ -89,6 +89,11 @@
     "FusedBatchNorm": "epsilon",
     "FusedBatchNormV3": "epsilon",
 }
+# Standard PlaceholderWithDefault node names, with their default values, to be
+# converted to Const nodes for inference.
+PLACEHOLDER_WITH_DEFAULT_LIST = {
+    "keras_learning_phase": "False",
+}
 
 
 def optimize_for_inference(
@@ -97,6 +102,7 @@
     output_node_names: Sequence[str],
     placeholder_type_enum: int,
     toco_compatible: bool = False,
+    placeholder_to_const_names=None,
 ) -> graph_pb2.GraphDef:
   """Applies a series of inference optimizations on the input graph.
 
@@ -110,12 +116,17 @@
       a list that specifies one value per input node name.
     toco_compatible: Boolean, if True, only runs optimizations that result in
       TOCO compatible graph operations (default=False).
+    placeholder_to_const_names: A list of `name=value` strings naming the
+      PlaceholderWithDefault or Placeholder nodes to convert to Constant.
 
   Returns:
     An optimized version of the input graph.
   """
   ensure_graph_is_valid(input_graph_def)
   optimized_graph_def = input_graph_def
+  optimized_graph_def = convert_placeholder_to_const(
+      optimized_graph_def, placeholder_to_const_names
+  )
   optimized_graph_def = strip_unused_lib.strip_unused(
       optimized_graph_def,
       input_node_names,
@@ -134,6 +145,47 @@
   return optimized_graph_def
 
 
+def strtobool(val_str):
+  """Returns the boolean value of a string representation."""
+  if val_str in ("True", "true"):
+    return True
+  elif val_str in ("False", "false"):
+    return False
+  else:
+    tf_logging.warning(
+        "Unsupported string value. Only False/false and True/true are"
+        " supported. val_str = %s",
+        val_str,
+    )
+    return False
+
+
+def parse_entry(entry):
+  """Parse a "key=value" pair separated by '='
+
+  eg: var_name=False
+  """
+  items = entry.split("=")
+  key = items[0].strip()  # remove blanks around keys
+  if len(items) > 1:
+    value = items[1]
+    return (key, value)
+  else:
+    return (None, None)
+
+
+def parse_nodes_dict(nodes):
+  """Parse a series of key-value pairs and return a dictionary"""
+  d = {}
+
+  if nodes:
+    for node in nodes:
+      key, val = parse_entry(node)
+      if key is not None:
+        d[key] = val
+  return d
+
+
 def ensure_graph_is_valid(graph_def: graph_pb2.GraphDef) -> None:
   """Makes sure that the graph is internally consistent.
 
@@ -592,3 +644,87 @@
 
   result_graph_def.node.extend(new_ops)
   return result_graph_def
+
+
+def convert_placeholder_to_const(input_graph_def, nodes_to_convert=None):
+  """Rename the PlaceHolderWithDefault node to constant
+
+  In a frozen graph, PlaceholderWithDefault nodes can be converted to
+  Constant op nodes with same value. This will help simplify the graph.
+
+  Args:
+    input_graph_def: A GraphDef containing a model.
+    nodes_to_convert: A list of PlaceholderWithDefault or Placeholder nodes to
+      be converted to Constants with their new value.
+
+  Returns:
+    modified graph with PlaceholderWithDefault node converted to Constant node
+  """
+
+  input_node_map = {}
+  for node in input_graph_def.node:
+    if node.name not in input_node_map:
+      input_node_map[node.name] = node
+    else:
+      raise ValueError("Duplicate node names detected for ", node.name)
+
+  # create a dictionary of nodes to be converted to Const
+  dict_to_change = {}
+  for key in PLACEHOLDER_WITH_DEFAULT_LIST:
+    dict_to_change[key] = PLACEHOLDER_WITH_DEFAULT_LIST[key]
+
+  if nodes_to_convert is not None and len(nodes_to_convert) > 0:
+    dict_list = parse_nodes_dict(nodes_to_convert)
+    dict_to_change.update(dict_list)
+
+  ph_node_list = []
+  for ph_node in dict_to_change:
+    if not ph_node or ph_node not in input_node_map:
+      continue
+    ph_node_list.append(ph_node)
+
+  # if no nodes found, then nothing to change
+  if not ph_node_list:
+    tf_logging.warning(
+        "No PlaceholderWithDefault nodes found to convert to "
+        "Constant. Maybe check the spellings"
+    )
+    return input_graph_def
+
+  result_graph_def = graph_pb2.GraphDef()
+  for node in input_graph_def.node:
+    is_replaced = False
+    new_node = node_def_pb2.NodeDef()
+    if node.op == "PlaceholderWithDefault" or node.op == "Placeholder":
+      match_key = [
+          find_key
+          for find_key in dict_to_change.keys()
+          if find_key in node.name
+      ]
+      if len(match_key) > 0:
+        if dtypes.bool.as_datatype_enum == node.attr["dtype"].type:
+          new_val_str = dict_to_change[match_key[0]]
+          new_node.op = "Const"
+          new_node.name = node.name
+          new_node.attr["dtype"].CopyFrom(node.attr["dtype"])
+          new_node.attr["value"].CopyFrom(
+              attr_value_pb2.AttrValue(
+                  tensor=tensor_util.make_tensor_proto(
+                      strtobool(new_val_str), dtype=dtypes.bool, shape=[]
+                  )
+              )
+          )
+          is_replaced = True
+        else:
+          tf_logging.warning(
+              "Not converting to Const. Currently only bool            "
+              " PlaceholderWithDefault or Placeholder can be converted to"
+              " const.             current dtype = ",
+              node.attr["dtype"],
+          )
+
+    if not is_replaced:
+      new_node.CopyFrom(node)
+
+    result_graph_def.node.extend([new_node])
+  return result_graph_def
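
As a rough usage sketch of the new conversion path (the graph construction below is illustrative; the tests that follow exercise the same logic through optimize_for_inference):

    from tensorflow.core.framework import graph_pb2
    from tensorflow.core.framework import node_def_pb2
    from tensorflow.python.framework import dtypes
    from tensorflow.python.tools import optimize_for_inference_lib

    g = graph_pb2.GraphDef()
    ph = node_def_pb2.NodeDef(op="Placeholder", name="phase_train")
    ph.attr["dtype"].type = dtypes.bool.as_datatype_enum
    g.node.extend([ph])

    # "name=value" strings are parsed by parse_nodes_dict into
    # {"phase_train": "False"} before the nodes are rewritten.
    out = optimize_for_inference_lib.convert_placeholder_to_const(
        g, ["phase_train=False"]
    )
    assert out.node[0].op == "Const"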
diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py
index f15c75c..73563e1 100644
--- a/tensorflow/python/tools/optimize_for_inference_test.py
+++ b/tensorflow/python/tools/optimize_for_inference_test.py
@@ -124,6 +124,127 @@
         graph_def, [], [add_name], dtypes.float32.as_datatype_enum)
     self.assertProtoEquals(expected_output, output)
 
+  def testConvertPlaceholderToConstant(self):
+    """Build the placeholder testing graph."""
+    placeholder_name = "phase_train"
+    relu_name = "r_relu"
+
+    g_def = graph_pb2.GraphDef()
+
+    ph_node = node_def_pb2.NodeDef()
+    ph_node.op = "Placeholder"
+    ph_node.name = placeholder_name
+
+    self.set_attr_dtype(ph_node, "dtype", dtypes.bool)
+    g_def.node.extend([ph_node])
+
+    r_node = self.create_node_def("Relu", relu_name, [placeholder_name])
+    g_def.node.extend([r_node])
+
+    opt_graph_def = optimize_for_inference_lib.optimize_for_inference(
+        g_def,
+        [],
+        [relu_name],
+        dtypes.float32.as_datatype_enum,
+        placeholder_to_const_names=["phase_train=False"],
+    )
+    for node in opt_graph_def.node:
+      self.assertNotEqual("Placeholder", node.op)
+      if node.name == "phase_train":
+        self.assertEqual(node.op, "Const")
+        const_value = optimize_for_inference_lib.values_from_const(node)
+        self.assertEqual(const_value, False)
+
+  def testConvertPlaceholderToConstant2(self):
+    """Build the placeholder testing graph."""
+    placeholder_name = "phase_train"
+    relu_name = "r_relu"
+
+    g_def = graph_pb2.GraphDef()
+
+    ph_node = node_def_pb2.NodeDef()
+    ph_node.op = "Placeholder"
+    ph_node.name = placeholder_name
+
+    self.set_attr_dtype(ph_node, "dtype", dtypes.bool)
+    g_def.node.extend([ph_node])
+
+    r_node = self.create_node_def("Relu", relu_name, [placeholder_name])
+    g_def.node.extend([r_node])
+
+    opt_graph_def = optimize_for_inference_lib.convert_placeholder_to_const(
+        g_def, ["phase_train=True"]
+    )
+    for node in opt_graph_def.node:
+      self.assertNotEqual("Placeholder", node.op)
+      if node.name == "phase_train":
+        self.assertEqual(node.op, "Const")
+        const_value = optimize_for_inference_lib.values_from_const(node)
+        self.assertEqual(const_value, True)
+
+  def testConvertPlaceholderWithDefaultToConstant(self):
+    """Build the placeholder_with_default testing graph."""
+    placeholder_name = "keras_learning_phase"
+    a_constant_name = "a_constant"
+    relu_name = "r_relu"
+
+    g_def = graph_pb2.GraphDef()
+    const_node = self.create_constant_node_def(
+        a_constant_name, value=True, dtype=dtypes.bool, shape=[]
+    )
+    g_def.node.extend([const_node])
+
+    ph_node = self.create_node_def(
+        "PlaceholderWithDefault", placeholder_name, [a_constant_name]
+    )
+    self.set_attr_dtype(ph_node, "dtype", dtypes.bool)
+    g_def.node.extend([ph_node])
+
+    r_node = self.create_node_def("Relu", relu_name, [placeholder_name])
+    g_def.node.extend([r_node])
+
+    opt_graph_def = optimize_for_inference_lib.convert_placeholder_to_const(
+        g_def
+    )
+    for node in opt_graph_def.node:
+      self.assertNotEqual("PlaceholderWithDefault", node.op)
+      if node.name == "keras_learning_phase":
+        self.assertEqual(node.op, "Const")
+        const_value = optimize_for_inference_lib.values_from_const(node)
+        # Notice optimize_for_inference rewrites keras_learning_phase to False
+        self.assertEqual(const_value, False)
+
+  def testConvertPlaceholderWithDefaultToConstant2(self):
+    """Build the placeholder_with_default testing graph."""
+    placeholder_name = "keras_learning_phase"
+    a_constant_name = "a_constant"
+    relu_name = "r_relu"
+
+    g_def = graph_pb2.GraphDef()
+    const_node = self.create_constant_node_def(
+        a_constant_name, value=True, dtype=dtypes.bool, shape=[]
+    )
+    g_def.node.extend([const_node])
+
+    ph_node = self.create_node_def(
+        "PlaceholderWithDefault", placeholder_name, [a_constant_name]
+    )
+    self.set_attr_dtype(ph_node, "dtype", dtypes.bool)
+    g_def.node.extend([ph_node])
+
+    r_node = self.create_node_def("Relu", relu_name, [placeholder_name])
+    g_def.node.extend([r_node])
+
+    opt_graph_def = optimize_for_inference_lib.optimize_for_inference(
+        g_def, [], [relu_name], dtypes.float32.as_datatype_enum
+    )
+    for node in opt_graph_def.node:
+      self.assertNotEqual("PlaceholderWithDefault", node.op)
+      if node.name == "keras_learning_phase":
+        self.assertEqual(node.op, "Const")
+        const_value = optimize_for_inference_lib.values_from_const(node)
+        self.assertEqual(const_value, False)
+
   @test_util.run_deprecated_v1
   def testFoldBatchNorms(self):
     with self.cached_session() as sess:
@@ -350,7 +471,6 @@
       self.assertNotEqual("Conv2D", node.op)
       self.assertNotEqual("MirrorPad", node.op)
 
-
   @test_util.run_deprecated_v1
   def testFusePadAndConv(self):
     with self.cached_session() as sess:
diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD
index de6db21..95200d2 100644
--- a/tensorflow/python/tpu/BUILD
+++ b/tensorflow/python/tpu/BUILD
@@ -37,12 +37,19 @@
     srcs = ["tpu_test_wrapper.bzl"],
 )
 
+py_strict_library(
+    name = "tpu_test_wrapper",
+    srcs = ["tpu_test_wrapper.py"],
+    srcs_version = "PY3",
+    deps = [
+        "//tensorflow/python/platform:flags",
+        "//tensorflow/python/util:tf_decorator",
+    ],
+)
+
 py_strict_test(
     name = "tpu_test_wrapper_test",
-    srcs = [
-        "tpu_test_wrapper.py",
-        "tpu_test_wrapper_test.py",
-    ],
+    srcs = ["tpu_test_wrapper_test.py"],
     main = "tpu_test_wrapper_test.py",
     python_version = "PY3",
     srcs_version = "PY3",
@@ -51,9 +58,9 @@
         "no_pip",
     ],
     deps = [
+        ":tpu_test_wrapper",
         "//tensorflow/python/platform:client_testlib",
         "//tensorflow/python/platform:flags",
-        "//tensorflow/python/util:tf_decorator",
         "@absl_py//absl/testing:flagsaver",
     ],
 )
@@ -259,11 +266,11 @@
         ":tpu_replication",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:monitoring",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:func_graph",
         "//tensorflow/python/framework:function",
+        "//tensorflow/python/framework:graph_io",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:tensor",
         "//tensorflow/python/framework:tensor_util",
@@ -512,9 +519,9 @@
         ":tpu_py",
         ":tpu_replication",
         "//tensorflow/python/eager:def_function",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
+        "//tensorflow/python/framework:importer",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/layers",
         "//tensorflow/python/ops:array_ops",
@@ -817,12 +824,44 @@
 )
 
 pytype_strict_library(
+    name = "tpu_embedding_v3_utils",
+    srcs = ["tpu_embedding_v3_utils.py"],
+    deps = [
+        "//tensorflow/core/tpu/kernels:sparse_core_layout_proto_py",
+        "//tensorflow/python/framework:constant_op",
+        "//tensorflow/python/framework:dtypes",
+        "//tensorflow/python/framework:tensor",
+        "//tensorflow/python/ops:array_ops",
+        "//tensorflow/python/ops:manip_ops",
+        "//tensorflow/python/ops:variables",
+        "//tensorflow/python/trackable:base",
+    ],
+)
+
+tf_py_strict_test(
+    name = "tpu_embedding_v3_utils_test",
+    srcs = ["tpu_embedding_v3_utils_test.py"],
+    deps = [
+        ":tpu_embedding_v3_utils",
+        "//tensorflow/python/compat:v2_compat",
+        "//tensorflow/python/eager:test",
+        "//tensorflow/python/framework:constant_op",
+        "//tensorflow/python/ops:array_ops",
+        "//tensorflow/python/ops:math_ops",
+        "//tensorflow/python/platform:client_testlib",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+pytype_strict_library(
     name = "tpu_embedding_for_serving",
     srcs = ["tpu_embedding_for_serving.py"],
     srcs_version = "PY3",
     deps = [
         ":tpu_embedding_base",
         ":tpu_embedding_v2_utils",
+        ":tpu_embedding_v3_utils",
+        "//tensorflow/core/tpu/kernels:sparse_core_layout_proto_py",
         "//tensorflow/python/distribute:distribute_lib",
         "//tensorflow/python/distribute:tpu_strategy",
         "//tensorflow/python/framework:constant_op",
@@ -837,7 +876,6 @@
         "//tensorflow/python/ops:sparse_ops",
         "//tensorflow/python/ops:variables",
         "//tensorflow/python/ops/ragged:ragged_tensor",
-        "//tensorflow/python/trackable:base",
         "//tensorflow/python/types:core",
         "//tensorflow/python/util:nest",
         "//tensorflow/python/util:tf_export",
diff --git a/tensorflow/python/tpu/ops/BUILD b/tensorflow/python/tpu/ops/BUILD
index dd61c8c..bc70f4c 100644
--- a/tensorflow/python/tpu/ops/BUILD
+++ b/tensorflow/python/tpu/ops/BUILD
@@ -44,10 +44,15 @@
 tf_gen_op_wrapper_py(
     name = "gen_xla_ops",
     out = "gen_xla_ops.py",
+    api_def_srcs = [
+        "//tensorflow/core/api_def:base_api_def",
+        "//tensorflow/core/api_def:python_api_def",
+    ],
     op_allowlist = [
         "ConvertToCooTensor",
         "GetMinibatchesInCsrWithPhysicalReplica",
         "GetMinibatchSplitsWithPhysicalReplica",
+        "StoreMinibatchStatisticsInFdo",
         "GlobalIterId",
         "TPUCopyWithDynamicShape",
         "TPUAnnotateTensorsWithDynamicShape",
diff --git a/tensorflow/python/tpu/tpu_embedding_for_serving.py b/tensorflow/python/tpu/tpu_embedding_for_serving.py
index a85f50b..9e5da7d 100644
--- a/tensorflow/python/tpu/tpu_embedding_for_serving.py
+++ b/tensorflow/python/tpu/tpu_embedding_for_serving.py
@@ -18,6 +18,7 @@
 
 from absl import logging
 
+from tensorflow.core.tpu.kernels import sparse_core_layout_pb2
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.framework import dtypes
@@ -34,7 +35,7 @@
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.tpu import tpu_embedding_base
 from tensorflow.python.tpu import tpu_embedding_v2_utils
-from tensorflow.python.trackable import base as trackable_base
+from tensorflow.python.tpu import tpu_embedding_v3_utils
 from tensorflow.python.types import core
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
@@ -149,131 +150,95 @@
       with ops.init_scope():
         self.build()
 
+  # TODO(silkyarora) Update the tests for all TPU embedding to expect this
+  # possibly empty information in checkpoints.
+  def _maybe_delete_sc_layouts_from_checkpoint(self):
+    # Remove the sparse_core_table_layouts from the checkpoint, it is only
+    # required for sparsecore.
+    if (
+        hasattr(
+            self,
+            tpu_embedding_v3_utils.SPARSECORE_LAYOUTS_CHECKPOINT_KEY,
+        )
+        and not self._get_sparse_core_table_layouts_str()
+    ):
+      delattr(
+          self,
+          tpu_embedding_v3_utils.SPARSECORE_LAYOUTS_CHECKPOINT_KEY,
+      )
+
   def build(self):
     """Create variables and slots variables for TPU embeddings."""
     super().build()
-    # Remove the training_restore_info from the checkpoint, it was only
-    # required to restore from sparsecore training.
-    if hasattr(self, "training_restore_info"):
-      delattr(self, "training_restore_info")
+    self._maybe_delete_sc_layouts_from_checkpoint()
 
-  def _unshuffle_from_sc_to_cpu(self, t: tensor.Tensor) -> tensor.Tensor:
-    num_tpu_devices, num_sc_per_chip = self.training_restore_info.read_value()
-    # TODO(silkyarora): Is there a better way to confirm that there is nothing
-    # to be done?
-    if num_tpu_devices == 0 and num_sc_per_chip == 0:
-      return t
-    num_sc_devices = num_tpu_devices * num_sc_per_chip
-    old_shape = t.shape
-    # The width of the table must be a multiple of number of SC devices. The
-    # tpu strategy does this round off at training time so we expect the
-    # checkpoints value to meet this requirement.
-    assert t.shape[0] % num_sc_devices == 0
-    intermediate_tensor = array_ops.reshape(
-        t, (num_sc_devices, t.shape[0] // num_sc_devices, t.shape[1])
-    )
-    intermediate_tensor = array_ops.transpose(intermediate_tensor, (1, 0, 2))
-    return array_ops.reshape(intermediate_tensor, old_shape)
-
-  def _remove_padding_from_sc(
-      self, value_in_checkpoint: tensor.Tensor, variable_shape: tuple[int, int]
-  ) -> tensor.Tensor:
-    checkpoint_value_shape = value_in_checkpoint.shape.as_list()
-    # If the checkpoint shape is at least the size of the variable, we conclude
-    # that the extra rows and cols must be padding.
-    is_init_value_padded = all(
-        [i >= j for i, j in zip(checkpoint_value_shape, variable_shape)]
-    )
-    if not is_init_value_padded:
-      return value_in_checkpoint
-    # checkpoint has padding so we can remove it.
-    begin = [0] * len(checkpoint_value_shape)
-    return array_ops.slice(
-        value_in_checkpoint, begin=begin, size=variable_shape
-    )
-
-  def _create_variables(
-      self, table: tpu_embedding_v2_utils.TableConfig, trainable: bool
-  ) -> Dict[str, tf_variables.Variable]:
-    """Create all variables including table variables and slot variables."""
-    variable_shape = (table.vocabulary_size, table.dim)
-
+  def _track_restore_info_for_cpu(self) -> None:
     def getter(name, shape, dtype, initializer, trainable):
+      del shape
       # _add_variable_with_custom_getter clears the shape sometimes, so we
       # take the global shape from outside the getter.
-      del shape
-      if isinstance(initializer, trackable_base.CheckpointInitialValueCallable):
-        checkpoint_init_value = initializer(variable_shape).wrapped_value
-        restore_uid = initializer.restore_uid
-        unshuffled = self._unshuffle_from_sc_to_cpu(checkpoint_init_value)
-        truncated = self._remove_padding_from_sc(unshuffled, variable_shape)
-        var = tf_variables.Variable(
-            name=name,
-            initial_value=truncated,
-            shape=variable_shape,
-            dtype=dtype,
-            trainable=trainable,
-        )
-        # Maybe initialize the variable
-        var._maybe_initialize_trackable()  # pylint:disable=protected-access
-        # Update the uid for this variable from the checkpoint init value.
-        # This lets the checkpoint deferred restoration code know that this
-        # variable was restored while creation, so no need to restore it from
-        # the checkpoint later.
-        if restore_uid is not None:
-          var._update_uid = initializer.restore_uid  # pylint:disable=protected-access
-        return var
-
-      initial_value = functools.partial(
-          initializer, variable_shape, dtype=dtype
-      )
+      initial_value = functools.partial(initializer, dtype=dtype)
       return tf_variables.Variable(
           name=name,
           initial_value=initial_value,
-          shape=variable_shape,
+          shape=None,
           dtype=dtype,
           trainable=trainable,
       )
 
-    def variable_creator(name, initializer, shape, trainable=True):
-      # Use add_variable_with_custom_getter here so that we take advantage of
-      # the checkpoint loading to allow restore before the variables get
-      # created which avoids double initialization.
-      return self._add_variable_with_custom_getter(
-          name=name,
-          initializer=initializer,
-          shape=shape,
-          dtype=dtypes.float32,
-          getter=getter,
-          trainable=trainable,
-      )
+    def empty_string(dtype: dtypes.DType):
+      return tf_constant("", dtype=dtype)
 
-    parameters = variable_creator(
-        table.name, table.initializer, variable_shape, trainable=trainable
-    )
-
-    def slot_creator(name, initializer):
-      return variable_creator(table.name + "/" + name, initializer, False)
-
-    if table.optimizer is not None:
-      slot_vars = table.optimizer._create_slots(parameters, slot_creator)  # pylint: disable=protected-access
-    else:
-      slot_vars = {}
-    slot_vars["parameters"] = parameters
-    return slot_vars
-
-  def _track_restore_info_for_cpu(self) -> None:
-    self.training_restore_info = tf_variables.Variable(
-        name="training_restore_info",
-        initial_value=tf_constant(
-            [0, 0],
-            shape=(2,),
-            dtype=dtypes.int32,
+    # _add_variable_with_custom_getter is used here to restore from checkpoint
+    # at creation time. The layouts from sparse core must be restored from the
+    # checkpoint before any other tables are restored.
+    setattr(
+        self,
+        tpu_embedding_v3_utils.SPARSECORE_LAYOUTS_CHECKPOINT_KEY,
+        self._add_variable_with_custom_getter(
+            name=tpu_embedding_v3_utils.SPARSECORE_LAYOUTS_CHECKPOINT_KEY,
+            initializer=empty_string,
+            dtype=dtypes.string,
+            getter=getter,
+            trainable=False,
         ),
-        shape=(2,),
-        dtype=dtypes.int32,
     )
 
+  def _get_sparse_core_table_layouts_str(self) -> bytes:
+    layouts_str = getattr(
+        self,
+        tpu_embedding_v3_utils.SPARSECORE_LAYOUTS_CHECKPOINT_KEY,
+    )
+    return layouts_str.read_value().numpy()
+
+  def _create_variables_from_stacked_tables(self):
+    sc_layouts = sparse_core_layout_pb2.SparseCoreTableLayouts()
+    sc_layouts.ParseFromString(self._get_sparse_core_table_layouts_str())
+    stacked_table_name_to_layouts = {}
+    for layout in sc_layouts.tables:
+      stacked_tables_list = stacked_table_name_to_layouts.setdefault(
+          layout.stacked_table_name, []
+      )
+      stacked_tables_list.append(layout)
+    table_to_config = {table.name: table for table in self._table_config}
+    variables = {}
+    for stacked_table_name, layouts in stacked_table_name_to_layouts.items():
+      logging.info(
+          "Loading stacked table state variables(%s) for %s tables",
+          stacked_table_name,
+          len(layouts),
+      )
+      stacked_var_trackable = (
+          tpu_embedding_v3_utils.SparseCoreStackedTableTrackable(
+              layouts, table_to_config
+          )
+      )
+      # The stacked table is added as trackable to the embedding so that the
+      # checkpoint key corresponding to the stacked table is read.
+      self._track_trackable(stacked_var_trackable, stacked_table_name)
+      variables.update(stacked_var_trackable.get_vars())
+    return variables
+
   def _create_variables_and_slots(
       self,
   ) -> Dict[str, Dict[str, tf_variables.Variable]]:
@@ -285,8 +250,14 @@
     """
     self._track_restore_info_for_cpu()
     variables = {}
+    # If there are stacked variables from the SC checkpoint, process those
+    # first.
+    stacked_variables = self._create_variables_from_stacked_tables()
     for table in self._table_config:
-      variables[table.name] = self._create_variables(table, trainable=True)
+      if table.name in stacked_variables:
+        variables[table.name] = {"parameters": stacked_variables[table.name]}
+      else:
+        variables[table.name] = self._create_variables(table, trainable=True)
     return variables
 
   def embedding_lookup(
diff --git a/tensorflow/python/tpu/tpu_embedding_v2_utils.py b/tensorflow/python/tpu/tpu_embedding_v2_utils.py
index 858c74c..0ea650b 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2_utils.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2_utils.py
@@ -1021,14 +1021,19 @@
 
   """
 
-  def __init__(self,
-               vocabulary_size: int,
-               dim: int,
-               initializer: Optional[Callable[[Any], None]] = None,
-               optimizer: Optional[_Optimizer] = None,
-               combiner: Text = "mean",
-               name: Optional[Text] = None,
-               quantization_config: QuantizationConfig = None):
+  def __init__(
+      self,
+      vocabulary_size: int,
+      dim: int,
+      initializer: Optional[Callable[[Any], None]] = None,
+      optimizer: Optional[_Optimizer] = None,
+      combiner: Text = "mean",
+      name: Optional[Text] = None,
+      quantization_config: QuantizationConfig = None,
+      # TODO(b/295372790): Change the type to SparseCoreTableLayout after it is
+      # open sourced.
+      layout: Optional[Any] = None,
+  ):
     """Embedding table configuration.
 
     Args:
@@ -1054,6 +1059,9 @@
       quantization_config: The simulated quantization config. An instance of
         `tf.tpu.experimental.embedding.QuantizationConfig`. See the class for
         more documentation.
+      layout: If the table already has its layout computed, you can pass it in
+        here. Otherwise, we will compute it for you. Most users should leave
+        this as None.
 
     Returns:
       `TableConfig`.
@@ -1100,6 +1108,7 @@
     self.combiner = combiner
     self.name = name
     self.quantization_config = quantization_config
+    self.layout = layout
 
   def __repr__(self):
     # If using the default initializer, just print "None" for clarity.
@@ -1268,11 +1277,6 @@
       raise ValueError(
           f"Argument `max_sequence_length` must be an int and must be >= 0. "
           f"Received: {max_sequence_length}")
-    if name is None:
-      logging.warning(
-          "Name of the Feature config must be specified for running on"
-          " SparseCore. Different feature configs must have unique names."
-      )
 
     self.table = table
     self.max_sequence_length = max_sequence_length
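
For context, a hedged sketch of how the new `layout` argument is expected to be passed; the values below are illustrative and most callers should leave `layout` as None:

    from tensorflow.python.tpu import tpu_embedding_v2_utils

    table = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=1024,
        dim=16,
        combiner="mean",
        name="video",
        # Pass a precomputed SparseCoreTableLayout here once it is open
        # sourced; until then this is normally left as None.
        layout=None,
    )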
diff --git a/tensorflow/python/tpu/tpu_embedding_v3.py b/tensorflow/python/tpu/tpu_embedding_v3.py
index 3d4966f..ca20bfe 100644
--- a/tensorflow/python/tpu/tpu_embedding_v3.py
+++ b/tensorflow/python/tpu/tpu_embedding_v3.py
@@ -530,13 +530,28 @@
           trainable=False,
       )
 
-    parameters = variable_creator(stacked_table_name, table_initialize_fn)
+    with variable_scope.variable_creator_scope(
+        make_sharded_variable_creator(self._strategy, shape_is_local=False)
+    ):
+      parameters = variable_creator(stacked_table_name, table_initialize_fn)
 
     def slot_creator(name, initializer):
       return variable_creator(stacked_table_name + "/" + name, initializer)
 
     if optimizer is not None:
-      slot_vars = optimizer._create_slots(parameters, slot_creator)  # pylint: disable=protected-access
+      # FIXME(b/305882915): tensorflow_recommender calls into keras legacy
+      # optimizer, which creates the slot variable with
+      # TPUShardedEmbeddingVariable.shape as shape, but that shape attribute
+      # returns a local shape. We shall change the shape attribute to
+      # return a global shape, but such a change will break users who already
+      # depend on the attribute being local.
+      shape_is_local = optimizer.slot_variable_creation_fn is not None
+      with variable_scope.variable_creator_scope(
+          make_sharded_variable_creator(
+              self._strategy, shape_is_local=shape_is_local
+          )
+      ):
+        slot_vars = optimizer._create_slots(parameters, slot_creator)  # pylint: disable=protected-access
     else:
       slot_vars = {}
     slot_vars["parameters"] = parameters
@@ -629,11 +644,11 @@
     self._table_to_sample_count = {
         table_name: 0 for table_name in self._stacked_table_to_tables
     }
-    for _, feature in self._flat_features:
+    for feature_path, feature in self._flat_features:
       stacked_table_name = self._table_to_stacked_table_offset[
           feature.table.name
       ][0]
-      self._feature_to_sample_offset[feature.name] = (
+      self._feature_to_sample_offset[feature_path] = (
           self._table_to_sample_count[stacked_table_name]
       )
       self._table_to_sample_count[stacked_table_name] += functools.reduce(
@@ -651,12 +666,9 @@
     """
     variables = {}
     for stacked_table_name, tables in self._stacked_table_to_tables.items():
-      with variable_scope.variable_creator_scope(
-          make_sharded_variable_creator(self._strategy)
-      ):
-        variables[stacked_table_name] = self._create_variables(
-            tables, stacked_table_name=stacked_table_name
-        )
+      variables[stacked_table_name] = self._create_variables(
+          tables, stacked_table_name=stacked_table_name
+      )
     return variables
 
   def _maybe_build(self):
@@ -1033,13 +1045,13 @@
     table_to_list_of_coos = {
         table_name: ([], [], []) for table_name in stacked_table_to_tables
     }
-    for inp, weight, (_, feature) in zip(
+    for inp, weight, (feature_path, feature) in zip(
         flat_inputs, flat_weights, flat_features
     ):
       table_name, col_offset, col_shift = table_to_stacked_table_offset[
           feature.table.name
       ]
-      row_offset = feature_to_sample_offset[feature.name]
+      row_offset = feature_to_sample_offset[feature_path]
       # Consider making this into one op per table rather than per feature?
       row_ids, col_ids, gains = TPUEmbeddingV2._convert_input_feature_to_coo(
           inp,
@@ -1312,13 +1324,10 @@
           flat_weights=flat_weights,
       )
     elif device is None:
-      # This is used by keras function tracing.
+      # This is used by keras function tracing. Use any of the TPU devices
+      # and trace once for a single device.
       tpu_devices = self._strategy.extended._tpu_devices  # pylint:disable=protected-access
-      num_replicas, num_cores_per_replica = tpu_devices.shape
-      if num_replicas > 1 or num_cores_per_replica > 1:
-        raise NotImplementedError(
-            "SPMD is not implemented, use strategy.run instead."
-        )
+
       with ops.device(device_util.get_host_for_device(tpu_devices[0][0])):
         return TPUEmbeddingV2.preprocess_features(
             num_replicas_in_sync=self._strategy.num_replicas_in_sync,
@@ -1719,14 +1728,14 @@
         )
         for table_name in stacked_table_to_tables
     }
-    for inp, weight, (_, feature) in zip(
+    for inp, weight, (feature_path, feature) in zip(
         flat_inputs, flat_weights, flat_features
     ):
       table_name, col_offset, col_shift = table_to_stacked_table_offset[
           feature.table.name
       ]
       stacked_table_sample_count = stacked_table_to_sample_count[table_name]
-      row_offset = feature_to_sample_offset[feature.name]
+      row_offset = feature_to_sample_offset[feature_path]
       # Consider making this into one op per table rather than per feature?
       row_ids_list, col_ids_list, gains_list, sample_count = (
           TPUEmbeddingV2._experimental_convert_input_feature_to_list_of_coo_tensors(
@@ -2208,12 +2217,13 @@
 
 
 def make_sharded_variable_creator(
-    strategy: distribute_lib.Strategy,
+    strategy: distribute_lib.Strategy, shape_is_local: bool
 ) -> Callable[..., Any]:
   """Create a variable creator which shards across all the tpu device.
 
   Args:
     strategy: a TPUStrategy object.
+    shape_is_local: Whether the shape passed to the creator is per replica.
 
   Returns:
     The sharded variable creator.
@@ -2255,7 +2265,8 @@
       )
 
     partition_shape = shape.as_list()
-    partition_shape[shard_dim] = partition_shape[shard_dim] // num_devices
+    if not shape_is_local:
+      partition_shape[shard_dim] = partition_shape[shard_dim] // num_devices
 
     unwrapped_arg_spec = tf_inspect.getargspec(unwrapped_initial_value)
     sharding_aware = "shard_info" in unwrapped_arg_spec.args
diff --git a/tensorflow/python/tpu/tpu_embedding_v3_utils.py b/tensorflow/python/tpu/tpu_embedding_v3_utils.py
new file mode 100644
index 0000000..ed30d99
--- /dev/null
+++ b/tensorflow/python/tpu/tpu_embedding_v3_utils.py
@@ -0,0 +1,215 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for Sparsecore Checkpoints."""
+
+import functools
+from typing import Any, Dict
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor
+from tensorflow.python.framework.constant_op import constant as tf_constant
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import manip_ops
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.trackable import base as trackable_base
+
+SPARSECORE_LAYOUTS_CHECKPOINT_KEY = "_sparse_core_table_layouts"
+
+
+def unshuffle_from_sc_to_cpu(
+    t: tensor.Tensor,
+    num_sparse_cores: int,
+    offset_in_shard: int,
+    size_in_shard: int,
+    shard_rotation: int = 0,
+) -> tensor.Tensor:
+  """Unshuffles the sparse core sharded embedding tables to unsharded.
+
+  This converts an input tensor representing a stacked, sharded embedding table
+  into a specific embedding table variable, using the provided layout metadata.
+
+  Args:
+    t: The input stacked and sharded embedding table from sparsecore.
+    num_sparse_cores: The number of sparsecores, this determines the number of
+      shards that are present in the input t.
+    offset_in_shard: Offset within a shard where the queried table starts.
+    size_in_shard: size (number of rows) of this queried table within each shard
+      of the input t.
+    shard_rotation: The rotation of this table's shards.
+
+  Returns:
+    An embedding table which is part of the stacked embedding table t.
+  """
+  old_shape = t.shape
+  # The number of rows in the table must be a multiple of the number of SC
+  # devices. The tpu strategy does this rounding at training time, so we expect
+  # the checkpoint value to meet this requirement.
+  if t.shape[0] % num_sparse_cores != 0:
+    raise ValueError(
+        "The number of rows in the table ({}) must be a multiple of the"
+        " number of sparse cores ({})".format(t.shape[0], num_sparse_cores)
+    )
+  # get shards in the input t
+  shards_t = array_ops.reshape(
+      t,
+      (
+          num_sparse_cores,
+          t.shape[0] // num_sparse_cores,
+          t.shape[1],
+      ),
+  )
+  # From each shard in t, get the part for just the queried table.
+  shards = shards_t[:, offset_in_shard : offset_in_shard + size_in_shard, :]
+  # This table's shards were rotated by `shard_rotation`, so we need to rotate
+  # the same amount in opposite direction
+  shards = manip_ops.roll(shards, -shard_rotation, axis=0)
+  # Re-arrange (transpose and reshape) the shards to get the queried embedding
+  # table.
+  intermediate_tensor = array_ops.transpose(shards, (1, 0, 2))
+  new_shape = size_in_shard * num_sparse_cores, old_shape[1]
+  return array_ops.reshape(intermediate_tensor, new_shape)
+
+
+def remove_padding_from_sc(
+    value_in_checkpoint: tensor.Tensor, variable_shape: tuple[int, int]
+) -> tensor.Tensor:
+  """Removes padding, if any, from sparsecore checkpoint.
+
+  Args:
+    value_in_checkpoint: input tensor value, usually from checkpoint.
+    variable_shape: Expected shape of tensor after removing padding.
+
+  Returns:
+    A slice of the input tensor to match the variable_shape if the
+    variable shape is a valid slice if the input tensor.
+  """
+  checkpoint_value_shape = value_in_checkpoint.shape.as_list()
+  # If the checkpoint shape is at least the size of the variable, we conclude
+  # that the extra rows and cols must be padding.
+  is_init_value_padded = all(
+      [i >= j for i, j in zip(checkpoint_value_shape, variable_shape)]
+  )
+  if not is_init_value_padded:
+    return value_in_checkpoint
+  # checkpoint has padding so we can remove it.
+  begin = [0] * len(checkpoint_value_shape)
+  return array_ops.slice(value_in_checkpoint, begin=begin, size=variable_shape)
+
+
+def map_indices_in_shard(
+    num_sparse_cores: int,
+    offset_in_shard: int,
+    shard_rotation: int,
+    row_indices: tensor.Tensor,
+) -> tuple[tensor.Tensor, tensor.Tensor]:
+  """Maps a row of a given table to its sparse core shard and position.
+
+  Given a row index of a logical table and its layout in sparse core, returns
+  the index of the shard where the row is placed and its relative position
+  within that sparse core shard.
+
+  Args:
+    num_sparse_cores: The number of sparsecores, this determines the number of
+      shards present.
+    offset_in_shard: Offset within a shard where the queried table starts.
+    shard_rotation: The rotation of this table's shards.
+    row_indices: row indices of the embedding table being looked up.
+
+  Returns:
+    A Tuple representing shard_index and position of the row in that shard.
+  """
+  shard_index = (
+      (row_indices % num_sparse_cores) + shard_rotation
+  ) % num_sparse_cores
+  position_in_shard = offset_in_shard + row_indices // num_sparse_cores
+  return (shard_index, position_in_shard)
+
+
+class SparseCoreLayoutsTrackable(trackable_base.Trackable):
+  """Trackable for sparsecore layouts used in training."""
+
+  def __init__(self, proto_str_tensor: tensor.Tensor):
+    self.value = proto_str_tensor
+
+  def _serialize_to_tensors(self) -> Dict[str, tensor.Tensor]:
+    return {trackable_base.VARIABLE_VALUE_KEY: self.value}
+
+  def _restore_from_tensors(
+      self, restored_tensors: Dict[str, tensor.Tensor]
+  ) -> None:
+    self.value = restored_tensors[trackable_base.VARIABLE_VALUE_KEY]
+
+
+class SparseCoreStackedTableTrackable(trackable_base.Trackable):
+  """Trackable for stacked tables generated from sparse core."""
+
+  def __init__(self, stacked_layouts, table_to_config):
+    self.vars = {}
+    self._stacked_layouts = stacked_layouts
+    for table_layout in stacked_layouts:
+      variable_shape = tuple(table_layout.unsharded_shape)
+      self.vars[table_layout.table_name] = tf_variables.Variable(
+          name=table_layout.table_name,
+          initial_value=functools.partial(
+              table_to_config[table_layout.table_name].initializer,
+              variable_shape,
+              dtype=dtypes.float32,
+          ),
+          shape=variable_shape,
+          dtype=dtypes.float32,
+      )
+
+  def _serialize_to_tensors(self) -> Any:
+    return {
+        # We need to export some variable here for restore to pick up the
+        # checkpoint key; the actual value is not important, so 0 works.
+        trackable_base.VARIABLE_VALUE_KEY: tf_constant(
+            0.0, dtype=dtypes.float32
+        ),
+    }
+
+  def _restore_from_tensors(self, restored_tensors: Dict[str, tensor.Tensor]):
+    def fn(restored_tensors):
+      value_from_checkpoint = restored_tensors[
+          trackable_base.VARIABLE_VALUE_KEY
+      ]
+      # Do unsharding to get the individual tables from the stacked table in
+      # the checkpoint.
+      for layout in self._stacked_layouts:
+        variable_shape = (
+            layout.unsharded_shape[0],
+            layout.unsharded_shape[1],
+        )
+        t_part = unshuffle_from_sc_to_cpu(
+            t=value_from_checkpoint,
+            num_sparse_cores=layout.num_sparse_cores,
+            offset_in_shard=layout.sparse_core_shard_row_offset,
+            size_in_shard=(
+                layout.unsharded_padded_shape[0] // layout.num_sparse_cores
+            ),
+            shard_rotation=layout.sparse_core_shard_rotation,
+        )
+        t_part = remove_padding_from_sc(t_part, variable_shape)
+        self.vars[layout.table_name].assign(t_part)
+
+    return fn(restored_tensors)
+
+  def get_var(self, name: str) -> tf_variables.Variable:
+    return self.vars[name]
+
+  def get_vars(self) -> Dict[str, tf_variables.Variable]:
+    return self.vars
+
+  def __repr__(self):
+    return "SparseCoreStackedTableTrackable({})".format(self.vars.keys())
diff --git a/tensorflow/python/tpu/tpu_embedding_v3_utils_test.py b/tensorflow/python/tpu/tpu_embedding_v3_utils_test.py
new file mode 100644
index 0000000..c790436
--- /dev/null
+++ b/tensorflow/python/tpu/tpu_embedding_v3_utils_test.py
@@ -0,0 +1,278 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test for tpu_embedding_v3_utils."""
+
+import collections
+
+from absl.testing import parameterized
+
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.framework.constant_op import constant as tf_constant
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.tpu import tpu_embedding_v3_utils as v3_utils
+
+TestTable = collections.namedtuple("Table", ["vocab", "dim", "shift"])
+
+
+def create_test_table_shards(
+    table: TestTable, num_sc_shards: int, table_data_start=0
+):
+  t = array_ops.reshape(
+      math_ops.range(
+          start=table_data_start,
+          delta=1,
+          limit=table_data_start + table.vocab * table.dim,
+      ),
+      (table.vocab, table.dim),
+  )
+  shards = [t[i::num_sc_shards, :] for i in range(num_sc_shards)]
+  if table.shift:
+    shards = collections.deque(shards)
+    shards.rotate(table.shift)
+    return (t, list(shards))
+  else:
+    return (t, shards)
+
+
+class TpuEmbeddingV3UtilsTest(test.TestCase, parameterized.TestCase):
+
+  def test_unpadding(self):
+    self.assertAllEqual(
+        v3_utils.remove_padding_from_sc(
+            array_ops.ones((4, 5)), variable_shape=(3, 2)
+        ),
+        array_ops.ones((3, 2)),
+    )
+    x = array_ops.reshape(math_ops.range(12), (3, 4))
+    self.assertAllEqual(
+        v3_utils.remove_padding_from_sc(x, variable_shape=(2, 2)),
+        tf_constant([[0, 1], [4, 5]]),
+    )
+    self.assertAllEqual(
+        v3_utils.remove_padding_from_sc(x, variable_shape=(3, 5)),
+        x,
+    )
+
+  @parameterized.named_parameters(
+      ("one", 8, 4, 4), ("two", 27, 6, 3), ("three", 128, 8, 4)
+  )
+  def test_unshuffle_one_table_basic(self, vocab, dim, num_sc):
+    # input vocab should be multiple of num_sc
+    self.assertEqual(vocab % num_sc, 0)
+    x, shards = create_test_table_shards(
+        TestTable(vocab=vocab, dim=dim, shift=0), num_sc
+    )
+    x_sharded = array_ops.concat(shards, axis=0)
+    self.assertAllEqual(
+        v3_utils.unshuffle_from_sc_to_cpu(
+            t=x_sharded,
+            num_sparse_cores=num_sc,
+            offset_in_shard=0,
+            size_in_shard=vocab // num_sc,
+            shard_rotation=0,
+        ),
+        x,
+    )
+
+  def test_unshuffle_stacking_basic(self):
+    num_sc = 4
+    ta = TestTable(vocab=12, dim=4, shift=0)
+    tb = TestTable(vocab=32, dim=4, shift=1)
+    x, x_shards = create_test_table_shards(ta, num_sc)
+    y, y_shards = create_test_table_shards(tb, num_sc)
+    stacked_shards = [
+        array_ops.concat([i, j], axis=0) for i, j in zip(x_shards, y_shards)
+    ]
+    stacked = array_ops.concat(stacked_shards, axis=0)
+    self.assertAllEqual(
+        v3_utils.unshuffle_from_sc_to_cpu(
+            t=stacked,
+            num_sparse_cores=num_sc,
+            offset_in_shard=0,
+            size_in_shard=ta.vocab // num_sc,
+            shard_rotation=ta.shift,
+        ),
+        x,
+    )
+    self.assertAllEqual(
+        v3_utils.unshuffle_from_sc_to_cpu(
+            t=stacked,
+            num_sparse_cores=num_sc,
+            offset_in_shard=ta.vocab // num_sc,
+            size_in_shard=tb.vocab // num_sc,
+            shard_rotation=tb.shift,
+        ),
+        y,
+    )
+
+  def test_unshuffle_stacking_many_tables(self):
+    num_sc = 4
+    tables = [
+        TestTable(vocab=12, dim=4, shift=0),
+        TestTable(vocab=32, dim=4, shift=1),
+        TestTable(vocab=32, dim=4, shift=2),
+        TestTable(vocab=32, dim=4, shift=3),
+        TestTable(vocab=32, dim=4, shift=4),
+        TestTable(vocab=32, dim=4, shift=5),
+    ]
+    u, u_shards = create_test_table_shards(tables[0], num_sc)
+    v, v_shards = create_test_table_shards(tables[1], num_sc)
+    w, w_shards = create_test_table_shards(tables[2], num_sc)
+    x, x_shards = create_test_table_shards(tables[3], num_sc)
+    y, y_shards = create_test_table_shards(tables[4], num_sc)
+    z, z_shards = create_test_table_shards(tables[5], num_sc)
+    stacked_shards = [
+        array_ops.concat([i, j, k, l, m, n], axis=0)
+        for i, j, k, l, m, n in zip(
+            u_shards, v_shards, w_shards, x_shards, y_shards, z_shards
+        )
+    ]
+    stacked = array_ops.concat(stacked_shards, axis=0)
+    self.assertAllEqual(
+        v3_utils.unshuffle_from_sc_to_cpu(
+            t=stacked,
+            num_sparse_cores=num_sc,
+            offset_in_shard=0,
+            size_in_shard=tables[0].vocab // num_sc,
+            shard_rotation=tables[0].shift,
+        ),
+        u,
+    )
+    self.assertAllEqual(
+        v3_utils.unshuffle_from_sc_to_cpu(
+            t=stacked,
+            num_sparse_cores=num_sc,
+            offset_in_shard=tables[0].vocab // num_sc,
+            size_in_shard=tables[1].vocab // num_sc,
+            shard_rotation=tables[1].shift,
+        ),
+        v,
+    )
+    self.assertAllEqual(
+        v3_utils.unshuffle_from_sc_to_cpu(
+            t=stacked,
+            num_sparse_cores=num_sc,
+            offset_in_shard=(tables[0].vocab + tables[1].vocab) // num_sc,
+            size_in_shard=tables[2].vocab // num_sc,
+            shard_rotation=tables[2].shift,
+        ),
+        w,
+    )
+    self.assertAllEqual(
+        v3_utils.unshuffle_from_sc_to_cpu(
+            t=stacked,
+            num_sparse_cores=num_sc,
+            offset_in_shard=(
+                tables[0].vocab + tables[1].vocab + tables[2].vocab
+            )
+            // num_sc,
+            size_in_shard=tables[3].vocab // num_sc,
+            shard_rotation=tables[3].shift,
+        ),
+        x,
+    )
+    self.assertAllEqual(
+        v3_utils.unshuffle_from_sc_to_cpu(
+            t=stacked,
+            num_sparse_cores=num_sc,
+            offset_in_shard=(
+                tables[0].vocab
+                + tables[1].vocab
+                + tables[2].vocab
+                + tables[3].vocab
+            )
+            // num_sc,
+            size_in_shard=tables[4].vocab // num_sc,
+            shard_rotation=tables[4].shift,
+        ),
+        y,
+    )
+    self.assertAllEqual(
+        v3_utils.unshuffle_from_sc_to_cpu(
+            t=stacked,
+            num_sparse_cores=num_sc,
+            offset_in_shard=(
+                tables[0].vocab
+                + tables[1].vocab
+                + tables[2].vocab
+                + tables[3].vocab
+                + tables[4].vocab
+            )
+            // num_sc,
+            size_in_shard=tables[5].vocab // num_sc,
+            shard_rotation=tables[5].shift,
+        ),
+        z,
+    )
+
+  def test_index_mapping_one_table(self):
+    num_sc = 4
+    x, shards = create_test_table_shards(
+        TestTable(vocab=12, dim=4, shift=0), num_sc
+    )
+    indices = tf_constant([1, 2, 5, 7, 9])
+    shard_idx, position_in_shard = v3_utils.map_indices_in_shard(
+        num_sparse_cores=num_sc,
+        offset_in_shard=0,
+        shard_rotation=0,
+        row_indices=indices,
+    )
+    self.assertAllEqual(
+        shard_idx,
+        indices % num_sc,
+    )
+    self.assertAllEqual(
+        [x[i] for i in indices],
+        [shards[j][k] for j, k in zip(shard_idx, position_in_shard)],
+    )
+
+  def test_index_mapping_stacked_tables(self):
+    num_sc = 4
+    ta = TestTable(vocab=12, dim=4, shift=0)
+    tb = TestTable(vocab=32, dim=4, shift=1)
+    x, x_shards = create_test_table_shards(ta, num_sc)
+    y, y_shards = create_test_table_shards(tb, num_sc, table_data_start=100)
+    stacked_shards = [
+        array_ops.concat([i, j], axis=0) for i, j in zip(x_shards, y_shards)
+    ]
+    indices_ta = tf_constant([1, 2, 7, 9, 11])
+    shard_idx, position_in_shard = v3_utils.map_indices_in_shard(
+        num_sparse_cores=num_sc,
+        offset_in_shard=0,
+        shard_rotation=ta.shift,
+        row_indices=indices_ta,
+    )
+    self.assertAllEqual(
+        [x[i] for i in indices_ta],
+        [stacked_shards[j][k] for j, k in zip(shard_idx, position_in_shard)],
+    )
+    indices_tb = tf_constant([1, 2, 7, 9, 15, 27])
+    shard_idx, position_in_shard = v3_utils.map_indices_in_shard(
+        num_sparse_cores=num_sc,
+        offset_in_shard=ta.vocab // num_sc,
+        shard_rotation=tb.shift,
+        row_indices=indices_tb,
+    )
+    self.assertAllEqual(
+        [y[i] for i in indices_tb],
+        [stacked_shards[j][k] for j, k in zip(shard_idx, position_in_shard)],
+    )
+
+
+if __name__ == "__main__":
+  v2_compat.enable_v2_behavior()
+  test.main()
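[Editorial note, not part of the patch] The test helper above encodes the SparseCore layout these utilities deal with: row r of a table lives in shard (r mod num_sc), the shard list may be rotated, and stacked tables place each table's slice at a fixed offset inside every shard. The sketch below reconstructs that layout and its inverse in plain NumPy under those assumptions; the helper names (make_shards, unshuffle) are illustrative only and are not the tpu_embedding_v3_utils API, whose exact rotation semantics may differ.

import numpy as np

def make_shards(vocab, dim, num_sc, shift, start=0):
  # Row r of the table lands in shard (r % num_sc); the shard list is then
  # rotated right by `shift`, mirroring create_test_table_shards above.
  t = np.arange(start, start + vocab * dim).reshape(vocab, dim)
  shards = [t[i::num_sc, :] for i in range(num_sc)]
  if shift:
    shards = shards[-shift:] + shards[:-shift]
  return t, shards

def unshuffle(stacked, num_sc, offset, size, rotation):
  # Inverse of the layout above: slice this table's rows out of each shard,
  # undo the rotation, then interleave rows back into vocabulary order.
  shard_len = stacked.shape[0] // num_sc
  shards = [stacked[i * shard_len + offset:i * shard_len + offset + size]
            for i in range(num_sc)]
  shards = shards[rotation:] + shards[:rotation]
  out = np.empty((size * num_sc, stacked.shape[1]), dtype=stacked.dtype)
  for i, s in enumerate(shards):
    out[i::num_sc] = s
  return out

t, shards = make_shards(vocab=12, dim=4, num_sc=4, shift=1)
stacked = np.concatenate(shards, axis=0)
assert (unshuffle(stacked, num_sc=4, offset=0, size=3, rotation=1) == t).all()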
diff --git a/tensorflow/python/training/BUILD b/tensorflow/python/training/BUILD
index feca1c7..a041d50 100644
--- a/tensorflow/python/training/BUILD
+++ b/tensorflow/python/training/BUILD
@@ -95,9 +95,9 @@
     srcs_version = "PY3",
     deps = [
         ":optimizer",
-        ":training_ops",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:math_ops",
+        "//tensorflow/python/ops:training_ops_gen",
         "//tensorflow/python/util:tf_export",
     ],
 )
@@ -108,11 +108,11 @@
     srcs_version = "PY3",
     deps = [
         ":optimizer",
-        ":training_ops",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:array_ops",
         "//tensorflow/python/ops:math_ops",
+        "//tensorflow/python/ops:training_ops_gen",
         "//tensorflow/python/util:tf_export",
     ],
 )
@@ -123,12 +123,12 @@
     srcs_version = "PY3",
     deps = [
         ":optimizer",
-        ":training_ops",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:array_ops",
         "//tensorflow/python/ops:array_ops_gen",
         "//tensorflow/python/ops:init_ops",
         "//tensorflow/python/ops:math_ops",
+        "//tensorflow/python/ops:training_ops_gen",
         "//tensorflow/python/util:tf_export",
     ],
 )
@@ -139,13 +139,13 @@
     srcs_version = "PY3",
     deps = [
         ":optimizer",
-        ":training_ops",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:control_flow_ops",
         "//tensorflow/python/ops:math_ops",
         "//tensorflow/python/ops:resource_variable_ops",
         "//tensorflow/python/ops:state_ops",
+        "//tensorflow/python/ops:training_ops_gen",
         "//tensorflow/python/util:tf_export",
     ],
 )
@@ -253,11 +253,11 @@
     srcs_version = "PY3",
     deps = [
         ":optimizer",
-        ":training_ops",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:array_ops",
         "//tensorflow/python/ops:math_ops",
+        "//tensorflow/python/ops:training_ops_gen",
         "//tensorflow/python/util:tf_export",
     ],
 )
@@ -273,11 +273,11 @@
     ],
     deps = [
         ":optimizer",
-        ":training_ops",
         "//tensorflow/python/framework:indexed_slices",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:math_ops",
         "//tensorflow/python/ops:resource_variable_ops",
+        "//tensorflow/python/ops:training_ops_gen",
         "//tensorflow/python/util:tf_export",
     ],
 )
@@ -327,9 +327,9 @@
     ],
     deps = [
         ":optimizer",
-        ":training_ops",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:math_ops",
+        "//tensorflow/python/ops:training_ops_gen",
         "//tensorflow/python/util:tf_export",
     ],
 )
@@ -403,10 +403,10 @@
     srcs_version = "PY3",
     deps = [
         ":optimizer",
-        ":training_ops",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:math_ops",
+        "//tensorflow/python/ops:training_ops_gen",
         "//tensorflow/python/util:tf_export",
     ],
 )
@@ -417,9 +417,9 @@
     srcs_version = "PY3",
     deps = [
         ":optimizer",
-        ":training_ops",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:math_ops",
+        "//tensorflow/python/ops:training_ops_gen",
         "//tensorflow/python/util:tf_export",
     ],
 )
@@ -469,11 +469,11 @@
     srcs_version = "PY3",
     deps = [
         ":optimizer",
-        ":training_ops",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:array_ops",
         "//tensorflow/python/ops:init_ops",
         "//tensorflow/python/ops:math_ops",
+        "//tensorflow/python/ops:training_ops_gen",
         "//tensorflow/python/util:tf_export",
     ],
 )
@@ -551,18 +551,6 @@
 )
 
 py_strict_library(
-    name = "training_ops",
-    srcs = [
-        "gen_training_ops.py",
-        "training_ops.py",
-    ],
-    srcs_version = "PY3",
-    deps = [
-        "//tensorflow/python/ops:training_ops_gen",
-    ],
-)
-
-py_strict_library(
     name = "warm_starting_util",
     srcs = ["warm_starting_util.py"],
     srcs_version = "PY3",
@@ -826,10 +814,10 @@
         "//tensorflow/python/checkpoint:checkpoint_management",
         "//tensorflow/python/client:session",
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:device",
         "//tensorflow/python/framework:errors",
+        "//tensorflow/python/framework:meta_graph",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:array_ops",
         "//tensorflow/python/ops:control_flow_ops",
@@ -890,11 +878,12 @@
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/data/ops:iterator_ops",
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:errors",
         "//tensorflow/python/framework:function",
+        "//tensorflow/python/framework:graph_io",
+        "//tensorflow/python/framework:meta_graph",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/lib/io:file_io",
@@ -976,9 +965,9 @@
         ":training_util",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/client:timeline",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:errors",
+        "//tensorflow/python/framework:meta_graph",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:init_ops",
         "//tensorflow/python/ops:variable_scope",
@@ -1011,8 +1000,8 @@
         ":training_util",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:dtypes",
+        "//tensorflow/python/framework:meta_graph",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:control_flow_ops",
         "//tensorflow/python/ops:lookup_ops",
@@ -1039,10 +1028,10 @@
         ":supervisor",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/checkpoint:checkpoint_management",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:errors",
+        "//tensorflow/python/framework:meta_graph",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/ops:array_ops",
@@ -1083,8 +1072,8 @@
     ],
     deps = [
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:dtypes",
+        "//tensorflow/python/framework:graph_io",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:tensor",
         "//tensorflow/python/ops:cond",
@@ -1182,7 +1171,6 @@
     "//tensorflow/python/ops:data_flow_ops_gen",
     "//tensorflow/python/ops:embedding_ops",
     "//tensorflow/python/framework:errors",
-    "//tensorflow/python/framework",
     "//tensorflow/python/framework:for_generated_wrappers",
     "//tensorflow/python/framework:test_lib",
     "//tensorflow/python/ops:custom_gradient",
@@ -1383,6 +1371,7 @@
         ":saver",
         "//tensorflow/python/client:session",
         "//tensorflow/python/framework:constant_op",
+        "//tensorflow/python/framework:importer",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:variable_v1",
     ] + TRAINING_TEST_DEPS,
@@ -1441,12 +1430,12 @@
     srcs = ["training_ops_test.py"],
     python_version = "PY3",
     deps = [
-        ":training_ops",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/ops:math_ops",
+        "//tensorflow/python/ops:training_ops_gen",
         "//tensorflow/python/ops:variable_v1",
     ] + TRAINING_TEST_DEPS,
 )
@@ -1482,7 +1471,6 @@
     disable_mlir_bridge = False,
     main = "training_ops_test.py",
     deps = [
-        ":training_ops",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
@@ -1490,6 +1478,7 @@
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/ops:math_ops",
         "//tensorflow/python/ops:resource_variable_ops",
+        "//tensorflow/python/ops:training_ops_gen",
         "//tensorflow/python/ops:variable_v1",
         "//tensorflow/python/ops:variables",
         "//tensorflow/python/platform:test",
@@ -1514,7 +1503,7 @@
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/ops:array_ops",
-        "//tensorflow/python/ops:variable_scope",
+        "//tensorflow/python/ops:resource_variables_toggle",
         "//tensorflow/python/ops:variable_v1",
         "//tensorflow/python/ops:variables",
         "//tensorflow/python/ops:while_loop",
@@ -1541,10 +1530,10 @@
         ":training_util",
         "//tensorflow/python/client:session",
         "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/framework",
         "//tensorflow/python/framework:constant_op",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:errors",
+        "//tensorflow/python/framework:meta_graph",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/ops:array_ops",
diff --git a/tensorflow/python/training/adadelta.py b/tensorflow/python/training/adadelta.py
index c3690de..f1d27ae 100644
--- a/tensorflow/python/training/adadelta.py
+++ b/tensorflow/python/training/adadelta.py
@@ -15,9 +15,9 @@
 
 """Adadelta for TensorFlow."""
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.training import optimizer
-from tensorflow.python.training import training_ops
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -146,7 +146,7 @@
   def _apply_dense(self, grad, var):
     accum = self.get_slot(var, "accum")
     accum_update = self.get_slot(var, "accum_update")
-    return training_ops.apply_adadelta(
+    return gen_training_ops.apply_adadelta(
         var,
         accum,
         accum_update,
@@ -159,7 +159,7 @@
   def _resource_apply_dense(self, grad, var):
     accum = self.get_slot(var, "accum")
     accum_update = self.get_slot(var, "accum_update")
-    return training_ops.resource_apply_adadelta(
+    return gen_training_ops.resource_apply_adadelta(
         var.handle,
         accum.handle,
         accum_update.handle,
@@ -172,7 +172,7 @@
   def _apply_sparse(self, grad, var):
     accum = self.get_slot(var, "accum")
     accum_update = self.get_slot(var, "accum_update")
-    return training_ops.sparse_apply_adadelta(
+    return gen_training_ops.sparse_apply_adadelta(
         var,
         accum,
         accum_update,
@@ -186,7 +186,7 @@
   def _resource_apply_sparse(self, grad, var, indices):
     accum = self.get_slot(var, "accum")
     accum_update = self.get_slot(var, "accum_update")
-    return training_ops.resource_sparse_apply_adadelta(
+    return gen_training_ops.resource_sparse_apply_adadelta(
         var.handle,
         accum.handle,
         accum_update.handle,
diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py
index 0fbe901..6cb2ca4 100644
--- a/tensorflow/python/training/adagrad.py
+++ b/tensorflow/python/training/adagrad.py
@@ -17,10 +17,10 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.training import optimizer
-from tensorflow.python.training import training_ops
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -158,7 +158,7 @@
 
   def _apply_dense(self, grad, var):
     acc = self.get_slot(var, "accumulator")
-    return training_ops.apply_adagrad(
+    return gen_training_ops.apply_adagrad(
         var,
         acc,
         math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
@@ -167,7 +167,7 @@
 
   def _resource_apply_dense(self, grad, var):
     acc = self.get_slot(var, "accumulator")
-    return training_ops.resource_apply_adagrad(
+    return gen_training_ops.resource_apply_adagrad(
         var.handle,
         acc.handle,
         math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype),
@@ -176,7 +176,7 @@
 
   def _apply_sparse(self, grad, var):
     acc = self.get_slot(var, "accumulator")
-    return training_ops.sparse_apply_adagrad(
+    return gen_training_ops.sparse_apply_adagrad(
         var,
         acc,
         math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
@@ -186,7 +186,7 @@
 
   def _resource_apply_sparse(self, grad, var, indices):
     acc = self.get_slot(var, "accumulator")
-    return training_ops.resource_sparse_apply_adagrad(
+    return gen_training_ops.resource_sparse_apply_adagrad(
         var.handle,
         acc.handle,
         math_ops.cast(self._learning_rate_tensor, grad.dtype),
diff --git a/tensorflow/python/training/adagrad_da.py b/tensorflow/python/training/adagrad_da.py
index 9f9784f..2081996 100644
--- a/tensorflow/python/training/adagrad_da.py
+++ b/tensorflow/python/training/adagrad_da.py
@@ -16,9 +16,9 @@
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.training import optimizer
-from tensorflow.python.training import training_ops
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -109,7 +109,7 @@
     gg_acc = self.get_slot(var, "gradient_squared_accumulator")
     with ops.device(var.device):
       global_step = array_ops.identity(self._global_step_on_worker)
-    return training_ops.apply_adagrad_da(
+    return gen_training_ops.apply_adagrad_da(
         var,
         g_acc,
         gg_acc,
@@ -125,7 +125,7 @@
     gg_acc = self.get_slot(var, "gradient_squared_accumulator")
     with ops.device(var.device):
       global_step = array_ops.identity(self._global_step_on_worker)
-    return training_ops.resource_apply_adagrad_da(
+    return gen_training_ops.resource_apply_adagrad_da(
         var.handle,
         g_acc.handle,
         gg_acc.handle,
@@ -141,7 +141,7 @@
     gg_acc = self.get_slot(var, "gradient_squared_accumulator")
     with ops.device(var.device):
       global_step = array_ops.identity(self._global_step_on_worker)
-    return training_ops.sparse_apply_adagrad_da(
+    return gen_training_ops.sparse_apply_adagrad_da(
         var,
         g_acc,
         gg_acc,
@@ -158,7 +158,7 @@
     gg_acc = self.get_slot(var, "gradient_squared_accumulator")
     with ops.device(var.device):
       global_step = array_ops.identity(self._global_step_on_worker)
-    return training_ops.resource_sparse_apply_adagrad_da(
+    return gen_training_ops.resource_sparse_apply_adagrad_da(
         var.handle,
         g_acc.handle,
         gg_acc.handle,
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index f479dbb..0192b22 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -16,11 +16,11 @@
 from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.training import optimizer
-from tensorflow.python.training import training_ops
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -214,7 +214,7 @@
     m = self.get_slot(var, "m")
     v = self.get_slot(var, "v")
     beta1_power, beta2_power = self._get_beta_accumulators()
-    return training_ops.apply_adam(
+    return gen_training_ops.apply_adam(
         var,
         m,
         v,
@@ -231,7 +231,7 @@
     m = self.get_slot(var, "m")
     v = self.get_slot(var, "v")
     beta1_power, beta2_power = self._get_beta_accumulators()
-    return training_ops.resource_apply_adam(
+    return gen_training_ops.resource_apply_adam(
         var.handle,
         m.handle,
         v.handle,
diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py
index 282ca74..cb44245 100644
--- a/tensorflow/python/training/ftrl.py
+++ b/tensorflow/python/training/ftrl.py
@@ -16,9 +16,9 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.training import optimizer
-from tensorflow.python.training import training_ops
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -164,7 +164,7 @@
     accum = self.get_slot(var, "accum")
     linear = self.get_slot(var, "linear")
     if self._l2_shrinkage_regularization_strength <= 0.0:
-      return training_ops.apply_ftrl(
+      return gen_training_ops.apply_ftrl(
           var,
           accum,
           linear,
@@ -177,7 +177,7 @@
           math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
           use_locking=self._use_locking)
     else:
-      return training_ops.apply_ftrl_v2(
+      return gen_training_ops.apply_ftrl_v2(
           var,
           accum,
           linear,
@@ -196,7 +196,7 @@
     accum = self.get_slot(var, "accum")
     linear = self.get_slot(var, "linear")
     if self._l2_shrinkage_regularization_strength <= 0.0:
-      return training_ops.resource_apply_ftrl(
+      return gen_training_ops.resource_apply_ftrl(
           var.handle,
           accum.handle,
           linear.handle,
@@ -209,7 +209,7 @@
           math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
           use_locking=self._use_locking)
     else:
-      return training_ops.resource_apply_ftrl_v2(
+      return gen_training_ops.resource_apply_ftrl_v2(
           var.handle,
           accum.handle,
           linear.handle,
@@ -228,7 +228,7 @@
     accum = self.get_slot(var, "accum")
     linear = self.get_slot(var, "linear")
     if self._l2_shrinkage_regularization_strength <= 0.0:
-      return training_ops.sparse_apply_ftrl(
+      return gen_training_ops.sparse_apply_ftrl(
           var,
           accum,
           linear,
@@ -242,7 +242,7 @@
           math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
           use_locking=self._use_locking)
     else:
-      return training_ops.sparse_apply_ftrl_v2(
+      return gen_training_ops.sparse_apply_ftrl_v2(
           var,
           accum,
           linear,
@@ -262,7 +262,7 @@
     accum = self.get_slot(var, "accum")
     linear = self.get_slot(var, "linear")
     if self._l2_shrinkage_regularization_strength <= 0.0:
-      return training_ops.resource_sparse_apply_ftrl(
+      return gen_training_ops.resource_sparse_apply_ftrl(
           var.handle,
           accum.handle,
           linear.handle,
@@ -275,7 +275,7 @@
           math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
           use_locking=self._use_locking)
     else:
-      return training_ops.resource_sparse_apply_ftrl_v2(
+      return gen_training_ops.resource_sparse_apply_ftrl_v2(
           var.handle,
           accum.handle,
           linear.handle,
diff --git a/tensorflow/python/training/gen_training_ops.py b/tensorflow/python/training/gen_training_ops.py
deleted file mode 100644
index a569a58..0000000
--- a/tensorflow/python/training/gen_training_ops.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Python wrappers for training ops."""
-# NOTE(allenl): The generated op wrappers for training ops were originally in
-# training/gen_training_ops.py. They moved to ops/gen_training_ops.py when
-# training/ became a module, and this is an alias to avoid breaking existing
-# imports.
-
-# go/tf-wildcard-import
-# pylint: disable=wildcard-import
-from tensorflow.python.ops.gen_training_ops import *
-# pylint: enable=wildcard-import
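[Editorial note, not part of the patch] With the training/gen_training_ops.py alias deleted here (and training/training_ops.py deleted in a later hunk of this diff), every optimizer in this change imports the generated wrappers from their canonical location. A minimal sketch of the new import path, assuming a standard TensorFlow checkout:

from tensorflow.python.ops import gen_training_ops

# Call sites are unchanged; only the module path moves. The optimizers below
# invoke gen_training_ops.apply_adadelta / apply_adagrad / apply_adam and
# their resource_/sparse_ variants exactly as training_ops.* did before.
print(hasattr(gen_training_ops, "apply_gradient_descent"))  # True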
diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py
index 007efce..2b06548 100644
--- a/tensorflow/python/training/gradient_descent.py
+++ b/tensorflow/python/training/gradient_descent.py
@@ -16,10 +16,10 @@
 """GradientDescent for TensorFlow."""
 from tensorflow.python.framework import indexed_slices
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.training import optimizer
-from tensorflow.python.training import training_ops
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -50,14 +50,14 @@
     self._learning_rate_tensor = None
 
   def _apply_dense(self, grad, var):
-    return training_ops.apply_gradient_descent(
+    return gen_training_ops.apply_gradient_descent(
         var,
         math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
         grad,
         use_locking=self._use_locking).op
 
   def _resource_apply_dense(self, grad, handle):
-    return training_ops.resource_apply_gradient_descent(
+    return gen_training_ops.resource_apply_gradient_descent(
         handle.handle, math_ops.cast(self._learning_rate_tensor,
                                      grad.dtype.base_dtype),
         grad, use_locking=self._use_locking)
diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py
index b7fec55..abeb6b0 100644
--- a/tensorflow/python/training/momentum.py
+++ b/tensorflow/python/training/momentum.py
@@ -15,9 +15,9 @@
 
 """Momentum for TensorFlow."""
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.training import optimizer
-from tensorflow.python.training import training_ops
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -164,7 +164,7 @@
 
   def _apply_dense(self, grad, var):
     mom = self.get_slot(var, "momentum")
-    return training_ops.apply_momentum(
+    return gen_training_ops.apply_momentum(
         var, mom,
         math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
         grad,
@@ -174,7 +174,7 @@
 
   def _resource_apply_dense(self, grad, var):
     mom = self.get_slot(var, "momentum")
-    return training_ops.resource_apply_momentum(
+    return gen_training_ops.resource_apply_momentum(
         var.handle, mom.handle,
         math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype),
         grad,
@@ -184,7 +184,7 @@
 
   def _apply_sparse(self, grad, var):
     mom = self.get_slot(var, "momentum")
-    return training_ops.sparse_apply_momentum(
+    return gen_training_ops.sparse_apply_momentum(
         var, mom,
         math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
         grad.values, grad.indices,
@@ -194,7 +194,7 @@
 
   def _resource_apply_sparse(self, grad, var, indices):
     mom = self.get_slot(var, "momentum")
-    return training_ops.resource_sparse_apply_momentum(
+    return gen_training_ops.resource_sparse_apply_momentum(
         var.handle, mom.handle,
         math_ops.cast(self._learning_rate_tensor, grad.dtype),
         grad, indices,
diff --git a/tensorflow/python/training/proximal_adagrad.py b/tensorflow/python/training/proximal_adagrad.py
index 9001a45..389c76b 100644
--- a/tensorflow/python/training/proximal_adagrad.py
+++ b/tensorflow/python/training/proximal_adagrad.py
@@ -16,9 +16,9 @@
 """ProximalAdagrad for TensorFlow."""
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.training import optimizer
-from tensorflow.python.training import training_ops
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -89,7 +89,7 @@
 
   def _apply_dense(self, grad, var):
     acc = self.get_slot(var, "accumulator")
-    return training_ops.apply_proximal_adagrad(
+    return gen_training_ops.apply_proximal_adagrad(
         var, acc, self._learning_rate_tensor,
         self._l1_regularization_strength_tensor,
         self._l2_regularization_strength_tensor,
@@ -97,7 +97,7 @@
 
   def _resource_apply_dense(self, grad, var):
     acc = self.get_slot(var, "accumulator")
-    return training_ops.resource_apply_proximal_adagrad(
+    return gen_training_ops.resource_apply_proximal_adagrad(
         var.handle, acc.handle, self._learning_rate_tensor,
         self._l1_regularization_strength_tensor,
         self._l2_regularization_strength_tensor,
@@ -105,7 +105,7 @@
 
   def _apply_sparse(self, grad, var):
     acc = self.get_slot(var, "accumulator")
-    return training_ops.sparse_apply_proximal_adagrad(
+    return gen_training_ops.sparse_apply_proximal_adagrad(
         var, acc, self._learning_rate_tensor,
         self._l1_regularization_strength_tensor,
         self._l2_regularization_strength_tensor,
@@ -114,7 +114,7 @@
 
   def _resource_apply_sparse(self, grad, var, indices):
     acc = self.get_slot(var, "accumulator")
-    return training_ops.resource_sparse_apply_proximal_adagrad(
+    return gen_training_ops.resource_sparse_apply_proximal_adagrad(
         var.handle, acc.handle,
         math_ops.cast(self._learning_rate_tensor, grad.dtype),
         math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
diff --git a/tensorflow/python/training/proximal_gradient_descent.py b/tensorflow/python/training/proximal_gradient_descent.py
index 70f6997..5620cd9 100644
--- a/tensorflow/python/training/proximal_gradient_descent.py
+++ b/tensorflow/python/training/proximal_gradient_descent.py
@@ -15,11 +15,11 @@
 
 """ProximalGradientDescent for TensorFlow."""
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_training_ops
 # pylint: disable=unused-import
 from tensorflow.python.ops import math_ops
 # pylint: enable=unused-import
 from tensorflow.python.training import optimizer
-from tensorflow.python.training import training_ops
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -58,7 +58,7 @@
     self._l2_regularization_strength_tensor = None
 
   def _apply_dense(self, grad, var):
-    return training_ops.apply_proximal_gradient_descent(
+    return gen_training_ops.apply_proximal_gradient_descent(
         var,
         self._learning_rate_tensor,
         self._l1_regularization_strength_tensor,
@@ -67,7 +67,7 @@
         use_locking=self._use_locking).op
 
   def _resource_apply_dense(self, grad, var):
-    return training_ops.resource_apply_proximal_gradient_descent(
+    return gen_training_ops.resource_apply_proximal_gradient_descent(
         var.handle,
         self._learning_rate_tensor,
         self._l1_regularization_strength_tensor,
@@ -76,7 +76,7 @@
         use_locking=self._use_locking)
 
   def _apply_sparse(self, grad, var):
-    return training_ops.sparse_apply_proximal_gradient_descent(
+    return gen_training_ops.sparse_apply_proximal_gradient_descent(
         var,
         self._learning_rate_tensor,
         self._l1_regularization_strength_tensor,
@@ -86,7 +86,7 @@
         use_locking=self._use_locking).op
 
   def _resource_apply_sparse(self, grad, var, indices):
-    return training_ops.resource_sparse_apply_proximal_gradient_descent(
+    return gen_training_ops.resource_sparse_apply_proximal_gradient_descent(
         var.handle,
         math_ops.cast(self._learning_rate_tensor, grad.dtype),
         math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
diff --git a/tensorflow/python/training/rmsprop.py b/tensorflow/python/training/rmsprop.py
index c20ea36..157f389 100644
--- a/tensorflow/python/training/rmsprop.py
+++ b/tensorflow/python/training/rmsprop.py
@@ -39,10 +39,10 @@
 
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.training import optimizer
-from tensorflow.python.training import training_ops
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -211,7 +211,7 @@
     mom = self.get_slot(var, "momentum")
     if self._centered:
       mg = self.get_slot(var, "mg")
-      return training_ops.apply_centered_rms_prop(
+      return gen_training_ops.apply_centered_rms_prop(
           var,
           mg,
           rms,
@@ -223,7 +223,7 @@
           grad,
           use_locking=self._use_locking).op
     else:
-      return training_ops.apply_rms_prop(
+      return gen_training_ops.apply_rms_prop(
           var,
           rms,
           mom,
@@ -239,7 +239,7 @@
     mom = self.get_slot(var, "momentum")
     if self._centered:
       mg = self.get_slot(var, "mg")
-      return training_ops.resource_apply_centered_rms_prop(
+      return gen_training_ops.resource_apply_centered_rms_prop(
           var.handle,
           mg.handle,
           rms.handle,
@@ -251,7 +251,7 @@
           grad,
           use_locking=self._use_locking)
     else:
-      return training_ops.resource_apply_rms_prop(
+      return gen_training_ops.resource_apply_rms_prop(
           var.handle,
           rms.handle,
           mom.handle,
@@ -267,7 +267,7 @@
     mom = self.get_slot(var, "momentum")
     if self._centered:
       mg = self.get_slot(var, "mg")
-      return training_ops.sparse_apply_centered_rms_prop(
+      return gen_training_ops.sparse_apply_centered_rms_prop(
           var,
           mg,
           rms,
@@ -280,7 +280,7 @@
           grad.indices,
           use_locking=self._use_locking)
     else:
-      return training_ops.sparse_apply_rms_prop(
+      return gen_training_ops.sparse_apply_rms_prop(
           var,
           rms,
           mom,
@@ -297,7 +297,7 @@
     mom = self.get_slot(var, "momentum")
     if self._centered:
       mg = self.get_slot(var, "mg")
-      return training_ops.resource_sparse_apply_centered_rms_prop(
+      return gen_training_ops.resource_sparse_apply_centered_rms_prop(
           var.handle,
           mg.handle,
           rms.handle,
@@ -310,7 +310,7 @@
           indices,
           use_locking=self._use_locking)
     else:
-      return training_ops.resource_sparse_apply_rms_prop(
+      return gen_training_ops.resource_sparse_apply_rms_prop(
           var.handle,
           rms.handle,
           mom.handle,
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 4fa6c57..92ec7cf 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -1371,7 +1371,9 @@
     # pylint: enable=line-too-long
     return export_meta_graph(
         filename=filename,
-        graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
+        graph_def=ops.get_default_graph().as_graph_def(
+            add_shapes=True, use_pybind11_proto=True
+        ),
         saver_def=self.saver_def,
         collection_list=collection_list,
         as_text=as_text,
@@ -1379,7 +1381,8 @@
         clear_devices=clear_devices,
         clear_extraneous_savers=clear_extraneous_savers,
         strip_default_attrs=strip_default_attrs,
-        save_debug_info=save_debug_info)
+        save_debug_info=save_debug_info,
+    )
 
   def restore(self, sess, save_path):
     """Restores previously saved variables.
diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py
index f01bc6b..b3139c9 100644
--- a/tensorflow/python/training/session_manager_test.py
+++ b/tensorflow/python/training/session_manager_test.py
@@ -23,7 +23,7 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import resource_variables_toggle
 from tensorflow.python.ops import variable_v1
 from tensorflow.python.ops import variables
 from tensorflow.python.ops import while_loop
@@ -39,7 +39,7 @@
   @classmethod
   def setUpClass(cls):
     super(SessionManagerTest, cls).setUpClass()
-    variable_scope.disable_resource_variables()
+    resource_variables_toggle.disable_resource_variables()
 
   def testPrepareSessionSucceeds(self):
     with ops.Graph().as_default():
@@ -678,7 +678,7 @@
   @classmethod
   def setUpClass(cls):
     super(ObsoleteSessionManagerTest, cls).setUpClass()
-    variable_scope.disable_resource_variables()
+    resource_variables_toggle.disable_resource_variables()
 
   def testPrepareSessionSucceeds(self):
     with ops.Graph().as_default():
diff --git a/tensorflow/python/training/training_ops.py b/tensorflow/python/training/training_ops.py
deleted file mode 100644
index 80a47f8..0000000
--- a/tensorflow/python/training/training_ops.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Python wrappers for training ops."""
-
-from tensorflow.python.ops import gen_training_ops  # pylint: disable=unused-import
-# go/tf-wildcard-import
-# pylint: disable=wildcard-import
-from tensorflow.python.ops.gen_training_ops import *
-# pylint: enable=wildcard-import
diff --git a/tensorflow/python/training/training_ops_test.py b/tensorflow/python/training/training_ops_test.py
index f451244..9ee2af2 100644
--- a/tensorflow/python/training/training_ops_test.py
+++ b/tensorflow/python/training/training_ops_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for tensorflow.learning.training_ops."""
+"""Tests for tensorflow.ops.gen_training_ops."""
 
 import itertools
 import threading
@@ -26,12 +26,12 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.framework.test_util import TensorFlowTestCase
 # Import resource_variable_ops for the variables-to-tensor implicit conversion.
+from tensorflow.python.ops import gen_training_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
 from tensorflow.python.ops import variable_v1
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
-from tensorflow.python.training import training_ops
 
 
 class TrainingOpsTest(TensorFlowTestCase):
@@ -56,7 +56,7 @@
       var = variable_v1.VariableV1(x)
       self.evaluate(variables.global_variables_initializer())
       self.assertAllCloseAccordingToType(x, self.evaluate(var))
-      apply_sgd = training_ops.apply_gradient_descent(var, alpha, delta)
+      apply_sgd = gen_training_ops.apply_gradient_descent(var, alpha, delta)
       out = self.evaluate(apply_sgd)
       self.assertShapeEqual(out, apply_sgd)
       self.assertAllCloseAccordingToType(x - alpha * delta, out)
@@ -79,7 +79,7 @@
       self.evaluate(variables.global_variables_initializer())
 
       self.assertAllCloseAccordingToType(x, self.evaluate(var))
-      apply_adagrad = training_ops.apply_adagrad(var, accum, lr, grad)
+      apply_adagrad = gen_training_ops.apply_adagrad(var, accum, lr, grad)
       out = self.evaluate(apply_adagrad)
       self.assertShapeEqual(out, apply_adagrad)
       self.assertAllCloseAccordingToType(x - lr * grad * (y + grad * grad)**
@@ -104,8 +104,8 @@
       self.evaluate(variables.global_variables_initializer())
 
       self.assertAllCloseAccordingToType(x, self.evaluate(var))
-      apply_ftrl = training_ops.apply_ftrl(var, accum, linear, grad, lr, l1, l2,
-                                           lr_power)
+      apply_ftrl = gen_training_ops.apply_ftrl(var, accum, linear, grad, lr, l1,
+                                               l2, lr_power)
       out = self.evaluate(apply_ftrl)
       self.assertShapeEqual(out, apply_ftrl)
       accum_update = y + grad * grad
@@ -150,7 +150,7 @@
 
       self.assertAllCloseAccordingToType(x, self.evaluate(var))
       apply_ftrl = (
-          training_ops.apply_ftrl(
+          gen_training_ops.apply_ftrl(
               var,
               accum,
               linear,
@@ -232,7 +232,7 @@
       self.evaluate(variables.global_variables_initializer())
 
       self.assertAllCloseAccordingToType(x, self.evaluate(var))
-      sparse_apply_adagrad = training_ops.sparse_apply_adagrad(
+      sparse_apply_adagrad = gen_training_ops.sparse_apply_adagrad(
           var, accum, lr, grad,
           constant_op.constant(indices, self._toType(indices.dtype)))
       out = self.evaluate(sparse_apply_adagrad)
@@ -264,7 +264,7 @@
       self.evaluate(variables.global_variables_initializer())
 
       self.assertAllCloseAccordingToType(x, self.evaluate(var))
-      sparse_apply_ftrl = training_ops.sparse_apply_ftrl(
+      sparse_apply_ftrl = gen_training_ops.sparse_apply_ftrl(
           var,
           accum,
           linear,
@@ -304,7 +304,7 @@
 
       self.assertAllCloseAccordingToType(x, self.evaluate(var))
       sparse_apply_ftrl = (
-          training_ops.sparse_apply_ftrl(
+          gen_training_ops.sparse_apply_ftrl(
               var,
               accum,
               linear,
@@ -440,9 +440,9 @@
       self.assertAllCloseAccordingToType(var, self.evaluate(var_t))
       new_var, _, _ = self._adamUpdateNumpy(var, grad, t, m, v, lr, beta1,
                                             beta2, epsilon)
-      apply_adam = training_ops.apply_adam(var_t, m_t, v_t, beta1_power_t,
-                                           beta2_power_t, lr_t, beta1_t,
-                                           beta2_t, epsilon_t, grad)
+      apply_adam = gen_training_ops.apply_adam(var_t, m_t, v_t, beta1_power_t,
+                                               beta2_power_t, lr_t, beta1_t,
+                                               beta2_t, epsilon_t, grad)
       out = self.evaluate(apply_adam)
       self.assertShapeEqual(out, apply_adam)
       self.assertAllCloseAccordingToType(new_var, out)
@@ -488,7 +488,7 @@
     def fn_resource_sparse_apply_adagrad_v2():
       ret = constant_op.constant(0, dtypes.int32)
       for i in math_ops.range(num_iter):
-        adagrad_op = training_ops.resource_sparse_apply_adagrad_v2(
+        adagrad_op = gen_training_ops.resource_sparse_apply_adagrad_v2(
             var.handle, accum.handle, lr, epsilon, grad,
             constant_op.constant(indices, dtypes.int32))
         with ops.control_dependencies([adagrad_op]):
diff --git a/tensorflow/python/util/BUILD b/tensorflow/python/util/BUILD
index 7ef8734..7d2d9ef 100644
--- a/tensorflow/python/util/BUILD
+++ b/tensorflow/python/util/BUILD
@@ -84,6 +84,10 @@
     name = "_pywrap_utils",
     srcs = ["util_wrapper.cc"],
     hdrs = ["util.h"],
+    enable_stub_generation = True,
+    pytype_srcs = [
+        "_pywrap_utils.pyi",
+    ],
     deps = [
         "//tensorflow/core/platform:platform_port",
         "//tensorflow/python/lib/core:pybind11_lib",
@@ -96,6 +100,10 @@
     name = "_pywrap_nest",
     srcs = ["nest_wrapper.cc"],
     hdrs = ["nest.h"],
+    enable_stub_generation = True,
+    pytype_srcs = [
+        "_pywrap_nest.pyi",
+    ],
     deps = [
         "//tensorflow/python/lib/core:pybind11_lib",
         "//third_party/python_runtime:headers",
@@ -348,7 +356,37 @@
         "//third_party/py/tensorflow_core:__subpackages__",
     ],
     deps = [
-        ":tf_decorator",
+        ":tf_decorator_py",
+        ":tf_inspect",
+    ],
+)
+
+py_strict_library(
+    name = "tf_contextlib",
+    srcs = ["tf_contextlib.py"],
+    compatible_with = get_compatible_with_portable(),
+    visibility = [
+        "//learning/brain/analytics:__subpackages__",
+        "//learning/deepmind/research/language/translation/lm:__subpackages__",
+        "//tensorflow:__pkg__",
+        "//tensorflow:__subpackages__",
+        "//third_party/py/tensorflow_core:__subpackages__",
+        "//third_party/py/tf_slim:__subpackages__",
+    ],
+    deps = [":tf_decorator_py"],
+)
+
+py_strict_library(
+    name = "tf_decorator_py",
+    srcs = ["tf_decorator.py"],
+    compatible_with = get_compatible_with_portable(),
+    visibility = [
+        "//learning/brain/analytics:__subpackages__",
+        "//learning/deepmind/research/language/translation/lm:__subpackages__",
+        "//tensorflow:__pkg__",
+        "//tensorflow:__subpackages__",
+        "//third_party/py/tensorflow_core:__subpackages__",
+        "//third_party/py/tf_slim:__subpackages__",
     ],
 )
 
@@ -357,7 +395,7 @@
     srcs = ["tf_export_test.py"],
     python_version = "PY3",
     deps = [
-        ":tf_decorator",
+        ":tf_decorator_py",
         ":tf_export",
         "//tensorflow/python/platform:client_testlib",
         "//tensorflow/python/platform:tf_logging",
@@ -379,12 +417,8 @@
 # TODO(mdan): Move this utility outside of TF.
 py_strict_library(
     name = "tf_decorator",
-    srcs = [
-        "tf_contextlib.py",
-        "tf_decorator.py",
-        "tf_inspect.py",
-    ],
     compatible_with = get_compatible_with_portable(),
+    deprecation = "This target has been split. Depend on the sub-targets instead.",
     srcs_version = "PY3",
     visibility = [
         "//tensorflow:__subpackages__",
@@ -396,7 +430,9 @@
         "//third_party/py/tensorflow_core:__subpackages__",
     ],
     deps = [
-        "@six_archive//:six",
+        ":tf_contextlib",
+        ":tf_decorator_py",
+        ":tf_inspect",
     ],
 )
 
@@ -615,7 +651,9 @@
     srcs = ["tf_contextlib_test.py"],
     python_version = "PY3",
     deps = [
-        ":tf_decorator",
+        ":tf_contextlib",
+        ":tf_decorator_py",
+        ":tf_inspect",
         "//tensorflow/python/platform:client_testlib",
     ],
 )
@@ -626,7 +664,8 @@
     srcs = ["tf_decorator_test.py"],
     python_version = "PY3",
     deps = [
-        ":tf_decorator",
+        ":tf_decorator_py",
+        ":tf_inspect",
         "//tensorflow/python/platform:client_testlib",
         "//tensorflow/python/platform:tf_logging",
     ],
@@ -637,7 +676,7 @@
     srcs = ["tf_should_use.py"],
     srcs_version = "PY3",
     deps = [
-        ":tf_decorator",
+        ":tf_decorator_py",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/platform:tf_logging",
@@ -666,7 +705,8 @@
     srcs = ["tf_inspect_test.py"],
     python_version = "PY3",
     deps = [
-        ":tf_decorator",
+        ":tf_decorator_py",
+        ":tf_inspect",
         "//tensorflow/python/platform:client_testlib",
         "//tensorflow/python/platform:tf_logging",
     ],
@@ -691,13 +731,15 @@
     deps = [
         ":decorator_utils",
         ":is_in_graph_mode",
+        ":tf_contextlib",
+        ":tf_decorator_py",
+        ":tf_inspect",
         "//tensorflow/python/framework:strict_mode",
         "//tensorflow/python/platform:tf_logging",
         # global_test_configuration is added here because all major tests depend on this
         # library. It isn't possible to add these test dependencies via tensorflow.bzl's
         # py_test because not all tensorflow tests use tensorflow.bzl's py_test.
         "//tensorflow/python:global_test_configuration",
-        ":tf_decorator",
         "//tensorflow/tools/docs:doc_controls",
     ],
 )
@@ -741,11 +783,11 @@
     compatible_with = get_compatible_with_portable(),
     visibility = util_subpackage_visibility,
     deps = [
+        ":tf_decorator_py",
         # global_test_configuration is added here because all major tests depend on this
         # library. It isn't possible to add these test dependencies via tensorflow.bzl's
         # py_test because not all tensorflow tests use tensorflow.bzl's py_test.
         "//tensorflow/python:global_test_configuration",
-        ":tf_decorator",
         "@six_archive//:six",
     ],
 )
@@ -783,8 +825,10 @@
     python_version = "PY3",
     deps = [
         ":lazy_loader",
+        ":tf_inspect",
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/platform:client_testlib",
+        "//tensorflow/python/platform:tf_logging",
     ],
 )
 
@@ -900,11 +944,11 @@
     compatible_with = get_compatible_with_portable(),
     visibility = util_subpackage_visibility,
     deps = [
+        ":tf_decorator_py",
         # global_test_configuration is added here because all major tests depend on this
         # library. It isn't possible to add these test dependencies via tensorflow.bzl's
         # py_test because not all tensorflow tests use tensorflow.bzl's py_test.
         "//tensorflow/python:global_test_configuration",
-        ":tf_decorator",
         ":tf_export",
     ],
 )
@@ -928,13 +972,14 @@
     visibility = util_subpackage_visibility,
     deps = [
         ":__init__",
+        ":tf_decorator_py",
+        ":tf_inspect",
         "//tensorflow/python/eager:monitoring",
         "//tensorflow/python/platform:tf_logging",
         # global_test_configuration is added here because all major tests depend on this
         # library. It isn't possible to add these test dependencies via tensorflow.bzl's
         # py_test because not all tensorflow tests use tensorflow.bzl's py_test.
         "//tensorflow/python:global_test_configuration",
-        ":tf_decorator",
         "//tensorflow/tools/compatibility:all_renames_v2",
         ":fast_module_type",
     ],
@@ -945,6 +990,8 @@
     srcs = ["function_utils.py"],
     visibility = util_subpackage_visibility,
     deps = [
+        ":tf_decorator_py",
+        ":tf_inspect",
         "//tensorflow/core:protos_all_py",
         # global_test_configuration is added here because all major tests depend on this
         # library. It isn't possible to add these test dependencies via tensorflow.bzl's
@@ -986,6 +1033,8 @@
     srcs = ["dispatch.py"],
     visibility = util_subpackage_visibility,
     deps = [
+        ":tf_decorator_py",
+        ":tf_inspect",
         ":traceback_utils",
         ":type_annotations",
         "//tensorflow/python/framework:_pywrap_python_api_dispatcher",
@@ -994,7 +1043,6 @@
         # library. It isn't possible to add these test dependencies via tensorflow.bzl's
         # py_test because not all tensorflow tests use tensorflow.bzl's py_test.
         "//tensorflow/python:global_test_configuration",
-        ":tf_decorator",
         ":tf_export",
     ],
 )
@@ -1051,11 +1099,11 @@
     compatible_with = get_compatible_with_portable(),
     visibility = util_subpackage_visibility,
     deps = [
+        ":tf_inspect",
         # global_test_configuration is added here because all major tests depend on this
         # library. It isn't possible to add these test dependencies via tensorflow.bzl's
         # py_test because not all tensorflow tests use tensorflow.bzl's py_test.
         "//tensorflow/python:global_test_configuration",
-        ":tf_decorator",
     ],
 )
 
@@ -1212,7 +1260,7 @@
     name = "tf_decorator_export",
     srcs = ["tf_decorator_export.py"],
     deps = [
-        ":tf_decorator",
+        ":tf_decorator_py",
         ":tf_export",
     ],
 )
diff --git a/third_party/xla/third_party/gpus/cuda/cuda_config.py.tpl b/tensorflow/python/util/_pywrap_nest.pyi
similarity index 84%
rename from third_party/xla/third_party/gpus/cuda/cuda_config.py.tpl
rename to tensorflow/python/util/_pywrap_nest.pyi
index 3da256b..8ea0151 100644
--- a/third_party/xla/third_party/gpus/cuda/cuda_config.py.tpl
+++ b/tensorflow/python/util/_pywrap_nest.pyi
@@ -1,4 +1,4 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-config = %{cuda_config}
+def FlattenDictItems(arg0: object) -> object: ...
diff --git a/tensorflow/python/util/_pywrap_utils.pyi b/tensorflow/python/util/_pywrap_utils.pyi
new file mode 100644
index 0000000..c8e51ec
--- /dev/null
+++ b/tensorflow/python/util/_pywrap_utils.pyi
@@ -0,0 +1,36 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+def AssertSameStructure(arg0: object, arg1: object, arg2: bool, arg3: bool) -> bool: ...
+def AssertSameStructureForData(arg0: object, arg1: object, arg2: bool) -> bool: ...
+def Flatten(arg0: object, arg1: bool) -> object: ...
+def FlattenForData(arg0: object) -> object: ...
+def IsAttrs(arg0: object) -> bool: ...
+def IsBF16SupportedByOneDNNOnThisCPU() -> bool: ...
+def IsCompositeTensor(arg0: object) -> bool: ...
+def IsMapping(arg0: object) -> bool: ...
+def IsMappingView(arg0: object) -> bool: ...
+def IsMutableMapping(arg0: object) -> bool: ...
+def IsNamedtuple(arg0: object, arg1: bool) -> object: ...
+def IsNested(arg0: object) -> bool: ...
+def IsNestedForData(arg0: object) -> bool: ...
+def IsNestedOrComposite(arg0: object) -> bool: ...
+def IsResourceVariable(arg0: object) -> bool: ...
+def IsTensor(arg0: object) -> bool: ...
+def IsTypeSpec(arg0: object) -> bool: ...
+def IsVariable(arg0: object) -> bool: ...
+def RegisterPyObject(arg0: object, arg1: object) -> object: ...
+def RegisterType(arg0: object, arg1: object) -> object: ...
+def SameNamedtuples(arg0: object, arg1: object) -> object: ...
diff --git a/tensorflow/python/util/lazy_loader.py b/tensorflow/python/util/lazy_loader.py
index 08e0fb0..717965d 100644
--- a/tensorflow/python/util/lazy_loader.py
+++ b/tensorflow/python/util/lazy_loader.py
@@ -20,6 +20,8 @@
 import types
 from tensorflow.python.platform import tf_logging as logging
 
+_TENSORFLOW_LAZY_LOADER_PREFIX = "_tfll"
+
 
 class LazyLoader(types.ModuleType):
   """Lazily import a module, mainly to avoid pulling in large dependencies.
@@ -30,31 +32,31 @@
 
   # The lint error here is incorrect.
   def __init__(self, local_name, parent_module_globals, name, warning=None):
-    self._local_name = local_name
-    self._parent_module_globals = parent_module_globals
-    self._warning = warning
+    self._tfll_local_name = local_name
+    self._tfll_parent_module_globals = parent_module_globals
+    self._tfll_warning = warning
 
     # These members allow doctest to process this module member without
     # triggering self._load(). self._load() mutates parent_module_globals and
     # triggers a "dict mutated during iteration" error from doctest.py.
     # - for from_module()
-    self.__module__ = name.rsplit(".", 1)[0]
+    super().__setattr__("__module__", name.rsplit(".", 1)[0])
     # - for is_routine()
-    self.__wrapped__ = None
+    super().__setattr__("__wrapped__", None)
 
-    super(LazyLoader, self).__init__(name)
+    super().__init__(name)
 
   def _load(self):
     """Load the module and insert it into the parent's globals."""
     # Import the target module and insert it into the parent's namespace
     module = importlib.import_module(self.__name__)
-    self._parent_module_globals[self._local_name] = module
+    self._tfll_parent_module_globals[self._tfll_local_name] = module
 
     # Emit a warning if one was specified
-    if self._warning:
-      logging.warning(self._warning)
+    if self._tfll_warning:
+      logging.warning(self._tfll_warning)
       # Make sure to only warn once.
-      self._warning = None
+      self._tfll_warning = None
 
     # Update this object's dict so that if someone keeps a reference to the
     #   LazyLoader, lookups are efficient (__getattr__ is only called on lookups
@@ -63,14 +65,42 @@
 
     return module
 
-  def __getattr__(self, item):
+  def __getattr__(self, name):
     module = self._load()
-    return getattr(module, item)
+    return getattr(module, name)
+
+  def __setattr__(self, name, value):
+    if name.startswith(_TENSORFLOW_LAZY_LOADER_PREFIX):
+      super().__setattr__(name, value)
+    else:
+      module = self._load()
+      setattr(module, name, value)
+      self.__dict__[name] = value
+      try:
+        # check if the module has __all__
+        if name not in self.__all__ and name != "__all__":
+          self.__all__.append(name)
+      except AttributeError:
+        pass
+
+  def __delattr__(self, name):
+    if name.startswith(_TENSORFLOW_LAZY_LOADER_PREFIX):
+      super().__delattr__(name)
+    else:
+      module = self._load()
+      delattr(module, name)
+      self.__dict__.pop(name)
+      try:
+        # check if the module has __all__
+        if name in self.__all__:
+          self.__all__.remove(name)
+      except AttributeError:
+        pass
 
   def __repr__(self):
     # Be careful not to trigger _load, since repr may be called in very
     # sensitive places.
-    return f"<LazyLoader {self.__name__} as {self._local_name}>"
+    return f"<LazyLoader {self.__name__} as {self._tfll_local_name}>"
 
   def __dir__(self):
     module = self._load()
@@ -82,15 +112,15 @@
 
   def __init__(  # pylint: disable=super-init-not-called
       self, parent_module_globals, mode=None, submodule=None, name="keras"):
-    self._parent_module_globals = parent_module_globals
-    self._mode = mode
-    self._submodule = submodule
-    self._name = name
-    self._initialized = False
+    self._tfll_parent_module_globals = parent_module_globals
+    self._tfll_mode = mode
+    self._tfll_submodule = submodule
+    self._tfll_name = name
+    self._tfll_initialized = False
 
   def _initialize(self):
     """Resolve the Keras version to use and initialize the loader."""
-    self._initialized = True
+    self._tfll_initialized = True
     package_name = None
     keras_version = None
     if os.environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"):
@@ -98,7 +128,7 @@
         import tf_keras  # pylint: disable=g-import-not-at-top,unused-import
 
         keras_version = "tf_keras"
-        if self._mode == "v1":
+        if self._tfll_mode == "v1":
           package_name = "tf_keras.api._v1.keras"
         else:
           package_name = "tf_keras.api._v2.keras"
@@ -120,7 +150,7 @@
         else:
           # This is the Keras 2.x case.
           keras_version = "keras_2"
-          if self._mode == "v1":
+          if self._tfll_mode == "v1":
             package_name = "keras.api._v1.keras"
           else:
             package_name = "keras.api._v2.keras"
@@ -129,25 +159,29 @@
             "Keras cannot be imported. Check that it is installed."
         )
 
-    self._keras_version = keras_version
+    self._tfll_keras_version = keras_version
     if keras_version is not None:
-      if self._submodule is not None:
-        package_name += "." + self._submodule
-      super().__init__(self._name, self._parent_module_globals, package_name)
+      if self._tfll_submodule is not None:
+        package_name += "." + self._tfll_submodule
+      super().__init__(
+          self._tfll_name, self._tfll_parent_module_globals, package_name
+      )
     else:
       raise ImportError(  # pylint: disable=raise-missing-from
           "Keras cannot be imported. Check that it is installed."
       )
 
   def __getattr__(self, item):
-    if item in ("_mode", "_initialized", "_name"):
+    if item in ("_tfll_mode", "_tfll_initialized", "_tfll_name"):
       return super(types.ModuleType, self).__getattribute__(item)
-    if not self._initialized:
+    if not self._tfll_initialized:
       self._initialize()
-    if self._keras_version == "keras_3":
-      if (self._mode == "v1" and
-          not self._submodule and
-          item.startswith("compat.v1.")):
+    if self._tfll_keras_version == "keras_3":
+      if (
+          self._tfll_mode == "v1"
+          and not self._tfll_submodule
+          and item.startswith("compat.v1.")
+      ):
         raise AttributeError(
             "`tf.compat.v1.keras` is not available with Keras 3. Keras 3 has "
             "no support for TF 1 APIs. You can install the `tf_keras` package "
@@ -155,15 +189,18 @@
             "`TF_USE_LEGACY_KERAS=True` to configure TensorFlow to route "
             "`tf.compat.v1.keras` to `tf_keras`."
         )
-      elif (self._mode == "v2" and
-            not self._submodule and
-            item.startswith("compat.v2.")):
+      elif (
+          self._tfll_mode == "v2"
+          and not self._tfll_submodule
+          and item.startswith("compat.v2.")
+      ):
         raise AttributeError(
             "`tf.compat.v2.keras` is not available with Keras 3. Just use "
             "`import keras` instead."
         )
-      elif (self._submodule and
-            self._submodule.startswith("__internal__.legacy.")):
+      elif self._tfll_submodule and self._tfll_submodule.startswith(
+          "__internal__.legacy."
+      ):
         raise AttributeError(
             f"`{item}` is not available with Keras 3."
         )
@@ -171,12 +208,14 @@
     return getattr(module, item)
 
   def __repr__(self):
-    if self._initialized:
-      return (f"<KerasLazyLoader ({self._keras_version}) "
-              f"{self.__name__} as {self._local_name} mode={self._mode}>")
+    if self._tfll_initialized:
+      return (
+          f"<KerasLazyLoader ({self._tfll_keras_version}) "
+          f"{self.__name__} as {self._tfll_local_name} mode={self._tfll_mode}>"
+      )
     return "<KerasLazyLoader>"
 
   def __dir__(self):
-    if not self._initialized:
+    if not self._tfll_initialized:
       self._initialize()
     return super().__dir__()
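
A minimal usage sketch of the LazyLoader write-through behaviour added above, mirroring the testLazyLoaderMock test added below. It assumes a TensorFlow build that already contains this change; the wrapped module ("json") and the attribute name ("marker") are arbitrary placeholders, not part of the patch.

    import json  # an arbitrary, cheap module to wrap lazily

    from tensorflow.python.util import lazy_loader

    # Wrap "json" under the local name "lazy_json"; nothing is imported yet.
    lazy_json = lazy_loader.LazyLoader(
        "lazy_json", globals(), "json", warning="json was loaded lazily.")

    # First touch: __setattr__ loads json, logs the warning once, and writes
    # the value through to the real module (and to its __all__ if defined).
    lazy_json.marker = "set via LazyLoader"
    print(lazy_json.marker)  # read back through the loaded module
    print(json.marker)       # the write landed on the real json module
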
diff --git a/tensorflow/python/util/lazy_loader_test.py b/tensorflow/python/util/lazy_loader_test.py
index 7309716..94f2581 100644
--- a/tensorflow/python/util/lazy_loader_test.py
+++ b/tensorflow/python/util/lazy_loader_test.py
@@ -20,7 +20,9 @@
 import types
 
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import lazy_loader
+from tensorflow.python.util import tf_inspect
 
 
 class LazyLoaderTest(test.TestCase):
@@ -36,6 +38,21 @@
 
     self.assertIsInstance(module.foo, lazy_loader.LazyLoader)
 
+  @test.mock.patch.object(logging, "warning", autospec=True)
+  def testLazyLoaderMock(self, mock_warning):
+    name = LazyLoaderTest.__module__
+    lazy_loader_module = lazy_loader.LazyLoader(
+        "lazy_loader_module", globals(), name, warning="Test warning.")
+
+    self.assertEqual(0, mock_warning.call_count)
+    lazy_loader_module.foo = 0
+    self.assertEqual(1, mock_warning.call_count)
+    foo = lazy_loader_module.foo
+    self.assertEqual(1, mock_warning.call_count)
+
+    # Check that values stayed the same
+    self.assertEqual(lazy_loader_module.foo, foo)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index a1531f5..2dc79e1 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -1395,7 +1395,11 @@
             is invalid to specify both "hidden" and "op_allowlist".
         cc_linkopts: Optional linkopts to be added to tf_cc_binary that contains the
             specified ops.
-        api_def_srcs: undocumented.
+        api_def_srcs: a list of targets that define the attributes of the API
+            endpoints for this target. For an api_def file to take effect, it
+            must be included (transitively) via this list.
+            For example, `visibility: HIDDEN` in the api_def hides the Op from
+            the tf.* namespace.
         compatible_with: undocumented.
         testonly: undocumented.
         copts: undocumented.
@@ -1751,7 +1755,8 @@
         linkopts = lrt_if_needed(),
         kernels = [],
         create_named_test_suite = False,
-        visibility = None):
+        visibility = None,
+        features = []):
     test_names = []
     for src in srcs:
         test_name = src_to_test_name(src)
@@ -1765,6 +1770,7 @@
             linkstatic = linkstatic,
             tags = tags,
             deps = deps,
+            features = features,
             visibility = visibility,
         )
         test_names.append(test_name)
diff --git a/tensorflow/tools/android/test/jni/object_tracking/logging.h b/tensorflow/tools/android/test/jni/object_tracking/logging.h
index 812ba77..3982f57 100644
--- a/tensorflow/tools/android/test/jni/object_tracking/logging.h
+++ b/tensorflow/tools/android/test/jni/object_tracking/logging.h
@@ -33,7 +33,7 @@
   TypeName(const TypeName&) = delete;         \
   void operator=(const TypeName&) = delete
 
-#if defined(COMPILER_GCC3)
+#if defined(__GNUC__)
 #define TF_PREDICT_FALSE(x) (__builtin_expect(x, 0))
 #define TF_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1))
 #else
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-graph.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-graph.pbtxt
index 80ad2d5..f713288 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-graph.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-graph.pbtxt
@@ -52,7 +52,7 @@
   }
   member_method {
     name: "as_graph_def"
-    argspec: "args=[\'self\', \'from_version\', \'add_shapes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+    argspec: "args=[\'self\', \'from_version\', \'add_shapes\', \'use_pybind11_proto\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\'], "
   }
   member_method {
     name: "as_graph_element"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt
index ce9bbed..e288edf 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt
@@ -22,7 +22,7 @@
   }
   member_method {
     name: "adjust_jpeg_quality"
-    argspec: "args=[\'image\', \'jpeg_quality\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'image\', \'jpeg_quality\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
   }
   member_method {
     name: "adjust_saturation"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 6844d3e..fbebe3b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1053,10 +1053,6 @@
     argspec: "args=[\'input\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'explicit_paddings\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'[]\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
   }
   member_method {
-    name: "convert_to_coo_tensor"
-    argspec: "args=[\'indices_or_row_splits\', \'values\', \'weights\', \'sample_count\', \'combiner\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
     name: "convert_to_tensor"
     argspec: "args=[\'value\', \'dtype\', \'name\', \'preferred_dtype\', \'dtype_hint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
@@ -1405,14 +1401,6 @@
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
   member_method {
-    name: "get_minibatch_splits_with_physical_replica"
-    argspec: "args=[\'program_key\', \'row_ids\', \'col_ids\', \'gains\', \'sample_count\', \'num_replica\', \'table_vocab_size\', \'feature_width\', \'num_sc_per_chip\', \'table_name\', \'mini_batch_splits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "get_minibatches_in_csr_with_physical_replica"
-    argspec: "args=[\'program_key\', \'row_ids\', \'col_ids\', \'gains\', \'splits\', \'id_counts\', \'sample_count\', \'num_replica\', \'max_minibatches_per_sc\', \'max_ids_per_chip_per_sample\', \'table_vocab_size\', \'feature_width\', \'num_sc_per_chip\', \'table_name\', \'mini_batch_in_csr\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
     name: "get_seed"
     argspec: "args=[\'op_seed\'], varargs=None, keywords=None, defaults=None"
   }
@@ -2477,14 +2465,6 @@
     argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToInt64\'], "
   }
   member_method {
-    name: "tpu_annotate_tensors_with_dynamic_shape"
-    argspec: "args=[\'tensors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "tpu_copy_with_dynamic_shape"
-    argspec: "args=[\'tensors\', \'unpadded_sizes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
     name: "trace"
     argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
@@ -2605,54 +2585,6 @@
     argspec: "args=[\'filename\', \'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
-    name: "xla_sparse_core_adagrad"
-    argspec: "args=[\'indices\', \'gradient\', \'learning_rate\', \'accumulator\', \'embedding_table\', \'feature_width\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_core_adagrad_momentum"
-    argspec: "args=[\'indices\', \'gradient\', \'learning_rate\', \'beta_1\', \'epsilon\', \'accumulator\', \'momentum\', \'embedding_table\', \'feature_width\', \'use_nesterov\', \'beta_2\', \'exponent\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_core_adam"
-    argspec: "args=[\'embedding_table\', \'indices\', \'gradient\', \'learning_rate\', \'momentum\', \'velocity\', \'beta_1\', \'beta_2\', \'epsilon\', \'feature_width\', \'use_sum_inside_sqrt\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_core_ftrl"
-    argspec: "args=[\'embedding_table\', \'accumulator\', \'linear\', \'learning_rate\', \'indices\', \'gradient\', \'beta\', \'learning_rate_power\', \'l2_regularization_strength\', \'feature_width\', \'multiply_linear_by_learning_rate\', \'l1_regularization_strength\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_core_sgd"
-    argspec: "args=[\'indices\', \'gradient\', \'learning_rate\', \'embedding_table\', \'feature_width\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_dense_matmul"
-    argspec: "args=[\'row_ids\', \'col_ids\', \'values\', \'offsets\', \'embedding_table\', \'max_ids_per_partition\', \'max_unique_ids_per_partition\', \'input_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_dense_matmul_grad_with_adagrad_and_csr_input"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'num_minibatches_per_physical_sparse_core\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_dense_matmul_grad_with_adagrad_momentum_and_csr_input"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'momenta\', \'num_minibatches_per_physical_sparse_core\', \'use_nesterov\', \'exponent\', \'beta1\', \'beta2\', \'epsilon\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_dense_matmul_grad_with_adam_and_csr_input"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'momenta\', \'velocity\', \'num_minibatches_per_physical_sparse_core\', \'use_sum_inside_sqrt\', \'beta1\', \'beta2\', \'epsilon\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_dense_matmul_grad_with_ftrl_and_csr_input"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'linear\', \'num_minibatches_per_physical_sparse_core\', \'multiply_linear_by_learning_rate\', \'beta\', \'learning_rate_power\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_dense_matmul_grad_with_sgd_and_csr_input"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'num_minibatches_per_physical_sparse_core\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_dense_matmul_with_csr_input"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'embedding_table\', \'num_minibatches_per_physical_sparse_core\', \'input_size\', \'quantization_config_low\', \'quantization_config_high\', \'quantization_config_num_buckets\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
     name: "zeros"
     argspec: "args=[\'shape\', \'dtype\', \'name\', \'layout\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\', \'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 250ba4a..f78ba2e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -4893,6 +4893,10 @@
     argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "StoreMinibatchStatisticsInFdo"
+    argspec: "args=[\'program_key\', \'max_ids\', \'max_uniques\', \'sample_count\', \'num_replica\', \'feature_width\', \'num_sc_per_chip\', \'table_name\', \'mini_batch_splits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "StridedSlice"
     argspec: "args=[\'input\', \'begin\', \'end\', \'strides\', \'begin_mask\', \'end_mask\', \'ellipsis_mask\', \'new_axis_mask\', \'shrink_axis_mask\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
   }
@@ -5682,23 +5686,23 @@
   }
   member_method {
     name: "XlaSparseDenseMatmulGradWithAdagradAndCsrInput"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'num_minibatches_per_physical_sparse_core\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'num_minibatches_per_physical_sparse_core\', \'table_name\', \'clip_weight_min\', \'clip_weight_max\', \'name\'], varargs=None, keywords=None, defaults=[\'-inf\', \'inf\', \'None\'], "
   }
   member_method {
     name: "XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'momenta\', \'num_minibatches_per_physical_sparse_core\', \'use_nesterov\', \'exponent\', \'beta1\', \'beta2\', \'epsilon\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'momenta\', \'num_minibatches_per_physical_sparse_core\', \'use_nesterov\', \'exponent\', \'beta1\', \'beta2\', \'epsilon\', \'table_name\', \'clip_weight_min\', \'clip_weight_max\', \'name\'], varargs=None, keywords=None, defaults=[\'-inf\', \'inf\', \'None\'], "
   }
   member_method {
     name: "XlaSparseDenseMatmulGradWithAdamAndCsrInput"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'momenta\', \'velocity\', \'num_minibatches_per_physical_sparse_core\', \'use_sum_inside_sqrt\', \'beta1\', \'beta2\', \'epsilon\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'momenta\', \'velocity\', \'num_minibatches_per_physical_sparse_core\', \'use_sum_inside_sqrt\', \'beta1\', \'beta2\', \'epsilon\', \'table_name\', \'clip_weight_min\', \'clip_weight_max\', \'name\'], varargs=None, keywords=None, defaults=[\'-inf\', \'inf\', \'None\'], "
   }
   member_method {
     name: "XlaSparseDenseMatmulGradWithFtrlAndCsrInput"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'linear\', \'num_minibatches_per_physical_sparse_core\', \'multiply_linear_by_learning_rate\', \'beta\', \'learning_rate_power\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'linear\', \'num_minibatches_per_physical_sparse_core\', \'multiply_linear_by_learning_rate\', \'beta\', \'learning_rate_power\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'table_name\', \'clip_weight_min\', \'clip_weight_max\', \'name\'], varargs=None, keywords=None, defaults=[\'-inf\', \'inf\', \'None\'], "
   }
   member_method {
     name: "XlaSparseDenseMatmulGradWithSgdAndCsrInput"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'num_minibatches_per_physical_sparse_core\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'num_minibatches_per_physical_sparse_core\', \'table_name\', \'clip_weight_min\', \'clip_weight_max\', \'name\'], varargs=None, keywords=None, defaults=[\'-inf\', \'inf\', \'None\'], "
   }
   member_method {
     name: "XlaSparseDenseMatmulWithCsrInput"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-builder.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-builder.pbtxt
index e4cc006..ba9c2c1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-builder.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-builder.pbtxt
@@ -17,6 +17,6 @@
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'as_text\'], varargs=None, keywords=None, defaults=[\'False\'], "
+    argspec: "args=[\'self\', \'as_text\', \'experimental_image_format\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
   }
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.builder.-saved-model-builder.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.builder.-saved-model-builder.pbtxt
index 44860b1..b348558 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.builder.-saved-model-builder.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.builder.-saved-model-builder.pbtxt
@@ -17,6 +17,6 @@
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'as_text\'], varargs=None, keywords=None, defaults=[\'False\'], "
+    argspec: "args=[\'self\', \'as_text\', \'experimental_image_format\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
   }
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-table-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-table-config.pbtxt
index 8487be2..ea39cff 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-table-config.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-table-config.pbtxt
@@ -4,6 +4,6 @@
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'vocabulary_size\', \'dim\', \'initializer\', \'optimizer\', \'combiner\', \'name\', \'quantization_config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'mean\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'vocabulary_size\', \'dim\', \'initializer\', \'optimizer\', \'combiner\', \'name\', \'quantization_config\', \'layout\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'mean\', \'None\', \'None\', \'None\'], "
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-graph.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-graph.pbtxt
index 80ad2d5..f713288 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-graph.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-graph.pbtxt
@@ -52,7 +52,7 @@
   }
   member_method {
     name: "as_graph_def"
-    argspec: "args=[\'self\', \'from_version\', \'add_shapes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+    argspec: "args=[\'self\', \'from_version\', \'add_shapes\', \'use_pybind11_proto\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\'], "
   }
   member_method {
     name: "as_graph_element"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.-func-graph.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.-func-graph.pbtxt
index d2db7d6..5fda52f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.-func-graph.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.-func-graph.pbtxt
@@ -113,7 +113,7 @@
   }
   member_method {
     name: "as_graph_def"
-    argspec: "args=[\'self\', \'from_version\', \'add_shapes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+    argspec: "args=[\'self\', \'from_version\', \'add_shapes\', \'use_pybind11_proto\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\'], "
   }
   member_method {
     name: "as_graph_element"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.saved_model.-saved-model-builder.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.saved_model.-saved-model-builder.pbtxt
index 53675d7..758beb4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.saved_model.-saved-model-builder.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.saved_model.-saved-model-builder.pbtxt
@@ -16,6 +16,6 @@
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'as_text\'], varargs=None, keywords=None, defaults=[\'False\'], "
+    argspec: "args=[\'self\', \'as_text\', \'experimental_image_format\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-layout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-layout.pbtxt
index c1a0c6f..a0b8430 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-layout.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-layout.pbtxt
@@ -56,6 +56,10 @@
     argspec: "args=[\'cls\', \'layout_str\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "global_shape_from_local_shape"
+    argspec: "args=[\'self\', \'local_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "inner_sharded"
     argspec: "args=[\'cls\', \'mesh\', \'inner_dim\', \'rank\'], varargs=None, keywords=None, defaults=None"
   }
@@ -72,6 +76,10 @@
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "local_shape_from_global_shape"
+    argspec: "args=[\'self\', \'global_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "num_shards"
     argspec: "args=[\'self\', \'idx\'], varargs=None, keywords=None, defaults=None"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt
index f5f9de5..85b36b7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt
@@ -22,7 +22,7 @@
   }
   member_method {
     name: "adjust_jpeg_quality"
-    argspec: "args=[\'image\', \'jpeg_quality\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'image\', \'jpeg_quality\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
   }
   member_method {
     name: "adjust_saturation"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index b68f187..c514ae5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -585,10 +585,6 @@
     argspec: "args=[\'input\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'explicit_paddings\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'[]\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
   }
   member_method {
-    name: "convert_to_coo_tensor"
-    argspec: "args=[\'indices_or_row_splits\', \'values\', \'weights\', \'sample_count\', \'combiner\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
     name: "convert_to_tensor"
     argspec: "args=[\'value\', \'dtype\', \'dtype_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
   }
@@ -713,14 +709,6 @@
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
   member_method {
-    name: "get_minibatch_splits_with_physical_replica"
-    argspec: "args=[\'program_key\', \'row_ids\', \'col_ids\', \'gains\', \'sample_count\', \'num_replica\', \'table_vocab_size\', \'feature_width\', \'num_sc_per_chip\', \'table_name\', \'mini_batch_splits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "get_minibatches_in_csr_with_physical_replica"
-    argspec: "args=[\'program_key\', \'row_ids\', \'col_ids\', \'gains\', \'splits\', \'id_counts\', \'sample_count\', \'num_replica\', \'max_minibatches_per_sc\', \'max_ids_per_chip_per_sample\', \'table_vocab_size\', \'feature_width\', \'num_sc_per_chip\', \'table_name\', \'mini_batch_in_csr\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
     name: "get_static_value"
     argspec: "args=[\'tensor\', \'partial\'], varargs=None, keywords=None, defaults=[\'False\'], "
   }
@@ -1165,14 +1153,6 @@
     argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
-    name: "tpu_annotate_tensors_with_dynamic_shape"
-    argspec: "args=[\'tensors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "tpu_copy_with_dynamic_shape"
-    argspec: "args=[\'tensors\', \'unpadded_sizes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
     name: "transpose"
     argspec: "args=[\'a\', \'perm\', \'conjugate\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'transpose\'], "
   }
@@ -1229,54 +1209,6 @@
     argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'maximum_iterations\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\'], "
   }
   member_method {
-    name: "xla_sparse_core_adagrad"
-    argspec: "args=[\'indices\', \'gradient\', \'learning_rate\', \'accumulator\', \'embedding_table\', \'feature_width\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_core_adagrad_momentum"
-    argspec: "args=[\'indices\', \'gradient\', \'learning_rate\', \'beta_1\', \'epsilon\', \'accumulator\', \'momentum\', \'embedding_table\', \'feature_width\', \'use_nesterov\', \'beta_2\', \'exponent\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_core_adam"
-    argspec: "args=[\'embedding_table\', \'indices\', \'gradient\', \'learning_rate\', \'momentum\', \'velocity\', \'beta_1\', \'beta_2\', \'epsilon\', \'feature_width\', \'use_sum_inside_sqrt\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_core_ftrl"
-    argspec: "args=[\'embedding_table\', \'accumulator\', \'linear\', \'learning_rate\', \'indices\', \'gradient\', \'beta\', \'learning_rate_power\', \'l2_regularization_strength\', \'feature_width\', \'multiply_linear_by_learning_rate\', \'l1_regularization_strength\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_core_sgd"
-    argspec: "args=[\'indices\', \'gradient\', \'learning_rate\', \'embedding_table\', \'feature_width\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_dense_matmul"
-    argspec: "args=[\'row_ids\', \'col_ids\', \'values\', \'offsets\', \'embedding_table\', \'max_ids_per_partition\', \'max_unique_ids_per_partition\', \'input_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_dense_matmul_grad_with_adagrad_and_csr_input"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'num_minibatches_per_physical_sparse_core\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_dense_matmul_grad_with_adagrad_momentum_and_csr_input"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'momenta\', \'num_minibatches_per_physical_sparse_core\', \'use_nesterov\', \'exponent\', \'beta1\', \'beta2\', \'epsilon\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_dense_matmul_grad_with_adam_and_csr_input"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'momenta\', \'velocity\', \'num_minibatches_per_physical_sparse_core\', \'use_sum_inside_sqrt\', \'beta1\', \'beta2\', \'epsilon\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_dense_matmul_grad_with_ftrl_and_csr_input"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'linear\', \'num_minibatches_per_physical_sparse_core\', \'multiply_linear_by_learning_rate\', \'beta\', \'learning_rate_power\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_dense_matmul_grad_with_sgd_and_csr_input"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'num_minibatches_per_physical_sparse_core\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "xla_sparse_dense_matmul_with_csr_input"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'embedding_table\', \'num_minibatches_per_physical_sparse_core\', \'input_size\', \'quantization_config_low\', \'quantization_config_high\', \'quantization_config_num_buckets\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
     name: "zeros"
     argspec: "args=[\'shape\', \'dtype\', \'name\', \'layout\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\', \'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 250ba4a..f78ba2e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -4893,6 +4893,10 @@
     argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "StoreMinibatchStatisticsInFdo"
+    argspec: "args=[\'program_key\', \'max_ids\', \'max_uniques\', \'sample_count\', \'num_replica\', \'feature_width\', \'num_sc_per_chip\', \'table_name\', \'mini_batch_splits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "StridedSlice"
     argspec: "args=[\'input\', \'begin\', \'end\', \'strides\', \'begin_mask\', \'end_mask\', \'ellipsis_mask\', \'new_axis_mask\', \'shrink_axis_mask\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
   }
@@ -5682,23 +5686,23 @@
   }
   member_method {
     name: "XlaSparseDenseMatmulGradWithAdagradAndCsrInput"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'num_minibatches_per_physical_sparse_core\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'num_minibatches_per_physical_sparse_core\', \'table_name\', \'clip_weight_min\', \'clip_weight_max\', \'name\'], varargs=None, keywords=None, defaults=[\'-inf\', \'inf\', \'None\'], "
   }
   member_method {
     name: "XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'momenta\', \'num_minibatches_per_physical_sparse_core\', \'use_nesterov\', \'exponent\', \'beta1\', \'beta2\', \'epsilon\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'momenta\', \'num_minibatches_per_physical_sparse_core\', \'use_nesterov\', \'exponent\', \'beta1\', \'beta2\', \'epsilon\', \'table_name\', \'clip_weight_min\', \'clip_weight_max\', \'name\'], varargs=None, keywords=None, defaults=[\'-inf\', \'inf\', \'None\'], "
   }
   member_method {
     name: "XlaSparseDenseMatmulGradWithAdamAndCsrInput"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'momenta\', \'velocity\', \'num_minibatches_per_physical_sparse_core\', \'use_sum_inside_sqrt\', \'beta1\', \'beta2\', \'epsilon\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'momenta\', \'velocity\', \'num_minibatches_per_physical_sparse_core\', \'use_sum_inside_sqrt\', \'beta1\', \'beta2\', \'epsilon\', \'table_name\', \'clip_weight_min\', \'clip_weight_max\', \'name\'], varargs=None, keywords=None, defaults=[\'-inf\', \'inf\', \'None\'], "
   }
   member_method {
     name: "XlaSparseDenseMatmulGradWithFtrlAndCsrInput"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'linear\', \'num_minibatches_per_physical_sparse_core\', \'multiply_linear_by_learning_rate\', \'beta\', \'learning_rate_power\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'accumulator\', \'linear\', \'num_minibatches_per_physical_sparse_core\', \'multiply_linear_by_learning_rate\', \'beta\', \'learning_rate_power\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'table_name\', \'clip_weight_min\', \'clip_weight_max\', \'name\'], varargs=None, keywords=None, defaults=[\'-inf\', \'inf\', \'None\'], "
   }
   member_method {
     name: "XlaSparseDenseMatmulGradWithSgdAndCsrInput"
-    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'num_minibatches_per_physical_sparse_core\', \'table_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'row_pointers\', \'sorted_sample_ids\', \'sorted_token_ids\', \'sorted_gains\', \'activation_gradients\', \'learning_rate\', \'embedding_table\', \'num_minibatches_per_physical_sparse_core\', \'table_name\', \'clip_weight_min\', \'clip_weight_max\', \'name\'], varargs=None, keywords=None, defaults=[\'-inf\', \'inf\', \'None\'], "
   }
   member_method {
     name: "XlaSparseDenseMatmulWithCsrInput"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-table-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-table-config.pbtxt
index 8487be2..ea39cff 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-table-config.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-table-config.pbtxt
@@ -4,6 +4,6 @@
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'vocabulary_size\', \'dim\', \'initializer\', \'optimizer\', \'combiner\', \'name\', \'quantization_config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'mean\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'vocabulary_size\', \'dim\', \'initializer\', \'optimizer\', \'combiner\', \'name\', \'quantization_config\', \'layout\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'mean\', \'None\', \'None\', \'None\'], "
   }
 }
diff --git a/tensorflow/tools/api/lib/BUILD b/tensorflow/tools/api/lib/BUILD
index 4dfa261..462bc6b 100644
--- a/tensorflow/tools/api/lib/BUILD
+++ b/tensorflow/tools/api/lib/BUILD
@@ -30,7 +30,8 @@
         ":api_objects_proto_py",
         "//tensorflow/python/platform:tf_logging",
         "//tensorflow/python/util:deprecation",
-        "//tensorflow/python/util:tf_decorator",
+        "//tensorflow/python/util:tf_decorator_py",
+        "//tensorflow/python/util:tf_inspect",
     ],
 )
 
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython
index 65e89a1..ec912d7 100644
--- a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython
@@ -45,3 +45,6 @@
 RUN SETUPTOOLS_USE_DISTUTILS=stdlib /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.10" "jax"
 RUN SETUPTOOLS_USE_DISTUTILS=stdlib /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.11" "jax"
 RUN SETUPTOOLS_USE_DISTUTILS=stdlib /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.12" "jax"
+
+COPY install/install_clang_17.sh /install/
+RUN /install/install_clang_17.sh
diff --git a/tensorflow/core/tfrt/saved_model/python/_pywrap_saved_model_aot_compile.pyi b/tensorflow/tools/ci_build/install/install_clang_17.sh
old mode 100644
new mode 100755
similarity index 61%
copy from tensorflow/core/tfrt/saved_model/python/_pywrap_saved_model_aot_compile.pyi
copy to tensorflow/tools/ci_build/install/install_clang_17.sh
index 05aae4b..50d2fcd
--- a/tensorflow/core/tfrt/saved_model/python/_pywrap_saved_model_aot_compile.pyi
+++ b/tensorflow/tools/ci_build/install/install_clang_17.sh
@@ -1,3 +1,4 @@
+#!/bin/bash -eu
 # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +14,18 @@
 # limitations under the License.
 # ==============================================================================
 
-class AotOptions:
-    def __init__(self) -> None: ...
+# LLVM/Clang: https://apt.llvm.org/
+apt-key adv --fetch-keys https://apt.llvm.org/llvm-snapshot.gpg.key
 
-def AotCompileSavedModel(input_model_dir: str = ..., aot_options: AotOptions = ..., output_model_dir: str = ...) -> None: ...
+# Set up custom sources
+cat >/etc/apt/sources.list.d/custom.list <<SOURCES
+
+# LLVM/Clang repository
+deb http://apt.llvm.org/focal/ llvm-toolchain-focal-17 main
+deb-src http://apt.llvm.org/focal/ llvm-toolchain-focal-17 main
+SOURCES
+
+apt-get update && apt-get install -y \
+    llvm-17 \
+    clang-17 \
+    lld-17
\ No newline at end of file
diff --git a/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc b/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc
index 89dd808..c388f53 100644
--- a/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc
+++ b/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc
@@ -23,9 +23,6 @@
 # Test-related settings below this point.
 test --verbose_failures=true --local_test_jobs=HOST_CPUS --test_output=errors
 
-# TODO(b/294367488) disable after 2.15 brancut
-test --flaky_test_attempts=3
-
 # Increase the test timeout as tests often take longer on mac.
 test --test_timeout=300,450,1200,3600
 test --test_size_filters=small,medium
diff --git a/tensorflow/tools/ci_build/release/requirements_common.txt b/tensorflow/tools/ci_build/release/requirements_common.txt
index e9a79b1..da3ea7d 100644
--- a/tensorflow/tools/ci_build/release/requirements_common.txt
+++ b/tensorflow/tools/ci_build/release/requirements_common.txt
@@ -5,18 +5,19 @@
 astunparse ~= 1.6.3
 flatbuffers ~= 23.5.26
 google_pasta ~= 0.2
-h5py ~= 3.8.0  # Earliest version for Python 3.11
-ml_dtypes ~= 0.2
+h5py ~= 3.10.0  # Earliest version for Python 3.12
+ml_dtypes ~= 0.3.1
 # TODO(b/262592253): Support older versions of NumPy for Python 3.10 and lower
 # to support TFX. Remove when Apache Beam upgrades to newer NumPy.
 numpy ~= 1.22.0; python_version < '3.11'
-numpy ~= 1.23.2; python_version >= '3.11' # Earliest version for Python 3.11
+numpy ~= 1.23.2; python_version == '3.11' # Earliest version for Python 3.11
+numpy ~= 1.26.0; python_version >= '3.12' # Earliest version for Python 3.12
 opt_einsum ~= 3.3.0
 protobuf ~= 3.20.3  # NOTE: Earliest version for Python 3.10
 six ~= 1.16.0
 termcolor ~= 2.1.1
-typing_extensions ~= 3.10.0.0
-wheel ~= 0.38.1
+typing_extensions ~= 4.8.0
+wheel ~= 0.41.2
 wrapt ~= 1.14.1
 
 # We need to pin the gast dependency exactly
@@ -31,14 +32,15 @@
 tf-estimator-nightly ~= 2.14.0.dev
 
 # Test dependencies
-grpcio ~= 1.49.1 # Earliest version for Python 3.11
-portpicker ~= 1.5.2
+grpcio ~= 1.59.0 # Earliest version for Python 3.12
+portpicker ~= 1.6.0
 scipy ~= 1.7.2; python_version < '3.11'
-scipy ~= 1.9.2; python_version >= '3.11' # Earliest version for Python 3.11
+scipy ~= 1.9.2; python_version == '3.11' # Earliest version for Python 3.11
+scipy ~= 1.11.3; python_version >= '3.12' # Earliest version for Python 3.12
 
 # This is usually vendored in setuptools but ensure it gets installed in CI anyway
 # No bound here, we prefer the one in setuptools
 packaging
 
 # For using Python 3.11 with Bazel 6 (b/286090018)
-lit ~= 16.0.5.post0
+lit ~= 17.0.2
diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh
index f058acb..ed0576a 100755
--- a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh
+++ b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh
@@ -60,14 +60,18 @@
 mkdir -p ${DIR}/include/tensorflow/c
 mkdir -p ${DIR}/include/tensorflow/c/eager
 mkdir -p ${DIR}/include/tensorflow/core/platform
+mkdir -p ${DIR}/include/tsl/c
+mkdir -p ${DIR}/include/tsl/platform
 mkdir -p ${DIR}/lib
 cp bazel-bin/tensorflow/tensorflow.dll ${DIR}/lib/tensorflow.dll
 cp bazel-bin/tensorflow/tensorflow.lib ${DIR}/lib/tensorflow.lib
 cp tensorflow/c/c_api.h \
   tensorflow/c/tf_attrtype.h \
+  tensorflow/c/tf_buffer.h  \
   tensorflow/c/tf_datatype.h \
   tensorflow/c/tf_status.h \
   tensorflow/c/tf_tensor.h \
+  tensorflow/c/tf_tensor_helper.h \
   tensorflow/c/tf_tstring.h \
   tensorflow/c/tf_file_statistics.h \
   tensorflow/c/tensor_interface.h \
@@ -81,6 +85,11 @@
 cp tensorflow/core/platform/ctstring.h \
   tensorflow/core/platform/ctstring_internal.h \
   ${DIR}/include/tensorflow/core/platform
+cp third_party/xla/third_party/tsl/tsl/c/tsl_status.h \
+   ${DIR}/include/tsl/c
+cp third_party/xla/third_party/tsl/tsl/platform/ctstring.h \
+   third_party/xla/third_party/tsl/tsl/platform/ctstring_internal.h \
+   ${DIR}/include/tsl/platform
 cp LICENSE ${DIR}/LICENSE
 cp bazel-bin/tensorflow/tools/lib_package/THIRD_PARTY_TF_C_LICENSES ${DIR}/
 cd ${DIR}
@@ -92,9 +101,11 @@
   include/tensorflow/c/eager/dlpack.h \
   include/tensorflow/c/c_api.h \
   include/tensorflow/c/tf_attrtype.h \
+  include/tensorflow/c/tf_buffer.h  \
   include/tensorflow/c/tf_datatype.h \
   include/tensorflow/c/tf_status.h \
   include/tensorflow/c/tf_tensor.h \
+  include/tensorflow/c/tf_tensor_helper.h \
   include/tensorflow/c/tf_tstring.h \
   include/tensorflow/c/tf_file_statistics.h \
   include/tensorflow/c/tensor_interface.h \
@@ -102,6 +113,9 @@
   include/tensorflow/c/c_api_experimental.h \
   include/tensorflow/core/platform/ctstring.h \
   include/tensorflow/core/platform/ctstring_internal.h \
+  include/tsl/c/tsl_status.h \
+  include/tsl/platform/ctstring.h \
+  include/tsl/platform/ctstring_internal.h \
   LICENSE \
   THIRD_PARTY_TF_C_LICENSES
 rm -rf lib include
diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD
index c9280c9..52e88a9 100644
--- a/tensorflow/tools/common/BUILD
+++ b/tensorflow/tools/common/BUILD
@@ -17,7 +17,7 @@
     name = "public_api",
     srcs = ["public_api.py"],
     srcs_version = "PY3",
-    deps = ["//tensorflow/python/util:tf_decorator"],
+    deps = ["//tensorflow/python/util:tf_inspect"],
 )
 
 py_strict_test(
@@ -40,7 +40,7 @@
     name = "traverse",
     srcs = ["traverse.py"],
     srcs_version = "PY3",
-    deps = ["//tensorflow/python/util:tf_decorator"],
+    deps = ["//tensorflow/python/util:tf_inspect"],
 )
 
 py_strict_test(
diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD
index 6bfd98e..3d77c92 100644
--- a/tensorflow/tools/compatibility/BUILD
+++ b/tensorflow/tools/compatibility/BUILD
@@ -171,8 +171,9 @@
         "//tensorflow:tensorflow_py",
         "//tensorflow/python/framework:test_lib",
         "//tensorflow/python/platform:client_testlib",
-        "//tensorflow/python/util:tf_decorator",
+        "//tensorflow/python/util:tf_decorator_py",
         "//tensorflow/python/util:tf_export",
+        "//tensorflow/python/util:tf_inspect",
         "//tensorflow/tools/common:public_api",
         "//tensorflow/tools/common:traverse",
         "@absl_py//absl/testing:parameterized",
diff --git a/tensorflow/tools/compatibility/update/BUILD b/tensorflow/tools/compatibility/update/BUILD
index 2c4b24598..f545e7d 100644
--- a/tensorflow/tools/compatibility/update/BUILD
+++ b/tensorflow/tools/compatibility/update/BUILD
@@ -19,7 +19,7 @@
         "//tensorflow/python:modules_with_exports",
         "//tensorflow/python:no_contrib",
         "//tensorflow/python/lib/io:file_io",
-        "//tensorflow/python/util:tf_decorator",
+        "//tensorflow/python/util:tf_decorator_py",
         "//tensorflow/python/util:tf_export",
         "//tensorflow/tools/common:public_api",
         "//tensorflow/tools/common:traverse",
@@ -39,8 +39,9 @@
         # copybara:uncomment "//third_party/py/tensorflow:tensorflow_compat_v2_estimator",
         "//tensorflow/python:no_contrib",
         "//tensorflow/python/lib/io:file_io",
-        "//tensorflow/python/util:tf_decorator",
+        "//tensorflow/python/util:tf_decorator_py",
         "//tensorflow/python/util:tf_export",
+        "//tensorflow/python/util:tf_inspect",
         "//tensorflow/tools/common:public_api",
         "//tensorflow/tools/common:traverse",
         "//tensorflow/tools/compatibility:tf_upgrade_v2_lib",
diff --git a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
index 18426a4..4091a57 100644
--- a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
+++ b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
@@ -300,8 +300,8 @@
     def_fp.write("\t ??_7ConfigProto@tensorflow@@6B@\n") # for _pywrap_tfe
     def_fp.write("\t ??_7CoordinatedTask@tensorflow@@6B@\n") # for _pywrap_tfe
     def_fp.write("\t ?InternalSwap@CoordinatedTask@tensorflow@@AEAAXPEAV12@@Z\n") # for _pywrap_tfe
-    def_fp.write("\t ?kSeed@MixingHashState@hash_internal@lts_20230125@absl@@0QEBXEB\n") # for _pywrap_tfcompile
-    def_fp.write("\t ?kEmptyGroup@container_internal@lts_20230125@absl@@3QBW4ctrl_t@123@B\n") # for _pywrap_tfcompile
+    def_fp.write("\t ?kSeed@MixingHashState@hash_internal@lts_20230802@absl@@0QEBXEB\n") # for _pywrap_tfcompile
+    def_fp.write("\t ?kEmptyGroup@container_internal@lts_20230802@absl@@3QBW4ctrl_t@123@B\n") # for _pywrap_tfcompile
     def_fp.write("\t ??_7GraphDef@tensorflow@@6B@\n")
     def_fp.write("\t ??_7DeviceProperties@tensorflow@@6B@\n")
     def_fp.write("\t ??_7MetaGraphDef@tensorflow@@6B@\n")
@@ -310,7 +310,7 @@
     def_fp.write("\t ??1CoordinatedTask@tensorflow@@UEAA@XZ\n") # for _pywrap_tfe
     def_fp.write("\t ?CopyFrom@CoordinatedTask@tensorflow@@QEAAXAEBV12@@Z\n") # for _pywrap_tfe
     def_fp.write("\t ??0CoordinatedTask@tensorflow@@IEAA@PEAVArena@protobuf@google@@_N@Z\n") # for _pywrap_tfe
-    def_fp.write("\t ?MaybeTrackCordImpl@CordzInfo@cord_internal@lts_20230125@absl@@CAXAEAVInlineData@234@AEBV5234@W4MethodIdentifier@CordzUpdateTracker@234@@Z\n") # for tensorflow::Status usage of absl::Cord
+    def_fp.write("\t ?MaybeTrackCordImpl@CordzInfo@cord_internal@lts_20230802@absl@@CAXAEAVInlineData@234@AEBV5234@W4MethodIdentifier@CordzUpdateTracker@234@@Z\n") # for tensorflow::Status usage of absl::Cord
 
 
     # Each symbols returned by undname matches the same position in candidates.
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 2166019..b926de8 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -66,7 +66,6 @@
         "//tensorflow/python/util:kernel_registry",
         "//tensorflow/python/framework:python_op_gen",
         "//tensorflow/python/client:tf_session_helper",
-        "//third_party/eigen3",
         "@local_xla//xla/stream_executor",
     ] + if_cuda([
         "@local_config_cuda//cuda:cuda_headers",
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index 423a79b..8a626aa 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -300,7 +300,6 @@
   fi
 
   mkdir -p ${TMPDIR}/third_party
-  cp -R $RUNFILES/third_party/eigen3 ${TMPDIR}/third_party
   cp -LR $RUNFILES/../local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual/third_party/gpus ${TMPDIR}/third_party
   cp $RUNFILES/tensorflow/tools/pip_package/THIRD_PARTY_NOTICES.txt "${TMPDIR}/tensorflow"
 
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 24596fd..aaf475a 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -119,7 +119,7 @@
     # dependencies on the release branch is updated to the stable releases (RC
     # or final). For example, 'keras-nightly ~= 2.14.0.dev' will be replaced by
     # 'keras >= 2.14.0rc0, < 2.15' on the release branch after the branch cut.
-    'tb-nightly ~= 2.15.0.a',
+    'tb-nightly ~= 2.16.0.a',
     'tf-estimator-nightly ~= 2.14.0.dev',
     'keras-nightly ~= 3.0.0.dev',
 ]
@@ -176,7 +176,7 @@
     'nvidia-curand-cu12 == 10.3.3.141',
     'nvidia-cusolver-cu12 == 11.5.2.141',
     'nvidia-cusparse-cu12 == 12.1.2.141',
-    'nvidia-nccl-cu12 == 2.16.5',
+    'nvidia-nccl-cu12 == 2.18.3',
     'nvidia-nvjitlink-cu12 == 12.2.140',
     'tensorrt == 8.6.1.post1',
     'tensorrt-bindings == 8.6.1',
@@ -349,7 +349,6 @@
     list(find_files('*.h', 'tensorflow/tsl')) +
     list(find_files('*.h', 'google/com_google_protobuf/src')) +
     list(find_files('*.inc', 'google/com_google_protobuf/src')) +
-    list(find_files('*', 'third_party/eigen3')) +
     list(find_files('*', 'third_party/gpus')) +
     list(find_files('*.h', 'tensorflow/include/external/com_google_absl')) +
     list(find_files('*.inc', 'tensorflow/include/external/com_google_absl')) +
diff --git a/tensorflow/tools/proto_splitter/split.py b/tensorflow/tools/proto_splitter/split.py
index cf37143..7a22950d 100644
--- a/tensorflow/tools/proto_splitter/split.py
+++ b/tensorflow/tools/proto_splitter/split.py
@@ -48,11 +48,14 @@
     """Splits proto message into a Sequence of protos/bytes."""
 
   @abc.abstractmethod
-  def write(self, file_prefix: str) -> None:
+  def write(self, file_prefix: str) -> str:
     """Serializes proto to disk.
 
     Args:
       file_prefix: string prefix of the filepath.
+
+    Returns:
+      The actual path the proto is written to.
     """
 
 
@@ -147,7 +150,7 @@
       self._built = True
     return self._chunks, self._chunked_message
 
-  def write(self, file_prefix: str) -> None:
+  def write(self, file_prefix: str) -> str:
     """Serializes a proto to disk.
 
     The writer writes all chunks into a riegeli file. The chunk metadata
@@ -157,6 +160,11 @@
       file_prefix: string prefix of the filepath. The writer will automatically
         attach a `.pb` or `.cpb` (chunked pb) suffix depending on whether the
         proto is split.
+
+    Returns:
+      The actual filepath the proto is written to. The filepath differs
+      depending on whether the proto is split, i.e., whether it is written
+      as a .pb or a .cpb file.
     """
     if self._parent_splitter is not None:
       raise ValueError(
@@ -174,7 +182,7 @@
           path, self._proto.SerializeToString(deterministic=True)
       )
       logging.info("Unchunked file exported to %s", path)
-      return
+      return path
 
     path = f"{file_prefix}.cpb"
     with riegeli.RecordWriter(file_io.FileIO(path, "wb")) as f:
@@ -206,6 +214,7 @@
         "Number of chunks created (including initial message): %s",
         len(chunks),
     )
+    return path
 
   def add_chunk(
       self,
diff --git a/tensorflow/tools/proto_splitter/split_test.py b/tensorflow/tools/proto_splitter/split_test.py
index de2f6d3..849d680 100644
--- a/tensorflow/tools/proto_splitter/split_test.py
+++ b/tensorflow/tools/proto_splitter/split_test.py
@@ -72,9 +72,10 @@
   def testWrite(self):
     path = os.path.join(self.create_tempdir(), "split-repeat")
     data = [_random_string(5), _random_string(10), _random_string(15)]
-    RepeatedStringSplitter(test_message_pb2.RepeatedString(strings=data)).write(
-        path
-    )
+    returned_path = RepeatedStringSplitter(
+        test_message_pb2.RepeatedString(strings=data)
+    ).write(path)
+    self.assertEqual(returned_path, f"{path}.cpb")
 
     with riegeli.RecordReader(open(f"{path}.cpb", "rb")) as reader:
       self.assertTrue(reader.check_file_format())
@@ -148,10 +149,11 @@
   def testWriteNoChunks(self):
     path = os.path.join(self.create_tempdir(), "split-none")
     proto = test_message_pb2.RepeatedString(strings=["a", "bc", "de"])
-    NoOpSplitter(proto).write(path)
+    returned_path = NoOpSplitter(proto).write(path)
 
     expected_file_path = path + ".pb"
     self.assertTrue(os.path.isfile(expected_file_path))
+    self.assertEqual(returned_path, expected_file_path)
 
     parsed_proto = test_message_pb2.RepeatedString()
     with open(expected_file_path, "rb") as f:
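
The write() change above makes every splitter report the path it actually wrote: file_prefix plus .pb for an unsplit proto, or .cpb for a chunked one, which is what the updated tests assert. A minimal self-contained sketch of that contract, using a hypothetical TinySplitter rather than the real splitter classes:

import os
import tempfile

class TinySplitter:
    """Hypothetical stand-in; only mirrors the return-the-actual-path contract."""

    def __init__(self, payload: bytes, chunked: bool):
        self._payload = payload
        self._chunked = chunked

    def write(self, file_prefix: str) -> str:
        # .cpb ("chunked pb") when the message was split, plain .pb otherwise.
        path = file_prefix + (".cpb" if self._chunked else ".pb")
        with open(path, "wb") as f:
            f.write(self._payload)
        return path

prefix = os.path.join(tempfile.mkdtemp(), "split-demo")
returned_path = TinySplitter(b"payload", chunked=False).write(prefix)
assert returned_path == prefix + ".pb"
print("written to", returned_path)
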
diff --git a/tensorflow/tools/tensorflow_builder/compat_checker/BUILD b/tensorflow/tools/tensorflow_builder/compat_checker/BUILD
index cb13531..c2eb4da 100644
--- a/tensorflow/tools/tensorflow_builder/compat_checker/BUILD
+++ b/tensorflow/tools/tensorflow_builder/compat_checker/BUILD
@@ -18,7 +18,7 @@
     srcs_version = "PY3",
     deps = [
         "//tensorflow/python/platform:tf_logging",
-        "//tensorflow/python/util:tf_decorator",
+        "//tensorflow/python/util:tf_inspect",
     ],
 )
 
diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/builder.devtoolset/build_devtoolset.sh b/tensorflow/tools/tf_sig_build_dockerfiles/builder.devtoolset/build_devtoolset.sh
index 6f4a566..b4c6367 100755
--- a/tensorflow/tools/tf_sig_build_dockerfiles/builder.devtoolset/build_devtoolset.sh
+++ b/tensorflow/tools/tf_sig_build_dockerfiles/builder.devtoolset/build_devtoolset.sh
@@ -184,7 +184,7 @@
 # TODO(klimek): Automate linking in all non-gcc / non-kernel include
 # directories.
 mkdir -p "/${TARGET}/usr/include/x86_64-linux-gnu"
-PYTHON_VERSIONS=("python3.9" "python3.10" "python3.11")
+PYTHON_VERSIONS=("python3.9" "python3.10" "python3.11" "python3.12")
 for v in "${PYTHON_VERSIONS[@]}"; do
   ln -s "/usr/local/include/${v}" "/${TARGET}/usr/include/x86_64-linux-gnu/${v}"
 done
diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt b/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt
index f0f05be..9fcd871 100644
--- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt
+++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt
@@ -11,8 +11,10 @@
 libcurand-12-2
 libcusolver-dev-12-2
 libcusparse-dev-12-2
+libcublas-12-2
 libcublas-dev-12-2
 libnccl-dev=2.18.5-1+cuda12.2
+libnccl2=2.18.5-1+cuda12.2
 # CuDNN: https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#ubuntu-network-installation
 libcudnn8-dev=8.9.4.25-1+cuda12.2
 libcudnn8=8.9.4.25-1+cuda12.2
@@ -28,6 +30,8 @@
 automake
 build-essential
 ca-certificates
+# TODO(b/308399490) Remove CMake once dm-tree (Keras dependency) has 3.12 wheels
+cmake
 llvm-17
 clang-17
 lld-17
diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt b/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt
index 9e12049..62e73c9 100644
--- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt
+++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt
@@ -8,19 +8,21 @@
 astunparse ~= 1.6.3
 flatbuffers ~= 23.5.26
 google_pasta ~= 0.2
-h5py ~= 3.8.0 # Earliest version for Python 3.11
-ml_dtypes ~= 0.2
+h5py ~= 3.10.0 # Earliest version for Python 3.12
+ml_dtypes ~= 0.3.1
 # TODO(b/262592253): Support older versions of NumPy for Python 3.10 and lower
 # to support TFX. Remove when Apache Beam upgrades to newer NumPy.
 numpy ~= 1.22.0; python_version < '3.11'
-numpy ~= 1.23.2; python_version >= '3.11' # Earliest version for Python 3.11
+numpy ~= 1.23.2; python_version == '3.11' # Earliest version for Python 3.11
+numpy ~= 1.26.0; python_version >= '3.12' # Earliest version for Python 3.12
 opt_einsum ~= 3.3.0
-packaging ~= 21.3
+packaging ~= 23.2
 protobuf ~= 3.20.3
 six ~= 1.16.0
 termcolor ~= 2.1.1
-typing_extensions ~= 3.10.0.0
-wheel ~= 0.38.1
+typing_extensions ~= 4.8.0
+wheel ~= 0.41.2
+setuptools >= 68.2.2
 wrapt ~= 1.14.1
 # We need to pin the gast dependency exactly
 gast == 0.4.0
@@ -34,13 +36,14 @@
 tb-nightly ~= 2.13.0.a
 tf-estimator-nightly ~= 2.14.0.dev
 # Test dependencies
-grpcio ~= 1.49.1 # Earliest version for Python 3.11
-portpicker ~= 1.5.2
+grpcio ~= 1.59.0 # Earliest version for Python 3.12
+portpicker ~= 1.6.0
 scipy ~= 1.7.2; python_version < '3.11'
-scipy ~= 1.9.2; python_version >= '3.11' # Earliest version for Python 3.11
+scipy ~= 1.9.2; python_version == '3.11' # Earliest version for Python 3.11
+scipy ~= 1.11.3; python_version >= '3.12' # Earliest version for Python 3.12
 # Required for TFLite import from JAX tests
-jax ~= 0.3.25
-jaxlib ~= 0.3.25 # Earliest version for Python 3.11
+jax ~= 0.3.25; python_version <= '3.11'
+jaxlib ~= 0.3.25; python_version <= '3.11' # Earliest version for Python 3.11
 # Needs to be addressed. Unblocked 2.4 branchcut cl/338377048
 PyYAML ~= 6.0
 # For uploading
@@ -52,4 +55,4 @@
 pylint ~= 2.13.9
 
 # For using Python 3.11 with Bazel 6 (b/286090018)
-lit ~= 16.0.5.post0
+lit ~= 17.0.2
diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu_gcc.bazelrc b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu_gcc.bazelrc
index 1ecb606..14b7564 100644
--- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu_gcc.bazelrc
+++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu_gcc.bazelrc
@@ -29,7 +29,7 @@
 build --profile=/tf/pkg/profile.json.gz
 
 # Use the NVCC toolchain to compile for manylinux2014
-build --crosstool_top="@sigbuild-r2.14_config_cuda//crosstool:toolchain"
+build --crosstool_top="@sigbuild-r2.16_config_cuda//crosstool:toolchain"
 
 # Test-related settings below this point.
 test --build_tests_only --keep_going --test_output=errors --verbose_failures=true
@@ -67,14 +67,14 @@
 build:rbe --remote_download_toplevel
 build:rbe --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
 build:rbe --linkopt=-lrt --host_linkopt=-lrt --linkopt=-lm --host_linkopt=-lm  # Unclear why this is here
-build:rbe --host_crosstool_top="@sigbuild-r2.14_config_cuda//crosstool:toolchain"
-build:rbe --crosstool_top="@sigbuild-r2.14_config_cuda//crosstool:toolchain"
-build:rbe --extra_toolchains="@sigbuild-r2.14_config_cuda//crosstool:toolchain-linux-x86_64"
-build:rbe --extra_execution_platforms="@sigbuild-r2.14_config_platform//:platform"
-build:rbe --host_platform="@sigbuild-r2.14_config_platform//:platform"
-build:rbe --platforms="@sigbuild-r2.14_config_platform//:platform"
+build:rbe --host_crosstool_top="@sigbuild-r2.16_config_cuda//crosstool:toolchain"
+build:rbe --crosstool_top="@sigbuild-r2.16_config_cuda//crosstool:toolchain"
+build:rbe --extra_toolchains="@sigbuild-r2.16_config_cuda//crosstool:toolchain-linux-x86_64"
+build:rbe --extra_execution_platforms="@sigbuild-r2.16_config_platform//:platform"
+build:rbe --host_platform="@sigbuild-r2.16_config_platform//:platform"
+build:rbe --platforms="@sigbuild-r2.16_config_platform//:platform"
 # Python config is the same across all containers because the binary is the same
-build:rbe --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14_config_python"
+build:rbe --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.16_config_python"
 build:rbe --remote_instance_name=projects/tensorflow-testing/instances/default_instance
 build:rbe --project_id="tensorflow-testing"
 
diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/wheel_verification.bats b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/wheel_verification.bats
index 60cde49..19662eb 100644
--- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/wheel_verification.bats
+++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/wheel_verification.bats
@@ -72,7 +72,7 @@
 # Is this still useful?
 @test "TensorFlow has Keras" {
     source /tf/venv/bin/activate
-    python3 -c 'import sys; import tensorflow as tf; sys.exit(0 if "_v2.keras" in tf.keras.__name__ else 1)'
+    python3 -c 'import sys; import tensorflow as tf; sys.exit(0 if "keras" in tf.keras.__name__ else 1)'
 }
 
 # Is this still useful?
diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/setup.python.sh b/tensorflow/tools/tf_sig_build_dockerfiles/setup.python.sh
index 98d23d0..007ee34 100755
--- a/tensorflow/tools/tf_sig_build_dockerfiles/setup.python.sh
+++ b/tensorflow/tools/tf_sig_build_dockerfiles/setup.python.sh
@@ -59,6 +59,7 @@
 curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
 python3 get-pip.py
 python3 -m pip install --no-cache-dir --upgrade pip
+python3 -m pip install -U setuptools
 
 # Disable the cache dir to save image space, and install packages
 python3 -m pip install --no-cache-dir -r $REQUIREMENTS -U
diff --git a/tensorflow/tools/toolchains/python/python_repo.bzl b/tensorflow/tools/toolchains/python/python_repo.bzl
index 59be9f6..77011b2 100644
--- a/tensorflow/tools/toolchains/python/python_repo.bzl
+++ b/tensorflow/tools/toolchains/python/python_repo.bzl
@@ -5,7 +5,7 @@
 """
 
 VERSIONS = ["3.9", "3.10", "3.11", "3.12"]
-DEFAULT_VERSION = "3.10"
+DEFAULT_VERSION = "3.11"
 WARNING = """
 TF_PYTHON_VERSION environment variable was not set correctly; using Python {}.
 
@@ -13,6 +13,11 @@
 export TF_PYTHON_VERSION=3.11
 """.format(DEFAULT_VERSION)
 
+content = """
+TF_PYTHON_VERSION = "{}"
+HERMETIC_PYTHON_VERSION = "{}"
+"""
+
 def _python_repository_impl(repository_ctx):
     repository_ctx.file("BUILD", "")
     version = repository_ctx.os.environ.get("TF_PYTHON_VERSION", "")
@@ -21,8 +26,7 @@
         version = DEFAULT_VERSION
     repository_ctx.file(
         "py_version.bzl",
-        "HERMETIC_PYTHON_VERSION = \"%s\"" %
-        version,
+        content.format(version, version),
     )
 
 python_repository = repository_rule(
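
The repository rule above now writes both TF_PYTHON_VERSION and HERMETIC_PYTHON_VERSION into py_version.bzl from the same template. A rough sketch of the generated contents, with plain Python standing in for the Starlark string formatting and "3.11" chosen only because it is the new DEFAULT_VERSION:

# Reproduces the formatting step performed by the repository rule.
content = """
TF_PYTHON_VERSION = "{}"
HERMETIC_PYTHON_VERSION = "{}"
"""

version = "3.11"  # Whatever TF_PYTHON_VERSION resolves to in the rule.
print(content.format(version, version))
# TF_PYTHON_VERSION = "3.11"
# HERMETIC_PYTHON_VERSION = "3.11"
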
diff --git a/tensorflow/tools/toolchains/remote_config/configs.bzl b/tensorflow/tools/toolchains/remote_config/configs.bzl
index 512dcf9..e8fc081 100644
--- a/tensorflow/tools/toolchains/remote_config/configs.bzl
+++ b/tensorflow/tools/toolchains/remote_config/configs.bzl
@@ -602,3 +602,82 @@
             "TF_TENSORRT_VERSION": "8.6",
         },
     )
+
+    sigbuild_tf_configs(
+        name_container_map = {
+            "sigbuild-r2.16": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505",
+            "sigbuild-r2.16-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505",
+            "sigbuild-r2.16-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:93c234df4c781af6974d86e9d1dd2e19ce0845b1b662c38e9a30d1de64eab3b0",
+            "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:d0a91705406aad65a79011683b8f7d4b8131625ea26a6d08aa7c6eb6955873a2",
+            "sigbuild-r2.16-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:ed7313f95bce391cbf3b498ff6c534d163cc2bb91ca1d6ef6363bde4fd9e0cfc",
+        },
+        # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12
+        # and manylinux2014 is 2.17.
+        env = {
+            "ABI_LIBC_VERSION": "glibc_2.19",
+            "ABI_VERSION": "gcc",
+            "BAZEL_COMPILER": "/dt9/usr/bin/gcc",
+            "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu",
+            "BAZEL_TARGET_CPU": "k8",
+            "BAZEL_TARGET_LIBC": "glibc_2.19",
+            "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu",
+            "CC": "/dt9/usr/bin/gcc",
+            "CC_TOOLCHAIN_NAME": "linux_gnu_x86",
+            "CLEAR_CACHE": "1",
+            "CUDNN_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu",
+            "GCC_HOST_COMPILER_PATH": "/dt9/usr/bin/gcc",
+            "GCC_HOST_COMPILER_PREFIX": "/usr/bin",
+            "HOST_CXX_COMPILER": "/dt9/usr/bin/gcc",
+            "HOST_C_COMPILER": "/dt9/usr/bin/gcc",
+            "PYTHON_BIN_PATH": "/usr/bin/python3",
+            "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu",
+            "TF_CUDA_CLANG": "0",
+            "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0",
+            "TF_CUDA_VERSION": "12.2",
+            "TF_CUDNN_VERSION": "8.9",
+            "TF_ENABLE_XLA": "1",
+            "TF_NEED_CUDA": "1",
+            "TF_NEED_TENSORRT": "1",
+            "TF_SYSROOT": "/dt9",
+            "TF_TENSORRT_VERSION": "8.6",
+        },
+    )
+
+    sigbuild_tf_configs(
+        name_container_map = {
+            "sigbuild-r2.16-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505",
+            "sigbuild-r2.16-clang-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505",
+            "sigbuild-r2.16-clang-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:93c234df4c781af6974d86e9d1dd2e19ce0845b1b662c38e9a30d1de64eab3b0",
+            "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:d0a91705406aad65a79011683b8f7d4b8131625ea26a6d08aa7c6eb6955873a2",
+            "sigbuild-r2.16-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:ed7313f95bce391cbf3b498ff6c534d163cc2bb91ca1d6ef6363bde4fd9e0cfc",
+        },
+        # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12
+        # and manylinux2014 is 2.17.
+        env = {
+            "ABI_LIBC_VERSION": "glibc_2.19",
+            "ABI_VERSION": "gcc",
+            "BAZEL_COMPILER": "/usr/lib/llvm-17/bin/clang",
+            "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu",
+            "BAZEL_TARGET_CPU": "k8",
+            "BAZEL_TARGET_LIBC": "glibc_2.19",
+            "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu",
+            "CC": "/usr/lib/llvm-17/bin/clang",
+            "CC_TOOLCHAIN_NAME": "linux_gnu_x86",
+            "CLEAR_CACHE": "1",
+            "CUDNN_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu",
+            "CLANG_CUDA_COMPILER_PATH": "/usr/lib/llvm-17/bin/clang",
+            "HOST_CXX_COMPILER": "/usr/lib/llvm-17/bin/clang",
+            "HOST_C_COMPILER": "/usr/lib/llvm-17/bin/clang",
+            "PYTHON_BIN_PATH": "/usr/bin/python3",
+            "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu",
+            "TF_CUDA_CLANG": "1",
+            "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0",
+            "TF_CUDA_VERSION": "12.2",
+            "TF_CUDNN_VERSION": "8.9",
+            "TF_ENABLE_XLA": "1",
+            "TF_NEED_CUDA": "1",
+            "TF_NEED_TENSORRT": "1",
+            "TF_SYSROOT": "/dt9",
+            "TF_TENSORRT_VERSION": "8.6",
+        },
+    )
diff --git a/tensorflow/tools/toolchains/remote_config/containers.bzl b/tensorflow/tools/toolchains/remote_config/containers.bzl
index 1b540ed..bfb4634 100644
--- a/tensorflow/tools/toolchains/remote_config/containers.bzl
+++ b/tensorflow/tools/toolchains/remote_config/containers.bzl
@@ -1,11 +1,11 @@
 """Docker images used with remote config and RBE."""
 
-"""SHA 256 values for each image."""
+# SHA 256 values for each image.
 container_digests = {
     # TF now uses only this container
     "cuda11.2-cudnn8.1-ubuntu20.04-manylinux2014-multipython": "sha256:48612bd85709cd014711d0b0f87e0806f3567d06d2e81c6e860516b87498b821",
     # JAX manylinux2014 configs.
-    "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:77234e5750afcf85c08e8980eff2e8c58ba207a0c32b06a372cafb687d144d2b",
+    "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:ab39410baf2fc1d31d50540acec7640d7f4814fa694e2421b696b6f0a058d645",
     "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:b699d6ae235ac601dc3e62391ac7c4606cb10331f8141983858c1580f5e74ddb",
     # ROCM, probably not all of them still in use
     "rocm-ubuntu18.04-manylinux2010-multipython": "sha256:6e953a09b145df338bcb03e9e36f99b291140c29b72d0a048fb6c5905ccad5eb",
diff --git a/tensorflow/workspace1.bzl b/tensorflow/workspace1.bzl
index c74a2e1..9b092a1 100644
--- a/tensorflow/workspace1.bzl
+++ b/tensorflow/workspace1.bzl
@@ -5,7 +5,6 @@
 load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
 load("@com_google_benchmark//:bazel/benchmark_deps.bzl", "benchmark_deps")
 load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
-load("@rules_cuda//cuda:dependencies.bzl", "rules_cuda_dependencies")
 load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")
 
 # buildifier: disable=unnamed-macro
@@ -13,10 +12,9 @@
     """Loads a set of TensorFlow dependencies. To be used in a WORKSPACE file.
 
     Args:
-      with_rules_cc: whether to load and patch rules_cc repository.
+      with_rules_cc: Unused, to be removed soon.
     """
     native.register_toolchains("@local_config_python//:py_toolchain")
-    rules_cuda_dependencies(with_rules_cc)
     rules_pkg_dependencies()
 
     closure_repositories()
diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl
index 7a35a4e..6507478 100644
--- a/tensorflow/workspace2.bzl
+++ b/tensorflow/workspace2.bzl
@@ -150,9 +150,9 @@
     # LINT.IfChange
     tf_http_archive(
         name = "XNNPACK",
-        sha256 = "f9c5e1cf1bcc7920985df92322b95e537f284914339c0836e91c352f51345182",
-        strip_prefix = "XNNPACK-bbbaa7352a3ea729987d3e654d37be93e8009691",
-        urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/bbbaa7352a3ea729987d3e654d37be93e8009691.zip"),
+        sha256 = "88e0158aff1e1498e34dfcaf08d948a73a3246a04fe96e548da71f6b9245a009",
+        strip_prefix = "XNNPACK-c7e7cde37615a81a529c326aa278bfab4cd6fe5a",
+        urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/c7e7cde37615a81a529c326aa278bfab4cd6fe5a.zip"),
     )
     # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake)
 
@@ -507,9 +507,9 @@
         name = "nccl_archive",
         build_file = "//third_party:nccl/archive.BUILD",
         patch_file = ["//third_party/nccl:archive.patch"],
-        sha256 = "0e3d7b6295beed81dc15002e88abf7a3b45b5c686b13b779ceac056f5612087f",
-        strip_prefix = "nccl-2.16.5-1",
-        urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.16.5-1.tar.gz"),
+        sha256 = "16ac98f3e926c024ce48e10ab220e19ce734adc48c423cfd55ad6f509bd1179f",
+        strip_prefix = "nccl-2.18.5-1",
+        urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.18.5-1.tar.gz"),
     )
 
     java_import_external(
diff --git a/third_party/absl/absl_designated_initializers.patch b/third_party/absl/absl_designated_initializers.patch
deleted file mode 100644
index 6ee2322..0000000
--- a/third_party/absl/absl_designated_initializers.patch
+++ /dev/null
@@ -1,65 +0,0 @@
-diff --git a/absl/crc/internal/crc_memcpy_x86_64.cc b/absl/crc/internal/crc_memcpy_x86_64.cc
-index 66f784de..ff424c54 100644
---- a/absl/crc/internal/crc_memcpy_x86_64.cc
-+++ b/absl/crc/internal/crc_memcpy_x86_64.cc
-@@ -359,18 +359,18 @@ CrcMemcpy::ArchSpecificEngines CrcMemcpy::GetArchSpecificEngines() {
-     case CpuType::kIntelHaswell:
-     case CpuType::kIntelIvybridge:
-       return {
--          .temporal = new FallbackCrcMemcpyEngine(),
--          .non_temporal = new CrcNonTemporalMemcpyAVXEngine(),
-+          /*.temporal=*/new FallbackCrcMemcpyEngine(),
-+          /*.non_temporal=*/new CrcNonTemporalMemcpyAVXEngine(),
-       };
-     // INTEL_SANDYBRIDGE performs better with SSE than AVX.
-     case CpuType::kIntelSandybridge:
-       return {
--          .temporal = new FallbackCrcMemcpyEngine(),
--          .non_temporal = new CrcNonTemporalMemcpyEngine(),
-+          /*.temporal=*/new FallbackCrcMemcpyEngine(),
-+          /*.non_temporal=*/new CrcNonTemporalMemcpyEngine(),
-       };
-     default:
--      return {.temporal = new FallbackCrcMemcpyEngine(),
--              .non_temporal = new FallbackCrcMemcpyEngine()};
-+      return {/*.temporal=*/new FallbackCrcMemcpyEngine(),
-+              /*.non_temporal=*/new FallbackCrcMemcpyEngine()};
-   }
- #else
-   // Get the underlying architecture.
-@@ -388,8 +388,8 @@ CrcMemcpy::ArchSpecificEngines CrcMemcpy::GetArchSpecificEngines() {
-     case CpuType::kAmdRome:
-     case CpuType::kAmdNaples:
-       return {
--          .temporal = new AcceleratedCrcMemcpyEngine<1, 2>(),
--          .non_temporal = new CrcNonTemporalMemcpyAVXEngine(),
-+          /*.temporal=*/new AcceleratedCrcMemcpyEngine<1, 2>(),
-+          /*.non_temporal=*/new CrcNonTemporalMemcpyAVXEngine(),
-       };
-     // PCLMULQDQ is slow and we don't have wide enough issue width to take
-     // advantage of it.  For an unknown architecture, don't risk using CLMULs.
-@@ -400,18 +400,18 @@ CrcMemcpy::ArchSpecificEngines CrcMemcpy::GetArchSpecificEngines() {
-     case CpuType::kIntelHaswell:
-     case CpuType::kIntelIvybridge:
-       return {
--          .temporal = new AcceleratedCrcMemcpyEngine<3, 0>(),
--          .non_temporal = new CrcNonTemporalMemcpyAVXEngine(),
-+          /*.temporal=*/new AcceleratedCrcMemcpyEngine<3, 0>(),
-+          /*.non_temporal=*/new CrcNonTemporalMemcpyAVXEngine(),
-       };
-     // INTEL_SANDYBRIDGE performs better with SSE than AVX.
-     case CpuType::kIntelSandybridge:
-       return {
--          .temporal = new AcceleratedCrcMemcpyEngine<3, 0>(),
--          .non_temporal = new CrcNonTemporalMemcpyEngine(),
-+          /*.temporal=*/new AcceleratedCrcMemcpyEngine<3, 0>(),
-+          /*.non_temporal=*/new CrcNonTemporalMemcpyEngine(),
-       };
-     default:
--      return {.temporal = new FallbackCrcMemcpyEngine(),
--              .non_temporal = new FallbackCrcMemcpyEngine()};
-+      return {/*.temporal=*/new FallbackCrcMemcpyEngine(),
-+              /*.non_temporal=*/new FallbackCrcMemcpyEngine()};
-   }
- #endif  // UNDEFINED_BEHAVIOR_SANITIZER
- }
diff --git a/third_party/absl/workspace.bzl b/third_party/absl/workspace.bzl
index 07f49ce..06f7516 100644
--- a/third_party/absl/workspace.bzl
+++ b/third_party/absl/workspace.bzl
@@ -7,8 +7,8 @@
 
     # Attention: tools parse and update these lines.
     # LINT.IfChange
-    ABSL_COMMIT = "b971ac5250ea8de900eae9f95e06548d14cd95fe"
-    ABSL_SHA256 = "8eeec9382fc0338ef5c60053f3a4b0e0708361375fe51c9e65d0ce46ccfe55a7"
+    ABSL_COMMIT = "fb3621f4f897824c0dbe0615fa94543df6192f30"
+    ABSL_SHA256 = "0320586856674d16b0b7a4d4afb22151bdc798490bb7f295eddd8f6a62b46fea"
     # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/abseil-cpp.cmake)
 
     SYS_DIRS = [
@@ -42,9 +42,6 @@
         build_file = "//third_party/absl:com_google_absl.BUILD",
         system_build_file = "//third_party/absl:system.BUILD",
         system_link_files = SYS_LINKS,
-        # This patch pulls in a fix for designated initializers that MSVC
-        # complains about. It shouldn't be necessary at the next LTS release.
-        patch_file = ["//third_party/absl:absl_designated_initializers.patch"],
         strip_prefix = "abseil-cpp-{commit}".format(commit = ABSL_COMMIT),
         urls = tf_mirror_urls("https://github.com/abseil/abseil-cpp/archive/{commit}.tar.gz".format(commit = ABSL_COMMIT)),
     )
diff --git a/third_party/ducc/ducc.BUILD b/third_party/ducc/ducc.BUILD
index 0eff685..8d71392 100644
--- a/third_party/ducc/ducc.BUILD
+++ b/third_party/ducc/ducc.BUILD
@@ -5,39 +5,39 @@
 
 exports_files(["LICENSE"])
 
-# The DUCC FFT source files are dual-licensed as BSD 3 clause and GPLv2.
-# We choose BSD 3 clause.
-DUCC_SOURCES = [
-    "google/ducc0_custom_lowlevel_threading.h",
-    "google/threading.cc",
-    "src/ducc0/infra/aligned_array.h",
-    "src/ducc0/infra/error_handling.h",
-    "src/ducc0/infra/misc_utils.h",
-    "src/ducc0/infra/simd.h",
-    "src/ducc0/infra/threading.cc",
-    "src/ducc0/infra/useful_macros.h",
-    "src/ducc0/math/cmplx.h",
-    "src/ducc0/math/unity_roots.h",
+DUCC_COPTS = [
+    "-frtti",
+    "-fexceptions",
+    "-ffp-contract=fast",
 ]
 
-DUCC_HEADERS = [
-    "google/threading.h",
-    "src/ducc0/fft/fft.h",
-    "src/ducc0/fft/fft1d_impl.h",
-    "src/ducc0/fft/fftnd_impl.h",
-    "src/ducc0/infra/mav.h",
-    "src/ducc0/infra/threading.h",
-]
-
+# This library exposes the raw DUCC FFT API.  It should be used
+# with caution, since inclusion of the headers will require any
+# dependent targets to be built with exceptions and RTTI enabled.
+# For a better-isolated target, use ":fft_wrapper".
 cc_library(
     name = "fft",
-    srcs = DUCC_SOURCES,
-    hdrs = DUCC_HEADERS,
-    copts = [
-        "-frtti",
-        "-fexceptions",
-        "-ffp-contract=fast",
+    srcs = [
+        "google/ducc0_custom_lowlevel_threading.h",
+        "google/threading.cc",
+        "src/ducc0/infra/aligned_array.h",
+        "src/ducc0/infra/error_handling.h",
+        "src/ducc0/infra/misc_utils.h",
+        "src/ducc0/infra/simd.h",
+        "src/ducc0/infra/threading.cc",
+        "src/ducc0/infra/useful_macros.h",
+        "src/ducc0/math/cmplx.h",
+        "src/ducc0/math/unity_roots.h",
     ],
+    hdrs = [
+        "google/threading.h",
+        "src/ducc0/fft/fft.h",
+        "src/ducc0/fft/fft1d_impl.h",
+        "src/ducc0/fft/fftnd_impl.h",
+        "src/ducc0/infra/mav.h",
+        "src/ducc0/infra/threading.h",
+    ],
+    copts = DUCC_COPTS,
     defines = [
         # Use custom TSL/Eigen threading.
         "DUCC0_CUSTOM_LOWLEVEL_THREADING=1",
@@ -45,21 +45,32 @@
     features = ["-use_header_modules"],
     include_prefix = "ducc",
     includes = [
+        ".",  # Needed for google/-relative paths.
         "google",  # Needed for finding ducc0_custom_lowlevel_threading.h.
-        "src",
+        "src",  # Needed for internal headers.
     ],
+    # The DUCC FFT source files are dual-licensed as BSD 3 clause and GPLv2.
+    # We choose BSD 3 clause.
     licenses = ["notice"],
+    visibility = ["//visibility:private"],
     deps = [
         # Required for custom threadpool usage:
-        "@org_tensorflow//third_party/eigen3",
-        "@local_tsl//tsl/platform:env",
+        "@eigen_archive//:eigen3",
         "@local_tsl//tsl/platform:mutex",
-        "@local_tsl//tsl/platform:platform_port",
     ],
 )
 
-# Export source files needed for mobile builds, which do not use granular targets.
-filegroup(
-    name = "mobile_srcs_no_runtime",
-    srcs = DUCC_SOURCES + DUCC_HEADERS,
+cc_library(
+    name = "fft_wrapper",
+    srcs = ["google/fft.cc"],
+    hdrs = ["google/fft.h"],
+    copts = DUCC_COPTS,
+    features = ["-use_header_modules"],
+    include_prefix = "ducc",
+    licenses = ["notice"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":fft",
+        "@eigen_archive//:eigen3",
+    ],
 )
diff --git a/third_party/ducc/ducc0_custom_lowlevel_threading.h b/third_party/ducc/ducc0_custom_lowlevel_threading.h
index 6ac63a9..688efe7 100644
--- a/third_party/ducc/ducc0_custom_lowlevel_threading.h
+++ b/third_party/ducc/ducc0_custom_lowlevel_threading.h
@@ -27,7 +27,7 @@
 using CondVar = tsl::condition_variable;
 
 // Missing variable used by DUCC threading.cc.
-static thread_local bool in_parallel_region = false;
+extern thread_local bool in_parallel_region;
 
 }  // namespace detail_threading
 }  // namespace ducc0
diff --git a/third_party/ducc/fft.cc b/third_party/ducc/fft.cc
new file mode 100644
index 0000000..ec3c66f
--- /dev/null
+++ b/third_party/ducc/fft.cc
@@ -0,0 +1,148 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "ducc/google/fft.h"
+
+#include <complex>
+#include <cstddef>
+#include <cstdlib>
+#include <exception>
+#include <iostream>
+#include <ostream>
+#include <vector>
+
+#include "ducc/google/threading.h"
+#include "ducc/src/ducc0/fft/fft.h"
+#include "ducc/src/ducc0/fft/fft1d_impl.h"  // IWYU pragma: keep, DUCC definitions.
+#include "ducc/src/ducc0/fft/fftnd_impl.h"  // IWYU pragma: keep, DUCC definitions.
+#include "ducc/src/ducc0/infra/mav.h"
+#include "ducc/src/ducc0/infra/threading.h"
+#include "unsupported/Eigen/CXX11/ThreadPool"
+
+namespace ducc0 {
+
+// Wrappers around DUCC calls.
+namespace google {
+
+using Shape = std::vector<std::size_t>;
+using Stride = std::vector<std::ptrdiff_t>;
+
+template <typename RealScalar>
+void c2c(const std::complex<RealScalar>* in, const Shape& in_shape,
+         const Stride& in_stride, std::complex<RealScalar>* out,
+         const Shape& out_shape, const Stride& out_stride, const Shape& axes,
+         bool forward, RealScalar scale,
+         Eigen::ThreadPoolInterface* thread_pool) {
+  ducc0::cfmav<std::complex<RealScalar>> m_in(in, in_shape, in_stride);
+  ducc0::vfmav<std::complex<RealScalar>> m_out(out, out_shape, out_stride);
+
+  try {
+    if (thread_pool == nullptr) {
+      // Use a fake threadpool.
+      ducc0::google::NoThreadPool no_thread_pool;
+      ducc0::detail_threading::ScopedUseThreadPool thread_pool_guard(
+          no_thread_pool);
+      ducc0::c2c(m_in, m_out, axes, forward, scale, 1);
+    } else {
+      EigenThreadPool eigen_thread_pool(*thread_pool);
+      ducc0::detail_threading::ScopedUseThreadPool thread_pool_guard(
+          eigen_thread_pool);
+      ducc0::c2c(m_in, m_out, axes, forward, scale,
+                 eigen_thread_pool.nthreads());
+    }
+  } catch (const std::exception& ex) {
+    std::cerr << "DUCC FFT c2c failed: " << ex.what() << std::endl;
+    std::abort();
+  }
+}
+
+template <typename RealScalar>
+void r2c(const RealScalar* in, const Shape& in_shape, const Stride& in_stride,
+         std::complex<RealScalar>* out, const Shape& out_shape,
+         const Stride& out_stride, const Shape& axes, bool forward,
+         RealScalar scale, Eigen::ThreadPoolInterface* thread_pool) {
+  ducc0::cfmav<RealScalar> m_in(in, in_shape, in_stride);
+  ducc0::vfmav<std::complex<RealScalar>> m_out(out, out_shape, out_stride);
+
+  try {
+    if (thread_pool == nullptr) {
+      // Use a fake threadpool.
+      ducc0::google::NoThreadPool no_thread_pool;
+      ducc0::detail_threading::ScopedUseThreadPool thread_pool_guard(
+          no_thread_pool);
+      ducc0::r2c(m_in, m_out, axes, forward, scale, 1);
+    } else {
+      EigenThreadPool eigen_thread_pool(*thread_pool);
+      ducc0::detail_threading::ScopedUseThreadPool thread_pool_guard(
+          eigen_thread_pool);
+      ducc0::r2c(m_in, m_out, axes, forward, scale,
+                 eigen_thread_pool.nthreads());
+    }
+  } catch (const std::exception& ex) {
+    std::cerr << "DUCC FFT r2c failed: " << ex.what() << std::endl;
+    std::abort();
+  }
+}
+
+template <typename RealScalar>
+void c2r(const std::complex<RealScalar>* in, const Shape& in_shape,
+         const Stride& in_stride, RealScalar* out, const Shape& out_shape,
+         const Stride& out_stride, const Shape& axes, bool forward,
+         RealScalar scale, Eigen::ThreadPoolInterface* thread_pool) {
+  ducc0::cfmav<std::complex<RealScalar>> m_in(in, in_shape, in_stride);
+  ducc0::vfmav<RealScalar> m_out(out, out_shape, out_stride);
+
+  try {
+    if (thread_pool == nullptr) {
+      // Use a fake threadpool.
+      ducc0::google::NoThreadPool no_thread_pool;
+      ducc0::detail_threading::ScopedUseThreadPool thread_pool_guard(
+          no_thread_pool);
+      ducc0::c2r(m_in, m_out, axes, forward, scale, 1);
+    } else {
+      EigenThreadPool eigen_thread_pool(*thread_pool);
+      ducc0::detail_threading::ScopedUseThreadPool thread_pool_guard(
+          eigen_thread_pool);
+      ducc0::c2r(m_in, m_out, axes, forward, scale,
+                 eigen_thread_pool.nthreads());
+    }
+  } catch (const std::exception& ex) {
+    std::cerr << "DUCC FFT c2r failed: " << ex.what() << std::endl;
+    std::abort();
+  }
+}
+
+#define FFT_DEFINITIONS(RealScalar)                                            \
+  template void c2c<RealScalar>(                                               \
+      const std::complex<RealScalar>* in, const Shape& in_shape,               \
+      const Stride& in_stride, std::complex<RealScalar>* out,                  \
+      const Shape& out_shape, const Stride& out_stride, const Shape& axes,     \
+      bool forward, RealScalar scale,                                          \
+      Eigen::ThreadPoolInterface* thread_pool);                                \
+  template void r2c<RealScalar>(                                               \
+      const RealScalar* in, const Shape& in_shape, const Stride& in_stride,    \
+      std::complex<RealScalar>* out, const Shape& out_shape,                   \
+      const Stride& out_stride, const Shape& axes, bool forward,               \
+      RealScalar scale, Eigen::ThreadPoolInterface* thread_pool);              \
+  template void c2r(const std::complex<RealScalar>* in, const Shape& in_shape, \
+                    const Stride& in_stride, RealScalar* out,                  \
+                    const Shape& out_shape, const Stride& out_stride,          \
+                    const Shape& axes, bool forward, RealScalar scale,         \
+                    Eigen::ThreadPoolInterface* thread_pool)
+FFT_DEFINITIONS(float);
+FFT_DEFINITIONS(double);
+#undef FFT_DEFINITIONS
+
+}  // namespace google
+}  // namespace ducc0
\ No newline at end of file
diff --git a/third_party/ducc/fft.h b/third_party/ducc/fft.h
new file mode 100644
index 0000000..8c1691d
--- /dev/null
+++ b/third_party/ducc/fft.h
@@ -0,0 +1,77 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_DUCC_GOOGLE_FFT_H_
+#define THIRD_PARTY_DUCC_GOOGLE_FFT_H_
+
+// Wrapper around the DUCC FFT library to isolate usage of exceptions
+// and RTTI.  Eliminates all direct usage of DUCC headers.
+
+#include <complex>
+#include <cstddef>
+#include <vector>
+
+#include "unsupported/Eigen/CXX11/ThreadPool"
+
+namespace ducc0 {
+namespace google {
+
+using Shape = std::vector<std::size_t>;
+using Stride = std::vector<std::ptrdiff_t>;
+
+template <typename RealScalar>
+void c2c(const std::complex<RealScalar>* in, const Shape& in_shape,
+         const Stride& in_stride, std::complex<RealScalar>* out,
+         const Shape& out_shape, const Stride& out_stride, const Shape& axes,
+         bool forward, RealScalar scale,
+         Eigen::ThreadPoolInterface* thread_pool);
+
+template <typename RealScalar>
+void r2c(const RealScalar* in, const Shape& in_shape, const Stride& in_stride,
+         std::complex<RealScalar>* out, const Shape& out_shape,
+         const Stride& out_stride, const Shape& axes, bool forward,
+         RealScalar scale, Eigen::ThreadPoolInterface* thread_pool);
+
+template <typename RealScalar>
+void c2r(const std::complex<RealScalar>* in, const Shape& in_shape,
+         const Stride& in_stride, RealScalar* out, const Shape& out_shape,
+         const Stride& out_stride, const Shape& axes, bool forward,
+         RealScalar scale, Eigen::ThreadPoolInterface* thread_pool);
+
+#define FFT_DECLARATIONS(RealScalar)                                        \
+  extern template void c2c<RealScalar>(                                     \
+      const std::complex<RealScalar>* in, const Shape& in_shape,            \
+      const Stride& in_stride, std::complex<RealScalar>* out,               \
+      const Shape& out_shape, const Stride& out_stride, const Shape& axes,  \
+      bool forward, RealScalar scale,                                       \
+      Eigen::ThreadPoolInterface* thread_pool);                             \
+  extern template void r2c<RealScalar>(                                     \
+      const RealScalar* in, const Shape& in_shape, const Stride& in_stride, \
+      std::complex<RealScalar>* out, const Shape& out_shape,                \
+      const Stride& out_stride, const Shape& axes, bool forward,            \
+      RealScalar scale, Eigen::ThreadPoolInterface* thread_pool);           \
+  extern template void c2r(                                                 \
+      const std::complex<RealScalar>* in, const Shape& in_shape,            \
+      const Stride& in_stride, RealScalar* out, const Shape& out_shape,     \
+      const Stride& out_stride, const Shape& axes, bool forward,            \
+      RealScalar scale, Eigen::ThreadPoolInterface* thread_pool)
+FFT_DECLARATIONS(float);
+FFT_DECLARATIONS(double);
+#undef FFT_DECLARATIONS
+
+}  // namespace google
+}  // namespace ducc0
+
+#endif  // THIRD_PARTY_DUCC_GOOGLE_FFT_H_
\ No newline at end of file
diff --git a/third_party/ducc/threading.cc b/third_party/ducc/threading.cc
index 6a52d8b..d079398 100644
--- a/third_party/ducc/threading.cc
+++ b/third_party/ducc/threading.cc
@@ -12,33 +12,30 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
 #include "ducc/google/threading.h"
 
+#include <thread>
 #include <utility>
 
 #include "ducc/src/ducc0/infra/threading.h"
-#include "tsl/platform/cpu_info.h"
-#include "tsl/platform/env.h"
-#include "tsl/platform/threadpool.h"
+#include "unsupported/Eigen/CXX11/ThreadPool"
 
 namespace ducc0 {
+
 namespace google {
+
 namespace {
 
 // Default shared global pool.  It is created on first use.
-thread_pool* GetGlobalThreadPoolSingleton() {
-  static tsl::thread::ThreadPool* tsl_pool =
-      new tsl::thread::ThreadPool(tsl::Env::Default(), "ducc_global_threadpool",
-                                  tsl::port::MaxParallelism());
-  static thread_pool* pool =
-      new EigenThreadPool(*tsl_pool->AsEigenThreadPool());
+EigenThreadPool* GetGlobalThreadPoolSingleton() {
+  static Eigen::ThreadPool* eigen_pool =
+      new Eigen::ThreadPool(std::thread::hardware_concurrency());
+  static EigenThreadPool* pool = new EigenThreadPool(*eigen_pool);
   return pool;
 }
 
-
 // Thread-local active pool for current execution.
-thread_pool*& GetActiveThreadPoolSingleton() {
+ducc0::detail_threading::thread_pool*& GetActiveThreadPoolSingleton() {
   thread_local thread_pool* active_pool = nullptr;
   return active_pool;
 }
@@ -49,6 +46,9 @@
 // Implementations required by ducc0.
 namespace detail_threading {
 
+// Definition of the variable required by DUCC's own threading.cc
+// (declared extern in ducc0_custom_lowlevel_threading.h).
+thread_local bool in_parallel_region = false;
+
 thread_pool* set_active_pool(thread_pool* new_pool) {
   return std::exchange(ducc0::google::GetActiveThreadPoolSingleton(), new_pool);
 }
@@ -65,5 +65,4 @@
 }
 
 }  // namespace detail_threading
-
-}  // namespace ducc0
+}  // namespace ducc0
\ No newline at end of file
diff --git a/third_party/ducc/threading.h b/third_party/ducc/threading.h
index b9953a1..a374e3d 100644
--- a/third_party/ducc/threading.h
+++ b/third_party/ducc/threading.h
@@ -17,7 +17,7 @@
 #define THIRD_PARTY_DUCC_GOOGLE_THREADING_H_
 
 #include "ducc/src/ducc0/infra/threading.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool"
+#include "unsupported/Eigen/CXX11/ThreadPool"
 
 namespace ducc0 {
 namespace google {
diff --git a/third_party/ducc/workspace.bzl b/third_party/ducc/workspace.bzl
index ce0d633..1475579 100644
--- a/third_party/ducc/workspace.bzl
+++ b/third_party/ducc/workspace.bzl
@@ -13,7 +13,9 @@
         build_file = "//third_party/ducc:ducc.BUILD",
         link_files = {
             "//third_party/ducc:ducc0_custom_lowlevel_threading.h": "google/ducc0_custom_lowlevel_threading.h",
-            "//third_party/ducc:threading.h": "google/threading.h",
+            "//third_party/ducc:fft.h": "google/fft.h",
+            "//third_party/ducc:fft.cc": "google/fft.cc",
             "//third_party/ducc:threading.cc": "google/threading.cc",
+            "//third_party/ducc:threading.h": "google/threading.h",
         },
     )
diff --git a/third_party/eigen3/BUILD b/third_party/eigen3/BUILD
index d9a8308..84a4205 100644
--- a/third_party/eigen3/BUILD
+++ b/third_party/eigen3/BUILD
@@ -1,68 +1,3 @@
 # Description:
 #   Eigen is a C++ template library for linear algebra: vectors,
 #   matrices, and related algorithms.
-# This is the BUILD file with extra code to patch into @eigen_archive.
-
-load("//third_party/mkl:build_defs.bzl", "if_mkl")
-
-licenses([
-    # Note: Eigen is an MPL2 library that includes GPL v3 and LGPL v2.1+ code.
-    #       We've taken special care to not reference any restricted code.
-    "reciprocal",  # MPL2
-    "notice",  # Portions BSD
-])
-
-exports_files(["LICENSE"])
-
-EIGEN3_THIRD_PARTY_HEADERS = [
-    "Eigen/Core",
-    "Eigen/LU",
-    "Eigen/Cholesky",
-    "Eigen/Eigenvalues",
-    "Eigen/OrderingMethods",
-    "Eigen/QR",
-    "Eigen/SparseCholesky",
-    "Eigen/SparseCore",
-    "Eigen/SVD",
-    "unsupported/Eigen/MatrixFunctions",
-    "unsupported/Eigen/SpecialFunctions",
-    "unsupported/Eigen/CXX11/ThreadPool",
-    "unsupported/Eigen/CXX11/Tensor",
-]
-
-cc_library(
-    name = "eigen3",
-    hdrs = EIGEN3_THIRD_PARTY_HEADERS,
-    includes = if_mkl(["./mkl_include"]),
-    visibility = ["//visibility:public"],
-    deps = [
-        "@eigen_archive//:eigen3_internal",
-    ],
-)
-
-filegroup(
-    name = "eigen_third_party_header_files",
-    srcs = EIGEN3_THIRD_PARTY_HEADERS,
-    visibility = ["//visibility:public"],
-)
-
-genrule(
-    name = "install_eigen_headers",
-    srcs = [
-        "@eigen_archive//:eigen_header_files",
-        "@eigen_archive//:eigen_source_files",
-        ":eigen_third_party_header_files",
-    ],
-    outs = ["include"],
-    cmd = """
-    mkdir $@
-    for f in $(SRCS); do
-      d="$${f%/*}"
-      d="$${d#*external/eigen_archive/}"
-
-      mkdir -p "$@/$${d}"
-      cp "$${f}" "$@/$${d}/"
-    done
-    """,
-    tags = ["manual"],
-)
diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky
deleted file mode 100644
index c199a025..0000000
--- a/third_party/eigen3/Eigen/Cholesky
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/Cholesky"
diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core
deleted file mode 100644
index d4b0367..0000000
--- a/third_party/eigen3/Eigen/Core
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/Core"
diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues
deleted file mode 100644
index bf739b9..0000000
--- a/third_party/eigen3/Eigen/Eigenvalues
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/Eigenvalues"
diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU
deleted file mode 100644
index 536149c..0000000
--- a/third_party/eigen3/Eigen/LU
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/LU"
diff --git a/third_party/eigen3/Eigen/OrderingMethods b/third_party/eigen3/Eigen/OrderingMethods
deleted file mode 100644
index 190fc22..0000000
--- a/third_party/eigen3/Eigen/OrderingMethods
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/OrderingMethods"
\ No newline at end of file
diff --git a/third_party/eigen3/Eigen/QR b/third_party/eigen3/Eigen/QR
deleted file mode 100644
index be067d3..0000000
--- a/third_party/eigen3/Eigen/QR
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/QR"
diff --git a/third_party/eigen3/Eigen/SVD b/third_party/eigen3/Eigen/SVD
deleted file mode 100644
index eecf47c..0000000
--- a/third_party/eigen3/Eigen/SVD
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/SVD"
diff --git a/third_party/eigen3/Eigen/SparseCholesky b/third_party/eigen3/Eigen/SparseCholesky
deleted file mode 100644
index a6d362b..0000000
--- a/third_party/eigen3/Eigen/SparseCholesky
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/SparseCholesky"
\ No newline at end of file
diff --git a/third_party/eigen3/Eigen/SparseCore b/third_party/eigen3/Eigen/SparseCore
deleted file mode 100644
index 3c60745..0000000
--- a/third_party/eigen3/Eigen/SparseCore
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/SparseCore"
\ No newline at end of file
diff --git a/third_party/eigen3/eigen_archive.BUILD b/third_party/eigen3/eigen_archive.BUILD
index 3179cad..78b1fc8 100644
--- a/third_party/eigen3/eigen_archive.BUILD
+++ b/third_party/eigen3/eigen_archive.BUILD
@@ -4,8 +4,6 @@
 # This is the BUILD file used for the @eigen_archive external repository.
 
 licenses([
-    # Note: Although Eigen also includes GPL V3 and LGPL v2.1+ code, TensorFlow
-    #       has taken special care to not reference any restricted code.
     "reciprocal",  # MPL2
     "notice",  # Portions BSD
 ])
@@ -26,38 +24,29 @@
     ] + ALL_FILES_WITH_EXTENSIONS,
 )
 
-# Internal eigen headers, known to be under an MPL2 license.
-EIGEN_MPL2_SOURCES = glob(
+# Internal eigen headers.
+EIGEN_SOURCES = glob(
     [
         "Eigen/**/src/**/*.h",
         "Eigen/**/src/**/*.inc",
         "unsupported/Eigen/**/src/**/*.h",
         "unsupported/Eigen/**/src/**/*.inc",
     ],
-    exclude = [
-        # This guarantees that any file depending on non MPL2 licensed code
-        # will not compile.
-        "Eigen/src/Core/util/NonMPL2.h",
-    ],
-)
-
-alias(
-    name = "eigen3",
-    actual = "@org_tensorflow//third_party/eigen3",
-    visibility = ["//visibility:public"],
 )
 
 cc_library(
-    name = "eigen3_internal",
-    srcs = EIGEN_MPL2_SOURCES,
+    name = "eigen3",
+    srcs = EIGEN_SOURCES,
     hdrs = EIGEN_HEADERS,
     defines = [
-        # This define (mostly) guarantees we don't link any problematic
-        # code. We use it, but we do not rely on it, as evidenced above.
-        "EIGEN_MPL2_ONLY",
         "EIGEN_MAX_ALIGN_BYTES=64",
+        "EIGEN_ALLOW_UNALIGNED_SCALARS",  # TODO(b/296071640): Remove when underlying bugs are fixed.
+        "EIGEN_USE_AVX512_GEMM_KERNELS=0",  # TODO(b/238649163): Remove this once no longer necessary.
     ],
-    includes = ["."],
+    includes = [
+        ".",  # Third-party libraries include eigen relative to its root.
+        "./mkl_include",  # For using MKL backend for Eigen when available.
+    ],
     visibility = ["//visibility:public"],
 )
 
@@ -69,6 +58,6 @@
 
 filegroup(
     name = "eigen_source_files",
-    srcs = EIGEN_MPL2_SOURCES,
+    srcs = EIGEN_SOURCES,
     visibility = ["//visibility:public"],
 )
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
deleted file mode 100644
index 41db119..0000000
--- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
+++ /dev/null
@@ -1 +0,0 @@
-#include "unsupported/Eigen/CXX11/Tensor"
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool b/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
deleted file mode 100644
index d2639af..0000000
--- a/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
+++ /dev/null
@@ -1 +0,0 @@
-#include "unsupported/Eigen/CXX11/ThreadPool"
diff --git a/third_party/eigen3/unsupported/Eigen/MatrixFunctions b/third_party/eigen3/unsupported/Eigen/MatrixFunctions
deleted file mode 100644
index 314b325..0000000
--- a/third_party/eigen3/unsupported/Eigen/MatrixFunctions
+++ /dev/null
@@ -1 +0,0 @@
-#include "unsupported/Eigen/MatrixFunctions"
diff --git a/third_party/eigen3/unsupported/Eigen/SpecialFunctions b/third_party/eigen3/unsupported/Eigen/SpecialFunctions
deleted file mode 100644
index ad13359..0000000
--- a/third_party/eigen3/unsupported/Eigen/SpecialFunctions
+++ /dev/null
@@ -1 +0,0 @@
-#include "unsupported/Eigen/SpecialFunctions"
diff --git a/third_party/eigen3/workspace.bzl b/third_party/eigen3/workspace.bzl
index d1d8d4a..027454e 100644
--- a/third_party/eigen3/workspace.bzl
+++ b/third_party/eigen3/workspace.bzl
@@ -7,8 +7,8 @@
 
     # Attention: tools parse and update these lines.
     # LINT.IfChange
-    EIGEN_COMMIT = "66e8f38891841bf88ee976a316c0c78a52f0cee5"
-    EIGEN_SHA256 = "01fcd68409c038bbcfd16394274c2bf71e2bb6dda89a2319e23fc59a2da17210"
+    EIGEN_COMMIT = "aa6964bf3a34fd607837dd8123bc42465185c4f8"
+    EIGEN_SHA256 = "35ba771e30c735a4215ed784d7e032086cf89fe6622dce4d793c45dd74373362"
     # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/eigen.cmake)
 
     tf_http_archive(
diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
index a87c49a..eb1be91 100644
--- a/third_party/llvm/workspace.bzl
+++ b/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@
 
 def repo(name):
     """Imports LLVM."""
-    LLVM_COMMIT = "fd1a0b0ee4d8f6092dad6caff5217b8fd2193798"
-    LLVM_SHA256 = "d1194a3e2404d3a00cf695e1f31a324d6def42bcce4251abcdd938e4b1f6eabb"
+    LLVM_COMMIT = "9e0a5be0de320e29226225b6e466474c031d9ca6"
+    LLVM_SHA256 = "b32656f1bee03ffee0de64af05b4f590ec488a51e30757015ef57aca720d8182"
 
     tf_http_archive(
         name = name,
diff --git a/third_party/nccl/archive.BUILD b/third_party/nccl/archive.BUILD
index 05293fd..1813bac 100644
--- a/third_party/nccl/archive.BUILD
+++ b/third_party/nccl/archive.BUILD
@@ -19,7 +19,7 @@
 
 NCCL_MAJOR = 2
 
-NCCL_MINOR = 16
+NCCL_MINOR = 18
 
 NCCL_PATCH = 5
 
@@ -210,6 +210,10 @@
     ],
     include_prefix = "third_party/nccl",
     linkopts = ["-lrt"],
+    # The following define is needed to expose format-specifier macros such as
+    # PRIx64 from <inttypes.h>, since the TensorFlow Docker image uses
+    # an old version of glibc.
+    local_defines = ["__STDC_FORMAT_MACROS"],
     strip_include_prefix = "src",
     target_compatible_with = select({
         "@local_config_cuda//cuda:using_clang": [],
diff --git a/third_party/nccl/archive.patch b/third_party/nccl/archive.patch
index f951a6a..8ef0af9 100644
--- a/third_party/nccl/archive.patch
+++ b/third_party/nccl/archive.patch
@@ -30,19 +30,6 @@
 similarity index 100%
 rename from src/collectives/device/sendrecv.cu
 rename to src/collectives/device/sendrecv.cu.cc
-diff --git a/src/include/nvtx.h b/src/include/nvtx.h
-index 2aeb932..cdc67d2 100644
---- a/src/include/nvtx.h
-+++ b/src/include/nvtx.h
-@@ -37,7 +37,7 @@ struct nccl_domain{static constexpr char const* name{"NCCL"};};
-
- class payload_schema {
-  public:
--  NVTX3_RELAXED_CONSTEXPR explicit payload_schema(const nvtxPayloadSchemaEntry_t entries[], size_t numEntries, const uint64_t schemaId, const char* schemaName = nullptr) noexcept
-+  explicit payload_schema(const nvtxPayloadSchemaEntry_t entries[], size_t numEntries, const uint64_t schemaId, const char* schemaName = nullptr) noexcept
-   {
-     schema_attr.name = schemaName;
-     schema_attr.entries = entries;
 diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h
 index accf8371a..4ab1bfac6 100644
 --- a/src/collectives/device/common.h
diff --git a/third_party/py/ml_dtypes/ml_dtypes.BUILD b/third_party/py/ml_dtypes/ml_dtypes.BUILD
index ccf607d..a85195e 100644
--- a/third_party/py/ml_dtypes/ml_dtypes.BUILD
+++ b/third_party/py/ml_dtypes/ml_dtypes.BUILD
@@ -17,7 +17,7 @@
         ".",
         "ml_dtypes",
     ],
-    deps = ["@org_tensorflow//third_party/eigen3"],
+    deps = ["@eigen_archive//:eigen3"],
 )
 
 cc_library(
@@ -48,7 +48,7 @@
     deps = [
         ":float8",
         ":int4",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
         "@org_tensorflow//third_party/py/numpy:headers",
     ],
 )
diff --git a/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD b/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD
index 37cd52d..574659a 100644
--- a/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD
+++ b/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD
@@ -55,7 +55,7 @@
         "//:float8",
         "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest_main",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
     ],
 )
 
@@ -66,6 +66,6 @@
     deps = [
         "//:int4",
         "@com_google_googletest//:gtest_main",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
     ],
 )
diff --git a/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD b/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD
index ccf607d..a85195e 100644
--- a/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD
+++ b/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD
@@ -17,7 +17,7 @@
         ".",
         "ml_dtypes",
     ],
-    deps = ["@org_tensorflow//third_party/eigen3"],
+    deps = ["@eigen_archive//:eigen3"],
 )
 
 cc_library(
@@ -48,7 +48,7 @@
     deps = [
         ":float8",
         ":int4",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
         "@org_tensorflow//third_party/py/numpy:headers",
     ],
 )
diff --git a/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD b/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD
index 37cd52d..574659a 100644
--- a/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD
+++ b/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD
@@ -55,7 +55,7 @@
         "//:float8",
         "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest_main",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
     ],
 )
 
@@ -66,6 +66,6 @@
     deps = [
         "//:int4",
         "@com_google_googletest//:gtest_main",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
     ],
 )
diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch
index 5161610..b3cb70f 100644
--- a/third_party/stablehlo/temporary.patch
+++ b/third_party/stablehlo/temporary.patch
@@ -181,6 +181,18 @@
  
  #-------------------------------------------------------------------------------
  # Directory setup
+diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/BUILD.bazel b/stablehlo/stablehlo/conversions/tosa/tests/BUILD.bazel
+--- stablehlo/stablehlo/conversions/tosa/tests/BUILD.bazel
++++ stablehlo/stablehlo/conversions/tosa/tests/BUILD.bazel
+@@ -29,7 +29,7 @@
+         "@LLVM_TOOLS_DIR@": package_path("@llvm-project//llvm:BUILD"),
+         "\"@STABLEHLO_TOOLS_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'stablehlo')",
+         "\"@STABLEHLO_SOURCE_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'stablehlo')",
+-    },
++     },
+     template = "lit.site.cfg.py.in",
+ )
+ 
 diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/Base.cpp
 --- stablehlo/stablehlo/dialect/Base.cpp
 +++ stablehlo/stablehlo/dialect/Base.cpp
@@ -980,6 +992,44 @@
 +}  // namespace mlir
 +
 +#endif  // STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H
+diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp
+--- stablehlo/stablehlo/dialect/StablehloOps.cpp
++++ stablehlo/stablehlo/dialect/StablehloOps.cpp
+@@ -1543,6 +1543,7 @@
+     p << " across dimensions = [";
+     llvm::interleaveComma(getDimensions().getValues<int64_t>(), p);
+     p << "]";
++    p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"});
+     p << " : ";
+     p.printFunctionalType(*this);
+   } else {
+@@ -1705,6 +1706,7 @@
+   if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") ||
+       parser.parseEqual() ||
+       parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) ||
++      parser.parseOptionalAttrDict(result.attributes) ||
+       parser.parseColon() || parser.parseType(reduceOpFnType) ||
+       parser.parseOptionalLocationSpecifier(explicitLoc))
+     return failure();
+diff --ruN a/stablehlo/stablehlo/tests/print_reduce.mlir b/stablehlo/stablehlo/tests/print_reduce.mlir
+--- stablehlo/stablehlo/tests/print_reduce.mlir
++++ stablehlo/stablehlo/tests/print_reduce.mlir
+@@ -168,3 +168,15 @@
+ 
+   func.return %0: tensor<4xf32>
+ }
++
++// The test case makes sure any custom attrs set on the reduce-op are
++// printed/parsed when pretty-printed.
++
++// CHECK-LABEL:  func @pretty_print_with_custom_attr
++// CHECK:          applies stablehlo.add across dimensions = [1] {custom_user_attr = 1 : i64}
++
++func.func @pretty_print_with_custom_attr(%arg0: tensor<2x64x13xf32>) -> tensor<2x13xf32> {
++  %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
++  %1 = stablehlo.reduce(%arg0 init: %0) applies stablehlo.add across dimensions = [1] {custom_user_attr = 1 : i64} : (tensor<2x64x13xf32>, tensor<f32>) -> tensor<2x13xf32>
++  return %1 : tensor<2x13xf32>
++}
 diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir
 --- stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir
 +++ stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir
diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl
index 4dc93b8..da650e4 100644
--- a/third_party/stablehlo/workspace.bzl
+++ b/third_party/stablehlo/workspace.bzl
@@ -4,8 +4,8 @@
 
 def repo():
     # LINT.IfChange
-    STABLEHLO_COMMIT = "03216ba4f6ead279db5912828f8c94634589007d"
-    STABLEHLO_SHA256 = "84e9624cc61e70586c2e4bb0356da8d7fdbe653d0a015fd67b5c9a56660ba258"
+    STABLEHLO_COMMIT = "5e41e674af78da676652459c2dcf6a0d76e59ddb"
+    STABLEHLO_SHA256 = "02f7db52b6dc6b14d3dcbe8e7982c8cebcc2b91b1843abbe292ab98eef1fc9f2"
     # LINT.ThenChange(Google-internal path)
 
     tf_http_archive(
diff --git a/third_party/tf_runtime/tf_runtime.patch b/third_party/tf_runtime/tf_runtime.patch
deleted file mode 100644
index a9a9d4a..0000000
--- a/third_party/tf_runtime/tf_runtime.patch
+++ /dev/null
@@ -1,84 +0,0 @@
-Intermittent patch to TFRT to submit a TF/TFRT cross-cutting change.
-This patch will be applied only until TF's TFRT commit is automatically bumped.
-
----
-
-diff --git a/backends/gpu/include/tfrt/gpu/gpu_types.h b/backends/gpu/include/tfrt/gpu/gpu_types.h
-index 3d311c3..a216716 100644
---- a/backends/gpu/include/tfrt/gpu/gpu_types.h
-+++ b/backends/gpu/include/tfrt/gpu/gpu_types.h
-@@ -295,11 +295,7 @@
-       wrapper::CurrentContext current, wrapper::Stream stream,
-       wrapper::CclComm comm)>;
- 
--  explicit GpuCclHandle(AsyncValueRef<GpuContext> context,
--                        wrapper::OwningCclComm comm, int num_ranks);
--  // TODO(hanbinyoon): Remove after transitioning to the above constructor.
--  explicit GpuCclHandle(AsyncValueRef<GpuContext> context,
--                        wrapper::OwningCclComm comm);
-+  GpuCclHandle(AsyncValueRef<GpuContext> context, wrapper::OwningCclComm comm);
-   ~GpuCclHandle();
- 
-   GpuCclHandle(GpuCclHandle&&) = default;
-@@ -311,8 +307,6 @@
-   llvm::Error ExecuteCallbacks(wrapper::CurrentContext current,
-                                wrapper::Stream stream);
- 
--  int num_ranks() const { return num_ranks_; }
--
-   const wrapper::OwningCclComm& operator->() const { return comm_; }
-   wrapper::CclComm get() const { return comm_.get(); }
-   wrapper::CclComm release();
-@@ -322,7 +316,6 @@
-  private:
-   AsyncValueRef<GpuContext> context_;
-   wrapper::OwningCclComm comm_;
--  int num_ranks_;
-   std::vector<Callback> callbacks_;
- };
- 
-diff --git a/backends/gpu/lib/gpu_types.cc b/backends/gpu/lib/gpu_types.cc
-index 38529bc..01e3dba 100644
---- a/backends/gpu/lib/gpu_types.cc
-+++ b/backends/gpu/lib/gpu_types.cc
-@@ -214,15 +214,8 @@
- GpuBlasHandle::~GpuBlasHandle() = default;
- 
- GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context,
--                           wrapper::OwningCclComm comm, int num_ranks)
--    : context_(std::move(context)),
--      comm_(std::move(comm)),
--      num_ranks_(num_ranks) {}
--
--// TODO(hanbinyoon): Remove after transitioning to the above constructor.
--GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context,
-                            wrapper::OwningCclComm comm)
--    : context_(std::move(context)), comm_(std::move(comm)), num_ranks_(0) {}
-+    : context_(std::move(context)), comm_(std::move(comm)) {}
- 
- GpuCclHandle::~GpuCclHandle() = default;
- 
-diff --git a/backends/gpu/lib/kernels/ccl_kernels.cc b/backends/gpu/lib/kernels/ccl_kernels.cc
-index 52ce820..9cfc1de 100644
---- a/backends/gpu/lib/kernels/ccl_kernels.cc
-+++ b/backends/gpu/lib/kernels/ccl_kernels.cc
-@@ -107,8 +107,6 @@
-   auto width = ToWidthInBytes(type);
-   if (!width) return width.takeError();
-   assert(*width != 0);
--  if (input->size() != output->size() * handle->num_ranks())
--    return MakeStringError("Input size must be output size times ranks.");
- 
-   handle->AddCallback([input = input.ValueRef(), output = output.ValueRef(),
-                        recvcount = output->size() / *width, type,
-@@ -116,6 +114,10 @@
-                           wrapper::CurrentContext current,
-                           wrapper::Stream stream,
-                           wrapper::CclComm comm) -> llvm::Error {
-+    auto count = wrapper::CclCommCount(comm);
-+    if (!count) return count.takeError();
-+    if (input->size() != output->size() * *count)
-+      return MakeStringError("Input size must be output size times ranks.");
-     return wrapper::CclReduceScatter(current, input->pointer(),
-                                      output->pointer(), recvcount, type, op,
-                                      comm, stream);
diff --git a/third_party/tf_runtime/tf_runtime_clangcl.patch b/third_party/tf_runtime/tf_runtime_clangcl.patch
deleted file mode 100644
index ce1859d..0000000
--- a/third_party/tf_runtime/tf_runtime_clangcl.patch
+++ /dev/null
@@ -1,14 +0,0 @@
-diff --git a/include/tfrt/support/std_mutex.h b/include/tfrt/support/std_mutex.h
-index 6238d097..9fb24279 100644
---- a/include/tfrt/support/std_mutex.h
-+++ b/include/tfrt/support/std_mutex.h
-@@ -50,7 +50,7 @@ class TFRT_CAPABILITY("mutex") mutex {
- 
-  private:
-   friend class mutex_lock;
--  std::mutex mu_;
-+  std::mutex mu_{};
- };
-
- // Wrap std::unique_lock<std::mutex> with support for thread annotations.
- 
diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl
index 44d692c..f4c03be 100644
--- a/third_party/tf_runtime/workspace.bzl
+++ b/third_party/tf_runtime/workspace.bzl
@@ -6,8 +6,8 @@
     """Imports TFRT."""
 
     # Attention: tools parse and update these lines.
-    TFRT_COMMIT = "6d71fa4816fafb69ee1caac955b2b3844290d577"
-    TFRT_SHA256 = "37da0c18b558e85e8e9c9e482c217221762679265245893ab87c136df7265446"
+    TFRT_COMMIT = "bc45a6d53a5554e3b12fd42d0e0b4862cf2cef92"
+    TFRT_SHA256 = "e77b6fd0de15ff2e3e0cccaa78f645aac6674d495b9f9a90fe348ce40a233c6b"
 
     tf_http_archive(
         name = "tf_runtime",
diff --git a/third_party/triton/cl568176943.patch b/third_party/triton/cl568176943.patch
index d91e505..c187e67 100644
--- a/third_party/triton/cl568176943.patch
+++ b/third_party/triton/cl568176943.patch
@@ -1,8 +1,16 @@
 diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
-index d2a3f7c74..cb668303a 100644
+index e78e7298c..a4685653c 100644
 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp
 +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
-@@ -273,8 +273,10 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
+@@ -40,7 +40,6 @@
+ #include "llvm/Support/SourceMgr.h"
+ #include "llvm/Target/TargetMachine.h"
+ #include "llvm/Transforms/InstCombine/InstCombine.h"
+-#include "third_party/py/triton/google/find_cuda.h"
+ #include <optional>
+ #ifdef _WIN32
+ #define WIN32_LEAN_AND_MEAN
+@@ -277,8 +276,10 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
      // Search for libdevice relative to its library path if used from Python
      // Then native code is in `triton/_C/libtriton.so` and libdevice in
      // `triton/third_party/cuda/lib/libdevice.10.bc`
diff --git a/third_party/triton/cl576548341.patch b/third_party/triton/cl576548341.patch
new file mode 100644
index 0000000..3efc805
--- /dev/null
+++ b/third_party/triton/cl576548341.patch
@@ -0,0 +1,16 @@
+diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+--- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp
++++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+@@ -362,8 +362,10 @@ Value addStringToModule(Location loc, Co
+   }
+ 
+   Value zero = i32_val(0);
+-  Value globalPtr =
+-      rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(ctx), global);
++  Type globalPtrType =
++      LLVM::LLVMPointerType::get(globalType, global.getAddrSpace());
++  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
++      UnknownLoc::get(ctx), globalPtrType, global.getSymName());
+   Value stringStart =
+       rewriter.create<LLVM::GEPOp>(UnknownLoc::get(ctx), ptr_ty(i8_ty),
+                                    globalPtr, SmallVector<Value>({zero, zero}));
diff --git a/third_party/triton/cl577369732.patch b/third_party/triton/cl577369732.patch
new file mode 100644
index 0000000..e63b9f3
--- /dev/null
+++ b/third_party/triton/cl577369732.patch
@@ -0,0 +1,116 @@
+==== triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp#19 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp ====
+# action=edit type=text
+--- triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp	2023-10-19 14:55:11.000000000 -0700
++++ triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp	2023-10-27 20:17:46.000000000 -0700
+@@ -759,7 +759,7 @@
+   OpBuilder builder(forOp);
+   // Get init operands for loop carried values
+   for (BlockArgument &arg : forOp.getRegionIterArgs()) {
+-    OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
++    OpOperand &operand = *forOp.getTiedLoopInit(arg);
+     setValueMapping(arg, operand.get(), 0);
+   }
+ 
+==== triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp#10 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp ====
+# action=edit type=text
+--- triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp	2023-10-19 14:55:11.000000000 -0700
++++ triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp	2023-10-27 20:17:46.000000000 -0700
+@@ -188,7 +188,7 @@
+   auto getIncomingOp = [this](Value v) -> Value {
+     if (auto arg = v.dyn_cast<BlockArgument>())
+       if (arg.getOwner()->getParentOp() == forOp.getOperation())
+-        return forOp.getOpOperandForRegionIterArg(arg).get();
++        return forOp.getTiedLoopInit(arg)->get();
+     return Value();
+   };
+ 
+@@ -298,10 +298,10 @@
+       Operation *firstDot = builder.clone(*dot, mapping);
+       if (Value a = operand2headPrefetch.lookup(dot.getA()))
+         firstDot->setOperand(
+-            0, newForOp.getRegionIterArgForOpOperand(*a.use_begin()));
++            0, newForOp.getTiedLoopRegionIterArg(&*a.use_begin()));
+       if (Value b = operand2headPrefetch.lookup(dot.getB()))
+         firstDot->setOperand(
+-            1, newForOp.getRegionIterArgForOpOperand(*b.use_begin()));
++            1, newForOp.getTiedLoopRegionIterArg(&*b.use_begin()));
+ 
+       // remaining part
+       int64_t kOff = prefetchWidth;
+==== triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp#18 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp ====
+# action=edit type=text
+--- triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp	2023-10-24 18:31:01.000000000 -0700
++++ triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp	2023-10-27 20:17:46.000000000 -0700
+@@ -245,7 +245,7 @@
+   for (OpOperand &use : value.getUses()) {
+     Operation *user = use.getOwner();
+     if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+-      Value arg = forOp.getRegionIterArgForOpOperand(use);
++      Value arg = forOp.getTiedLoopRegionIterArg(&use);
+       Value result = forOp.getResultForOpOperand(use);
+       setEncoding({arg, result}, info, changed, user);
+       continue;
+@@ -767,7 +767,7 @@
+       SmallVector<Value> newOperands;
+       for (auto arg : forOp.getRegionIterArgs()) {
+         if (slice.count(arg)) {
+-          OpOperand &initVal = forOp.getOpOperandForRegionIterArg(arg);
++          OpOperand &initVal = *forOp.getTiedLoopInit(arg);
+           argMapping.push_back(std::make_pair(
+               forOp.getResultForOpOperand(initVal).getResultNumber(),
+               forOp.getInitArgs().size() + newOperands.size()));
+==== triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp#16 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp ====
+# action=edit type=text
+--- triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp	2023-10-24 18:31:01.000000000 -0700
++++ triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp	2023-10-27 20:17:46.000000000 -0700
+@@ -430,10 +430,10 @@
+     Block *block = blockArg.getOwner();
+     Operation *parentOp = block->getParentOp();
+     if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
+-      OpOperand &initOperand = forOp.getOpOperandForRegionIterArg(blockArg);
++      OpOperand *initOperand = forOp.getTiedLoopInit(blockArg);
+       Value yieldOperand = forOp.getBody()->getTerminator()->getOperand(
+           blockArg.getArgNumber() - forOp.getNumInductionVars());
+-      queue.push_back({initOperand.get(), encoding});
++      queue.push_back({initOperand->get(), encoding});
+       queue.push_back({yieldOperand, encoding});
+       continue;
+     }
+==== triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp#1 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp ====
+# action=edit type=text
+--- triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp	2023-10-12 01:35:16.000000000 -0700
++++ triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp	2023-10-27 20:17:46.000000000 -0700
+@@ -88,9 +88,8 @@
+     auto parentOp = blockArg.getOwner()->getParentOp();
+     if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
+       if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) {
+-        if (failed(getDependentPointers(
+-                forOp.getOpOperandForRegionIterArg(blockArg).get(),
+-                dependentSet, processedSet)))
++        if (failed(getDependentPointers(forOp.getTiedLoopInit(blockArg)->get(),
++                                        dependentSet, processedSet)))
+           return failure();
+ 
+         unsigned operandIdx =
+@@ -383,7 +382,7 @@
+       if (failed(addControlOperandsForForOp(forOp)))
+         return failure();
+       if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) {
+-        Value operand = forOp.getOpOperandForRegionIterArg(blockArg).get();
++        Value operand = forOp.getTiedLoopInit(blockArg)->get();
+         if (failed(tryInsertAndPropagate(operand)))
+           return failure();
+ 
+==== triton/test/lib/Analysis/TestAlias.cpp#5 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/test/lib/Analysis/TestAlias.cpp ====
+# action=edit type=text
+--- triton/test/lib/Analysis/TestAlias.cpp	2023-10-19 14:55:11.000000000 -0700
++++ triton/test/lib/Analysis/TestAlias.cpp	2023-10-27 20:17:47.000000000 -0700
+@@ -87,7 +87,7 @@
+       }
+       if (auto forOp = dyn_cast<scf::ForOp>(op)) {
+         for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) {
+-          auto operand = forOp.getOpOperandForRegionIterArg(arg.value()).get();
++          auto operand = forOp.getTiedLoopInit(arg.value())->get();
+           auto opNames = getAllocOpNames(operand);
+           auto argName = getValueOperandName(arg.value(), state);
+           print(argName, opNames, os);
diff --git a/third_party/triton/cl577379396.patch b/third_party/triton/cl577379396.patch
new file mode 100644
index 0000000..ee569f9
--- /dev/null
+++ b/third_party/triton/cl577379396.patch
@@ -0,0 +1,33 @@
+diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
+--- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
++++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
+@@ -246,7 +246,7 @@ SmallVector<Value> LayoutPropagation::pr
+     Operation *user = use.getOwner();
+     if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+       Value arg = forOp.getTiedLoopRegionIterArg(&use);
+-      Value result = forOp.getResultForOpOperand(use);
++      Value result = forOp.getTiedLoopResult(&use);
+       setEncoding({arg, result}, info, changed, user);
+       continue;
+     }
+@@ -769,7 +769,7 @@ static void rewriteSlice(SetVector<Value
+         if (slice.count(arg)) {
+           OpOperand &initVal = *forOp.getTiedLoopInit(arg);
+           argMapping.push_back(std::make_pair(
+-              forOp.getResultForOpOperand(initVal).getResultNumber(),
++              forOp.getTiedLoopResult(&initVal).getResultNumber(),
+               forOp.getInitArgs().size() + newOperands.size()));
+           newOperands.push_back(mapping.lookup(initVal.get()));
+         }
+diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
++++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+@@ -545,7 +545,7 @@ struct ForOpDeadArgElimination : public 
+       Value value = queue.pop_back_val();
+       if (auto nestedFor = value.getDefiningOp<scf::ForOp>()) {
+         auto result = value.cast<OpResult>();
+-        OpOperand &forOperand = nestedFor.getOpOperandForResult(result);
++        OpOperand &forOperand = *nestedFor.getTiedLoopInit(result);
+         markLive(forOperand.get());
+         auto nestedYieldOp =
+             cast<scf::YieldOp>(nestedFor.getBody()->getTerminator());
diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl
index 7e15f39..9ca9639 100644
--- a/third_party/triton/workspace.bzl
+++ b/third_party/triton/workspace.bzl
@@ -5,8 +5,8 @@
 def repo():
     """Imports Triton."""
 
-    TRITON_COMMIT = "cl568176943"
-    TRITON_SHA256 = "5ffa5b538641fa306c8a24010438294ce7f43f80a462fe373a7cf747afde18b5"
+    TRITON_COMMIT = "cl575842988"
+    TRITON_SHA256 = "caa815ec863182eb3745fdc0884f521d622aa2b37be521b850f7ea330cadc923"
 
     tf_http_archive(
         name = "triton",
@@ -17,5 +17,8 @@
         patch_file = [
             "//third_party/triton:cl568176943.patch",
             "//third_party/triton:b304456327.patch",
+            "//third_party/triton:cl576548341.patch",
+            "//third_party/triton:cl577369732.patch",
+            "//third_party/triton:cl577379396.patch",
         ],
     )
diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc
index 7035378..e9fc2d4 100644
--- a/third_party/xla/.bazelrc
+++ b/third_party/xla/.bazelrc
@@ -55,6 +55,7 @@
 #
 #     rbe_linux_cpu:                  RBE options to build with only CPU support.
 #     rbe_linux_cuda:                 RBE options to build with GPU support using clang.
+#     rbe_linux_cuda_nvcc:            RBE options to build with GPU support using nvcc.
 #
 #     rbe_win_py39: Windows Python 3.9 RBE config
 #
@@ -237,9 +238,12 @@
 # Select supported compute capabilities (supported graphics cards).
 # This is the same as the official TensorFlow builds.
 # See https://developer.nvidia.com/cuda-gpus#compute
-# TODO(angerson, perfinion): What does sm_ vs compute_ mean? How can users
-# select a good value for this? See go/tf-pip-cuda
-build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_75,compute_80"
+# `compute_XY` enables PTX embedding in addition to SASS. PTX
+# is forward compatible beyond the current compute capability major
+# release while SASS is only forward compatible inside the current
+# major release. Example: sm_80 kernels can run on sm_89 GPUs but
+# not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs.
+build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
 
 # Set up compilation CUDA version and paths and use the CUDA Clang toolchain.
 build:cuda_clang_official --config=cuda_clang
@@ -249,7 +253,7 @@
 build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc"
 build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-17/bin/clang"
 build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
-build:cuda_clang_official --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain"
+build:cuda_clang_official --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain"
 
 # Debug config
 build:dbg -c dbg
@@ -482,12 +486,12 @@
 
 build:rbe_linux_cpu --config=rbe_linux
 # Linux cpu and cuda builds share the same toolchain now.
-build:rbe_linux_cpu --host_crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain"
-build:rbe_linux_cpu --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain"
-build:rbe_linux_cpu --extra_toolchains="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain-linux-x86_64"
-build:rbe_linux_cpu --extra_execution_platforms="@sigbuild-r2.14-clang_config_platform//:platform"
-build:rbe_linux_cpu --host_platform="@sigbuild-r2.14-clang_config_platform//:platform"
-build:rbe_linux_cpu --platforms="@sigbuild-r2.14-clang_config_platform//:platform"
+build:rbe_linux_cpu --host_crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain"
+build:rbe_linux_cpu --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain"
+build:rbe_linux_cpu --extra_toolchains="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain-linux-x86_64"
+build:rbe_linux_cpu --extra_execution_platforms="@sigbuild-r2.16-clang_config_platform//:platform"
+build:rbe_linux_cpu --host_platform="@sigbuild-r2.16-clang_config_platform//:platform"
+build:rbe_linux_cpu --platforms="@sigbuild-r2.16-clang_config_platform//:platform"
 # This is needed for all Clang17 builds but must not be present in GCC builds.
 build:rbe_linux_cpu --copt=-Wno-error=unused-command-line-argument
 # This was added in clang-16 by https://reviews.llvm.org/D133574.
@@ -496,7 +500,7 @@
 # See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183.
 build:rbe_linux_cpu --copt=-Wno-gnu-offsetof-extensions
 # Python config is the same across all containers because the binary is the same
-build:rbe_linux_cpu --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14-clang_config_python"
+build:rbe_linux_cpu --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.16-clang_config_python"
 build:rbe_linux_cpu --python_path="/usr/bin/python3"
 # These you may need to change for your own GCP project.
 common:rbe_linux_cpu --remote_instance_name=projects/tensorflow-testing/instances/default_instance
@@ -517,11 +521,40 @@
 build:rbe_linux_cuda --config=rbe_linux_cpu
 # For Remote build execution -- GPU configuration
 build:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1
-build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.14-clang_config_cuda"
-build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.14-clang_config_tensorrt"
-build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.14-clang_config_nccl"
+build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.16-clang_config_cuda"
+build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.16-clang_config_tensorrt"
+build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_config_nccl"
 test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
 
+build:rbe_linux_cuda_nvcc --config=cuda
+build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1
+build:rbe_linux_cuda_nvcc --@local_xla//xla/python:enable_gpu=true
+build:rbe_linux_cuda_nvcc --@local_xla//xla/python:jax_cuda_pip_rpaths=true
+build:rbe_linux_cuda_nvcc --define=xla_python_enable_gpu=true
+build:rbe_linux_cuda_nvcc --config=tensorrt
+build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_75,compute_80"
+build:rbe_linux_cuda_nvcc --action_env=TF_CUDA_VERSION="12"
+build:rbe_linux_cuda_nvcc --action_env=TF_CUDNN_VERSION="8"
+build:rbe_linux_cuda_nvcc --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.2"
+build:rbe_linux_cuda_nvcc --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc"
+build:rbe_linux_cuda_nvcc --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
+build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain"
+build:rbe_linux_cuda_nvcc --config=rbe_linux
+build:rbe_linux_cuda_nvcc --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain"
+build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain-linux-x86_64"
+build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_python3.9"
+build:rbe_linux_cuda_nvcc --python_path="/usr/bin/python3"
+# These you may need to change for your own GCP project.
+common:rbe_linux_cuda_nvcc --remote_instance_name=projects/tensorflow-testing/instances/default_instance
+build:rbe_linux_cuda_nvcc --repo_env=REMOTE_GPU_TESTING=1
+build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_cuda"
+build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_tensorrt"
+build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_nccl"
+test:rbe_linux_cuda_nvcc --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
+
 # TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed
 build:rbe_win --config=rbe_base
 build:rbe_win --crosstool_top="//tensorflow/tools/toolchains/win/tf_win_05022023:toolchain"
@@ -576,8 +609,6 @@
 # Here are bazelrc configs for release builds
 # Build TensorFlow v2.
 test:release_base --test_size_filters=small,medium
-# TODO(b/294367488) disable after 2.15 brancut
-test:release_base --flaky_test_attempts=3
 
 # Target the AVX instruction set
 build:release_linux_base --config=avx_linux
@@ -615,7 +646,7 @@
 
 # Use the Clang toolchain to compile
 build:release_cpu_linux --config=release_linux_base
-build:release_cpu_linux --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain"
+build:release_cpu_linux --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain"
 
 build:release_gpu_linux --config=release_cpu_linux
 # Set up compilation CUDA version and paths and use the CUDA Clang toolchain.
@@ -684,7 +715,7 @@
 build:macos   --config=no_tfrt
 build:windows --config=no_tfrt
 build:rocm --config=no_tfrt
-build:no_tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/ir,tensorflow/compiler/mlir/tfrt/ir/mlrt,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/mlrt,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/compiler/mlir/tfrt/transforms/mlrt,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/runtime_fallback/test,tensorflow/core/runtime_fallback/test/gpu,tensorflow/core/runtime_fallback/test/saved_model,tensorflow/core/runtime_fallback/test/testdata,tensorflow/core/tfrt/stubs,tensorflow/core/tfrt/tfrt_session,tensorflow/core/tfrt/mlrt,tensorflow/core/tfrt/mlrt/attribute,tensorflow/core/tfrt/mlrt/kernel,tensorflow/core/tfrt/mlrt/bytecode,tensorflow/core/tfrt/mlrt/interpreter,tensorflow/compiler/mlir/tfrt/translate/mlrt,tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug,tensorflow/core/tfrt/saved_model/python,tensorflow/core/tfrt/graph_executor/python,tensorflow/core/tfrt/saved_model/utils
+build:no_tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/ir,tensorflow/compiler/mlir/tfrt/ir/mlrt,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ifrt,tensorflow/compiler/mlir/tfrt/tests/mlrt,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/compiler/mlir/tfrt/transforms/mlrt,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/runtime_fallback/test,tensorflow/core/runtime_fallback/test/gpu,tensorflow/core/runtime_fallback/test/saved_model,tensorflow/core/runtime_fallback/test/testdata,tensorflow/core/tfrt/stubs,tensorflow/core/tfrt/tfrt_session,tensorflow/core/tfrt/mlrt,tensorflow/core/tfrt/mlrt/attribute,tensorflow/core/tfrt/mlrt/kernel,tensorflow/core/tfrt/mlrt/bytecode,tensorflow/core/tfrt/mlrt/interpreter,tensorflow/compiler/mlir/tfrt/translate/mlrt,tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug,tensorflow/core/tfrt/saved_model/python,tensorflow/core/tfrt/graph_executor/python,tensorflow/core/tfrt/saved_model/utils
 
 # BEGIN TF CACHE HELPER OPTIONS
 # Options when using remote execution
diff --git a/third_party/xla/.kokoro/linux/build.sh b/third_party/xla/.kokoro/linux/build.sh
index edced7a..89071ee 100644
--- a/third_party/xla/.kokoro/linux/build.sh
+++ b/third_party/xla/.kokoro/linux/build.sh
@@ -26,6 +26,10 @@
   [[ "$KOKORO_JOB_NAME" =~ tensorflow/xla/linux/.*gpu.* ]]
 }
 
+function is_use_nvcc() {
+  [[ -z "${USE_NVCC:-}" ]] || [[ "$USE_NVCC" == "true" ]]
+}
+
 # Pull the container (in case it was updated since the instance started) and
 # store its SHA in the Sponge log.
 docker pull "$DOCKER_IMAGE"
@@ -44,16 +48,23 @@
 TARGET_FILTER=""
 TAGS_FILTER="-no_oss,-oss_excluded,-oss_serial"
 ADDITIONAL_FLAGS=""
+RBE_CONFIG=""
 
 if is_linux_gpu_job ; then
     TAGS_FILTER="$TAGS_FILTER,gpu,requires-gpu-nvidia,-no_gpu"
     ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute"
     RC_FILE="/usertools/gpu.bazelrc"
+    if is_use_nvcc ; then
+      RBE_CONFIG="rbe_linux_cuda_nvcc"
+    else
+      RBE_CONFIG="rbe_linux_cuda"
+    fi
     echo "***NOTE: nvidia-smi lists the highest CUDA version the driver supports, which may be different than the version of CUDA actually used!!***"
     nvidia-smi
 else
     TAGS_FILTER="$TAGS_FILTER,-gpu,-requires-gpu-nvidia"
     ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --config=nonccl"
+    RBE_CONFIG="rbe_linux_cpu"
 fi
 
 # Build & test XLA
@@ -65,7 +76,7 @@
         --features=layering_check \
         --profile=/tf/pkg/profile.json.gz \
         --flaky_test_attempts=3 \
-        --config=rbe \
+        --config=$RBE_CONFIG \
         --jobs=150 \
         --nobuild_tests_only \
         $ADDITIONAL_FLAGS \
diff --git a/third_party/xla/configure.py b/third_party/xla/configure.py
index 7b7c3c5..b2b30df 100644
--- a/third_party/xla/configure.py
+++ b/third_party/xla/configure.py
@@ -15,8 +15,8 @@
 """configure script to get build parameters from user."""
 
 import argparse
-import glob
 import os
+import pathlib
 import platform
 import re
 import subprocess
@@ -31,7 +31,7 @@
 
 _DEFAULT_CUDA_VERSION = '11'
 _DEFAULT_CUDNN_VERSION = '2'
-_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0'
+_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '5.2,7.0'
 
 _DEFAULT_PROMPT_ASK_ATTEMPTS = 10
 
@@ -688,8 +688,9 @@
         ' binary GPU code, or as "sm_xy" to only include the binary '
         'code.\nPlease note that each additional compute capability '
         'significantly increases your build time and binary size, and that '
-        'XLA only supports compute capabilities >= 3.5 [Default is: '
-        '%s]: ' % default_cuda_compute_capabilities)
+        'XLA only supports compute capabilities >= 5.2 [Default is: '
+        '%s]: ' % default_cuda_compute_capabilities
+    )
     tf_cuda_compute_capabilities = get_from_env_or_user_or_default(
         environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES',
         ask_cuda_compute_capabilities, default_cuda_compute_capabilities)
@@ -701,7 +702,7 @@
     for compute_capability in tf_cuda_compute_capabilities.split(','):
       m = re.match('[0-9]+.[0-9]+', compute_capability)
       if not m:
-        # We now support sm_35,sm_50,sm_60,compute_70.
+        # We now support sm_52,compute_70.
         sm_compute_match = re.match('(sm|compute)_?([0-9]+[0-9]+)',
                                     compute_capability)
         if not sm_compute_match:
@@ -709,25 +710,22 @@
           all_valid = False
         else:
           ver = int(sm_compute_match.group(2))
-          if ver < 30:
+          if ver < 52:
             print(
                 'ERROR: XLA only supports small CUDA compute'
-                ' capabilities of sm_30 and higher. Please re-specify the list'
-                ' of compute capabilities excluding version %s.' % ver)
+                ' capabilities of sm_52 and higher. Please re-specify the list'
+                ' of compute capabilities excluding version %s.' % ver
+            )
             all_valid = False
-          if ver < 35:
-            print('WARNING: XLA does not support CUDA compute capabilities '
-                  'lower than sm_35. Disable XLA when running on older GPUs.')
       else:
         ver = float(m.group(0))
-        if ver < 3.0:
-          print('ERROR: XLA only supports CUDA compute capabilities 3.0 '
-                'and higher. Please re-specify the list of compute '
-                'capabilities excluding version %s.' % ver)
+        if ver < 5.2:
+          print(
+              'ERROR: XLA only supports CUDA compute capabilities 5.2 '
+              'and higher. Please re-specify the list of compute '
+              'capabilities excluding version %s.' % ver
+          )
           all_valid = False
-        if ver < 3.5:
-          print('WARNING: XLA does not support CUDA compute capabilities '
-                'lower than 3.5. Disable XLA when running on older GPUs.')
 
     if all_valid:
       break
@@ -854,14 +852,20 @@
     if environ_cp.get('TF_NCCL_VERSION', None):
       cuda_libraries.append('nccl')
 
-  paths = glob.glob('**/third_party/gpus/find_cuda_config.py', recursive=True)
-  if not paths:
+  find_cuda_script = os.path.join(
+      pathlib.Path(__file__).parent.resolve(),
+      'third_party/tsl/third_party/gpus/find_cuda_config.py',
+  )
+  if not os.path.isfile(find_cuda_script):
     raise FileNotFoundError(
-        "Can't find 'find_cuda_config.py' script inside working directory")
+        "Can't find 'find_cuda_config.py' script inside working directory,"
+        f' expected in {find_cuda_script}'
+    )
   proc = subprocess.Popen(
-      [environ_cp['PYTHON_BIN_PATH'], paths[0]] + cuda_libraries,
+      [environ_cp['PYTHON_BIN_PATH'], find_cuda_script] + cuda_libraries,
       stdout=subprocess.PIPE,
-      env=maybe_encode_env(environ_cp))
+      env=maybe_encode_env(environ_cp),
+  )
 
   if proc.wait():
     # Errors from find_cuda_config.py were sent to stderr.
diff --git a/third_party/xla/opensource_only.files b/third_party/xla/opensource_only.files
index 1930b01..9abb254 100644
--- a/third_party/xla/opensource_only.files
+++ b/third_party/xla/opensource_only.files
@@ -3,82 +3,15 @@
 compiler/xla/stream_executor/build_defs.bzl:
 third_party/BUILD:
 third_party/__init__:.py
-third_party/absl/com_google_absl.BUILD:
-third_party/clang_toolchain/BUILD:
-third_party/clang_toolchain/cc_configure_clang.bzl:
-third_party/clang_toolchain/download_clang.bzl:
 third_party/compute_library/BUILD:
 third_party/compute_library/build_defs.bzl:
-third_party/curl.BUILD:
-third_party/cython.BUILD:
-third_party/eigen3/BUILD:
-third_party/eigen3/Eigen/Cholesky:
-third_party/eigen3/Eigen/Core:
-third_party/eigen3/Eigen/Eigenvalues:
-third_party/eigen3/Eigen/LU:
-third_party/eigen3/Eigen/OrderingMethods:
-third_party/eigen3/Eigen/QR:
-third_party/eigen3/Eigen/SVD:
-third_party/eigen3/Eigen/SparseCholesky:
-third_party/eigen3/Eigen/SparseCore:
-third_party/eigen3/LICENSE:
-third_party/eigen3/eigen_archive.BUILD:
-third_party/eigen3/unsupported/Eigen/CXX11/Tensor:
-third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool:
-third_party/eigen3/unsupported/Eigen/MatrixFunctions:
-third_party/eigen3/unsupported/Eigen/SpecialFunctions:
-third_party/gif.BUILD:
-third_party/gif_fix_strtok_r.patch:
-third_party/git/BUILD.tpl:
-third_party/git/BUILD:
-third_party/git/git_configure.bzl:
-third_party/gpus/BUILD:
-third_party/gpus/crosstool/BUILD.rocm.tpl:
-third_party/gpus/crosstool/BUILD.tpl:
-third_party/gpus/crosstool/BUILD:
-third_party/gpus/crosstool/LICENSE:
-third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl:
-third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl:
-third_party/gpus/cuda/BUILD.tpl:
-third_party/gpus/cuda/BUILD:
-third_party/gpus/cuda/LICENSE:
-third_party/gpus/cuda/build_defs.bzl.tpl:
-third_party/gpus/cuda/cuda_config.h.tpl:
-third_party/gpus/cuda/cuda_config.py.tpl:
-third_party/gpus/cuda_configure.bzl:
-third_party/gpus/find_cuda_config:.py
-third_party/gpus/rocm/BUILD.tpl:
-third_party/gpus/rocm/BUILD:
-third_party/gpus/rocm/build_defs.bzl.tpl:
-third_party/gpus/rocm/rocm_config.h.tpl:
-third_party/gpus/rocm_configure.bzl:
-third_party/grpc/BUILD:
-third_party/implib_so/BUILD:
-third_party/implib_so/get_symbols.py:
-third_party/implib_so/make_stub.py:
 third_party/llvm_openmp/BUILD:
 third_party/llvm_openmp/cmake_vars.bzl:
 third_party/llvm_openmp/expand_cmake_vars:.py
 third_party/llvm_openmp/openmp.bzl:
-third_party/mkl/BUILD:
-third_party/mkl/build_defs.bzl:
-third_party/mkl_dnn/LICENSE:
-third_party/mkl_dnn/build_defs.bzl:
-third_party/mkl_dnn/mkldnn_acl.BUILD:
-third_party/mkl_dnn/mkldnn_v1.BUILD:
-third_party/nccl/BUILD:
-third_party/nccl/LICENSE:
-third_party/nccl/archive.BUILD:
-third_party/nccl/archive.patch:
-third_party/nccl/build_defs.bzl.tpl:
-third_party/nccl/nccl_configure.bzl:
-third_party/nccl/system.BUILD.tpl:
 third_party/ortools/BUILD:
 third_party/ortools/glpk.BUILD:
 third_party/ortools/ortools.patch:
-third_party/png.BUILD:
-third_party/png_fix_rpi.patch:
-third_party/protobuf/BUILD:
 third_party/py/non_hermetic/BUILD.tpl:
 third_party/py/non_hermetic/BUILD:
 third_party/py/non_hermetic/README:
@@ -86,55 +19,9 @@
 third_party/py/non_hermetic/ml_dtypes/LICENSE:
 third_party/py/non_hermetic/numpy/BUILD:
 third_party/py/non_hermetic/python_configure.bzl:
-third_party/pybind11.BUILD:
-third_party/pybind11_bazel/BUILD:
 third_party/python_runtime/BUILD:
-third_party/remote_config/BUILD.tpl:
-third_party/remote_config/BUILD:
-third_party/remote_config/common.bzl:
-third_party/remote_config/remote_platform_configure.bzl:
 third_party/repo.bzl:
-third_party/six.BUILD:
-third_party/snappy.BUILD:
 third_party/stablehlo/BUILD:
-third_party/systemlibs/BUILD.tpl:
-third_party/systemlibs/BUILD:
-third_party/systemlibs/absl_py.BUILD:
-third_party/systemlibs/absl_py.absl.flags.BUILD:
-third_party/systemlibs/absl_py.absl.logging.BUILD:
-third_party/systemlibs/absl_py.absl.testing.BUILD:
-third_party/systemlibs/boringssl.BUILD:
-third_party/systemlibs/build_defs.bzl.tpl:
-third_party/systemlibs/curl.BUILD:
-third_party/systemlibs/cython.BUILD:
-third_party/systemlibs/double_conversion.BUILD:
-third_party/systemlibs/gif.BUILD:
-third_party/systemlibs/google_cloud_cpp.BUILD:
-third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD:
-third_party/systemlibs/grpc.BUILD:
-third_party/systemlibs/jsoncpp.BUILD:
-third_party/systemlibs/lmdb.BUILD:
-third_party/systemlibs/nsync.BUILD:
-third_party/systemlibs/png.BUILD:
-third_party/systemlibs/protobuf.BUILD:
-third_party/systemlibs/protobuf.bzl:
-third_party/systemlibs/re2.BUILD:
-third_party/systemlibs/six.BUILD:
-third_party/systemlibs/snappy.BUILD:
-third_party/systemlibs/sqlite.BUILD:
-third_party/systemlibs/syslibs_configure.bzl:
-third_party/systemlibs/zlib.BUILD:
-third_party/tensorrt/BUILD.tpl:
-third_party/tensorrt/BUILD:
-third_party/tensorrt/LICENSE:
-third_party/tensorrt/build_defs.bzl.tpl:
-third_party/tensorrt/plugin/BUILD:
-third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl:
-third_party/tensorrt/tensorrt/tensorrt_config.py.tpl:
-third_party/tensorrt/tensorrt_configure.bzl:
-third_party/tensorrt/workspace.bzl:
-third_party/tf_runtime/BUILD:
-third_party/zlib.BUILD:
 tools/toolchains/BUILD:
 tools/toolchains/clang6/BUILD:
 tools/toolchains/cpus/py/BUILD:
diff --git a/third_party/xla/third_party/absl/BUILD b/third_party/xla/third_party/absl/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/xla/third_party/absl/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/xla/third_party/absl/absl_designated_initializers.patch b/third_party/xla/third_party/absl/absl_designated_initializers.patch
deleted file mode 100644
index 6ee2322..0000000
--- a/third_party/xla/third_party/absl/absl_designated_initializers.patch
+++ /dev/null
@@ -1,65 +0,0 @@
-diff --git a/absl/crc/internal/crc_memcpy_x86_64.cc b/absl/crc/internal/crc_memcpy_x86_64.cc
-index 66f784de..ff424c54 100644
---- a/absl/crc/internal/crc_memcpy_x86_64.cc
-+++ b/absl/crc/internal/crc_memcpy_x86_64.cc
-@@ -359,18 +359,18 @@ CrcMemcpy::ArchSpecificEngines CrcMemcpy::GetArchSpecificEngines() {
-     case CpuType::kIntelHaswell:
-     case CpuType::kIntelIvybridge:
-       return {
--          .temporal = new FallbackCrcMemcpyEngine(),
--          .non_temporal = new CrcNonTemporalMemcpyAVXEngine(),
-+          /*.temporal=*/new FallbackCrcMemcpyEngine(),
-+          /*.non_temporal=*/new CrcNonTemporalMemcpyAVXEngine(),
-       };
-     // INTEL_SANDYBRIDGE performs better with SSE than AVX.
-     case CpuType::kIntelSandybridge:
-       return {
--          .temporal = new FallbackCrcMemcpyEngine(),
--          .non_temporal = new CrcNonTemporalMemcpyEngine(),
-+          /*.temporal=*/new FallbackCrcMemcpyEngine(),
-+          /*.non_temporal=*/new CrcNonTemporalMemcpyEngine(),
-       };
-     default:
--      return {.temporal = new FallbackCrcMemcpyEngine(),
--              .non_temporal = new FallbackCrcMemcpyEngine()};
-+      return {/*.temporal=*/new FallbackCrcMemcpyEngine(),
-+              /*.non_temporal=*/new FallbackCrcMemcpyEngine()};
-   }
- #else
-   // Get the underlying architecture.
-@@ -388,8 +388,8 @@ CrcMemcpy::ArchSpecificEngines CrcMemcpy::GetArchSpecificEngines() {
-     case CpuType::kAmdRome:
-     case CpuType::kAmdNaples:
-       return {
--          .temporal = new AcceleratedCrcMemcpyEngine<1, 2>(),
--          .non_temporal = new CrcNonTemporalMemcpyAVXEngine(),
-+          /*.temporal=*/new AcceleratedCrcMemcpyEngine<1, 2>(),
-+          /*.non_temporal=*/new CrcNonTemporalMemcpyAVXEngine(),
-       };
-     // PCLMULQDQ is slow and we don't have wide enough issue width to take
-     // advantage of it.  For an unknown architecture, don't risk using CLMULs.
-@@ -400,18 +400,18 @@ CrcMemcpy::ArchSpecificEngines CrcMemcpy::GetArchSpecificEngines() {
-     case CpuType::kIntelHaswell:
-     case CpuType::kIntelIvybridge:
-       return {
--          .temporal = new AcceleratedCrcMemcpyEngine<3, 0>(),
--          .non_temporal = new CrcNonTemporalMemcpyAVXEngine(),
-+          /*.temporal=*/new AcceleratedCrcMemcpyEngine<3, 0>(),
-+          /*.non_temporal=*/new CrcNonTemporalMemcpyAVXEngine(),
-       };
-     // INTEL_SANDYBRIDGE performs better with SSE than AVX.
-     case CpuType::kIntelSandybridge:
-       return {
--          .temporal = new AcceleratedCrcMemcpyEngine<3, 0>(),
--          .non_temporal = new CrcNonTemporalMemcpyEngine(),
-+          /*.temporal=*/new AcceleratedCrcMemcpyEngine<3, 0>(),
-+          /*.non_temporal=*/new CrcNonTemporalMemcpyEngine(),
-       };
-     default:
--      return {.temporal = new FallbackCrcMemcpyEngine(),
--              .non_temporal = new FallbackCrcMemcpyEngine()};
-+      return {/*.temporal=*/new FallbackCrcMemcpyEngine(),
-+              /*.non_temporal=*/new FallbackCrcMemcpyEngine()};
-   }
- #endif  // UNDEFINED_BEHAVIOR_SANITIZER
- }
diff --git a/third_party/xla/third_party/absl/com_google_absl.BUILD b/third_party/xla/third_party/absl/com_google_absl.BUILD
deleted file mode 100644
index 8fca145..0000000
--- a/third_party/xla/third_party/absl/com_google_absl.BUILD
+++ /dev/null
@@ -1,5 +0,0 @@
-package(default_visibility = ["//visibility:public"])
-
-licenses(["notice"])  # Apache
-
-exports_files(["LICENSE"])
diff --git a/third_party/xla/third_party/absl/invert_the_is_inline_bin.patch b/third_party/xla/third_party/absl/invert_the_is_inline_bin.patch
deleted file mode 100644
index 28c4e9b..0000000
--- a/third_party/xla/third_party/absl/invert_the_is_inline_bin.patch
+++ /dev/null
@@ -1,108 +0,0 @@
-From 5c9f72faadaca7250b341b99da358e855a8d902e Mon Sep 17 00:00:00 2001
-From: Abseil Team <absl-team@google.com>
-Date: Tue, 5 Sep 2023 10:45:53 -0700
-Subject: [PATCH] Invert the "is inlined" bit of absl::Status
-
-This change makes  RepToPointer/PointerToRep have 0 instructions.
-This makes IsMovedFrom simpler (although this could always have left out the IsInlined check since that bit can never be set on the aligned pointer)
-
-In exchange, it makes CodeToInlinedRep slower, but does not inhibit replacing it with a constant.
-InlinedRepToCode is unaffected.
-
-PiperOrigin-RevId: 562826801
-Change-Id: I2732f04ab293b773edc2efdec546b3a287b980c2
----
- absl/status/status.cc |  4 ++++
- absl/status/status.h  | 23 +++++++++++++----------
- 2 files changed, 17 insertions(+), 10 deletions(-)
-
-diff --git a/absl/status/status.cc b/absl/status/status.cc
-index 577dea4b..911f4b28 100644
---- a/absl/status/status.cc
-+++ b/absl/status/status.cc
-@@ -46,6 +46,10 @@
- namespace absl {
- ABSL_NAMESPACE_BEGIN
- 
-+static_assert(
-+    alignof(status_internal::StatusRep) >= 4,
-+    "absl::Status assumes it can use the bottom 2 bits of a StatusRep*.");
-+
- std::string StatusCodeToString(StatusCode code) {
-   switch (code) {
-     case StatusCode::kOk:
-diff --git a/absl/status/status.h b/absl/status/status.h
-index 595064c0..2dac2fea 100644
---- a/absl/status/status.h
-+++ b/absl/status/status.h
-@@ -51,10 +51,15 @@
- #ifndef ABSL_STATUS_STATUS_H_
- #define ABSL_STATUS_STATUS_H_
- 
-+#include <cassert>
-+#include <cstdint>
- #include <ostream>
- #include <string>
- #include <utility>
- 
-+#include "absl/base/attributes.h"
-+#include "absl/base/config.h"
-+#include "absl/base/optimization.h"
- #include "absl/functional/function_ref.h"
- #include "absl/status/internal/status_internal.h"
- #include "absl/strings/cord.h"
-@@ -644,13 +649,13 @@ class Status final {
-   std::string ToStringSlow(StatusToStringMode mode) const;
- 
-   // Status supports two different representations.
--  //  - When the low bit is off it is an inlined representation.
-+  //  - When the low bit is set it is an inlined representation.
-   //    It uses the canonical error space, no message or payload.
-   //    The error code is (rep_ >> 2).
-   //    The (rep_ & 2) bit is the "moved from" indicator, used in IsMovedFrom().
--  //  - When the low bit is on it is an external representation.
-+  //  - When the low bit is off it is an external representation.
-   //    In this case all the data comes from a heap allocated Rep object.
--  //    (rep_ - 1) is a status_internal::StatusRep* pointer to that structure.
-+  //    rep_ is a status_internal::StatusRep* pointer to that structure.
-   uintptr_t rep_;
- };
- 
-@@ -839,18 +844,16 @@ inline status_internal::Payloads* Status::GetPayloads() {
-   return IsInlined(rep_) ? nullptr : RepToPointer(rep_)->payloads.get();
- }
- 
--inline bool Status::IsInlined(uintptr_t rep) { return (rep & 1) == 0; }
-+inline bool Status::IsInlined(uintptr_t rep) { return (rep & 1) != 0; }
- 
--inline bool Status::IsMovedFrom(uintptr_t rep) {
--  return IsInlined(rep) && (rep & 2) != 0;
--}
-+inline bool Status::IsMovedFrom(uintptr_t rep) { return (rep & 2) != 0; }
- 
- inline uintptr_t Status::MovedFromRep() {
-   return CodeToInlinedRep(absl::StatusCode::kInternal) | 2;
- }
- 
- inline uintptr_t Status::CodeToInlinedRep(absl::StatusCode code) {
--  return static_cast<uintptr_t>(code) << 2;
-+  return (static_cast<uintptr_t>(code) << 2) + 1;
- }
- 
- inline absl::StatusCode Status::InlinedRepToCode(uintptr_t rep) {
-@@ -860,11 +863,11 @@ inline absl::StatusCode Status::InlinedRepToCode(uintptr_t rep) {
- 
- inline status_internal::StatusRep* Status::RepToPointer(uintptr_t rep) {
-   assert(!IsInlined(rep));
--  return reinterpret_cast<status_internal::StatusRep*>(rep - 1);
-+  return reinterpret_cast<status_internal::StatusRep*>(rep);
- }
- 
- inline uintptr_t Status::PointerToRep(status_internal::StatusRep* rep) {
--  return reinterpret_cast<uintptr_t>(rep) + 1;
-+  return reinterpret_cast<uintptr_t>(rep);
- }
- 
- inline void Status::Ref(uintptr_t rep) {
--- 
-2.25.1
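
The commit message in the deleted patch above describes a tagged-pointer encoding for `rep_`: either a heap `StatusRep*` stored verbatim (possible because the pointer is at least 4-byte aligned, so its low bits are zero), or an inlined value whose low bit is set, with bit 1 as the moved-from flag and the canonical error code in the remaining bits. A minimal sketch of that encoding under those assumptions, with a stand-in rep type rather than abseil's real internals:

    #include <cassert>
    #include <cstdint>

    // Stand-in only; the real status_internal::StatusRep is >= 4-byte aligned.
    struct alignas(4) StatusRep {};

    inline bool IsInlined(std::uintptr_t rep) { return (rep & 1) != 0; }
    inline bool IsMovedFrom(std::uintptr_t rep) { return (rep & 2) != 0; }

    // Inlined reps: error code in the upper bits, bit 0 set as the "inlined" tag.
    inline std::uintptr_t CodeToInlinedRep(int code) {
      return (static_cast<std::uintptr_t>(code) << 2) + 1;
    }
    inline int InlinedRepToCode(std::uintptr_t rep) {
      assert(IsInlined(rep));
      return static_cast<int>(rep >> 2);
    }

    // External reps: the pointer is stored unmodified, so these conversions
    // need no arithmetic (the point of inverting the bit).
    inline StatusRep* RepToPointer(std::uintptr_t rep) {
      assert(!IsInlined(rep));
      return reinterpret_cast<StatusRep*>(rep);
    }
    inline std::uintptr_t PointerToRep(StatusRep* rep) {
      return reinterpret_cast<std::uintptr_t>(rep);
    }
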
diff --git a/third_party/xla/third_party/absl/system.BUILD b/third_party/xla/third_party/absl/system.BUILD
deleted file mode 100644
index 134d273..0000000
--- a/third_party/xla/third_party/absl/system.BUILD
+++ /dev/null
@@ -1,8 +0,0 @@
-package(default_visibility = ["//visibility:public"])
-
-licenses(["notice"])  # Apache
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/absl/system.absl.algorithm.BUILD b/third_party/xla/third_party/absl/system.absl.algorithm.BUILD
deleted file mode 100644
index ffcb03a..0000000
--- a/third_party/xla/third_party/absl/system.absl.algorithm.BUILD
+++ /dev/null
@@ -1,10 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(default_visibility = ["//visibility:public"])
-
-[cc_library(
-    name = n,
-) for n in [
-    "algorithm",
-    "container",
-]]
diff --git a/third_party/xla/third_party/absl/system.absl.base.BUILD b/third_party/xla/third_party/absl/system.absl.base.BUILD
deleted file mode 100644
index d6bf874..0000000
--- a/third_party/xla/third_party/absl/system.absl.base.BUILD
+++ /dev/null
@@ -1,107 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(default_visibility = ["//visibility:public"])
-
-[cc_library(
-    name = n,
-) for n in [
-    "config",
-    "core_headers",
-    "base_internal",
-    "dynamic_annotations",
-    "atomic_hook",
-    "errno_saver",
-    "fast_type_id",
-    "pretty_function",
-]]
-
-cc_library(
-    name = "log_severity",
-    linkopts = ["-labsl_log_severity"],
-)
-
-cc_library(
-    name = "raw_logging_internal",
-    linkopts = ["-labsl_raw_logging_internal"],
-    visibility = [
-        "//absl:__subpackages__",
-    ],
-    deps = [
-        ":log_severity",
-    ],
-)
-
-cc_library(
-    name = "spinlock_wait",
-    linkopts = ["-labsl_spinlock_wait"],
-    visibility = [
-        "//absl/base:__pkg__",
-    ],
-)
-
-cc_library(
-    name = "malloc_internal",
-    linkopts = [
-        "-labsl_malloc_internal",
-        "-pthread",
-    ],
-    deps = [
-        ":base",
-        ":raw_logging_internal",
-    ],
-)
-
-cc_library(
-    name = "base",
-    linkopts = [
-        "-labsl_base",
-        "-pthread",
-    ],
-    deps = [
-        ":log_severity",
-        ":raw_logging_internal",
-        ":spinlock_wait",
-    ],
-)
-
-cc_library(
-    name = "throw_delegate",
-    linkopts = ["-labsl_throw_delegate"],
-    visibility = [
-        "//absl:__subpackages__",
-    ],
-    deps = [
-        ":raw_logging_internal",
-    ],
-)
-
-cc_library(
-    name = "endian",
-    deps = [
-        ":base",
-    ],
-)
-
-cc_library(
-    name = "exponential_biased",
-    linkopts = ["-labsl_exponential_biased"],
-    visibility = [
-        "//absl:__subpackages__",
-    ],
-)
-
-cc_library(
-    name = "periodic_sampler",
-    linkopts = ["-labsl_periodic_sampler"],
-    deps = [
-        ":exponential_biased",
-    ],
-)
-
-cc_library(
-    name = "strerror",
-    linkopts = ["-labsl_strerror"],
-    visibility = [
-        "//absl:__subpackages__",
-    ],
-)
diff --git a/third_party/xla/third_party/absl/system.absl.cleanup.BUILD b/third_party/xla/third_party/absl/system.absl.cleanup.BUILD
deleted file mode 100644
index eec527b..0000000
--- a/third_party/xla/third_party/absl/system.absl.cleanup.BUILD
+++ /dev/null
@@ -1,6 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-cc_library(
-    name = "cleanup",
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/absl/system.absl.container.BUILD b/third_party/xla/third_party/absl/system.absl.container.BUILD
deleted file mode 100644
index 95c1626..0000000
--- a/third_party/xla/third_party/absl/system.absl.container.BUILD
+++ /dev/null
@@ -1,217 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(default_visibility = ["//visibility:public"])
-
-cc_library(
-    name = "compressed_tuple",
-    deps = [
-        "//absl/utility",
-    ],
-)
-
-cc_library(
-    name = "fixed_array",
-    deps = [
-        ":compressed_tuple",
-        "//absl/algorithm",
-        "//absl/base:config",
-        "//absl/base:core_headers",
-        "//absl/base:dynamic_annotations",
-        "//absl/base:throw_delegate",
-        "//absl/memory",
-    ],
-)
-
-cc_library(
-    name = "inlined_vector_internal",
-    deps = [
-        ":compressed_tuple",
-        "//absl/base:core_headers",
-        "//absl/memory",
-        "//absl/meta:type_traits",
-        "//absl/types:span",
-    ],
-)
-
-cc_library(
-    name = "inlined_vector",
-    deps = [
-        ":inlined_vector_internal",
-        "//absl/algorithm",
-        "//absl/base:core_headers",
-        "//absl/base:throw_delegate",
-        "//absl/memory",
-    ],
-)
-
-cc_library(
-    name = "flat_hash_map",
-    deps = [
-        ":container_memory",
-        ":hash_function_defaults",
-        ":raw_hash_map",
-        "//absl/algorithm:container",
-        "//absl/memory",
-    ],
-)
-
-cc_library(
-    name = "flat_hash_set",
-    deps = [
-        ":container_memory",
-        ":hash_function_defaults",
-        ":raw_hash_set",
-        "//absl/algorithm:container",
-        "//absl/base:core_headers",
-        "//absl/memory",
-    ],
-)
-
-cc_library(
-    name = "node_hash_map",
-    deps = [
-        ":container_memory",
-        ":hash_function_defaults",
-        ":node_hash_policy",
-        ":raw_hash_map",
-        "//absl/algorithm:container",
-        "//absl/memory",
-    ],
-)
-
-cc_library(
-    name = "node_hash_set",
-    deps = [
-        ":hash_function_defaults",
-        ":node_hash_policy",
-        ":raw_hash_set",
-        "//absl/algorithm:container",
-        "//absl/memory",
-    ],
-)
-
-cc_library(
-    name = "container_memory",
-    deps = [
-        "//absl/base:config",
-        "//absl/memory",
-        "//absl/meta:type_traits",
-        "//absl/utility",
-    ],
-)
-
-cc_library(
-    name = "hash_function_defaults",
-    deps = [
-        "//absl/base:config",
-        "//absl/hash",
-        "//absl/strings",
-        "//absl/strings:cord",
-    ],
-)
-
-cc_library(
-    name = "hash_policy_traits",
-    deps = ["//absl/meta:type_traits"],
-)
-
-cc_library(
-    name = "hashtable_debug",
-    deps = [
-        ":hashtable_debug_hooks",
-    ],
-)
-
-cc_library(
-    name = "hashtable_debug_hooks",
-    deps = [
-        "//absl/base:config",
-    ],
-)
-
-cc_library(
-    name = "hashtablez_sampler",
-    linkopts = ["-labsl_hashtablez_sampler"],
-    deps = [
-        "//absl/base",
-        "//absl/base:core_headers",
-        "//absl/base:exponential_biased",
-        "//absl/debugging:stacktrace",
-        "//absl/memory",
-        "//absl/synchronization",
-        "//absl/utility",
-    ],
-)
-
-cc_library(
-    name = "node_hash_policy",
-    deps = ["//absl/base:config"],
-)
-
-cc_library(
-    name = "raw_hash_map",
-    deps = [
-        ":container_memory",
-        ":raw_hash_set",
-        "//absl/base:throw_delegate",
-    ],
-)
-
-cc_library(
-    name = "common",
-    deps = [
-        "//absl/meta:type_traits",
-        "//absl/types:optional",
-    ],
-)
-
-cc_library(
-    name = "raw_hash_set",
-    linkopts = ["-labsl_raw_hash_set"],
-    deps = [
-        ":common",
-        ":compressed_tuple",
-        ":container_memory",
-        ":hash_policy_traits",
-        ":hashtable_debug_hooks",
-        ":hashtablez_sampler",
-        ":layout",
-        "//absl/base:config",
-        "//absl/base:core_headers",
-        "//absl/base:endian",
-        "//absl/memory",
-        "//absl/meta:type_traits",
-        "//absl/numeric:bits",
-        "//absl/utility",
-    ],
-)
-
-cc_library(
-    name = "layout",
-    deps = [
-        "//absl/base:config",
-        "//absl/base:core_headers",
-        "//absl/meta:type_traits",
-        "//absl/strings",
-        "//absl/types:span",
-        "//absl/utility",
-    ],
-)
-
-cc_library(
-    name = "btree",
-    deps = [
-        ":common",
-        ":compressed_tuple",
-        ":container_memory",
-        ":layout",
-        "//absl/base:core_headers",
-        "//absl/base:throw_delegate",
-        "//absl/memory",
-        "//absl/meta:type_traits",
-        "//absl/strings",
-        "//absl/strings:cord",
-        "//absl/types:compare",
-        "//absl/utility",
-    ],
-)
diff --git a/third_party/xla/third_party/absl/system.absl.debugging.BUILD b/third_party/xla/third_party/absl/system.absl.debugging.BUILD
deleted file mode 100644
index 931ffdc..0000000
--- a/third_party/xla/third_party/absl/system.absl.debugging.BUILD
+++ /dev/null
@@ -1,69 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(default_visibility = ["//visibility:public"])
-
-cc_library(
-    name = "stacktrace",
-    linkopts = ["-labsl_stacktrace"],
-    deps = [
-        ":debugging_internal",
-    ],
-)
-
-cc_library(
-    name = "symbolize",
-    linkopts = ["-labsl_symbolize"],
-    deps = [
-        ":debugging_internal",
-        ":demangle_internal",
-        "//absl/base",
-        "//absl/base:dynamic_annotations",
-        "//absl/base:malloc_internal",
-        "//absl/base:raw_logging_internal",
-        "//absl/strings",
-    ],
-)
-
-cc_library(
-    name = "failure_signal_handler",
-    linkopts = [
-        "-labsl_failure_signal_handler",
-        "-labsl_examine_stack",
-    ],
-    deps = [
-        ":stacktrace",
-        ":symbolize",
-        "//absl/base",
-        "//absl/base:errno_saver",
-        "//absl/base:raw_logging_internal",
-    ],
-)
-
-cc_library(
-    name = "debugging_internal",
-    linkopts = ["-labsl_debugging_internal"],
-    deps = [
-        "//absl/base:dynamic_annotations",
-        "//absl/base:errno_saver",
-        "//absl/base:raw_logging_internal",
-    ],
-)
-
-cc_library(
-    name = "demangle_internal",
-    linkopts = ["-labsl_demangle_internal"],
-    deps = [
-        "//absl/base",
-    ],
-)
-
-cc_library(
-    name = "leak_check",
-    linkopts = ["-labsl_leak_check"],
-)
-
-cc_library(
-    name = "leak_check_disable",
-    linkopts = ["-labsl_leak_check_disable"],
-    alwayslink = 1,
-)
diff --git a/third_party/xla/third_party/absl/system.absl.flags.BUILD b/third_party/xla/third_party/absl/system.absl.flags.BUILD
deleted file mode 100644
index aff653c..0000000
--- a/third_party/xla/third_party/absl/system.absl.flags.BUILD
+++ /dev/null
@@ -1,155 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(default_visibility = ["//visibility:public"])
-
-cc_library(
-    name = "program_name",
-    linkopts = ["-labsl_flags_program_name"],
-    visibility = [
-        "//absl/flags:__pkg__",
-    ],
-    deps = [
-        "//absl/strings",
-        "//absl/synchronization",
-    ],
-)
-
-cc_library(
-    name = "config",
-    linkopts = ["-labsl_flags_config"],
-    deps = [
-        ":program_name",
-        "//absl/strings",
-        "//absl/synchronization",
-    ],
-)
-
-cc_library(
-    name = "marshalling",
-    linkopts = ["-labsl_flags_marshalling"],
-    deps = [
-        "//absl/base:log_severity",
-        "//absl/strings",
-        "//absl/strings:str_format",
-    ],
-)
-
-cc_library(
-    name = "commandlineflag_internal",
-    linkopts = ["-labsl_flags_commandlineflag_internal"],
-)
-
-cc_library(
-    name = "commandlineflag",
-    linkopts = ["-labsl_flags_commandlineflag"],
-    deps = [
-        ":commandlineflag_internal",
-        "//absl/strings",
-        "//absl/types:optional",
-    ],
-)
-
-cc_library(
-    name = "private_handle_accessor",
-    linkopts = ["-labsl_flags_private_handle_accessor"],
-    visibility = [
-        "//absl/flags:__pkg__",
-    ],
-    deps = [
-        ":commandlineflag",
-        ":commandlineflag_internal",
-        "//absl/strings",
-    ],
-)
-
-cc_library(
-    name = "reflection",
-    linkopts = ["-labsl_flags_reflection"],
-    deps = [
-        ":commandlineflag",
-        ":commandlineflag_internal",
-        ":config",
-        ":private_handle_accessor",
-        "//absl/container:flat_hash_map",
-        "//absl/strings",
-        "//absl/synchronization",
-    ],
-)
-
-cc_library(
-    name = "flag_internal",
-    linkopts = ["-labsl_flags_internal"],
-    visibility = ["//absl/base:__subpackages__"],
-    deps = [
-        ":commandlineflag",
-        ":commandlineflag_internal",
-        ":config",
-        ":marshalling",
-        ":reflection",
-        "//absl/base",
-        "//absl/memory",
-        "//absl/meta:type_traits",
-        "//absl/strings",
-        "//absl/synchronization",
-        "//absl/utility",
-    ],
-)
-
-cc_library(
-    name = "flag",
-    linkopts = ["-labsl_flags"],
-    deps = [
-        ":config",
-        ":flag_internal",
-        ":reflection",
-        "//absl/base",
-        "//absl/strings",
-    ],
-)
-
-cc_library(
-    name = "usage_internal",
-    linkopts = ["-labsl_flags_usage_internal"],
-    visibility = [
-        "//absl/flags:__pkg__",
-    ],
-    deps = [
-        ":commandlineflag",
-        ":config",
-        ":flag",
-        ":flag_internal",
-        ":private_handle_accessor",
-        ":program_name",
-        ":reflection",
-        "//absl/strings",
-    ],
-)
-
-cc_library(
-    name = "usage",
-    linkopts = ["-labsl_flags_usage"],
-    deps = [
-        ":usage_internal",
-        "//absl/strings",
-        "//absl/synchronization",
-    ],
-)
-
-cc_library(
-    name = "parse",
-    linkopts = ["-labsl_flags_parse"],
-    deps = [
-        ":commandlineflag",
-        ":commandlineflag_internal",
-        ":config",
-        ":flag",
-        ":flag_internal",
-        ":private_handle_accessor",
-        ":program_name",
-        ":reflection",
-        ":usage",
-        ":usage_internal",
-        "//absl/strings",
-        "//absl/synchronization",
-    ],
-)
diff --git a/third_party/xla/third_party/absl/system.absl.functional.BUILD b/third_party/xla/third_party/absl/system.absl.functional.BUILD
deleted file mode 100644
index a4f70ac..0000000
--- a/third_party/xla/third_party/absl/system.absl.functional.BUILD
+++ /dev/null
@@ -1,11 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(default_visibility = ["//visibility:public"])
-
-cc_library(
-    name = "bind_front",
-)
-
-cc_library(
-    name = "function_ref",
-)
diff --git a/third_party/xla/third_party/absl/system.absl.hash.BUILD b/third_party/xla/third_party/absl/system.absl.hash.BUILD
deleted file mode 100644
index 3367340..0000000
--- a/third_party/xla/third_party/absl/system.absl.hash.BUILD
+++ /dev/null
@@ -1,37 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(default_visibility = ["//visibility:public"])
-
-cc_library(
-    name = "hash",
-    linkopts = ["-labsl_hash"],
-    deps = [
-        ":city",
-        ":low_level_hash",
-        "//absl/base:endian",
-        "//absl/container:fixed_array",
-        "//absl/numeric:int128",
-        "//absl/strings",
-        "//absl/types:optional",
-        "//absl/types:variant",
-        "//absl/utility",
-    ],
-)
-
-cc_library(
-    name = "city",
-    linkopts = ["-labsl_city"],
-    deps = [
-        "//absl/base:endian",
-    ],
-)
-
-cc_library(
-    name = "low_level_hash",
-    linkopts = ["-labsl_low_level_hash"],
-    visibility = ["//visibility:private"],
-    deps = [
-        "//absl/base:endian",
-        "//absl/numeric:int128",
-    ],
-)
diff --git a/third_party/xla/third_party/absl/system.absl.memory.BUILD b/third_party/xla/third_party/absl/system.absl.memory.BUILD
deleted file mode 100644
index 592c004..0000000
--- a/third_party/xla/third_party/absl/system.absl.memory.BUILD
+++ /dev/null
@@ -1,7 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(default_visibility = ["//visibility:public"])
-
-cc_library(
-    name = "memory",
-)
diff --git a/third_party/xla/third_party/absl/system.absl.meta.BUILD b/third_party/xla/third_party/absl/system.absl.meta.BUILD
deleted file mode 100644
index 966a7ac..0000000
--- a/third_party/xla/third_party/absl/system.absl.meta.BUILD
+++ /dev/null
@@ -1,6 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-cc_library(
-    name = "type_traits",
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/absl/system.absl.numeric.BUILD b/third_party/xla/third_party/absl/system.absl.numeric.BUILD
deleted file mode 100644
index 59a5836..0000000
--- a/third_party/xla/third_party/absl/system.absl.numeric.BUILD
+++ /dev/null
@@ -1,16 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(default_visibility = ["//visibility:public"])
-
-cc_library(
-    name = "bits",
-)
-
-cc_library(
-    name = "int128",
-    linkopts = ["-labsl_int128"],
-)
-
-cc_library(
-    name = "representation",
-)
diff --git a/third_party/xla/third_party/absl/system.absl.random.BUILD b/third_party/xla/third_party/absl/system.absl.random.BUILD
deleted file mode 100644
index 948de07..0000000
--- a/third_party/xla/third_party/absl/system.absl.random.BUILD
+++ /dev/null
@@ -1,53 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(default_visibility = ["//visibility:public"])
-
-cc_library(
-    name = "random",
-    deps = [
-        ":distributions",
-        ":seed_sequences",
-        "//absl/base:endian",
-    ],
-)
-
-cc_library(
-    name = "distributions",
-    linkopts = ["-labsl_random_distributions"],
-    deps = [
-        "//absl/numeric:bits",
-        "//absl/numeric:int128",
-        "//absl/strings",
-    ],
-)
-
-cc_library(
-    name = "seed_gen_exception",
-    linkopts = ["-labsl_random_seed_gen_exception"],
-)
-
-cc_library(
-    name = "seed_sequences",
-    linkopts = [
-        "-labsl_random_internal_platform",
-        "-labsl_random_internal_pool_urbg",
-        "-labsl_random_internal_randen",
-        "-labsl_random_internal_randen_hwaes",
-        "-labsl_random_internal_randen_hwaes_impl",
-        "-labsl_random_internal_randen_slow",
-        "-labsl_random_internal_seed_material",
-        "-labsl_random_seed_sequences",
-        "-pthread",
-    ],
-    deps = [
-        ":seed_gen_exception",
-        "//absl/base",
-        "//absl/base:endian",
-        "//absl/base:raw_logging_internal",
-        "//absl/container:inlined_vector",
-        "//absl/numeric:int128",
-        "//absl/strings",
-        "//absl/types:optional",
-        "//absl/types:span",
-    ],
-)
diff --git a/third_party/xla/third_party/absl/system.absl.status.BUILD b/third_party/xla/third_party/absl/system.absl.status.BUILD
deleted file mode 100644
index e50e979..0000000
--- a/third_party/xla/third_party/absl/system.absl.status.BUILD
+++ /dev/null
@@ -1,31 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(default_visibility = ["//visibility:public"])
-
-cc_library(
-    name = "status",
-    linkopts = ["-labsl_status"],
-    deps = [
-        "//absl/base:atomic_hook",
-        "//absl/base:raw_logging_internal",
-        "//absl/container:inlined_vector",
-        "//absl/debugging:stacktrace",
-        "//absl/debugging:symbolize",
-        "//absl/strings",
-        "//absl/strings:cord",
-        "//absl/strings:str_format",
-        "//absl/types:optional",
-    ],
-)
-
-cc_library(
-    name = "statusor",
-    linkopts = ["-labsl_statusor"],
-    deps = [
-        ":status",
-        "//absl/base:raw_logging_internal",
-        "//absl/strings",
-        "//absl/types:variant",
-        "//absl/utility",
-    ],
-)
diff --git a/third_party/xla/third_party/absl/system.absl.strings.BUILD b/third_party/xla/third_party/absl/system.absl.strings.BUILD
deleted file mode 100644
index fa9a7a8..0000000
--- a/third_party/xla/third_party/absl/system.absl.strings.BUILD
+++ /dev/null
@@ -1,49 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(default_visibility = ["//visibility:public"])
-
-cc_library(
-    name = "strings",
-    linkopts = ["-labsl_strings"],
-    deps = [
-        ":internal",
-        "//absl/base",
-        "//absl/base:throw_delegate",
-        "//absl/memory",
-        "//absl/numeric:bits",
-        "//absl/numeric:int128",
-    ],
-)
-
-cc_library(
-    name = "internal",
-    linkopts = ["-labsl_strings_internal"],
-    deps = [
-        "//absl/base:endian",
-        "//absl/base:raw_logging_internal",
-    ],
-)
-
-cc_library(
-    name = "cord",
-    linkopts = ["-labsl_cord"],
-    deps = [
-        ":str_format",
-        "//absl/container:compressed_tuple",
-        "//absl/container:fixed_array",
-        "//absl/container:inlined_vector",
-        "//absl/container:layout",
-    ],
-)
-
-cc_library(
-    name = "str_format",
-    linkopts = ["-labsl_str_format_internal"],
-    deps = [
-        ":strings",
-        "//absl/functional:function_ref",
-        "//absl/numeric:representation",
-        "//absl/types:optional",
-        "//absl/types:span",
-    ],
-)
diff --git a/third_party/xla/third_party/absl/system.absl.synchronization.BUILD b/third_party/xla/third_party/absl/system.absl.synchronization.BUILD
deleted file mode 100644
index c0fa37a..0000000
--- a/third_party/xla/third_party/absl/system.absl.synchronization.BUILD
+++ /dev/null
@@ -1,36 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(default_visibility = ["//visibility:public"])
-
-# Internal data structure for efficiently detecting mutex dependency cycles
-cc_library(
-    name = "graphcycles_internal",
-    linkopts = ["-labsl_graphcycles_internal"],
-    visibility = [
-        "//absl:__subpackages__",
-    ],
-    deps = [
-        "//absl/base",
-        "//absl/base:malloc_internal",
-        "//absl/base:raw_logging_internal",
-    ],
-)
-
-cc_library(
-    name = "synchronization",
-    linkopts = [
-        "-labsl_synchronization",
-        "-pthread",
-    ],
-    deps = [
-        ":graphcycles_internal",
-        "//absl/base",
-        "//absl/base:atomic_hook",
-        "//absl/base:dynamic_annotations",
-        "//absl/base:malloc_internal",
-        "//absl/base:raw_logging_internal",
-        "//absl/debugging:stacktrace",
-        "//absl/debugging:symbolize",
-        "//absl/time",
-    ],
-)
diff --git a/third_party/xla/third_party/absl/system.absl.time.BUILD b/third_party/xla/third_party/absl/system.absl.time.BUILD
deleted file mode 100644
index fe295c3..0000000
--- a/third_party/xla/third_party/absl/system.absl.time.BUILD
+++ /dev/null
@@ -1,18 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(default_visibility = ["//visibility:public"])
-
-cc_library(
-    name = "time",
-    linkopts = [
-        "-labsl_time",
-        "-labsl_civil_time",
-        "-labsl_time_zone",
-    ],
-    deps = [
-        "//absl/base",
-        "//absl/base:raw_logging_internal",
-        "//absl/numeric:int128",
-        "//absl/strings",
-    ],
-)
diff --git a/third_party/xla/third_party/absl/system.absl.types.BUILD b/third_party/xla/third_party/absl/system.absl.types.BUILD
deleted file mode 100644
index db94fc9..0000000
--- a/third_party/xla/third_party/absl/system.absl.types.BUILD
+++ /dev/null
@@ -1,59 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-package(default_visibility = ["//visibility:public"])
-
-cc_library(
-    name = "any",
-    deps = [
-        ":bad_any_cast",
-    ],
-)
-
-cc_library(
-    name = "bad_any_cast",
-    linkopts = ["-labsl_bad_any_cast_impl"],
-)
-
-cc_library(
-    name = "span",
-    deps = [
-        "//absl/base:throw_delegate",
-    ],
-)
-
-cc_library(
-    name = "optional",
-    deps = [
-        ":bad_optional_access",
-    ],
-)
-
-cc_library(
-    name = "bad_optional_access",
-    linkopts = ["-labsl_bad_optional_access"],
-    deps = [
-        "//absl/base:raw_logging_internal",
-    ],
-)
-
-cc_library(
-    name = "bad_variant_access",
-    linkopts = ["-labsl_bad_variant_access"],
-    deps = [
-        "//absl/base:raw_logging_internal",
-    ],
-)
-
-cc_library(
-    name = "variant",
-    deps = [
-        ":bad_variant_access",
-    ],
-)
-
-cc_library(
-    name = "compare",
-    deps = [
-        "//absl/meta:type_traits",
-    ],
-)
diff --git a/third_party/xla/third_party/absl/system.absl.utility.BUILD b/third_party/xla/third_party/absl/system.absl.utility.BUILD
deleted file mode 100644
index e15049e..0000000
--- a/third_party/xla/third_party/absl/system.absl.utility.BUILD
+++ /dev/null
@@ -1,6 +0,0 @@
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-cc_library(
-    name = "utility",
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/absl/workspace.bzl b/third_party/xla/third_party/absl/workspace.bzl
deleted file mode 100644
index 07f49ce..0000000
--- a/third_party/xla/third_party/absl/workspace.bzl
+++ /dev/null
@@ -1,50 +0,0 @@
-"""Provides the repository macro to import absl."""
-
-load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
-
-def repo():
-    """Imports absl."""
-
-    # Attention: tools parse and update these lines.
-    # LINT.IfChange
-    ABSL_COMMIT = "b971ac5250ea8de900eae9f95e06548d14cd95fe"
-    ABSL_SHA256 = "8eeec9382fc0338ef5c60053f3a4b0e0708361375fe51c9e65d0ce46ccfe55a7"
-    # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/abseil-cpp.cmake)
-
-    SYS_DIRS = [
-        "algorithm",
-        "base",
-        "cleanup",
-        "container",
-        "debugging",
-        "flags",
-        "functional",
-        "hash",
-        "memory",
-        "meta",
-        "numeric",
-        "random",
-        "status",
-        "strings",
-        "synchronization",
-        "time",
-        "types",
-        "utility",
-    ]
-    SYS_LINKS = {
-        "//third_party/absl:system.absl.{name}.BUILD".format(name = n): "absl/{name}/BUILD.bazel".format(name = n)
-        for n in SYS_DIRS
-    }
-
-    tf_http_archive(
-        name = "com_google_absl",
-        sha256 = ABSL_SHA256,
-        build_file = "//third_party/absl:com_google_absl.BUILD",
-        system_build_file = "//third_party/absl:system.BUILD",
-        system_link_files = SYS_LINKS,
-        # This patch pulls in a fix for designated initializers that MSVC
-        # complains about. It shouldn't be necessary at the next LTS release.
-        patch_file = ["//third_party/absl:absl_designated_initializers.patch"],
-        strip_prefix = "abseil-cpp-{commit}".format(commit = ABSL_COMMIT),
-        urls = tf_mirror_urls("https://github.com/abseil/abseil-cpp/archive/{commit}.tar.gz".format(commit = ABSL_COMMIT)),
-    )
diff --git a/third_party/xla/third_party/benchmark/BUILD b/third_party/xla/third_party/benchmark/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/xla/third_party/benchmark/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/xla/third_party/benchmark/workspace.bzl b/third_party/xla/third_party/benchmark/workspace.bzl
deleted file mode 100644
index 679133c..0000000
--- a/third_party/xla/third_party/benchmark/workspace.bzl
+++ /dev/null
@@ -1,14 +0,0 @@
-"""Provides the repo macro to import google benchmark"""
-
-load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
-
-def repo():
-    """Imports benchmark."""
-    BM_COMMIT = "f7547e29ccaed7b64ef4f7495ecfff1c9f6f3d03"
-    BM_SHA256 = "552ca3d4d1af4beeb1907980f7096315aa24150d6baf5ac1e5ad90f04846c670"
-    tf_http_archive(
-        name = "com_google_benchmark",
-        sha256 = BM_SHA256,
-        strip_prefix = "benchmark-{commit}".format(commit = BM_COMMIT),
-        urls = tf_mirror_urls("https://github.com/google/benchmark/archive/{commit}.tar.gz".format(commit = BM_COMMIT)),
-    )
diff --git a/third_party/xla/third_party/clang_toolchain/BUILD b/third_party/xla/third_party/clang_toolchain/BUILD
deleted file mode 100644
index e69de29..0000000
--- a/third_party/xla/third_party/clang_toolchain/BUILD
+++ /dev/null
diff --git a/third_party/xla/third_party/clang_toolchain/cc_configure_clang.bzl b/third_party/xla/third_party/clang_toolchain/cc_configure_clang.bzl
deleted file mode 100644
index a6b87ab..0000000
--- a/third_party/xla/third_party/clang_toolchain/cc_configure_clang.bzl
+++ /dev/null
@@ -1,27 +0,0 @@
-""" Downloads clang and configures the crosstool using bazel's autoconf."""
-
-load("@bazel_tools//tools/cpp:cc_configure.bzl", "cc_autoconf_impl")
-load(":download_clang.bzl", "download_clang")
-
-_TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG"
-_TF_NEED_CUDA = "TF_NEED_CUDA"
-
-def _cc_clang_autoconf(repo_ctx):
-    if repo_ctx.os.environ.get(_TF_DOWNLOAD_CLANG) != "1":
-        return
-    if repo_ctx.os.environ.get(_TF_NEED_CUDA) == "1":
-        # Clang is handled separately for CUDA configs.
-        # See cuda_configure.bzl for more details.
-        return
-
-    download_clang(repo_ctx, out_folder = "extra_tools")
-    overridden_tools = {"gcc": "extra_tools/bin/clang"}
-    cc_autoconf_impl(repo_ctx, overridden_tools)
-
-cc_download_clang_toolchain = repository_rule(
-    environ = [
-        _TF_DOWNLOAD_CLANG,
-        _TF_NEED_CUDA,
-    ],
-    implementation = _cc_clang_autoconf,
-)
diff --git a/third_party/xla/third_party/clang_toolchain/download_clang.bzl b/third_party/xla/third_party/clang_toolchain/download_clang.bzl
deleted file mode 100644
index 6e6091b..0000000
--- a/third_party/xla/third_party/clang_toolchain/download_clang.bzl
+++ /dev/null
@@ -1,64 +0,0 @@
-""" Helpers to download a recent clang release."""
-
-def _get_platform_folder(os_name):
-    os_name = os_name.lower()
-    if os_name.startswith("windows"):
-        return "Win"
-    if os_name.startswith("mac os"):
-        return "Mac"
-    if not os_name.startswith("linux"):
-        fail("Unknown platform")
-    return "Linux_x64"
-
-def _download_chromium_clang(
-        repo_ctx,
-        platform_folder,
-        package_version,
-        sha256,
-        out_folder):
-    cds_url = "https://commondatastorage.googleapis.com/chromium-browser-clang"
-    cds_file = "clang-%s.tgz" % package_version
-    cds_full_url = "{0}/{1}/{2}".format(cds_url, platform_folder, cds_file)
-    repo_ctx.download_and_extract(cds_full_url, output = out_folder, sha256 = sha256)
-
-def download_clang(repo_ctx, out_folder):
-    """ Download a fresh clang release and put it into out_folder.
-
-    Clang itself will be located in 'out_folder/bin/clang'.
-    We currently download one of the latest releases of clang by the
-    Chromium project (see
-    https://chromium.googlesource.com/chromium/src/+/master/docs/clang.md).
-
-    Args:
-      repo_ctx: An instance of repository_context object.
-      out_folder: A folder to extract the compiler into.
-    """
-    # TODO(ibiryukov): we currently download and extract some extra tools in the
-    # clang release (e.g., sanitizers). We should probably remove the ones
-    # we don't need and document the ones we want to provide in addition to clang.
-
-    # Latest CLANG_REVISION and CLANG_SUB_REVISION of Chromium's release
-    # can be found in https://chromium.googlesource.com/chromium/src/tools/clang/+/master/scripts/update.py
-    CLANG_REVISION = "b4160cb94c54f0b31d0ce14694950dac7b6cd83f"
-    CLANG_SVN_REVISION = "371856"
-    CLANG_SUB_REVISION = 1
-    package_version = "%s-%s-%s" % (
-        CLANG_SVN_REVISION,
-        CLANG_REVISION[:8],
-        CLANG_SUB_REVISION,
-    )
-
-    checksums = {
-        "Linux_x64": "919c19df3ebd7db03b72575b2de5198404357659fc8c85c2d66e679ad4acbafe",
-        "Mac": "5632c516f3ac5fab3654d0a874688cad6c7f99b96845da27ab12336a14187aa2",
-        "Win": "235545b33f4d697190032cb538fdcaba227017c95b752ea8af8f29aab8da7479",
-    }
-
-    platform_folder = _get_platform_folder(repo_ctx.os.name)
-    _download_chromium_clang(
-        repo_ctx,
-        platform_folder,
-        package_version,
-        checksums[platform_folder],
-        out_folder,
-    )
diff --git a/third_party/xla/third_party/curl.BUILD b/third_party/xla/third_party/curl.BUILD
deleted file mode 100644
index 8dcd544..0000000
--- a/third_party/xla/third_party/curl.BUILD
+++ /dev/null
@@ -1,822 +0,0 @@
-# Description:
-#   curl is a tool for talking to web servers.
-
-licenses(["notice"])  # MIT/X derivative license
-
-exports_files(["COPYING"])
-
-CURL_WIN_COPTS = [
-    "/Iexternal/curl/lib",
-    "/DBUILDING_LIBCURL",
-    "/DHAVE_CONFIG_H",
-    "/DCURL_DISABLE_FTP",
-    "/DCURL_DISABLE_NTLM",
-    "/DCURL_DISABLE_PROXY",
-    "/DHAVE_LIBZ",
-    "/DHAVE_ZLIB_H",
-    # Defining _USING_V110_SDK71_ is hackery to defeat curl's incorrect
-    # detection of what OS releases we can build on with VC 2012. This
-    # may not be needed (or may have to change) if the WINVER setting
-    # changes in //third_party/msvc/vc_12_0/CROSSTOOL.
-    "/D_USING_V110_SDK71_",
-]
-
-CURL_WIN_SRCS = [
-    "lib/asyn-thread.c",
-    "lib/inet_ntop.c",
-    "lib/system_win32.c",
-    "lib/setup-win32.h",
-]
-
-cc_library(
-    name = "curl",
-    srcs = [
-        "include/curl_config.h",
-        "lib/altsvc.c",
-        "lib/altsvc.h",
-        "lib/amigaos.c",
-        "lib/amigaos.h",
-        "lib/arpa_telnet.h",
-        "lib/asyn.h",
-        "lib/asyn-ares.c",
-        "lib/base64.c",
-        "lib/bufq.c",
-        "lib/bufq.h",
-        "lib/bufref.c",
-        "lib/bufref.h",
-        "lib/c-hyper.c",
-        "lib/c-hyper.h",
-        "lib/cf-h1-proxy.c",
-        "lib/cf-h1-proxy.h",
-        "lib/cf-h2-proxy.c",
-        "lib/cf-h2-proxy.h",
-        "lib/cf-haproxy.c",
-        "lib/cf-haproxy.h",
-        "lib/cf-https-connect.c",
-        "lib/cf-https-connect.h",
-        "lib/cf-socket.c",
-        "lib/cf-socket.h",
-        "lib/cfilters.c",
-        "lib/cfilters.h",
-        "lib/config-amigaos.h",
-        "lib/config-dos.h",
-        "lib/config-mac.h",
-        "lib/config-os400.h",
-        "lib/config-plan9.h",
-        "lib/config-riscos.h",
-        "lib/config-win32.h",
-        "lib/config-win32ce.h",
-        "lib/conncache.c",
-        "lib/conncache.h",
-        "lib/connect.c",
-        "lib/connect.h",
-        "lib/content_encoding.c",
-        "lib/content_encoding.h",
-        "lib/cookie.c",
-        "lib/cookie.h",
-        "lib/curl_addrinfo.c",
-        "lib/curl_addrinfo.h",
-        "lib/curl_base64.h",
-        "lib/curl_ctype.h",
-        "lib/curl_des.c",
-        "lib/curl_des.h",
-        "lib/curl_endian.c",
-        "lib/curl_endian.h",
-        "lib/curl_fnmatch.c",
-        "lib/curl_fnmatch.h",
-        "lib/curl_get_line.c",
-        "lib/curl_get_line.h",
-        "lib/curl_gethostname.c",
-        "lib/curl_gethostname.h",
-        "lib/curl_gssapi.c",
-        "lib/curl_gssapi.h",
-        "lib/curl_hmac.h",
-        "lib/curl_krb5.h",
-        "lib/curl_ldap.h",
-        "lib/curl_md4.h",
-        "lib/curl_md5.h",
-        "lib/curl_memory.h",
-        "lib/curl_memrchr.c",
-        "lib/curl_memrchr.h",
-        "lib/curl_multibyte.c",
-        "lib/curl_multibyte.h",
-        "lib/curl_ntlm_core.c",
-        "lib/curl_ntlm_core.h",
-        "lib/curl_ntlm_wb.c",
-        "lib/curl_ntlm_wb.h",
-        "lib/curl_path.c",
-        "lib/curl_path.h",
-        "lib/curl_printf.h",
-        "lib/curl_range.c",
-        "lib/curl_range.h",
-        "lib/curl_rtmp.c",
-        "lib/curl_rtmp.h",
-        "lib/curl_sasl.c",
-        "lib/curl_sasl.h",
-        "lib/curl_setup.h",
-        "lib/curl_setup_once.h",
-        "lib/curl_sha256.h",
-        "lib/curl_sspi.c",
-        "lib/curl_sspi.h",
-        "lib/curl_threads.c",
-        "lib/curl_threads.h",
-        "lib/curl_trc.c",
-        "lib/curl_trc.h",
-        "lib/curlx.h",
-        "lib/dict.c",
-        "lib/dict.h",
-        "lib/doh.c",
-        "lib/doh.h",
-        "lib/dynbuf.c",
-        "lib/dynbuf.h",
-        "lib/dynhds.c",
-        "lib/dynhds.h",
-        "lib/easy.c",
-        "lib/easy_lock.h",
-        "lib/easygetopt.c",
-        "lib/easyif.h",
-        "lib/easyoptions.c",
-        "lib/easyoptions.h",
-        "lib/escape.c",
-        "lib/escape.h",
-        "lib/file.c",
-        "lib/file.h",
-        "lib/fileinfo.c",
-        "lib/fileinfo.h",
-        "lib/fopen.c",
-        "lib/fopen.h",
-        "lib/formdata.c",
-        "lib/formdata.h",
-        "lib/ftp.c",
-        "lib/ftp.h",
-        "lib/ftplistparser.c",
-        "lib/ftplistparser.h",
-        "lib/functypes.h",
-        "lib/getenv.c",
-        "lib/getinfo.c",
-        "lib/getinfo.h",
-        "lib/gopher.c",
-        "lib/gopher.h",
-        "lib/hash.c",
-        "lib/hash.h",
-        "lib/headers.c",
-        "lib/headers.h",
-        "lib/hmac.c",
-        "lib/hostasyn.c",
-        "lib/hostip.c",
-        "lib/hostip.h",
-        "lib/hostip4.c",
-        "lib/hostip6.c",
-        "lib/hostsyn.c",
-        "lib/hsts.c",
-        "lib/hsts.h",
-        "lib/http.c",
-        "lib/http.h",
-        "lib/http1.c",
-        "lib/http1.h",
-        "lib/http2.c",
-        "lib/http2.h",
-        "lib/http_aws_sigv4.c",
-        "lib/http_aws_sigv4.h",
-        "lib/http_chunks.c",
-        "lib/http_chunks.h",
-        "lib/http_digest.c",
-        "lib/http_digest.h",
-        "lib/http_negotiate.c",
-        "lib/http_negotiate.h",
-        "lib/http_ntlm.c",
-        "lib/http_ntlm.h",
-        "lib/http_proxy.c",
-        "lib/http_proxy.h",
-        "lib/idn.c",
-        "lib/idn.h",
-        "lib/if2ip.c",
-        "lib/if2ip.h",
-        "lib/imap.c",
-        "lib/imap.h",
-        "lib/inet_ntop.h",
-        "lib/inet_pton.c",
-        "lib/inet_pton.h",
-        "lib/krb5.c",
-        "lib/ldap.c",
-        "lib/llist.c",
-        "lib/llist.h",
-        "lib/macos.c",
-        "lib/macos.h",
-        "lib/md4.c",
-        "lib/md5.c",
-        "lib/memdebug.c",
-        "lib/memdebug.h",
-        "lib/mime.c",
-        "lib/mime.h",
-        "lib/mprintf.c",
-        "lib/mqtt.c",
-        "lib/mqtt.h",
-        "lib/multi.c",
-        "lib/multihandle.h",
-        "lib/multiif.h",
-        "lib/netrc.c",
-        "lib/netrc.h",
-        "lib/nonblock.c",
-        "lib/nonblock.h",
-        "lib/noproxy.c",
-        "lib/noproxy.h",
-        "lib/openldap.c",
-        "lib/parsedate.c",
-        "lib/parsedate.h",
-        "lib/pingpong.c",
-        "lib/pingpong.h",
-        "lib/pop3.c",
-        "lib/pop3.h",
-        "lib/progress.c",
-        "lib/progress.h",
-        "lib/psl.c",
-        "lib/psl.h",
-        "lib/rand.c",
-        "lib/rand.h",
-        "lib/rename.c",
-        "lib/rename.h",
-        "lib/rtsp.c",
-        "lib/rtsp.h",
-        "lib/select.c",
-        "lib/select.h",
-        "lib/sendf.c",
-        "lib/sendf.h",
-        "lib/setopt.c",
-        "lib/setopt.h",
-        "lib/setup-os400.h",
-        "lib/setup-vms.h",
-        "lib/sha256.c",
-        "lib/share.c",
-        "lib/share.h",
-        "lib/sigpipe.h",
-        "lib/slist.c",
-        "lib/slist.h",
-        "lib/smb.c",
-        "lib/smb.h",
-        "lib/smtp.c",
-        "lib/smtp.h",
-        "lib/sockaddr.h",
-        "lib/socketpair.c",
-        "lib/socketpair.h",
-        "lib/socks.c",
-        "lib/socks.h",
-        "lib/socks_gssapi.c",
-        "lib/socks_sspi.c",
-        "lib/speedcheck.c",
-        "lib/speedcheck.h",
-        "lib/splay.c",
-        "lib/splay.h",
-        "lib/strcase.c",
-        "lib/strcase.h",
-        "lib/strdup.c",
-        "lib/strdup.h",
-        "lib/strerror.c",
-        "lib/strerror.h",
-        "lib/strtok.c",
-        "lib/strtok.h",
-        "lib/strtoofft.c",
-        "lib/strtoofft.h",
-        "lib/system_win32.h",
-        "lib/telnet.c",
-        "lib/telnet.h",
-        "lib/tftp.c",
-        "lib/tftp.h",
-        "lib/timediff.c",
-        "lib/timediff.h",
-        "lib/timeval.c",
-        "lib/timeval.h",
-        "lib/transfer.c",
-        "lib/transfer.h",
-        "lib/url.c",
-        "lib/url.h",
-        "lib/urlapi.c",
-        "lib/urlapi-int.h",
-        "lib/urldata.h",
-        "lib/vauth/cleartext.c",
-        "lib/vauth/cram.c",
-        "lib/vauth/digest.c",
-        "lib/vauth/digest.h",
-        "lib/vauth/digest_sspi.c",
-        "lib/vauth/gsasl.c",
-        "lib/vauth/krb5_gssapi.c",
-        "lib/vauth/krb5_sspi.c",
-        "lib/vauth/ntlm.c",
-        "lib/vauth/ntlm.h",
-        "lib/vauth/ntlm_sspi.c",
-        "lib/vauth/oauth2.c",
-        "lib/vauth/spnego_gssapi.c",
-        "lib/vauth/spnego_sspi.c",
-        "lib/vauth/vauth.c",
-        "lib/vauth/vauth.h",
-        "lib/version.c",
-        "lib/version_win32.c",
-        "lib/version_win32.h",
-        "lib/vquic/curl_msh3.c",
-        "lib/vquic/curl_msh3.h",
-        "lib/vquic/curl_ngtcp2.c",
-        "lib/vquic/curl_ngtcp2.h",
-        "lib/vquic/curl_quiche.c",
-        "lib/vquic/curl_quiche.h",
-        "lib/vquic/vquic.c",
-        "lib/vquic/vquic.h",
-        "lib/vquic/vquic_int.h",
-        "lib/vssh/libssh.c",
-        "lib/vssh/libssh2.c",
-        "lib/vssh/ssh.h",
-        "lib/vssh/wolfssh.c",
-        "lib/vtls/bearssl.c",
-        "lib/vtls/bearssl.h",
-        "lib/vtls/gtls.c",
-        "lib/vtls/gtls.h",
-        "lib/vtls/hostcheck.c",
-        "lib/vtls/hostcheck.h",
-        "lib/vtls/keylog.c",
-        "lib/vtls/keylog.h",
-        "lib/vtls/mbedtls.c",
-        "lib/vtls/mbedtls.h",
-        "lib/vtls/mbedtls_threadlock.c",
-        "lib/vtls/mbedtls_threadlock.h",
-        "lib/vtls/openssl.c",
-        "lib/vtls/openssl.h",
-        "lib/vtls/rustls.c",
-        "lib/vtls/rustls.h",
-        "lib/vtls/schannel.c",
-        "lib/vtls/schannel.h",
-        "lib/vtls/schannel_int.h",
-        "lib/vtls/schannel_verify.c",
-        "lib/vtls/sectransp.h",
-        "lib/vtls/vtls.c",
-        "lib/vtls/vtls.h",
-        "lib/vtls/vtls_int.h",
-        "lib/vtls/wolfssl.c",
-        "lib/vtls/wolfssl.h",
-        "lib/vtls/x509asn1.c",
-        "lib/vtls/x509asn1.h",
-        "lib/warnless.c",
-        "lib/warnless.h",
-        "lib/ws.c",
-        "lib/ws.h",
-    ] + select({
-        "@local_tsl//tsl:macos": [
-            "lib/vtls/sectransp.c",
-        ],
-        "@local_tsl//tsl:ios": [
-            "lib/vtls/sectransp.c",
-        ],
-        "@local_tsl//tsl:windows": CURL_WIN_SRCS,
-        "//conditions:default": [
-        ],
-    }),
-    hdrs = [
-        "include/curl/curl.h",
-        "include/curl/curlver.h",
-        "include/curl/easy.h",
-        "include/curl/header.h",
-        "include/curl/mprintf.h",
-        "include/curl/multi.h",
-        "include/curl/options.h",
-        "include/curl/stdcheaders.h",
-        "include/curl/system.h",
-        "include/curl/typecheck-gcc.h",
-        "include/curl/urlapi.h",
-        "include/curl/websockets.h",
-    ],
-    copts = select({
-        "@local_tsl//tsl:windows": CURL_WIN_COPTS,
-        "//conditions:default": [
-            "-Iexternal/curl/lib",
-            "-D_GNU_SOURCE",
-            "-DBUILDING_LIBCURL",
-            "-DHAVE_CONFIG_H",
-            "-DCURL_DISABLE_FTP",
-            "-DCURL_DISABLE_NTLM",  # turning it off in configure is not enough
-            "-DHAVE_LIBZ",
-            "-DHAVE_ZLIB_H",
-            "-Wno-string-plus-int",
-        ],
-    }) + select({
-        "@local_tsl//tsl:macos": [
-            "-fno-constant-cfstrings",
-        ],
-        "@local_tsl//tsl:windows": [
-            # See curl.h for discussion of write size and Windows
-            "/DCURL_MAX_WRITE_SIZE=16384",
-        ],
-        "//conditions:default": [
-            "-DCURL_MAX_WRITE_SIZE=65536",
-        ],
-    }),
-    defines = ["CURL_STATICLIB"],
-    includes = ["include"],
-    linkopts = select({
-        "@local_tsl//tsl:android": [
-            "-pie",
-        ],
-        "@local_tsl//tsl:macos": [
-            "-Wl,-framework",
-            "-Wl,CoreFoundation",
-            "-Wl,-framework",
-            "-Wl,SystemConfiguration",
-            "-Wl,-framework",
-            "-Wl,Security",
-        ],
-        "@local_tsl//tsl:ios": [],
-        "@local_tsl//tsl:windows": [
-            "-DEFAULTLIB:ws2_32.lib",
-            "-DEFAULTLIB:advapi32.lib",
-            "-DEFAULTLIB:crypt32.lib",
-            "-DEFAULTLIB:Normaliz.lib",
-        ],
-        "//conditions:default": [
-            "-lrt",
-        ],
-    }),
-    visibility = ["//visibility:public"],
-    deps = [
-        "@zlib",
-    ] + select({
-        "@local_tsl//tsl:ios": [],
-        "@local_tsl//tsl:windows": [],
-        "//conditions:default": [
-            "@boringssl//:ssl",
-        ],
-    }),
-)
-
-CURL_BIN_WIN_COPTS = [
-    "/Iexternal/curl/lib",
-    "/DHAVE_CONFIG_H",
-    "/DCURL_DISABLE_LIBCURL_OPTION",
-]
-
-cc_binary(
-    name = "curl_bin",
-    srcs = [
-        "lib/config-win32.h",
-        "src/slist_wc.c",
-        "src/slist_wc.h",
-        "src/tool_binmode.c",
-        "src/tool_binmode.h",
-        "src/tool_bname.c",
-        "src/tool_bname.h",
-        "src/tool_cb_dbg.c",
-        "src/tool_cb_dbg.h",
-        "src/tool_cb_hdr.c",
-        "src/tool_cb_hdr.h",
-        "src/tool_cb_prg.c",
-        "src/tool_cb_prg.h",
-        "src/tool_cb_rea.c",
-        "src/tool_cb_rea.h",
-        "src/tool_cb_see.c",
-        "src/tool_cb_see.h",
-        "src/tool_cb_wrt.c",
-        "src/tool_cb_wrt.h",
-        "src/tool_cfgable.c",
-        "src/tool_cfgable.h",
-        "src/tool_dirhie.c",
-        "src/tool_dirhie.h",
-        "src/tool_doswin.c",
-        "src/tool_doswin.h",
-        "src/tool_easysrc.c",
-        "src/tool_easysrc.h",
-        "src/tool_filetime.c",
-        "src/tool_filetime.h",
-        "src/tool_formparse.c",
-        "src/tool_formparse.h",
-        "src/tool_getparam.c",
-        "src/tool_getparam.h",
-        "src/tool_getpass.c",
-        "src/tool_getpass.h",
-        "src/tool_help.c",
-        "src/tool_help.h",
-        "src/tool_helpers.c",
-        "src/tool_helpers.h",
-        "src/tool_homedir.c",
-        "src/tool_homedir.h",
-        "src/tool_hugehelp.c",
-        "src/tool_hugehelp.h",
-        "src/tool_libinfo.c",
-        "src/tool_libinfo.h",
-        "src/tool_main.c",
-        "src/tool_main.h",
-        "src/tool_metalink.c",
-        "src/tool_metalink.h",
-        "src/tool_mfiles.c",
-        "src/tool_mfiles.h",
-        "src/tool_msgs.c",
-        "src/tool_msgs.h",
-        "src/tool_operate.c",
-        "src/tool_operate.h",
-        "src/tool_operhlp.c",
-        "src/tool_operhlp.h",
-        "src/tool_panykey.c",
-        "src/tool_panykey.h",
-        "src/tool_paramhlp.c",
-        "src/tool_paramhlp.h",
-        "src/tool_parsecfg.c",
-        "src/tool_parsecfg.h",
-        "src/tool_progress.c",
-        "src/tool_progress.h",
-        "src/tool_sdecls.h",
-        "src/tool_setopt.c",
-        "src/tool_setopt.h",
-        "src/tool_setup.h",
-        "src/tool_sleep.c",
-        "src/tool_sleep.h",
-        "src/tool_strdup.c",
-        "src/tool_strdup.h",
-        "src/tool_urlglob.c",
-        "src/tool_urlglob.h",
-        "src/tool_util.c",
-        "src/tool_util.h",
-        "src/tool_version.h",
-        "src/tool_vms.c",
-        "src/tool_vms.h",
-        "src/tool_writeenv.c",
-        "src/tool_writeenv.h",
-        "src/tool_writeout.c",
-        "src/tool_writeout.h",
-        "src/tool_writeout_json.c",
-        "src/tool_writeout_json.h",
-        "src/tool_xattr.c",
-        "src/tool_xattr.h",
-    ],
-    copts = select({
-        "@local_tsl//tsl:windows": CURL_BIN_WIN_COPTS,
-        "//conditions:default": [
-            "-Iexternal/curl/lib",
-            "-D_GNU_SOURCE",
-            "-DHAVE_CONFIG_H",
-            "-DCURL_DISABLE_LIBCURL_OPTION",
-            "-Wno-string-plus-int",
-        ],
-    }),
-    deps = [":curl"],
-)
-
-genrule(
-    name = "configure",
-    outs = ["include/curl_config.h"],
-    cmd = "\n".join([
-        "cat <<'EOF' >$@",
-        "#ifndef EXTERNAL_CURL_INCLUDE_CURL_CONFIG_H_",
-        "#define EXTERNAL_CURL_INCLUDE_CURL_CONFIG_H_",
-        "",
-        "#if !defined(_WIN32) && !defined(__APPLE__)",
-        "#  include <openssl/opensslv.h>",
-        "#  if defined(OPENSSL_IS_BORINGSSL)",
-        "#    define HAVE_BORINGSSL 1",
-        "#  endif",
-        "#endif",
-        "",
-        "#if defined(_WIN32)",
-        "#  include \"lib/config-win32.h\"",
-        "#  define BUILDING_LIBCURL 1",
-        "#  define CURL_DISABLE_CRYPTO_AUTH 1",
-        "#  define CURL_DISABLE_DICT 1",
-        "#  define CURL_DISABLE_FILE 1",
-        "#  define CURL_DISABLE_GOPHER 1",
-        "#  define CURL_DISABLE_IMAP 1",
-        "#  define CURL_DISABLE_LDAP 1",
-        "#  define CURL_DISABLE_LDAPS 1",
-        "#  define CURL_DISABLE_POP3 1",
-        "#  define CURL_PULL_WS2TCPIP_H 1",
-        "#  define CURL_DISABLE_SMTP 1",
-        "#  define CURL_DISABLE_TELNET 1",
-        "#  define CURL_DISABLE_TFTP 1",
-        "#  define CURL_PULL_WS2TCPIP_H 1",
-        "#  define USE_WINDOWS_SSPI 1",
-        "#  define USE_WIN32_IDN 1",
-        "#  define USE_SCHANNEL 1",
-        "#  define WANT_IDN_PROTOTYPES 1",
-        "#elif defined(__APPLE__)",
-        "#  define HAVE_FSETXATTR_6 1",
-        "#  define HAVE_SETMODE 1",
-        "#  define HAVE_SYS_FILIO_H 1",
-        "#  define HAVE_SYS_SOCKIO_H 1",
-        "#  define OS \"x86_64-apple-darwin15.5.0\"",
-        "#  define USE_SECTRANSP 1",
-        "#else",
-        "#  define CURL_CA_BUNDLE \"/etc/ssl/certs/ca-certificates.crt\"",
-        "#  define GETSERVBYPORT_R_ARGS 6",
-        "#  define GETSERVBYPORT_R_BUFSIZE 4096",
-        "#  define HAVE_BORINGSSL 1",
-        "#  define HAVE_CLOCK_GETTIME_MONOTONIC 1",
-        "#  define HAVE_CRYPTO_CLEANUP_ALL_EX_DATA 1",
-        "#  define HAVE_FSETXATTR_5 1",
-        "#  define HAVE_GETHOSTBYADDR_R 1",
-        "#  define HAVE_GETHOSTBYADDR_R_8 1",
-        "#  define HAVE_GETHOSTBYNAME_R 1",
-        "#  define HAVE_GETHOSTBYNAME_R_6 1",
-        "#  define HAVE_GETSERVBYPORT_R 1",
-        "#  define HAVE_LIBSSL 1",
-        "#  define HAVE_MALLOC_H 1",
-        "#  define HAVE_MSG_NOSIGNAL 1",
-        "#  define HAVE_OPENSSL_CRYPTO_H 1",
-        "#  define HAVE_OPENSSL_ERR_H 1",
-        "#  define HAVE_OPENSSL_PEM_H 1",
-        "#  define HAVE_OPENSSL_PKCS12_H 1",
-        "#  define HAVE_OPENSSL_RSA_H 1",
-        "#  define HAVE_OPENSSL_SSL_H 1",
-        "#  define HAVE_OPENSSL_X509_H 1",
-        "#  define HAVE_RAND_EGD 1",
-        "#  define HAVE_RAND_STATUS 1",
-        "#  define HAVE_SSL_GET_SHUTDOWN 1",
-        "#  define HAVE_TERMIOS_H 1",
-        "#  define OS \"x86_64-pc-linux-gnu\"",
-        "#  define RANDOM_FILE \"/dev/urandom\"",
-        "#  define USE_OPENSSL 1",
-        "#endif",
-        "",
-        "#if !defined(_WIN32)",
-        "#  define CURL_DISABLE_DICT 1",
-        "#  define CURL_DISABLE_FILE 1",
-        "#  define CURL_DISABLE_GOPHER 1",
-        "#  define CURL_DISABLE_IMAP 1",
-        "#  define CURL_DISABLE_LDAP 1",
-        "#  define CURL_DISABLE_LDAPS 1",
-        "#  define CURL_DISABLE_POP3 1",
-        "#  define CURL_DISABLE_SMTP 1",
-        "#  define CURL_DISABLE_TELNET 1",
-        "#  define CURL_DISABLE_TFTP 1",
-        "#  define CURL_EXTERN_SYMBOL __attribute__ ((__visibility__ (\"default\")))",
-        "#  define ENABLE_IPV6 1",
-        "#  define GETHOSTNAME_TYPE_ARG2 size_t",
-        "#  define GETNAMEINFO_QUAL_ARG1 const",
-        "#  define GETNAMEINFO_TYPE_ARG1 struct sockaddr *",
-        "#  define GETNAMEINFO_TYPE_ARG2 socklen_t",
-        "#  define GETNAMEINFO_TYPE_ARG46 socklen_t",
-        "#  define GETNAMEINFO_TYPE_ARG7 int",
-        "#  define HAVE_ALARM 1",
-        "#  define HAVE_ALLOCA_H 1",
-        "#  define HAVE_ARPA_INET_H 1",
-        "#  define HAVE_ARPA_TFTP_H 1",
-        "#  define HAVE_ASSERT_H 1",
-        "#  define HAVE_BASENAME 1",
-        "#  define HAVE_BOOL_T 1",
-        "#  define HAVE_CONNECT 1",
-        "#  define HAVE_DLFCN_H 1",
-        "#  define HAVE_ERRNO_H 1",
-        "#  define HAVE_FCNTL 1",
-        "#  define HAVE_FCNTL_H 1",
-        "#  define HAVE_FCNTL_O_NONBLOCK 1",
-        "#  define HAVE_FDOPEN 1",
-        "#  define HAVE_FORK 1",
-        "#  define HAVE_FREEADDRINFO 1",
-        "#  define HAVE_FREEIFADDRS 1",
-        "#  if !defined(__ANDROID__)",
-        "#    define HAVE_FSETXATTR 1",
-        "#  endif",
-        "#  define HAVE_FTRUNCATE 1",
-        "#  define HAVE_GAI_STRERROR 1",
-        "#  define HAVE_GETADDRINFO 1",
-        "#  define HAVE_GETADDRINFO_THREADSAFE 1",
-        "#  define HAVE_GETEUID 1",
-        "#  define HAVE_GETHOSTBYADDR 1",
-        "#  define HAVE_GETHOSTBYNAME 1",
-        "#  define HAVE_GETHOSTNAME 1",
-        "#  if !defined(__ANDROID__)",
-        "#    define HAVE_GETIFADDRS 1",
-        "#  endif",
-        "#  define HAVE_GETNAMEINFO 1",
-        "#  define HAVE_GETPPID 1",
-        "#  define HAVE_GETPROTOBYNAME 1",
-        "#  define HAVE_GETPWUID 1",
-        "#  if !defined(__ANDROID__)",
-        "#    define HAVE_GETPWUID_R 1",
-        "#  endif",
-        "#  define HAVE_GETRLIMIT 1",
-        "#  define HAVE_GETTIMEOFDAY 1",
-        "#  define HAVE_GMTIME_R 1",
-        "#  if !defined(__ANDROID__)",
-        "#    define HAVE_IFADDRS_H 1",
-        "#  endif",
-        "#  define HAVE_IF_NAMETOINDEX 1",
-        "#  define HAVE_INET_ADDR 1",
-        "#  define HAVE_INET_NTOP 1",
-        "#  define HAVE_INET_PTON 1",
-        "#  define HAVE_INTTYPES_H 1",
-        "#  define HAVE_IOCTL 1",
-        "#  define HAVE_IOCTL_FIONBIO 1",
-        "#  define HAVE_IOCTL_SIOCGIFADDR 1",
-        "#  define HAVE_LIBGEN_H 1",
-        "#  define HAVE_LIBZ 1",
-        "#  define HAVE_LIMITS_H 1",
-        "#  define HAVE_LL 1",
-        "#  define HAVE_LOCALE_H 1",
-        "#  define HAVE_LOCALTIME_R 1",
-        "#  define HAVE_LONGLONG 1",
-        "#  define HAVE_MEMORY_H 1",
-        "#  define HAVE_NETDB_H 1",
-        "#  define HAVE_NETINET_IN_H 1",
-        "#  define HAVE_NETINET_TCP_H 1",
-        "#  define HAVE_NET_IF_H 1",
-        "#  define HAVE_PERROR 1",
-        "#  define HAVE_PIPE 1",
-        "#  define HAVE_POLL 1",
-        "#  define HAVE_POLL_FINE 1",
-        "#  define HAVE_POLL_H 1",
-        "#  define HAVE_POSIX_STRERROR_R 1",
-        "#  define HAVE_PWD_H 1",
-        "#  define HAVE_RECV 1",
-        "#  define HAVE_SELECT 1",
-        "#  define HAVE_SEND 1",
-        "#  define HAVE_SETJMP_H 1",
-        "#  define HAVE_SETLOCALE 1",
-        "#  define HAVE_SETRLIMIT 1",
-        "#  define HAVE_SETSOCKOPT 1",
-        "#  define HAVE_SGTTY_H 1",
-        "#  define HAVE_SIGACTION 1",
-        "#  define HAVE_SIGINTERRUPT 1",
-        "#  define HAVE_SIGNAL 1",
-        "#  define HAVE_SIGNAL_H 1",
-        "#  define HAVE_SIGSETJMP 1",
-        "#  define HAVE_SIG_ATOMIC_T 1",
-        "#  define HAVE_SOCKADDR_IN6_SIN6_SCOPE_ID 1",
-        "#  define HAVE_SOCKET 1",
-        "#  define HAVE_SOCKETPAIR 1",
-        "#  define HAVE_STDBOOL_H 1",
-        "#  define HAVE_STDINT_H 1",
-        "#  define HAVE_STDIO_H 1",
-        "#  define HAVE_STDLIB_H 1",
-        "#  define HAVE_STRCASECMP 1",
-        "#  define HAVE_STRDUP 1",
-        "#  define HAVE_STRERROR_R 1",
-        "#  define HAVE_STRINGS_H 1",
-        "#  define HAVE_STRING_H 1",
-        "#  define HAVE_STRNCASECMP 1",
-        "#  define HAVE_STRSTR 1",
-        "#  define HAVE_STRTOK_R 1",
-        "#  define HAVE_STRTOLL 1",
-        "#  define HAVE_STRUCT_SOCKADDR_STORAGE 1",
-        "#  define HAVE_STRUCT_TIMEVAL 1",
-        "#  define HAVE_SYS_IOCTL_H 1",
-        "#  define HAVE_SYS_PARAM_H 1",
-        "#  define HAVE_SYS_POLL_H 1",
-        "#  define HAVE_SYS_RESOURCE_H 1",
-        "#  define HAVE_SYS_SELECT_H 1",
-        "#  define HAVE_SYS_SOCKET_H 1",
-        "#  define HAVE_SYS_STAT_H 1",
-        "#  define HAVE_SYS_TIME_H 1",
-        "#  define HAVE_SYS_TYPES_H 1",
-        "#  define HAVE_SYS_UIO_H 1",
-        "#  define HAVE_SYS_UN_H 1",
-        "#  define HAVE_SYS_WAIT_H 1",
-        "#  define HAVE_SYS_XATTR_H 1",
-        "#  define HAVE_TIME_H 1",
-        "#  define HAVE_UNAME 1",
-        "#  define HAVE_UNISTD_H 1",
-        "#  define HAVE_UTIME 1",
-        "#  define HAVE_UTIME_H 1",
-        "#  define HAVE_VARIADIC_MACROS_C99 1",
-        "#  define HAVE_VARIADIC_MACROS_GCC 1",
-        "#  define HAVE_WRITABLE_ARGV 1",
-        "#  define HAVE_WRITEV 1",
-        "#  define HAVE_ZLIB_H 1",
-        "#  define LT_OBJDIR \".libs/\"",
-        "#  define PACKAGE \"curl\"",
-        "#  define PACKAGE_BUGREPORT \"a suitable curl mailing list: https://curl.haxx.se/mail/\"",
-        "#  define PACKAGE_NAME \"curl\"",
-        "#  define PACKAGE_STRING \"curl -\"",
-        "#  define PACKAGE_TARNAME \"curl\"",
-        "#  define PACKAGE_URL \"\"",
-        "#  define PACKAGE_VERSION \"-\"",
-        "#  define RECV_TYPE_ARG1 int",
-        "#  define RECV_TYPE_ARG2 void *",
-        "#  define RECV_TYPE_ARG3 size_t",
-        "#  define RECV_TYPE_ARG4 int",
-        "#  define RECV_TYPE_RETV ssize_t",
-        "#  define RETSIGTYPE void",
-        "#  define SELECT_QUAL_ARG5",
-        "#  define SELECT_TYPE_ARG1 int",
-        "#  define SELECT_TYPE_ARG234 fd_set *",
-        "#  define SELECT_TYPE_ARG5 struct timeval *",
-        "#  define SELECT_TYPE_RETV int",
-        "#  define SEND_QUAL_ARG2 const",
-        "#  define SEND_TYPE_ARG1 int",
-        "#  define SEND_TYPE_ARG2 void *",
-        "#  define SEND_TYPE_ARG3 size_t",
-        "#  define SEND_TYPE_ARG4 int",
-        "#  define SEND_TYPE_RETV ssize_t",
-        "#  define SIZEOF_INT 4",
-        "#  define SIZEOF_LONG 8",
-        "#  define SIZEOF_OFF_T 8",
-        "#  define SIZEOF_CURL_OFF_T 8",
-        "#  define SIZEOF_SHORT 2",
-        "#  define SIZEOF_SIZE_T 8",
-        "#  define SIZEOF_TIME_T 8",
-        "#  define SIZEOF_VOIDP 8",
-        "#  define STDC_HEADERS 1",
-        "#  define STRERROR_R_TYPE_ARG3 size_t",
-        "#  define TIME_WITH_SYS_TIME 1",
-        "#  define VERSION \"-\"",
-        "#  ifndef _DARWIN_USE_64_BIT_INODE",
-        "#    define _DARWIN_USE_64_BIT_INODE 1",
-        "#  endif",
-        "#endif",
-        "",
-        "#endif  // EXTERNAL_CURL_INCLUDE_CURL_CONFIG_H_",
-        "EOF",
-    ]),
-)
diff --git a/third_party/xla/third_party/cython.BUILD b/third_party/xla/third_party/cython.BUILD
deleted file mode 100644
index ac8c331..0000000
--- a/third_party/xla/third_party/cython.BUILD
+++ /dev/null
@@ -1,28 +0,0 @@
-# Modified version of @cython//:BUILD.bazel
-
-py_library(
-    name = "cython_lib",
-    srcs = glob(
-        ["Cython/**/*.py"],
-        exclude = [
-            "**/Tests/*.py",
-        ],
-    ) + ["cython.py"],
-    data = glob([
-        "Cython/**/*.pyx",
-        "Cython/Utility/*.*",
-        "Cython/Includes/**/*.pxd",
-    ]),
-    srcs_version = "PY3",
-    visibility = ["//visibility:public"],
-)
-
-# May not be named "cython", since that conflicts with Cython/ on OSX
-py_binary(
-    name = "cython_binary",
-    srcs = ["cython.py"],
-    main = "cython.py",
-    srcs_version = "PY3",
-    visibility = ["//visibility:public"],
-    deps = ["cython_lib"],
-)
diff --git a/third_party/xla/third_party/eigen3/BUILD b/third_party/xla/third_party/eigen3/BUILD
deleted file mode 100644
index 631cc89..0000000
--- a/third_party/xla/third_party/eigen3/BUILD
+++ /dev/null
@@ -1,71 +0,0 @@
-# Description:
-#   Eigen is a C++ template library for linear algebra: vectors,
-#   matrices, and related algorithms.
-# This is the BUILD file with extra code to patch into @eigen_archive.
-
-load("//third_party/mkl:build_defs.bzl", "if_mkl")
-
-licenses([
-    # Note: Eigen is an MPL2 library that includes GPL v3 and LGPL v2.1+ code.
-    #       We've taken special care to not reference any restricted code.
-    "reciprocal",  # MPL2
-    "notice",  # Portions BSD
-])
-
-exports_files(
-    ["LICENSE"],
-    visibility = ["//visibility:public"],
-)
-
-EIGEN3_THIRD_PARTY_HEADERS = [
-    "Eigen/Core",
-    "Eigen/LU",
-    "Eigen/Cholesky",
-    "Eigen/Eigenvalues",
-    "Eigen/OrderingMethods",
-    "Eigen/QR",
-    "Eigen/SparseCholesky",
-    "Eigen/SparseCore",
-    "Eigen/SVD",
-    "unsupported/Eigen/MatrixFunctions",
-    "unsupported/Eigen/SpecialFunctions",
-    "unsupported/Eigen/CXX11/ThreadPool",
-    "unsupported/Eigen/CXX11/Tensor",
-]
-
-cc_library(
-    name = "eigen3",
-    hdrs = EIGEN3_THIRD_PARTY_HEADERS,
-    includes = if_mkl(["./mkl_include"]),
-    visibility = ["//visibility:public"],
-    deps = [
-        "@eigen_archive//:eigen3_internal",
-    ],
-)
-
-filegroup(
-    name = "eigen_third_party_header_files",
-    srcs = EIGEN3_THIRD_PARTY_HEADERS,
-    visibility = ["//visibility:public"],
-)
-
-genrule(
-    name = "install_eigen_headers",
-    srcs = [
-        "@eigen_archive//:eigen_header_files",
-        "@eigen_archive//:eigen_source_files",
-        ":eigen_third_party_header_files",
-    ],
-    outs = ["include"],
-    cmd = """
-    mkdir $@
-    for f in $(SRCS); do
-      d="$${f%/*}"
-      d="$${d#*external/eigen_archive/}"
-
-      mkdir -p "$@/$${d}"
-      cp "$${f}" "$@/$${d}/"
-    done
-    """,
-    tags = ["manual"],
-)
diff --git a/third_party/xla/third_party/eigen3/Eigen/Cholesky b/third_party/xla/third_party/eigen3/Eigen/Cholesky
deleted file mode 100644
index c199a025..0000000
--- a/third_party/xla/third_party/eigen3/Eigen/Cholesky
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/Cholesky"
diff --git a/third_party/xla/third_party/eigen3/Eigen/Core b/third_party/xla/third_party/eigen3/Eigen/Core
deleted file mode 100644
index d4b0367..0000000
--- a/third_party/xla/third_party/eigen3/Eigen/Core
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/Core"
diff --git a/third_party/xla/third_party/eigen3/Eigen/Eigenvalues b/third_party/xla/third_party/eigen3/Eigen/Eigenvalues
deleted file mode 100644
index bf739b9..0000000
--- a/third_party/xla/third_party/eigen3/Eigen/Eigenvalues
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/Eigenvalues"
diff --git a/third_party/xla/third_party/eigen3/Eigen/LU b/third_party/xla/third_party/eigen3/Eigen/LU
deleted file mode 100644
index 536149c..0000000
--- a/third_party/xla/third_party/eigen3/Eigen/LU
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/LU"
diff --git a/third_party/xla/third_party/eigen3/Eigen/OrderingMethods b/third_party/xla/third_party/eigen3/Eigen/OrderingMethods
deleted file mode 100644
index 190fc22..0000000
--- a/third_party/xla/third_party/eigen3/Eigen/OrderingMethods
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/OrderingMethods"
\ No newline at end of file
diff --git a/third_party/xla/third_party/eigen3/Eigen/QR b/third_party/xla/third_party/eigen3/Eigen/QR
deleted file mode 100644
index be067d3..0000000
--- a/third_party/xla/third_party/eigen3/Eigen/QR
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/QR"
diff --git a/third_party/xla/third_party/eigen3/Eigen/SVD b/third_party/xla/third_party/eigen3/Eigen/SVD
deleted file mode 100644
index eecf47c..0000000
--- a/third_party/xla/third_party/eigen3/Eigen/SVD
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/SVD"
diff --git a/third_party/xla/third_party/eigen3/Eigen/SparseCholesky b/third_party/xla/third_party/eigen3/Eigen/SparseCholesky
deleted file mode 100644
index a6d362b..0000000
--- a/third_party/xla/third_party/eigen3/Eigen/SparseCholesky
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/SparseCholesky"
\ No newline at end of file
diff --git a/third_party/xla/third_party/eigen3/Eigen/SparseCore b/third_party/xla/third_party/eigen3/Eigen/SparseCore
deleted file mode 100644
index 3c60745..0000000
--- a/third_party/xla/third_party/eigen3/Eigen/SparseCore
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/SparseCore"
\ No newline at end of file
diff --git a/third_party/xla/third_party/eigen3/LICENSE b/third_party/xla/third_party/eigen3/LICENSE
deleted file mode 100644
index eff7afb..0000000
--- a/third_party/xla/third_party/eigen3/LICENSE
+++ /dev/null
@@ -1,1072 +0,0 @@
-Eigen is primarily MPL2 licensed. See COPYING.MPL2 and these links:
-  http://www.mozilla.org/MPL/2.0/
-  http://www.mozilla.org/MPL/2.0/FAQ.html
-
-Some files contain third-party code under BSD or LGPL licenses, whence
-the other COPYING.* files here.
-
-All the LGPL code is either LGPL 2.1-only, or LGPL 2.1-or-later.
-For this reason, the COPYING.LGPL file contains the LGPL 2.1 text.
-
-If you want to guarantee that the Eigen code that you are #including
-is licensed under the MPL2 and possibly more permissive licenses (like
-BSD), #define this preprocessor symbol: EIGEN_MPL2_ONLY 
-For example, with most compilers, you could add this to your project
-      CXXFLAGS: -DEIGEN_MPL2_ONLY 
-This will cause a compilation error to be generated if you #include
-any code that is LGPL licensed.
-
-----------------------------------------------------------------------
-Following applies to:
-./test/mapstaticmethods.cpp
-./test/schur_real.cpp
-./test/prec_inverse_4x4.cpp
-./test/smallvectors.cpp
-./test/redux.cpp
-./test/special_numbers.cpp
-./test/adjoint.cpp
-./test/resize.cpp
-./test/mixingtypes.cpp
-./test/product_trmv.cpp
-./test/sparse_solvers.cpp
-./test/cholesky.cpp
-./test/geo_quaternion.cpp
-./test/miscmatrices.cpp
-./test/stddeque.cpp
-./test/integer_types.cpp
-./test/product_large.cpp
-./test/eigensolver_generic.cpp
-./test/householder.cpp
-./test/geo_orthomethods.cpp
-./test/array_for_matrix.cpp
-./test/sparseLM.cpp
-./test/upperbidiagonalization.cpp
-./test/nomalloc.cpp
-./test/packetmath.cpp
-./test/jacobisvd.cpp
-./test/geo_transformations.cpp
-./test/swap.cpp
-./test/eigensolver_selfadjoint.cpp
-./test/inverse.cpp
-./test/product_selfadjoint.cpp
-./test/product_trsolve.cpp
-./test/product_extra.cpp
-./test/sparse_solver.h
-./test/mapstride.cpp
-./test/mapped_matrix.cpp
-./test/geo_eulerangles.cpp
-./test/eigen2support.cpp
-./test/denseLM.cpp
-./test/stdvector.cpp
-./test/nesting_ops.cpp
-./test/sparse_permutations.cpp
-./test/zerosized.cpp
-./test/exceptions.cpp
-./test/vectorwiseop.cpp
-./test/cwiseop.cpp
-./test/basicstuff.cpp
-./test/product_trmm.cpp
-./test/linearstructure.cpp
-./test/sparse_product.cpp
-./test/stdvector_overload.cpp
-./test/stable_norm.cpp
-./test/umeyama.cpp
-./test/unalignedcount.cpp
-./test/triangular.cpp
-./test/product_mmtr.cpp
-./test/sparse_basic.cpp
-./test/sparse_vector.cpp
-./test/meta.cpp
-./test/real_qz.cpp
-./test/ref.cpp
-./test/eigensolver_complex.cpp
-./test/cholmod_support.cpp
-./test/conjugate_gradient.cpp
-./test/sparse.h
-./test/simplicial_cholesky.cpp
-./test/bicgstab.cpp
-./test/dynalloc.cpp
-./test/product_notemporary.cpp
-./test/geo_hyperplane.cpp
-./test/lu.cpp
-./test/qr.cpp
-./test/hessenberg.cpp
-./test/sizeof.cpp
-./test/main.h
-./test/selfadjoint.cpp
-./test/permutationmatrices.cpp
-./test/superlu_support.cpp
-./test/qtvector.cpp
-./test/geo_homogeneous.cpp
-./test/determinant.cpp
-./test/array_reverse.cpp
-./test/unalignedassert.cpp
-./test/stdlist.cpp
-./test/product_symm.cpp
-./test/corners.cpp
-./test/dontalign.cpp
-./test/visitor.cpp
-./test/geo_alignedbox.cpp
-./test/diagonalmatrices.cpp
-./test/product_small.cpp
-./test/eigensolver_generalized_real.cpp
-./test/umfpack_support.cpp
-./test/first_aligned.cpp
-./test/qr_fullpivoting.cpp
-./test/array_replicate.cpp
-./test/geo_parametrizedline.cpp
-./test/eigen2/eigen2_unalignedassert.cpp
-./test/eigen2/eigen2_prec_inverse_4x4.cpp
-./test/eigen2/eigen2_alignedbox.cpp
-./test/eigen2/eigen2_sparse_product.cpp
-./test/eigen2/eigen2_meta.cpp
-./test/eigen2/eigen2_nomalloc.cpp
-./test/eigen2/eigen2_visitor.cpp
-./test/eigen2/eigen2_packetmath.cpp
-./test/eigen2/eigen2_svd.cpp
-./test/eigen2/eigen2_mixingtypes.cpp
-./test/eigen2/eigen2_qr.cpp
-./test/eigen2/eigen2_cwiseop.cpp
-./test/eigen2/eigen2_geometry_with_eigen2_prefix.cpp
-./test/eigen2/eigen2_smallvectors.cpp
-./test/eigen2/eigen2_commainitializer.cpp
-./test/eigen2/eigen2_sparse_solvers.cpp
-./test/eigen2/eigen2_hyperplane.cpp
-./test/eigen2/eigen2_eigensolver.cpp
-./test/eigen2/eigen2_linearstructure.cpp
-./test/eigen2/eigen2_sizeof.cpp
-./test/eigen2/eigen2_parametrizedline.cpp
-./test/eigen2/eigen2_lu.cpp
-./test/eigen2/eigen2_adjoint.cpp
-./test/eigen2/eigen2_geometry.cpp
-./test/eigen2/eigen2_stdvector.cpp
-./test/eigen2/eigen2_newstdvector.cpp
-./test/eigen2/eigen2_submatrices.cpp
-./test/eigen2/sparse.h
-./test/eigen2/eigen2_swap.cpp
-./test/eigen2/eigen2_triangular.cpp
-./test/eigen2/eigen2_basicstuff.cpp
-./test/eigen2/gsl_helper.h
-./test/eigen2/eigen2_dynalloc.cpp
-./test/eigen2/eigen2_array.cpp
-./test/eigen2/eigen2_map.cpp
-./test/eigen2/main.h
-./test/eigen2/eigen2_miscmatrices.cpp
-./test/eigen2/eigen2_product_large.cpp
-./test/eigen2/eigen2_first_aligned.cpp
-./test/eigen2/eigen2_cholesky.cpp
-./test/eigen2/eigen2_determinant.cpp
-./test/eigen2/eigen2_sum.cpp
-./test/eigen2/eigen2_inverse.cpp
-./test/eigen2/eigen2_regression.cpp
-./test/eigen2/eigen2_product_small.cpp
-./test/eigen2/eigen2_qtvector.cpp
-./test/eigen2/eigen2_sparse_vector.cpp
-./test/eigen2/product.h
-./test/eigen2/eigen2_sparse_basic.cpp
-./test/eigen2/eigen2_bug_132.cpp
-./test/array.cpp
-./test/product_syrk.cpp
-./test/commainitializer.cpp
-./test/conservative_resize.cpp
-./test/qr_colpivoting.cpp
-./test/nullary.cpp
-./test/bandmatrix.cpp
-./test/pastix_support.cpp
-./test/product.h
-./test/block.cpp
-./test/vectorization_logic.cpp
-./test/jacobi.cpp
-./test/diagonal.cpp
-./test/schur_complex.cpp
-./test/sizeoverflow.cpp
-./bench/BenchTimer.h
-./bench/benchFFT.cpp
-./bench/eig33.cpp
-./bench/spbench/spbenchsolver.h
-./bench/spbench/spbenchstyle.h
-./lapack/complex_double.cpp
-./lapack/cholesky.cpp
-./lapack/lapack_common.h
-./lapack/eigenvalues.cpp
-./lapack/single.cpp
-./lapack/lu.cpp
-./lapack/complex_single.cpp
-./lapack/double.cpp
-./demos/mix_eigen_and_c/binary_library.cpp
-./demos/mix_eigen_and_c/binary_library.h
-./demos/mix_eigen_and_c/example.c
-./demos/mandelbrot/mandelbrot.cpp
-./demos/mandelbrot/mandelbrot.h
-./demos/opengl/icosphere.cpp
-./demos/opengl/icosphere.h
-./demos/opengl/camera.cpp
-./demos/opengl/quaternion_demo.h
-./demos/opengl/camera.h
-./demos/opengl/trackball.h
-./demos/opengl/gpuhelper.h
-./demos/opengl/trackball.cpp
-./demos/opengl/gpuhelper.cpp
-./demos/opengl/quaternion_demo.cpp
-./debug/gdb/printers.py
-./unsupported/test/minres.cpp
-./unsupported/test/openglsupport.cpp
-./unsupported/test/jacobisvd.cpp
-./unsupported/test/dgmres.cpp
-./unsupported/test/matrix_square_root.cpp
-./unsupported/test/bdcsvd.cpp
-./unsupported/test/matrix_exponential.cpp
-./unsupported/test/forward_adolc.cpp
-./unsupported/test/polynomialsolver.cpp
-./unsupported/test/matrix_function.cpp
-./unsupported/test/sparse_extra.cpp
-./unsupported/test/matrix_functions.h
-./unsupported/test/svd_common.h
-./unsupported/test/FFTW.cpp
-./unsupported/test/alignedvector3.cpp
-./unsupported/test/autodiff.cpp
-./unsupported/test/gmres.cpp
-./unsupported/test/BVH.cpp
-./unsupported/test/levenberg_marquardt.cpp
-./unsupported/test/matrix_power.cpp
-./unsupported/test/kronecker_product.cpp
-./unsupported/test/splines.cpp
-./unsupported/test/polynomialutils.cpp
-./unsupported/bench/bench_svd.cpp
-./unsupported/Eigen/IterativeSolvers
-./unsupported/Eigen/src/IterativeSolvers/DGMRES.h
-./unsupported/Eigen/src/IterativeSolvers/IncompleteLU.h
-./unsupported/Eigen/src/IterativeSolvers/GMRES.h
-./unsupported/Eigen/src/IterativeSolvers/IncompleteCholesky.h
-./unsupported/Eigen/src/IterativeSolvers/Scaling.h
-./unsupported/Eigen/src/IterativeSolvers/MINRES.h
-./unsupported/Eigen/src/SparseExtra/RandomSetter.h
-./unsupported/Eigen/src/SparseExtra/MatrixMarketIterator.h
-./unsupported/Eigen/src/SparseExtra/DynamicSparseMatrix.h
-./unsupported/Eigen/src/SparseExtra/MarketIO.h
-./unsupported/Eigen/src/SparseExtra/BlockOfDynamicSparseMatrix.h
-./unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h
-./unsupported/Eigen/src/NonLinearOptimization/LevenbergMarquardt.h
-./unsupported/Eigen/src/NonLinearOptimization/HybridNonLinearSolver.h
-./unsupported/Eigen/src/BVH/BVAlgorithms.h
-./unsupported/Eigen/src/BVH/KdBVH.h
-./unsupported/Eigen/src/AutoDiff/AutoDiffScalar.h
-./unsupported/Eigen/src/AutoDiff/AutoDiffJacobian.h
-./unsupported/Eigen/src/AutoDiff/AutoDiffVector.h
-./unsupported/Eigen/src/Splines/Spline.h
-./unsupported/Eigen/src/Splines/SplineFitting.h
-./unsupported/Eigen/src/Splines/SplineFwd.h
-./unsupported/Eigen/src/SVD/JacobiSVD.h
-./unsupported/Eigen/src/SVD/BDCSVD.h
-./unsupported/Eigen/src/SVD/SVDBase.h
-./unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h
-./unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h
-./unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h
-./unsupported/Eigen/src/MatrixFunctions/StemFunction.h
-./unsupported/Eigen/src/MatrixFunctions/MatrixPower.h
-./unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h
-./unsupported/Eigen/src/MatrixFunctions/MatrixFunctionAtomic.h
-./unsupported/Eigen/src/MoreVectorization/MathFunctions.h
-./unsupported/Eigen/src/LevenbergMarquardt/LevenbergMarquardt.h
-./unsupported/Eigen/src/FFT/ei_fftw_impl.h
-./unsupported/Eigen/src/FFT/ei_kissfft_impl.h
-./unsupported/Eigen/src/Polynomials/PolynomialSolver.h
-./unsupported/Eigen/src/Polynomials/Companion.h
-./unsupported/Eigen/src/Polynomials/PolynomialUtils.h
-./unsupported/Eigen/src/NumericalDiff/NumericalDiff.h
-./unsupported/Eigen/src/Skyline/SkylineProduct.h
-./unsupported/Eigen/src/Skyline/SkylineMatrixBase.h
-./unsupported/Eigen/src/Skyline/SkylineStorage.h
-./unsupported/Eigen/src/Skyline/SkylineUtil.h
-./unsupported/Eigen/src/Skyline/SkylineInplaceLU.h
-./unsupported/Eigen/src/Skyline/SkylineMatrix.h
-./unsupported/Eigen/SparseExtra
-./unsupported/Eigen/AdolcForward
-./unsupported/Eigen/KroneckerProduct
-./unsupported/Eigen/NonLinearOptimization
-./unsupported/Eigen/BVH
-./unsupported/Eigen/OpenGLSupport
-./unsupported/Eigen/ArpackSupport
-./unsupported/Eigen/AutoDiff
-./unsupported/Eigen/Splines
-./unsupported/Eigen/MPRealSupport
-./unsupported/Eigen/MatrixFunctions
-./unsupported/Eigen/MoreVectorization
-./unsupported/Eigen/LevenbergMarquardt
-./unsupported/Eigen/AlignedVector3
-./unsupported/Eigen/FFT
-./unsupported/Eigen/Polynomials
-./unsupported/Eigen/NumericalDiff
-./unsupported/Eigen/Skyline
-./COPYING.README
-./COPYING.README
-./LICENSE
-./LICENSE
-./LICENSE
-./Eigen/Eigen2Support
-./Eigen/src/Eigen2Support/VectorBlock.h
-./Eigen/src/Eigen2Support/Cwise.h
-./Eigen/src/Eigen2Support/Minor.h
-./Eigen/src/Eigen2Support/Lazy.h
-./Eigen/src/Eigen2Support/Memory.h
-./Eigen/src/Eigen2Support/MathFunctions.h
-./Eigen/src/Eigen2Support/Geometry/AlignedBox.h
-./Eigen/src/Eigen2Support/Geometry/Hyperplane.h
-./Eigen/src/Eigen2Support/Geometry/Quaternion.h
-./Eigen/src/Eigen2Support/Geometry/Rotation2D.h
-./Eigen/src/Eigen2Support/Geometry/ParametrizedLine.h
-./Eigen/src/Eigen2Support/Geometry/RotationBase.h
-./Eigen/src/Eigen2Support/Geometry/Translation.h
-./Eigen/src/Eigen2Support/Geometry/Scaling.h
-./Eigen/src/Eigen2Support/Geometry/AngleAxis.h
-./Eigen/src/Eigen2Support/Geometry/Transform.h
-./Eigen/src/Eigen2Support/TriangularSolver.h
-./Eigen/src/Eigen2Support/LU.h
-./Eigen/src/Eigen2Support/QR.h
-./Eigen/src/Eigen2Support/SVD.h
-./Eigen/src/Eigen2Support/Meta.h
-./Eigen/src/Eigen2Support/Block.h
-./Eigen/src/Eigen2Support/Macros.h
-./Eigen/src/Eigen2Support/LeastSquares.h
-./Eigen/src/Eigen2Support/CwiseOperators.h
-./Eigen/src/Jacobi/Jacobi.h
-./Eigen/src/misc/Kernel.h
-./Eigen/src/misc/SparseSolve.h
-./Eigen/src/misc/Solve.h
-./Eigen/src/misc/Image.h
-./Eigen/src/SparseCore/SparseColEtree.h
-./Eigen/src/SparseCore/SparseTranspose.h
-./Eigen/src/SparseCore/SparseUtil.h
-./Eigen/src/SparseCore/SparseCwiseBinaryOp.h
-./Eigen/src/SparseCore/SparseDiagonalProduct.h
-./Eigen/src/SparseCore/SparseProduct.h
-./Eigen/src/SparseCore/SparseDot.h
-./Eigen/src/SparseCore/SparseCwiseUnaryOp.h
-./Eigen/src/SparseCore/SparseSparseProductWithPruning.h
-./Eigen/src/SparseCore/SparseBlock.h
-./Eigen/src/SparseCore/SparseDenseProduct.h
-./Eigen/src/SparseCore/CompressedStorage.h
-./Eigen/src/SparseCore/SparseMatrixBase.h
-./Eigen/src/SparseCore/MappedSparseMatrix.h
-./Eigen/src/SparseCore/SparseTriangularView.h
-./Eigen/src/SparseCore/SparseView.h
-./Eigen/src/SparseCore/SparseFuzzy.h
-./Eigen/src/SparseCore/TriangularSolver.h
-./Eigen/src/SparseCore/SparseSelfAdjointView.h
-./Eigen/src/SparseCore/SparseMatrix.h
-./Eigen/src/SparseCore/SparseVector.h
-./Eigen/src/SparseCore/AmbiVector.h
-./Eigen/src/SparseCore/ConservativeSparseSparseProduct.h
-./Eigen/src/SparseCore/SparseRedux.h
-./Eigen/src/SparseCore/SparsePermutation.h
-./Eigen/src/Eigenvalues/RealSchur.h
-./Eigen/src/Eigenvalues/ComplexEigenSolver.h
-./Eigen/src/Eigenvalues/GeneralizedEigenSolver.h
-./Eigen/src/Eigenvalues/ComplexSchur.h
-./Eigen/src/Eigenvalues/RealQZ.h
-./Eigen/src/Eigenvalues/EigenSolver.h
-./Eigen/src/Eigenvalues/HessenbergDecomposition.h
-./Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h
-./Eigen/src/Eigenvalues/Tridiagonalization.h
-./Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h
-./Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h
-./Eigen/src/SuperLUSupport/SuperLUSupport.h
-./Eigen/src/StlSupport/StdDeque.h
-./Eigen/src/StlSupport/StdVector.h
-./Eigen/src/StlSupport/StdList.h
-./Eigen/src/StlSupport/details.h
-./Eigen/src/SparseQR/SparseQR.h
-./Eigen/src/LU/Inverse.h
-./Eigen/src/LU/arch/Inverse_SSE.h
-./Eigen/src/LU/Determinant.h
-./Eigen/src/LU/PartialPivLU.h
-./Eigen/src/LU/FullPivLU.h
-./Eigen/src/UmfPackSupport/UmfPackSupport.h
-./Eigen/src/OrderingMethods/Ordering.h
-./Eigen/src/OrderingMethods/Eigen_Colamd.h
-./Eigen/src/QR/HouseholderQR.h
-./Eigen/src/QR/ColPivHouseholderQR.h
-./Eigen/src/QR/FullPivHouseholderQR.h
-./Eigen/src/SVD/JacobiSVD.h
-./Eigen/src/SVD/UpperBidiagonalization.h
-./Eigen/src/Geometry/OrthoMethods.h
-./Eigen/src/Geometry/AlignedBox.h
-./Eigen/src/Geometry/Hyperplane.h
-./Eigen/src/Geometry/Quaternion.h
-./Eigen/src/Geometry/EulerAngles.h
-./Eigen/src/Geometry/Rotation2D.h
-./Eigen/src/Geometry/ParametrizedLine.h
-./Eigen/src/Geometry/RotationBase.h
-./Eigen/src/Geometry/arch/Geometry_SSE.h
-./Eigen/src/Geometry/Umeyama.h
-./Eigen/src/Geometry/Homogeneous.h
-./Eigen/src/Geometry/Translation.h
-./Eigen/src/Geometry/Scaling.h
-./Eigen/src/Geometry/AngleAxis.h
-./Eigen/src/Geometry/Transform.h
-./Eigen/src/plugins/BlockMethods.h
-./Eigen/src/plugins/CommonCwiseUnaryOps.h
-./Eigen/src/plugins/CommonCwiseBinaryOps.h
-./Eigen/src/plugins/MatrixCwiseUnaryOps.h
-./Eigen/src/plugins/MatrixCwiseBinaryOps.h
-./Eigen/src/Householder/Householder.h
-./Eigen/src/Householder/HouseholderSequence.h
-./Eigen/src/Householder/BlockHouseholder.h
-./Eigen/src/Core/VectorBlock.h
-./Eigen/src/Core/Matrix.h
-./Eigen/src/Core/Ref.h
-./Eigen/src/Core/SelfAdjointView.h
-./Eigen/src/Core/MathFunctions.h
-./Eigen/src/Core/GlobalFunctions.h
-./Eigen/src/Core/MapBase.h
-./Eigen/src/Core/EigenBase.h
-./Eigen/src/Core/GenericPacketMath.h
-./Eigen/src/Core/NestByValue.h
-./Eigen/src/Core/CwiseUnaryOp.h
-./Eigen/src/Core/SolveTriangular.h
-./Eigen/src/Core/Fuzzy.h
-./Eigen/src/Core/Visitor.h
-./Eigen/src/Core/Map.h
-./Eigen/src/Core/NoAlias.h
-./Eigen/src/Core/Diagonal.h
-./Eigen/src/Core/StableNorm.h
-./Eigen/src/Core/CoreIterators.h
-./Eigen/src/Core/products/Parallelizer.h
-./Eigen/src/Core/products/SelfadjointMatrixVector.h
-./Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h
-./Eigen/src/Core/products/TriangularSolverMatrix.h
-./Eigen/src/Core/products/GeneralMatrixMatrix.h
-./Eigen/src/Core/products/SelfadjointProduct.h
-./Eigen/src/Core/products/CoeffBasedProduct.h
-./Eigen/src/Core/products/TriangularMatrixVector.h
-./Eigen/src/Core/products/SelfadjointMatrixMatrix.h
-./Eigen/src/Core/products/TriangularSolverVector.h
-./Eigen/src/Core/products/SelfadjointRank2Update.h
-./Eigen/src/Core/products/GeneralBlockPanelKernel.h
-./Eigen/src/Core/products/GeneralMatrixVector.h
-./Eigen/src/Core/products/TriangularMatrixMatrix.h
-./Eigen/src/Core/Reverse.h
-./Eigen/src/Core/BooleanRedux.h
-./Eigen/src/Core/Replicate.h
-./Eigen/src/Core/arch/AltiVec/PacketMath.h
-./Eigen/src/Core/arch/AltiVec/Complex.h
-./Eigen/src/Core/arch/SSE/PacketMath.h
-./Eigen/src/Core/arch/SSE/Complex.h
-./Eigen/src/Core/arch/SSE/MathFunctions.h
-./Eigen/src/Core/arch/NEON/PacketMath.h
-./Eigen/src/Core/arch/NEON/Complex.h
-./Eigen/src/Core/arch/Default/Settings.h
-./Eigen/src/Core/CwiseUnaryView.h
-./Eigen/src/Core/Array.h
-./Eigen/src/Core/ArrayWrapper.h
-./Eigen/src/Core/Swap.h
-./Eigen/src/Core/Transpositions.h
-./Eigen/src/Core/Random.h
-./Eigen/src/Core/IO.h
-./Eigen/src/Core/SelfCwiseBinaryOp.h
-./Eigen/src/Core/VectorwiseOp.h
-./Eigen/src/Core/Select.h
-./Eigen/src/Core/ArrayBase.h
-./Eigen/src/Core/DenseCoeffsBase.h
-./Eigen/src/Core/DiagonalProduct.h
-./Eigen/src/Core/Assign.h
-./Eigen/src/Core/Redux.h
-./Eigen/src/Core/ForceAlignedAccess.h
-./Eigen/src/Core/BandMatrix.h
-./Eigen/src/Core/PlainObjectBase.h
-./Eigen/src/Core/DenseBase.h
-./Eigen/src/Core/Flagged.h
-./Eigen/src/Core/CwiseBinaryOp.h
-./Eigen/src/Core/ProductBase.h
-./Eigen/src/Core/TriangularMatrix.h
-./Eigen/src/Core/Transpose.h
-./Eigen/src/Core/DiagonalMatrix.h
-./Eigen/src/Core/Dot.h
-./Eigen/src/Core/Functors.h
-./Eigen/src/Core/PermutationMatrix.h
-./Eigen/src/Core/NumTraits.h
-./Eigen/src/Core/MatrixBase.h
-./Eigen/src/Core/DenseStorage.h
-./Eigen/src/Core/util/Memory.h
-./Eigen/src/Core/util/StaticAssert.h
-./Eigen/src/Core/util/BlasUtil.h
-./Eigen/src/Core/util/MatrixMapper.h
-./Eigen/src/Core/util/XprHelper.h
-./Eigen/src/Core/util/ForwardDeclarations.h
-./Eigen/src/Core/util/Meta.h
-./Eigen/src/Core/util/Macros.h
-./Eigen/src/Core/util/Constants.h
-./Eigen/src/Core/CwiseNullaryOp.h
-./Eigen/src/Core/Block.h
-./Eigen/src/Core/GeneralProduct.h
-./Eigen/src/Core/CommaInitializer.h
-./Eigen/src/Core/ReturnByValue.h
-./Eigen/src/Core/Stride.h
-./Eigen/src/SPQRSupport/SuiteSparseQRSupport.h
-./Eigen/src/SparseLU/SparseLU_column_dfs.h
-./Eigen/src/SparseLU/SparseLU_panel_dfs.h
-./Eigen/src/SparseLU/SparseLU_relax_snode.h
-./Eigen/src/SparseLU/SparseLU_panel_bmod.h
-./Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h
-./Eigen/src/SparseLU/SparseLU_Utils.h
-./Eigen/src/SparseLU/SparseLU_gemm_kernel.h
-./Eigen/src/SparseLU/SparseLU_kernel_bmod.h
-./Eigen/src/SparseLU/SparseLU_pivotL.h
-./Eigen/src/SparseLU/SparseLU_Memory.h
-./Eigen/src/SparseLU/SparseLU_heap_relax_snode.h
-./Eigen/src/SparseLU/SparseLUImpl.h
-./Eigen/src/SparseLU/SparseLU_copy_to_ucol.h
-./Eigen/src/SparseLU/SparseLU_Structs.h
-./Eigen/src/SparseLU/SparseLU.h
-./Eigen/src/SparseLU/SparseLU_column_bmod.h
-./Eigen/src/SparseLU/SparseLU_pruneL.h
-./Eigen/src/IterativeLinearSolvers/IncompleteLUT.h
-./Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h
-./Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h
-./Eigen/src/IterativeLinearSolvers/ConjugateGradient.h
-./Eigen/src/IterativeLinearSolvers/BiCGSTAB.h
-./Eigen/src/SparseCholesky/SimplicialCholesky.h
-./Eigen/src/Cholesky/LDLT.h
-./Eigen/src/Cholesky/LLT.h
-./Eigen/src/CholmodSupport/CholmodSupport.h
-./Eigen/src/PaStiXSupport/PaStiXSupport.h
-./Eigen/src/MetisSupport/MetisSupport.h
-./Eigen/StdVector
-./Eigen/Core
-./Eigen/OrderingMethods
-./Eigen/SparseLU
-./Eigen/StdList
-./Eigen/StdDeque
-./Eigen/SparseCholesky
-./Eigen/SparseCore
-./scripts/relicense.py
-./scripts/relicense.py
-./blas/BandTriangularSolver.h
-./blas/PackedTriangularMatrixVector.h
-./blas/complex_double.cpp
-./blas/level2_real_impl.h
-./blas/level1_cplx_impl.h
-./blas/level1_impl.h
-./blas/level1_real_impl.h
-./blas/level3_impl.h
-./blas/single.cpp
-./blas/level2_cplx_impl.h
-./blas/PackedSelfadjointProduct.h
-./blas/Rank2Update.h
-./blas/complex_single.cpp
-./blas/PackedTriangularSolverVector.h
-./blas/double.cpp
-./blas/common.h
-./blas/level2_impl.h
-./blas/GeneralRank1Update.h
-
-Mozilla Public License Version 2.0
-==================================
-
-1. Definitions
---------------
-
-1.1. "Contributor"
-    means each individual or legal entity that creates, contributes to
-    the creation of, or owns Covered Software.
-
-1.2. "Contributor Version"
-    means the combination of the Contributions of others (if any) used
-    by a Contributor and that particular Contributor's Contribution.
-
-1.3. "Contribution"
-    means Covered Software of a particular Contributor.
-
-1.4. "Covered Software"
-    means Source Code Form to which the initial Contributor has attached
-    the notice in Exhibit A, the Executable Form of such Source Code
-    Form, and Modifications of such Source Code Form, in each case
-    including portions thereof.
-
-1.5. "Incompatible With Secondary Licenses"
-    means
-
-    (a) that the initial Contributor has attached the notice described
-        in Exhibit B to the Covered Software; or
-
-    (b) that the Covered Software was made available under the terms of
-        version 1.1 or earlier of the License, but not also under the
-        terms of a Secondary License.
-
-1.6. "Executable Form"
-    means any form of the work other than Source Code Form.
-
-1.7. "Larger Work"
-    means a work that combines Covered Software with other material, in 
-    a separate file or files, that is not Covered Software.
-
-1.8. "License"
-    means this document.
-
-1.9. "Licensable"
-    means having the right to grant, to the maximum extent possible,
-    whether at the time of the initial grant or subsequently, any and
-    all of the rights conveyed by this License.
-
-1.10. "Modifications"
-    means any of the following:
-
-    (a) any file in Source Code Form that results from an addition to,
-        deletion from, or modification of the contents of Covered
-        Software; or
-
-    (b) any new file in Source Code Form that contains any Covered
-        Software.
-
-1.11. "Patent Claims" of a Contributor
-    means any patent claim(s), including without limitation, method,
-    process, and apparatus claims, in any patent Licensable by such
-    Contributor that would be infringed, but for the grant of the
-    License, by the making, using, selling, offering for sale, having
-    made, import, or transfer of either its Contributions or its
-    Contributor Version.
-
-1.12. "Secondary License"
-    means either the GNU General Public License, Version 2.0, the GNU
-    Lesser General Public License, Version 2.1, the GNU Affero General
-    Public License, Version 3.0, or any later versions of those
-    licenses.
-
-1.13. "Source Code Form"
-    means the form of the work preferred for making modifications.
-
-1.14. "You" (or "Your")
-    means an individual or a legal entity exercising rights under this
-    License. For legal entities, "You" includes any entity that
-    controls, is controlled by, or is under common control with You. For
-    purposes of this definition, "control" means (a) the power, direct
-    or indirect, to cause the direction or management of such entity,
-    whether by contract or otherwise, or (b) ownership of more than
-    fifty percent (50%) of the outstanding shares or beneficial
-    ownership of such entity.
-
-2. License Grants and Conditions
---------------------------------
-
-2.1. Grants
-
-Each Contributor hereby grants You a world-wide, royalty-free,
-non-exclusive license:
-
-(a) under intellectual property rights (other than patent or trademark)
-    Licensable by such Contributor to use, reproduce, make available,
-    modify, display, perform, distribute, and otherwise exploit its
-    Contributions, either on an unmodified basis, with Modifications, or
-    as part of a Larger Work; and
-
-(b) under Patent Claims of such Contributor to make, use, sell, offer
-    for sale, have made, import, and otherwise transfer either its
-    Contributions or its Contributor Version.
-
-2.2. Effective Date
-
-The licenses granted in Section 2.1 with respect to any Contribution
-become effective for each Contribution on the date the Contributor first
-distributes such Contribution.
-
-2.3. Limitations on Grant Scope
-
-The licenses granted in this Section 2 are the only rights granted under
-this License. No additional rights or licenses will be implied from the
-distribution or licensing of Covered Software under this License.
-Notwithstanding Section 2.1(b) above, no patent license is granted by a
-Contributor:
-
-(a) for any code that a Contributor has removed from Covered Software;
-    or
-
-(b) for infringements caused by: (i) Your and any other third party's
-    modifications of Covered Software, or (ii) the combination of its
-    Contributions with other software (except as part of its Contributor
-    Version); or
-
-(c) under Patent Claims infringed by Covered Software in the absence of
-    its Contributions.
-
-This License does not grant any rights in the trademarks, service marks,
-or logos of any Contributor (except as may be necessary to comply with
-the notice requirements in Section 3.4).
-
-2.4. Subsequent Licenses
-
-No Contributor makes additional grants as a result of Your choice to
-distribute the Covered Software under a subsequent version of this
-License (see Section 10.2) or under the terms of a Secondary License (if
-permitted under the terms of Section 3.3).
-
-2.5. Representation
-
-Each Contributor represents that the Contributor believes its
-Contributions are its original creation(s) or it has sufficient rights
-to grant the rights to its Contributions conveyed by this License.
-
-2.6. Fair Use
-
-This License is not intended to limit any rights You have under
-applicable copyright doctrines of fair use, fair dealing, or other
-equivalents.
-
-2.7. Conditions
-
-Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
-in Section 2.1.
-
-3. Responsibilities
--------------------
-
-3.1. Distribution of Source Form
-
-All distribution of Covered Software in Source Code Form, including any
-Modifications that You create or to which You contribute, must be under
-the terms of this License. You must inform recipients that the Source
-Code Form of the Covered Software is governed by the terms of this
-License, and how they can obtain a copy of this License. You may not
-attempt to alter or restrict the recipients' rights in the Source Code
-Form.
-
-3.2. Distribution of Executable Form
-
-If You distribute Covered Software in Executable Form then:
-
-(a) such Covered Software must also be made available in Source Code
-    Form, as described in Section 3.1, and You must inform recipients of
-    the Executable Form how they can obtain a copy of such Source Code
-    Form by reasonable means in a timely manner, at a charge no more
-    than the cost of distribution to the recipient; and
-
-(b) You may distribute such Executable Form under the terms of this
-    License, or sublicense it under different terms, provided that the
-    license for the Executable Form does not attempt to limit or alter
-    the recipients' rights in the Source Code Form under this License.
-
-3.3. Distribution of a Larger Work
-
-You may create and distribute a Larger Work under terms of Your choice,
-provided that You also comply with the requirements of this License for
-the Covered Software. If the Larger Work is a combination of Covered
-Software with a work governed by one or more Secondary Licenses, and the
-Covered Software is not Incompatible With Secondary Licenses, this
-License permits You to additionally distribute such Covered Software
-under the terms of such Secondary License(s), so that the recipient of
-the Larger Work may, at their option, further distribute the Covered
-Software under the terms of either this License or such Secondary
-License(s).
-
-3.4. Notices
-
-You may not remove or alter the substance of any license notices
-(including copyright notices, patent notices, disclaimers of warranty,
-or limitations of liability) contained within the Source Code Form of
-the Covered Software, except that You may alter any license notices to
-the extent required to remedy known factual inaccuracies.
-
-3.5. Application of Additional Terms
-
-You may choose to offer, and to charge a fee for, warranty, support,
-indemnity or liability obligations to one or more recipients of Covered
-Software. However, You may do so only on Your own behalf, and not on
-behalf of any Contributor. You must make it absolutely clear that any
-such warranty, support, indemnity, or liability obligation is offered by
-You alone, and You hereby agree to indemnify every Contributor for any
-liability incurred by such Contributor as a result of warranty, support,
-indemnity or liability terms You offer. You may include additional
-disclaimers of warranty and limitations of liability specific to any
-jurisdiction.
-
-4. Inability to Comply Due to Statute or Regulation
----------------------------------------------------
-
-If it is impossible for You to comply with any of the terms of this
-License with respect to some or all of the Covered Software due to
-statute, judicial order, or regulation then You must: (a) comply with
-the terms of this License to the maximum extent possible; and (b)
-describe the limitations and the code they affect. Such description must
-be placed in a text file included with all distributions of the Covered
-Software under this License. Except to the extent prohibited by statute
-or regulation, such description must be sufficiently detailed for a
-recipient of ordinary skill to be able to understand it.
-
-5. Termination
---------------
-
-5.1. The rights granted under this License will terminate automatically
-if You fail to comply with any of its terms. However, if You become
-compliant, then the rights granted under this License from a particular
-Contributor are reinstated (a) provisionally, unless and until such
-Contributor explicitly and finally terminates Your grants, and (b) on an
-ongoing basis, if such Contributor fails to notify You of the
-non-compliance by some reasonable means prior to 60 days after You have
-come back into compliance. Moreover, Your grants from a particular
-Contributor are reinstated on an ongoing basis if such Contributor
-notifies You of the non-compliance by some reasonable means, this is the
-first time You have received notice of non-compliance with this License
-from such Contributor, and You become compliant prior to 30 days after
-Your receipt of the notice.
-
-5.2. If You initiate litigation against any entity by asserting a patent
-infringement claim (excluding declaratory judgment actions,
-counter-claims, and cross-claims) alleging that a Contributor Version
-directly or indirectly infringes any patent, then the rights granted to
-You by any and all Contributors for the Covered Software under Section
-2.1 of this License shall terminate.
-
-5.3. In the event of termination under Sections 5.1 or 5.2 above, all
-end user license agreements (excluding distributors and resellers) which
-have been validly granted by You or Your distributors under this License
-prior to termination shall survive termination.
-
-************************************************************************
-*                                                                      *
-*  6. Disclaimer of Warranty                                           *
-*  -------------------------                                           *
-*                                                                      *
-*  Covered Software is provided under this License on an "as is"       *
-*  basis, without warranty of any kind, either expressed, implied, or  *
-*  statutory, including, without limitation, warranties that the       *
-*  Covered Software is free of defects, merchantable, fit for a        *
-*  particular purpose or non-infringing. The entire risk as to the     *
-*  quality and performance of the Covered Software is with You.        *
-*  Should any Covered Software prove defective in any respect, You     *
-*  (not any Contributor) assume the cost of any necessary servicing,   *
-*  repair, or correction. This disclaimer of warranty constitutes an   *
-*  essential part of this License. No use of any Covered Software is   *
-*  authorized under this License except under this disclaimer.         *
-*                                                                      *
-************************************************************************
-
-************************************************************************
-*                                                                      *
-*  7. Limitation of Liability                                          *
-*  --------------------------                                          *
-*                                                                      *
-*  Under no circumstances and under no legal theory, whether tort      *
-*  (including negligence), contract, or otherwise, shall any           *
-*  Contributor, or anyone who distributes Covered Software as          *
-*  permitted above, be liable to You for any direct, indirect,         *
-*  special, incidental, or consequential damages of any character      *
-*  including, without limitation, damages for lost profits, loss of    *
-*  goodwill, work stoppage, computer failure or malfunction, or any    *
-*  and all other commercial damages or losses, even if such party      *
-*  shall have been informed of the possibility of such damages. This   *
-*  limitation of liability shall not apply to liability for death or   *
-*  personal injury resulting from such party's negligence to the       *
-*  extent applicable law prohibits such limitation. Some               *
-*  jurisdictions do not allow the exclusion or limitation of           *
-*  incidental or consequential damages, so this exclusion and          *
-*  limitation may not apply to You.                                    *
-*                                                                      *
-************************************************************************
-
-8. Litigation
--------------
-
-Any litigation relating to this License may be brought only in the
-courts of a jurisdiction where the defendant maintains its principal
-place of business and such litigation shall be governed by laws of that
-jurisdiction, without reference to its conflict-of-law provisions.
-Nothing in this Section shall prevent a party's ability to bring
-cross-claims or counter-claims.
-
-9. Miscellaneous
-----------------
-
-This License represents the complete agreement concerning the subject
-matter hereof. If any provision of this License is held to be
-unenforceable, such provision shall be reformed only to the extent
-necessary to make it enforceable. Any law or regulation which provides
-that the language of a contract shall be construed against the drafter
-shall not be used to construe this License against a Contributor.
-
-10. Versions of the License
----------------------------
-
-10.1. New Versions
-
-Mozilla Foundation is the license steward. Except as provided in Section
-10.3, no one other than the license steward has the right to modify or
-publish new versions of this License. Each version will be given a
-distinguishing version number.
-
-10.2. Effect of New Versions
-
-You may distribute the Covered Software under the terms of the version
-of the License under which You originally received the Covered Software,
-or under the terms of any subsequent version published by the license
-steward.
-
-10.3. Modified Versions
-
-If you create software not governed by this License, and you want to
-create a new license for such software, you may create and use a
-modified version of this License if you rename the license and remove
-any references to the name of the license steward (except to note that
-such modified license differs from this License).
-
-10.4. Distributing Source Code Form that is Incompatible With Secondary
-Licenses
-
-If You choose to distribute Source Code Form that is Incompatible With
-Secondary Licenses under the terms of this version of the License, the
-notice described in Exhibit B of this License must be attached.
-
-Exhibit A - Source Code Form License Notice
--------------------------------------------
-
-  This Source Code Form is subject to the terms of the Mozilla Public
-  License, v. 2.0. If a copy of the MPL was not distributed with this
-  file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-If it is not possible or desirable to put the notice in a particular
-file, then You may include the notice in a location (such as a LICENSE
-file in a relevant directory) where a recipient would be likely to look
-for such a notice.
-
-You may add additional accurate notices of copyright ownership.
-
-Exhibit B - "Incompatible With Secondary Licenses" Notice
----------------------------------------------------------
-
-  This Source Code Form is "Incompatible With Secondary Licenses", as
-  defined by the Mozilla Public License, v. 2.0.
-
-----------------------------------------------------------------------
-Following applies to:
-./doc/UsingIntelMKL.dox
-./doc/UsingIntelMKL.dox
-./Eigen/src/Eigenvalues/ComplexSchur_MKL.h
-./Eigen/src/Eigenvalues/ComplexSchur_MKL.h
-./Eigen/src/Eigenvalues/SelfAdjointEigenSolver_MKL.h
-./Eigen/src/Eigenvalues/SelfAdjointEigenSolver_MKL.h
-./Eigen/src/Eigenvalues/RealSchur_MKL.h
-./Eigen/src/Eigenvalues/RealSchur_MKL.h
-./Eigen/src/LU/arch/Inverse_SSE.h
-./Eigen/src/LU/arch/Inverse_SSE.h
-./Eigen/src/LU/PartialPivLU_MKL.h
-./Eigen/src/LU/PartialPivLU_MKL.h
-./Eigen/src/QR/HouseholderQR_MKL.h
-./Eigen/src/QR/HouseholderQR_MKL.h
-./Eigen/src/QR/ColPivHouseholderQR_MKL.h
-./Eigen/src/QR/ColPivHouseholderQR_MKL.h
-./Eigen/src/SVD/JacobiSVD_MKL.h
-./Eigen/src/SVD/JacobiSVD_MKL.h
-./Eigen/src/PardisoSupport/PardisoSupport.h
-./Eigen/src/PardisoSupport/PardisoSupport.h
-./Eigen/src/Core/Assign_MKL.h
-./Eigen/src/Core/Assign_MKL.h
-./Eigen/src/Core/products/SelfadjointMatrixVector_MKL.h
-./Eigen/src/Core/products/SelfadjointMatrixVector_MKL.h
-./Eigen/src/Core/products/GeneralMatrixVector_MKL.h
-./Eigen/src/Core/products/GeneralMatrixVector_MKL.h
-./Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h
-./Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h
-./Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h
-./Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h
-./Eigen/src/Core/products/GeneralMatrixMatrix_MKL.h
-./Eigen/src/Core/products/GeneralMatrixMatrix_MKL.h
-./Eigen/src/Core/products/TriangularMatrixVector_MKL.h
-./Eigen/src/Core/products/TriangularMatrixVector_MKL.h
-./Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h
-./Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h
-./Eigen/src/Core/products/TriangularSolverMatrix_MKL.h
-./Eigen/src/Core/products/TriangularSolverMatrix_MKL.h
-./Eigen/src/Core/util/MKL_support.h
-./Eigen/src/Core/util/MKL_support.h
-./Eigen/src/Cholesky/LLT_MKL.h
-./Eigen/src/Cholesky/LLT_MKL.h
-
-/*
- Copyright (c) 2011, Intel Corporation. All rights reserved.
-
- Redistribution and use in source and binary forms, with or without
- modification, are permitted provided that the following conditions
- are met:
-
- * Redistributions of source code must retain the above copyright
-   notice, this list of conditions and the following disclaimer.  *
-   Redistributions in binary form must reproduce the above copyright
-   notice, this list of conditions and the following disclaimer in the
-   documentation and/or other materials provided with the
-   distribution.  * Neither the name of Intel Corporation nor the
-   names of its contributors may be used to endorse or promote
-   products derived from this software without specific prior written
-   permission.
-
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
- "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
- LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
- A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
- OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
- SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
- LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
- DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
- THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- */
-
-
-----------------------------------------------------------------------
-Following applies to:
-./unsupported/Eigen/src/LevenbergMarquardt/LevenbergMarquardt.h
-./unsupported/Eigen/src/LevenbergMarquardt/LMcovar.h
-./unsupported/Eigen/src/LevenbergMarquardt/LMonestep.h
-./unsupported/Eigen/src/LevenbergMarquardt/LMpar.h
-./unsupported/Eigen/src/LevenbergMarquardt/LMqrsolv.h
-
-Minpack Copyright Notice (1999) University of Chicago.  All rights
-reserved
-
-Redistribution and use in source and binary forms, with or
-without modification, are permitted provided that the
-following conditions are met:
-
-1. Redistributions of source code must retain the above
-copyright notice, this list of conditions and the following
-disclaimer.
-
-2. Redistributions in binary form must reproduce the above
-copyright notice, this list of conditions and the following
-disclaimer in the documentation and/or other materials
-provided with the distribution.
-
-3. The end-user documentation included with the
-redistribution, if any, must include the following
-acknowledgment:
-
-   "This product includes software developed by the
-   University of Chicago, as Operator of Argonne National
-   Laboratory.
-
-Alternately, this acknowledgment may appear in the software
-itself, if and wherever such third-party acknowledgments
-normally appear.
-
-4. WARRANTY DISCLAIMER. THE SOFTWARE IS SUPPLIED "AS IS"
-WITHOUT WARRANTY OF ANY KIND. THE COPYRIGHT HOLDER, THE
-UNITED STATES, THE UNITED STATES DEPARTMENT OF ENERGY, AND
-THEIR EMPLOYEES: (1) DISCLAIM ANY WARRANTIES, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED WARRANTIES
-OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE
-OR NON-INFRINGEMENT, (2) DO NOT ASSUME ANY LEGAL LIABILITY
-OR RESPONSIBILITY FOR THE ACCURACY, COMPLETENESS, OR
-USEFULNESS OF THE SOFTWARE, (3) DO NOT REPRESENT THAT USE OF
-THE SOFTWARE WOULD NOT INFRINGE PRIVATELY OWNED RIGHTS, (4)
-DO NOT WARRANT THAT THE SOFTWARE WILL FUNCTION
-UNINTERRUPTED, THAT IT IS ERROR-FREE OR THAT ANY ERRORS WILL
-BE CORRECTED.
-
-5. LIMITATION OF LIABILITY. IN NO EVENT WILL THE COPYRIGHT
-HOLDER, THE UNITED STATES, THE UNITED STATES DEPARTMENT OF
-ENERGY, OR THEIR EMPLOYEES: BE LIABLE FOR ANY INDIRECT,
-INCIDENTAL, CONSEQUENTIAL, SPECIAL OR PUNITIVE DAMAGES OF
-ANY KIND OR NATURE, INCLUDING BUT NOT LIMITED TO LOSS OF
-PROFITS OR LOSS OF DATA, FOR ANY REASON WHATSOEVER, WHETHER
-SUCH LIABILITY IS ASSERTED ON THE BASIS OF CONTRACT, TORT
-(INCLUDING NEGLIGENCE OR STRICT LIABILITY), OR OTHERWISE,
-EVEN IF ANY OF SAID PARTIES HAS BEEN WARNED OF THE
-POSSIBILITY OF SUCH LOSS OR DAMAGES.
diff --git a/third_party/xla/third_party/eigen3/eigen_archive.BUILD b/third_party/xla/third_party/eigen3/eigen_archive.BUILD
deleted file mode 100644
index b8c0202..0000000
--- a/third_party/xla/third_party/eigen3/eigen_archive.BUILD
+++ /dev/null
@@ -1,74 +0,0 @@
-# Description:
-#   Eigen is a C++ template library for linear algebra: vectors,
-#   matrices, and related algorithms.
-# This is the BUILD file used for the @eigen_archive external repository.
-
-licenses([
-    # Note: Although Eigen also includes GPL V3 and LGPL v2.1+ code, TensorFlow
-    #       has taken special care to not reference any restricted code.
-    "reciprocal",  # MPL2
-    "notice",  # Portions BSD
-])
-
-exports_files(["COPYING.MPL2"])
-
-ALL_FILES_WITH_EXTENSIONS = glob(["**/*.*"])
-
-# Top-level headers, excluding anything in one of the  ../src/.. directories.
-EIGEN_HEADERS = glob(
-    [
-        "Eigen/*",
-        "unsupported/Eigen/*",
-        "unsupported/Eigen/CXX11/*",
-    ],
-    exclude = [
-        "**/src/**",
-    ] + ALL_FILES_WITH_EXTENSIONS,
-)
-
-# Internal eigen headers, known to be under an MPL2 license.
-EIGEN_MPL2_SOURCES = glob(
-    [
-        "Eigen/**/src/**/*.h",
-        "Eigen/**/src/**/*.inc",
-        "unsupported/Eigen/**/src/**/*.h",
-        "unsupported/Eigen/**/src/**/*.inc",
-    ],
-    exclude = [
-        # This guarantees that any file depending on non MPL2 licensed code
-        # will not compile.
-        "Eigen/src/Core/util/NonMPL2.h",
-    ],
-)
-
-alias(
-    name = "eigen3",
-    actual = "@local_xla//third_party/eigen3",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "eigen3_internal",
-    srcs = EIGEN_MPL2_SOURCES,
-    hdrs = EIGEN_HEADERS,
-    defines = [
-        # This define (mostly) guarantees we don't link any problematic
-        # code. We use it, but we do not rely on it, as evidenced above.
-        "EIGEN_MPL2_ONLY",
-        "EIGEN_MAX_ALIGN_BYTES=64",
-    ],
-    includes = ["."],
-    visibility = ["//visibility:public"],
-)
-
-filegroup(
-    name = "eigen_header_files",
-    srcs = EIGEN_HEADERS,
-    visibility = ["//visibility:public"],
-)
-
-filegroup(
-    name = "eigen_source_files",
-    srcs = EIGEN_MPL2_SOURCES,
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/xla/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
deleted file mode 100644
index 41db119..0000000
--- a/third_party/xla/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
+++ /dev/null
@@ -1 +0,0 @@
-#include "unsupported/Eigen/CXX11/Tensor"
diff --git a/third_party/xla/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool b/third_party/xla/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
deleted file mode 100644
index d2639af..0000000
--- a/third_party/xla/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
+++ /dev/null
@@ -1 +0,0 @@
-#include "unsupported/Eigen/CXX11/ThreadPool"
diff --git a/third_party/xla/third_party/eigen3/unsupported/Eigen/MatrixFunctions b/third_party/xla/third_party/eigen3/unsupported/Eigen/MatrixFunctions
deleted file mode 100644
index 314b325..0000000
--- a/third_party/xla/third_party/eigen3/unsupported/Eigen/MatrixFunctions
+++ /dev/null
@@ -1 +0,0 @@
-#include "unsupported/Eigen/MatrixFunctions"
diff --git a/third_party/xla/third_party/eigen3/unsupported/Eigen/SpecialFunctions b/third_party/xla/third_party/eigen3/unsupported/Eigen/SpecialFunctions
deleted file mode 100644
index ad13359..0000000
--- a/third_party/xla/third_party/eigen3/unsupported/Eigen/SpecialFunctions
+++ /dev/null
@@ -1 +0,0 @@
-#include "unsupported/Eigen/SpecialFunctions"
diff --git a/third_party/xla/third_party/eigen3/workspace.bzl b/third_party/xla/third_party/eigen3/workspace.bzl
deleted file mode 100644
index d1d8d4a..0000000
--- a/third_party/xla/third_party/eigen3/workspace.bzl
+++ /dev/null
@@ -1,20 +0,0 @@
-"""Provides the repository macro to import Eigen."""
-
-load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
-
-def repo():
-    """Imports Eigen."""
-
-    # Attention: tools parse and update these lines.
-    # LINT.IfChange
-    EIGEN_COMMIT = "66e8f38891841bf88ee976a316c0c78a52f0cee5"
-    EIGEN_SHA256 = "01fcd68409c038bbcfd16394274c2bf71e2bb6dda89a2319e23fc59a2da17210"
-    # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/eigen.cmake)
-
-    tf_http_archive(
-        name = "eigen_archive",
-        build_file = "//third_party/eigen3:eigen_archive.BUILD",
-        sha256 = EIGEN_SHA256,
-        strip_prefix = "eigen-{commit}".format(commit = EIGEN_COMMIT),
-        urls = tf_mirror_urls("https://gitlab.com/libeigen/eigen/-/archive/{commit}/eigen-{commit}.tar.gz".format(commit = EIGEN_COMMIT)),
-    )
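The workspace macro deleted above registered @eigen_archive via tf_http_archive. A minimal sketch of how such a repo() macro is typically invoked from a top-level workspace .bzl file, assuming standard TensorFlow-style wiring (the wrapper function name below is illustrative, not taken from this diff):

# Illustrative wiring; initialize_third_party() is an assumed wrapper name.
load("//third_party/eigen3:workspace.bzl", eigen3 = "repo")

def initialize_third_party():
    # Makes @eigen_archive available before any BUILD file references it.
    eigen3()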
diff --git a/third_party/xla/third_party/farmhash/BUILD b/third_party/xla/third_party/farmhash/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/xla/third_party/farmhash/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/xla/third_party/farmhash/farmhash.BUILD b/third_party/xla/third_party/farmhash/farmhash.BUILD
deleted file mode 100644
index 4b84646..0000000
--- a/third_party/xla/third_party/farmhash/farmhash.BUILD
+++ /dev/null
@@ -1,23 +0,0 @@
-licenses(["notice"])  # MIT
-
-exports_files(["COPYING"])
-
-config_setting(
-    name = "windows",
-    values = {
-        "cpu": "x64_windows",
-    },
-)
-
-cc_library(
-    name = "farmhash",
-    srcs = ["src/farmhash.cc"],
-    hdrs = ["src/farmhash.h"],
-    # Disable __builtin_expect support on Windows
-    copts = select({
-        ":windows": ["/DFARMHASH_OPTIONAL_BUILTIN_EXPECT"],
-        "//conditions:default": [],
-    }),
-    includes = ["src/."],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/farmhash/farmhash_gpu.BUILD b/third_party/xla/third_party/farmhash/farmhash_gpu.BUILD
deleted file mode 100644
index 78e551d..0000000
--- a/third_party/xla/third_party/farmhash/farmhash_gpu.BUILD
+++ /dev/null
@@ -1,14 +0,0 @@
-# Description:
-# This is a modified farmhash to only include GPU-related functions.
-
-package(
-    default_visibility = ["//visibility:public"],
-)
-
-licenses(["notice"])  # MIT
-
-cc_library(
-    name = "farmhash_gpu",
-    hdrs = ["src/farmhash_gpu.h"],
-    include_prefix = "third_party/farmhash_gpu",
-)
diff --git a/third_party/xla/third_party/farmhash/farmhash_support_cuda.patch b/third_party/xla/third_party/farmhash/farmhash_support_cuda.patch
deleted file mode 100644
index a5400a0..0000000
--- a/third_party/xla/third_party/farmhash/farmhash_support_cuda.patch
+++ /dev/null
@@ -1,289 +0,0 @@
-From eb130493c8042280a01e03c28bb89bd5ae0c5d18 Mon Sep 17 00:00:00 2001
-From: Kaixi Hou <kaixih@nvidia.com>
-Date: Tue, 23 Mar 2021 12:49:18 -0700
-Subject: [PATCH] Add device modifiers for GPUs
-
----
- src/{farmhash.cc => farmhash_gpu.h} | 95 +++++++++++++++++++++++------
- 1 file changed, 75 insertions(+), 20 deletions(-)
- rename src/{farmhash.cc => farmhash_gpu.h} (99%)
-
-diff --git a/src/farmhash.cc b/src/farmhash_gpu.h
-similarity index 99%
-rename from src/farmhash.cc
-rename to src/farmhash_gpu.h
-index cfd4a47..50994b6 100644
---- a/src/farmhash.cc
-+++ b/src/farmhash_gpu.h
-@@ -20,6 +20,17 @@
- //
- // FarmHash, by Geoff Pike
- 
-+#ifndef FARM_HASH_GPU_H_
-+#define FARM_HASH_GPU_H_
-+
-+#include <cstdint>
-+#include <string.h>   // for memcpy and memset
-+
-+#define NAMESPACE_FOR_HASH_FUNCTIONS_GPU util_gpu
-+#define DEVICE_MODIFIER __device__ __host__
-+
-+// We use DEVICE_MODIFIER to remove those code unused by GPUs.
-+#ifndef DEVICE_MODIFIER
- #include "farmhash.h"
- // FARMHASH ASSUMPTIONS: Modify as needed, or use -DFARMHASH_ASSUME_SSE42 etc.
- // Note that if you use -DFARMHASH_ASSUME_SSE42 you likely need -msse42
-@@ -187,7 +198,14 @@
- #define uint64_in_expected_order(x) (x)
- #endif
- 
--namespace NAMESPACE_FOR_HASH_FUNCTIONS {
-+#endif // DEVICE_MODIFIER
-+
-+#define uint32_in_expected_order(x) (x)
-+#define uint64_in_expected_order(x) (x)
-+
-+#define STATIC_INLINE DEVICE_MODIFIER inline
-+
-+namespace NAMESPACE_FOR_HASH_FUNCTIONS_GPU {
- 
- STATIC_INLINE uint64_t Fetch64(const char *p) {
-   uint64_t result;
-@@ -201,6 +219,7 @@ STATIC_INLINE uint32_t Fetch32(const char *p) {
-   return uint32_in_expected_order(result);
- }
- 
-+#ifndef DEVICE_MODIFIER
- STATIC_INLINE uint32_t Bswap32(uint32_t val) { return bswap_32(val); }
- STATIC_INLINE uint64_t Bswap64(uint64_t val) { return bswap_64(val); }
- 
-@@ -210,12 +229,14 @@ STATIC_INLINE uint32_t BasicRotate32(uint32_t val, int shift) {
-   // Avoid shifting by 32: doing so yields an undefined result.
-   return shift == 0 ? val : ((val >> shift) | (val << (32 - shift)));
- }
-+#endif // DEVICE_MODIFIER
- 
- STATIC_INLINE uint64_t BasicRotate64(uint64_t val, int shift) {
-   // Avoid shifting by 64: doing so yields an undefined result.
-   return shift == 0 ? val : ((val >> shift) | (val << (64 - shift)));
- }
- 
-+#ifndef DEVICE_MODIFIER
- #if defined(_WIN32) && defined(FARMHASH_ROTR)
- 
- STATIC_INLINE uint32_t Rotate32(uint32_t val, int shift) {
-@@ -240,12 +261,18 @@ STATIC_INLINE uint64_t Rotate64(uint64_t val, int shift) {
- }
- 
- #endif
-+#endif // DEVICE_MODIFIER
- 
--}  // namespace NAMESPACE_FOR_HASH_FUNCTIONS
-+STATIC_INLINE uint64_t Rotate64(uint64_t val, int shift) {
-+  return BasicRotate64(val, shift);
-+}
-+
-+}  // namespace NAMESPACE_FOR_HASH_FUNCTIONS_GPU
- 
- // FARMHASH PORTABILITY LAYER: debug mode or max speed?
- // One may use -DFARMHASH_DEBUG=1 or -DFARMHASH_DEBUG=0 to force the issue.
- 
-+#ifndef DEVICE_MODIFIER
- #if !defined(FARMHASH_DEBUG) && (!defined(NDEBUG) || defined(_DEBUG))
- #define FARMHASH_DEBUG 1
- #endif
-@@ -345,14 +372,21 @@ STATIC_INLINE __m128i Fetch128(const char* s) {
- 
- #undef PERMUTE3
- #define PERMUTE3(a, b, c) do { std::swap(a, b); std::swap(a, c); } while (0)
-+#endif // DEVICE_MODIFIER
-+
-+struct Pair {
-+  uint64_t first;
-+  uint64_t second;
-+};
- 
--namespace NAMESPACE_FOR_HASH_FUNCTIONS {
-+namespace NAMESPACE_FOR_HASH_FUNCTIONS_GPU {
- 
- // Some primes between 2^63 and 2^64 for various uses.
- static const uint64_t k0 = 0xc3a5c85c97cb3127ULL;
- static const uint64_t k1 = 0xb492b66fbe98f273ULL;
- static const uint64_t k2 = 0x9ae16a3b2f90404fULL;
- 
-+#ifndef DEVICE_MODIFIER
- // Magic numbers for 32-bit hashing.  Copied from Murmur3.
- static const uint32_t c1 = 0xcc9e2d51;
- static const uint32_t c2 = 0x1b873593;
-@@ -399,28 +433,34 @@ template <> uint128_t DebugTweak(uint128_t x) {
-   }
-   return x;
- }
-+#endif // DEVICE_MODIFIER
-+}  // namespace NAMESPACE_FOR_HASH_FUNCTIONS_GPU
- 
--}  // namespace NAMESPACE_FOR_HASH_FUNCTIONS
--
-+#ifndef DEVICE_MODIFIER
- using namespace std;
--using namespace NAMESPACE_FOR_HASH_FUNCTIONS;
--namespace farmhashna {
-+#endif // DEVICE_MODIFIER
-+using namespace NAMESPACE_FOR_HASH_FUNCTIONS_GPU;
-+namespace farmhashna_gpu {
- #undef Fetch
- #define Fetch Fetch64
- 
- #undef Rotate
- #define Rotate Rotate64
- 
-+#ifndef DEVICE_MODIFIER
- #undef Bswap
- #define Bswap Bswap64
-+#endif // DEVICE_MODIFIER
- 
- STATIC_INLINE uint64_t ShiftMix(uint64_t val) {
-   return val ^ (val >> 47);
- }
- 
-+#ifndef DEVICE_MODIFIER
- STATIC_INLINE uint64_t HashLen16(uint64_t u, uint64_t v) {
-   return Hash128to64(Uint128(u, v));
- }
-+#endif // DEVICE_MODIFIER
- 
- STATIC_INLINE uint64_t HashLen16(uint64_t u, uint64_t v, uint64_t mul) {
-   // Murmur-inspired hashing.
-@@ -471,7 +511,7 @@ STATIC_INLINE uint64_t HashLen17to32(const char *s, size_t len) {
- 
- // Return a 16-byte hash for 48 bytes.  Quick and dirty.
- // Callers do best to use "random-looking" values for a and b.
--STATIC_INLINE pair<uint64_t, uint64_t> WeakHashLen32WithSeeds(
-+STATIC_INLINE Pair WeakHashLen32WithSeeds(
-     uint64_t w, uint64_t x, uint64_t y, uint64_t z, uint64_t a, uint64_t b) {
-   a += w;
-   b = Rotate(b + a + z, 21);
-@@ -479,11 +519,11 @@ STATIC_INLINE pair<uint64_t, uint64_t> WeakHashLen32WithSeeds(
-   a += x;
-   a += y;
-   b += Rotate(a, 44);
--  return make_pair(a + z, b + c);
-+  return Pair{a + z, b + c};
- }
- 
- // Return a 16-byte hash for s[0] ... s[31], a, and b.  Quick and dirty.
--STATIC_INLINE pair<uint64_t, uint64_t> WeakHashLen32WithSeeds(
-+STATIC_INLINE Pair WeakHashLen32WithSeeds(
-     const char* s, uint64_t a, uint64_t b) {
-   return WeakHashLen32WithSeeds(Fetch(s),
-                                 Fetch(s + 8),
-@@ -510,7 +550,7 @@ STATIC_INLINE uint64_t HashLen33to64(const char *s, size_t len) {
-                    e + Rotate(f + a, 18) + g, mul);
- }
- 
--uint64_t Hash64(const char *s, size_t len) {
-+DEVICE_MODIFIER uint64_t Hash64(const char *s, size_t len) {
-   const uint64_t seed = 81;
-   if (len <= 32) {
-     if (len <= 16) {
-@@ -527,8 +567,8 @@ uint64_t Hash64(const char *s, size_t len) {
-   uint64_t x = seed;
-   uint64_t y = seed * k1 + 113;
-   uint64_t z = ShiftMix(y * k2 + 113) * k2;
--  pair<uint64_t, uint64_t> v = make_pair(0, 0);
--  pair<uint64_t, uint64_t> w = make_pair(0, 0);
-+  Pair v = {0, 0};
-+  Pair w = {0, 0};
-   x = x * k2 + Fetch(s);
- 
-   // Set end so that after the loop we have 1 to 64 bytes left to process.
-@@ -543,7 +583,9 @@ uint64_t Hash64(const char *s, size_t len) {
-     z = Rotate(z + w.first, 33) * k1;
-     v = WeakHashLen32WithSeeds(s, v.second * k1, x + w.first);
-     w = WeakHashLen32WithSeeds(s + 32, z + w.second, y + Fetch(s + 16));
--    std::swap(z, x);
-+    auto tmp = z;
-+    z = x;
-+    x = tmp;
-     s += 64;
-   } while (s != end);
-   uint64_t mul = k1 + ((z & 0xff) << 1);
-@@ -559,12 +601,15 @@ uint64_t Hash64(const char *s, size_t len) {
-   z = Rotate(z + w.first, 33) * mul;
-   v = WeakHashLen32WithSeeds(s, v.second * mul, x + w.first);
-   w = WeakHashLen32WithSeeds(s + 32, z + w.second, y + Fetch(s + 16));
--  std::swap(z, x);
-+  auto tmp = z;
-+  z = x;
-+  x = tmp;
-   return HashLen16(HashLen16(v.first, w.first, mul) + ShiftMix(y) * k0 + z,
-                    HashLen16(v.second, w.second, mul) + x,
-                    mul);
- }
- 
-+#ifndef DEVICE_MODIFIER
- uint64_t Hash64WithSeeds(const char *s, size_t len, uint64_t seed0, uint64_t seed1);
- 
- uint64_t Hash64WithSeed(const char *s, size_t len, uint64_t seed) {
-@@ -574,7 +619,9 @@ uint64_t Hash64WithSeed(const char *s, size_t len, uint64_t seed) {
- uint64_t Hash64WithSeeds(const char *s, size_t len, uint64_t seed0, uint64_t seed1) {
-   return HashLen16(Hash64(s, len) - seed0, seed1);
- }
--}  // namespace farmhashna
-+#endif // DEVICE_MODIFIER
-+}  // namespace farmhashna_gpu
-+#ifndef DEVICE_MODIFIER
- namespace farmhashuo {
- #undef Fetch
- #define Fetch Fetch64
-@@ -1864,8 +1911,10 @@ uint128_t Fingerprint128(const char* s, size_t len) {
-   return CityHash128(s, len);
- }
- }  // namespace farmhashcc
--namespace NAMESPACE_FOR_HASH_FUNCTIONS {
-+#endif // DEVICE_MODIFIER
-+namespace NAMESPACE_FOR_HASH_FUNCTIONS_GPU {
- 
-+#ifndef DEVICE_MODIFIER
- // BASIC STRING HASHING
- 
- // Hash function for a byte array.  See also Hash(), below.
-@@ -1948,12 +1997,14 @@ uint128_t Hash128WithSeed(const char* s, size_t len, uint128_t seed) {
- uint32_t Fingerprint32(const char* s, size_t len) {
-   return farmhashmk::Hash32(s, len);
- }
-+#endif // DEVICE_MODIFIER
- 
- // Fingerprint function for a byte array.
--uint64_t Fingerprint64(const char* s, size_t len) {
--  return farmhashna::Hash64(s, len);
-+DEVICE_MODIFIER uint64_t Fingerprint64(const char* s, size_t len) {
-+  return farmhashna_gpu::Hash64(s, len);
- }
- 
-+#ifndef DEVICE_MODIFIER
- // Fingerprint function for a byte array.
- uint128_t Fingerprint128(const char* s, size_t len) {
-   return farmhashcc::Fingerprint128(s, len);
-@@ -1961,9 +2012,11 @@ uint128_t Fingerprint128(const char* s, size_t len) {
- 
- // Older and still available but perhaps not as fast as the above:
- //   farmhashns::Hash32{,WithSeed}()
-+#endif // DEVICE_MODIFIER
- 
--}  // namespace NAMESPACE_FOR_HASH_FUNCTIONS
-+}  // namespace NAMESPACE_FOR_HASH_FUNCTIONS_GPU
- 
-+#ifndef DEVICE_MODIFIER
- #if FARMHASHSELFTEST
- 
- #ifndef FARMHASH_SELF_TEST_GUARD
-@@ -11829,3 +11882,5 @@ int main() {
- }
- 
- #endif  // FARMHASHSELFTEST
-+#endif // DEVICE_MODIFIER
-+#endif // FARM_HASH_GPU_H_
--- 
-2.17.1
-
diff --git a/third_party/xla/third_party/farmhash/workspace.bzl b/third_party/xla/third_party/farmhash/workspace.bzl
deleted file mode 100644
index f2733138..0000000
--- a/third_party/xla/third_party/farmhash/workspace.bzl
+++ /dev/null
@@ -1,29 +0,0 @@
-"""Provides the repository macro to import farmhash."""
-
-load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
-
-def repo():
-    """Imports farmhash."""
-
-    # Attention: tools parse and update these lines.
-    # LINT.IfChange
-    FARMHASH_COMMIT = "0d859a811870d10f53a594927d0d0b97573ad06d"
-    FARMHASH_SHA256 = "18392cf0736e1d62ecbb8d695c31496b6507859e8c75541d7ad0ba092dc52115"
-    # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/farmhash.cmake)
-
-    tf_http_archive(
-        name = "farmhash_archive",
-        build_file = "//third_party/farmhash:farmhash.BUILD",
-        sha256 = FARMHASH_SHA256,
-        strip_prefix = "farmhash-{commit}".format(commit = FARMHASH_COMMIT),
-        urls = tf_mirror_urls("https://github.com/google/farmhash/archive/{commit}.tar.gz".format(commit = FARMHASH_COMMIT)),
-    )
-
-    tf_http_archive(
-        name = "farmhash_gpu_archive",
-        build_file = "//third_party/farmhash:farmhash_gpu.BUILD",
-        patch_file = ["//third_party/farmhash:farmhash_support_cuda.patch"],
-        sha256 = FARMHASH_SHA256,
-        strip_prefix = "farmhash-{commit}".format(commit = FARMHASH_COMMIT),
-        urls = tf_mirror_urls("https://github.com/google/farmhash/archive/{commit}.tar.gz".format(commit = FARMHASH_COMMIT)),
-    )
diff --git a/third_party/xla/third_party/gemmlowp/BUILD b/third_party/xla/third_party/gemmlowp/BUILD
deleted file mode 100644
index 3c41380..0000000
--- a/third_party/xla/third_party/gemmlowp/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/xla/third_party/gemmlowp/workspace.bzl b/third_party/xla/third_party/gemmlowp/workspace.bzl
deleted file mode 100644
index b980355..0000000
--- a/third_party/xla/third_party/gemmlowp/workspace.bzl
+++ /dev/null
@@ -1,19 +0,0 @@
-"""Provides the repository macro to import gemmlowp."""
-
-load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
-
-def repo():
-    """Imports gemmlowp."""
-
-    # Attention: tools parse and update these lines.
-    # LINT.IfChange
-    GEMMLOWP_COMMIT = "e844ffd17118c1e17d94e1ba4354c075a4577b88"
-    GEMMLOWP_SHA256 = "522b7a82d920ebd0c4408a5365866a40b81d1c0d60b2369011d315cca03c6476"
-    # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/gemmlowp.cmake)
-
-    tf_http_archive(
-        name = "gemmlowp",
-        sha256 = GEMMLOWP_SHA256,
-        strip_prefix = "gemmlowp-{commit}".format(commit = GEMMLOWP_COMMIT),
-        urls = tf_mirror_urls("https://github.com/google/gemmlowp/archive/{commit}.zip".format(commit = GEMMLOWP_COMMIT)),
-    )
diff --git a/third_party/xla/third_party/gif.BUILD b/third_party/xla/third_party/gif.BUILD
deleted file mode 100644
index 51621ba..0000000
--- a/third_party/xla/third_party/gif.BUILD
+++ /dev/null
@@ -1,61 +0,0 @@
-# Description:
-#   A library for decoding and encoding GIF images
-
-licenses(["notice"])  # MIT
-
-exports_files(["COPYING"])
-
-cc_library(
-    name = "gif",
-    srcs = [
-        "dgif_lib.c",
-        "egif_lib.c",
-        "gif_err.c",
-        "gif_font.c",
-        "gif_hash.c",
-        "gif_hash.h",
-        "gif_lib_private.h",
-        "gifalloc.c",
-        "openbsd-reallocarray.c",
-        "quantize.c",
-    ],
-    hdrs = ["gif_lib.h"],
-    defines = select({
-        ":android": [
-            "S_IREAD=S_IRUSR",
-            "S_IWRITE=S_IWUSR",
-            "S_IEXEC=S_IXUSR",
-        ],
-        "//conditions:default": [],
-    }),
-    includes = ["."],
-    visibility = ["//visibility:public"],
-    deps = select({
-        ":windows": [":windows_polyfill"],
-        "//conditions:default": [],
-    }),
-)
-
-cc_library(
-    name = "windows_polyfill",
-    hdrs = ["windows/unistd.h"],
-    includes = ["windows"],
-)
-
-genrule(
-    name = "windows_unistd_h",
-    outs = ["windows/unistd.h"],
-    cmd = "touch $@",
-)
-
-config_setting(
-    name = "windows",
-    values = {
-        "cpu": "x64_windows",
-    },
-)
-
-config_setting(
-    name = "android",
-    values = {"crosstool_top": "//external:android/crosstool"},
-)
diff --git a/third_party/xla/third_party/gif_fix_strtok_r.patch b/third_party/xla/third_party/gif_fix_strtok_r.patch
deleted file mode 100644
index c9c9c30..0000000
--- a/third_party/xla/third_party/gif_fix_strtok_r.patch
+++ /dev/null
@@ -1,15 +0,0 @@
-diff -r -u ./fixed_gif_font.c ./gif_font.c
---- ./fixed_gif_font.c	2019-09-05 11:05:25.009598262 -0700
-+++ ./gif_font.c	2019-09-05 10:52:45.308389085 -0700
-@@ -11,6 +11,11 @@
-
- #include "gif_lib.h"
-
-+// Windows doesn't have strtok_r.
-+#if defined(WIN32) || defined(_WIN32) || defined(__WIN32) && !defined(__CYGWIN__)
-+#define strtok_r strtok_s
-+#endif
-+
- /*****************************************************************************
-  Ascii 8 by 8 regular font - only first 128 characters are supported.
- *****************************************************************************/
diff --git a/third_party/xla/third_party/git/BUILD b/third_party/xla/third_party/git/BUILD
deleted file mode 100644
index e69de29..0000000
--- a/third_party/xla/third_party/git/BUILD
+++ /dev/null
diff --git a/third_party/xla/third_party/git/BUILD.tpl b/third_party/xla/third_party/git/BUILD.tpl
deleted file mode 100644
index 7b031e7..0000000
--- a/third_party/xla/third_party/git/BUILD.tpl
+++ /dev/null
@@ -1,10 +0,0 @@
-# Description:
-# Exports generated files used to generate tensorflow/core/util/version_info.cc
-
-package(default_visibility = ["//visibility:public"])
-
-licenses(["notice"])
-
-exports_files(
-    glob(["gen/*"]),
-)
diff --git a/third_party/xla/third_party/git/git_configure.bzl b/third_party/xla/third_party/git/git_configure.bzl
deleted file mode 100644
index d74f300..0000000
--- a/third_party/xla/third_party/git/git_configure.bzl
+++ /dev/null
@@ -1,71 +0,0 @@
-"""Repository rule for Git autoconfiguration.
-
-`git_configure` depends on the following environment variables:
-
-  * `PYTHON_BIN_PATH`: location of python binary.
-"""
-
-_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"
-
-def _fail(msg):
-    """Output failure message when auto configuration fails."""
-    red = "\033[0;31m"
-    no_color = "\033[0m"
-    fail("%sGit Configuration Error:%s %s\n" % (red, no_color, msg))
-
-def _get_python_bin(repository_ctx):
-    """Gets the python bin path."""
-    python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH)
-    if python_bin != None:
-        return python_bin
-    python_bin_path = repository_ctx.which("python3")
-    if python_bin_path != None:
-        return str(python_bin_path)
-    python_bin_path = repository_ctx.which("python")
-    if python_bin_path != None:
-        return str(python_bin_path)
-    _fail("Cannot find python in PATH, please make sure " +
-          "python is installed and add its directory in PATH, or --define " +
-          "%s='/something/else'.\nPATH=%s" % (
-              _PYTHON_BIN_PATH,
-              repository_ctx.os.environ.get("PATH", ""),
-          ))
-
-def _git_conf_impl(repository_ctx):
-    repository_ctx.template(
-        "BUILD",
-        Label("//third_party/git:BUILD.tpl"),
-    )
-
-    tensorflow_root_path = str(repository_ctx.path(
-        Label("@local_xla//:BUILD"),
-    ))[:-len("BUILD")]
-    python_script_path = repository_ctx.path(
-        Label("@local_xla//tensorflow/tools/git:gen_git_source.py"),
-    )
-    generated_files_path = repository_ctx.path("gen")
-
-    r = repository_ctx.execute(
-        ["test", "-f", "%s/.git/logs/HEAD" % tensorflow_root_path],
-    )
-    if r.return_code == 0:
-        unused_var = repository_ctx.path(Label("//:.git/HEAD"))  # pylint: disable=unused-variable
-
-    result = repository_ctx.execute([
-        _get_python_bin(repository_ctx),
-        python_script_path,
-        "--configure",
-        tensorflow_root_path,
-        "--gen_root_path",
-        generated_files_path,
-    ], quiet = False)
-
-    if not result.return_code == 0:
-        _fail(result.stderr)
-
-git_configure = repository_rule(
-    implementation = _git_conf_impl,
-    environ = [
-        _PYTHON_BIN_PATH,
-    ],
-)
diff --git a/third_party/xla/third_party/gpus/BUILD b/third_party/xla/third_party/gpus/BUILD
deleted file mode 100644
index e69de29..0000000
--- a/third_party/xla/third_party/gpus/BUILD
+++ /dev/null
diff --git a/third_party/xla/third_party/gpus/check_cuda_libs.py b/third_party/xla/third_party/gpus/check_cuda_libs.py
deleted file mode 100644
index b7d98ef..0000000
--- a/third_party/xla/third_party/gpus/check_cuda_libs.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Verifies that a list of libraries is installed on the system.
-
-Takes a list of arguments with every two subsequent arguments being a logical
-tuple of (path, check_soname). The path to the library and either True or False
-to indicate whether to check the soname field on the shared library.
-
-Example Usage:
-./check_cuda_libs.py /path/to/lib1.so True /path/to/lib2.so False
-"""
-import os
-import os.path
-import subprocess
-import sys
-
-# pylint: disable=g-import-not-at-top,g-importing-member
-try:
-  from shutil import which
-except ImportError:
-  from distutils.spawn import find_executable as which
-# pylint: enable=g-import-not-at-top,g-importing-member
-
-
-class ConfigError(Exception):
-  pass
-
-
-def check_cuda_lib(path, check_soname=True):
-  """Tests if a library exists on disk and whether its soname matches the filename.
-
-  Args:
-    path: the path to the library.
-    check_soname: whether to check the soname as well.
-
-  Raises:
-    ConfigError: If the library does not exist or if its soname does not match
-    the filename.
-  """
-  if not os.path.isfile(path):
-    raise ConfigError("No library found under: " + path)
-  objdump = which("objdump")
-  if check_soname and objdump is not None:
-    # Decode is necessary as in py3 the return type changed from str to bytes
-    output = subprocess.check_output([objdump, "-p", path]).decode("utf-8")
-    output = [line for line in output.splitlines() if "SONAME" in line]
-    sonames = [line.strip().split(" ")[-1] for line in output]
-    if not any(soname == os.path.basename(path) for soname in sonames):
-      raise ConfigError("None of the libraries match their SONAME: " + path)
-
-
-def main():
-  try:
-    args = [argv for argv in sys.argv[1:]]
-    if len(args) % 2 == 1:
-      raise ConfigError("Expected even number of arguments")
-    checked_paths = []
-    for i in range(0, len(args), 2):
-      path = args[i]
-      check_cuda_lib(path, check_soname=args[i + 1] == "True")
-      checked_paths.append(path)
-    # pylint: disable=superfluous-parens
-    print(os.linesep.join(checked_paths))
-    # pylint: enable=superfluous-parens
-  except ConfigError as e:
-    sys.stderr.write(str(e))
-    sys.exit(1)
-
-
-if __name__ == "__main__":
-  main()
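check_cuda_libs.py, deleted above, verifies that each listed library exists and, optionally, that the SONAME reported by objdump -p matches the filename. A minimal sketch of calling the same check from Python rather than through the CLI, using a hypothetical library path (the path is an assumption for illustration only):

# Illustrative use of the helper deleted above; the path is an assumption.
from check_cuda_libs import ConfigError, check_cuda_lib

try:
    # Checks that the file exists and that `objdump -p` reports a matching SONAME.
    check_cuda_lib("/usr/local/cuda/lib64/libcudart.so", check_soname=True)
except ConfigError as e:
    print("CUDA library check failed:", e)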
diff --git a/third_party/xla/third_party/gpus/crosstool/BUILD b/third_party/xla/third_party/gpus/crosstool/BUILD
deleted file mode 100644
index e69de29..0000000
--- a/third_party/xla/third_party/gpus/crosstool/BUILD
+++ /dev/null
diff --git a/third_party/xla/third_party/gpus/crosstool/BUILD.rocm.tpl b/third_party/xla/third_party/gpus/crosstool/BUILD.rocm.tpl
deleted file mode 100644
index a742cfc..0000000
--- a/third_party/xla/third_party/gpus/crosstool/BUILD.rocm.tpl
+++ /dev/null
@@ -1,118 +0,0 @@
-# This file is expanded from a template by cuda_configure.bzl
-# Update cuda_configure.bzl#verify_build_defines when adding new variables.
-
-load(":cc_toolchain_config.bzl", "cc_toolchain_config")
-
-licenses(["restricted"])
-
-package(default_visibility = ["//visibility:public"])
-
-toolchain(
-    name = "toolchain-linux-x86_64",
-    exec_compatible_with = [
-        "@platforms//os:linux",
-        "@platforms//cpu:x86_64",
-    ],
-    target_compatible_with = [
-        "@platforms//os:linux",
-        "@platforms//cpu:x86_64",
-    ],
-    toolchain = ":cc-compiler-local",
-    toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
-)
-
-cc_toolchain_suite(
-    name = "toolchain",
-    toolchains = {
-        "local|compiler": ":cc-compiler-local",
-        "arm": ":cc-compiler-local",
-        "aarch64": ":cc-compiler-local",
-        "k8": ":cc-compiler-local",
-        "piii": ":cc-compiler-local",
-        "ppc": ":cc-compiler-local",
-    },
-)
-
-cc_toolchain(
-    name = "cc-compiler-local",
-    all_files = ":crosstool_wrapper_driver_is_not_gcc",
-    compiler_files = ":crosstool_wrapper_driver_is_not_gcc",
-    ar_files = ":crosstool_wrapper_driver_is_not_gcc",
-    as_files = ":crosstool_wrapper_driver_is_not_gcc",
-    dwp_files = ":empty",
-    linker_files = ":crosstool_wrapper_driver_is_not_gcc",
-    objcopy_files = ":empty",
-    strip_files = ":empty",
-    # To support linker flags that need to go to the start of command line
-    # we need the toolchain to support parameter files. Parameter files are
-    # last on the command line and contain all shared libraries to link, so all
-    # regular options will be left of them.
-    supports_param_files = 1,
-    toolchain_identifier = "local_linux",
-    toolchain_config = ":cc-compiler-local-config",
-)
-
-cc_toolchain_config(
-    name = "cc-compiler-local-config",
-    cpu = "local",
-    compiler = "compiler",
-    toolchain_identifier = "local_linux",
-    host_system_name = "local",
-    target_system_name = "local",
-    target_libc = "local",
-    abi_version = "local",
-    abi_libc_version = "local",
-    cxx_builtin_include_directories = [%{cxx_builtin_include_directories}],
-    host_compiler_path = "%{host_compiler_path}",
-    host_compiler_prefix = "%{host_compiler_prefix}",
-    compile_flags = [
-        "-U_FORTIFY_SOURCE",
-        "-fstack-protector",
-        "-Wall",
-        "-Wunused-but-set-parameter",
-        "-Wno-free-nonheap-object",
-        "-fno-omit-frame-pointer",
-    ],
-    opt_compile_flags = [
-        "-g0",
-        "-O2",
-        "-D_FORTIFY_SOURCE=1",
-        "-DNDEBUG",
-        "-ffunction-sections",
-        "-fdata-sections",
-    ],
-    dbg_compile_flags = ["-g"],
-    cxx_flags = ["-std=c++14"],
-    link_flags = [
-        "-fuse-ld=gold",
-        "-Wl,-no-as-needed",
-        "-Wl,-z,relro,-z,now",
-        "-pass-exit-codes",
-        "-lstdc++",
-        "-lm",
-    ],
-    link_libs = [],
-    opt_link_flags = [],
-    unfiltered_compile_flags = [
-        "-fno-canonical-system-headers",
-        "-Wno-builtin-macro-redefined",
-        "-D__DATE__=\"redacted\"",
-        "-D__TIMESTAMP__=\"redacted\"",
-        "-D__TIME__=\"redacted\"",
-    ] + [%{unfiltered_compile_flags}],
-    linker_bin_path = "%{linker_bin_path}",
-    coverage_compile_flags = ["--coverage"],
-    coverage_link_flags = ["--coverage"],
-    supports_start_end_lib = True,
-)
-
-filegroup(
-    name = "empty",
-    srcs = [],
-)
-
-filegroup(
-    name = "crosstool_wrapper_driver_is_not_gcc",
-    srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"],
-)
-
diff --git a/third_party/xla/third_party/gpus/crosstool/BUILD.tpl b/third_party/xla/third_party/gpus/crosstool/BUILD.tpl
deleted file mode 100644
index 8eda7a1..0000000
--- a/third_party/xla/third_party/gpus/crosstool/BUILD.tpl
+++ /dev/null
@@ -1,144 +0,0 @@
-# This file is expanded from a template by cuda_configure.bzl
-# Update cuda_configure.bzl#verify_build_defines when adding new variables.
-
-load(":cc_toolchain_config.bzl", "cc_toolchain_config")
-
-licenses(["restricted"])
-
-package(default_visibility = ["//visibility:public"])
-
-toolchain(
-    name = "toolchain-linux-x86_64",
-    exec_compatible_with = [
-        "@platforms//os:linux",
-        "@platforms//cpu:x86_64",
-    ],
-    target_compatible_with = [
-        "@platforms//os:linux",
-        "@platforms//cpu:x86_64",
-    ],
-    toolchain = ":cc-compiler-local",
-    toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
-)
-
-cc_toolchain_suite(
-    name = "toolchain",
-    toolchains = {
-        "local|compiler": ":cc-compiler-local",
-        "darwin|compiler": ":cc-compiler-darwin",
-        "x64_windows|msvc-cl": ":cc-compiler-windows",
-        "x64_windows": ":cc-compiler-windows",
-        "arm": ":cc-compiler-local",
-        "aarch64": ":cc-compiler-local",
-        "k8": ":cc-compiler-local",
-        "piii": ":cc-compiler-local",
-        "ppc": ":cc-compiler-local",
-        "darwin": ":cc-compiler-darwin",
-    },
-)
-
-cc_toolchain(
-    name = "cc-compiler-local",
-    all_files = "%{compiler_deps}",
-    compiler_files = "%{compiler_deps}",
-    ar_files = "%{compiler_deps}",
-    as_files = "%{compiler_deps}",
-    dwp_files = ":empty",
-    linker_files = "%{compiler_deps}",
-    objcopy_files = ":empty",
-    strip_files = ":empty",
-    # To support linker flags that need to go to the start of command line
-    # we need the toolchain to support parameter files. Parameter files are
-    # last on the command line and contain all shared libraries to link, so all
-    # regular options will be left of them.
-    supports_param_files = 1,
-    toolchain_identifier = "local_linux",
-    toolchain_config = ":cc-compiler-local-config",
-)
-
-cc_toolchain_config(
-    name = "cc-compiler-local-config",
-    cpu = "local",
-    builtin_include_directories = [%{cxx_builtin_include_directories}],
-    extra_no_canonical_prefixes_flags = [%{extra_no_canonical_prefixes_flags}],
-    host_compiler_path = "%{host_compiler_path}",
-    host_compiler_prefix = "%{host_compiler_prefix}",
-    host_compiler_warnings = [%{host_compiler_warnings}],
-    host_unfiltered_compile_flags = [%{unfiltered_compile_flags}],
-    linker_bin_path = "%{linker_bin_path}",
-    builtin_sysroot = "%{builtin_sysroot}",
-    cuda_path = "%{cuda_toolkit_path}",
-    compiler = "%{compiler}",
-)
-
-cc_toolchain(
-    name = "cc-compiler-darwin",
-    all_files = "%{compiler_deps}",
-    compiler_files = "%{compiler_deps}",
-    ar_files = "%{compiler_deps}",
-    as_files = "%{compiler_deps}",
-    dwp_files = ":empty",
-    linker_files = "%{compiler_deps}",
-    objcopy_files = ":empty",
-    strip_files = ":empty",
-    supports_param_files = 0,
-    toolchain_identifier = "local_darwin",
-    toolchain_config = ":cc-compiler-local-darwin",
-)
-
-cc_toolchain_config(
-    name = "cc-compiler-local-darwin",
-    cpu = "darwin",
-    builtin_include_directories = [%{cxx_builtin_include_directories}],
-    extra_no_canonical_prefixes_flags = [%{extra_no_canonical_prefixes_flags}],
-    host_compiler_path = "%{host_compiler_path}",
-    host_compiler_prefix = "%{host_compiler_prefix}",
-    host_compiler_warnings = [%{host_compiler_warnings}],
-    host_unfiltered_compile_flags = [%{unfiltered_compile_flags}],
-    linker_bin_path = "%{linker_bin_path}",
-)
-
-cc_toolchain(
-    name = "cc-compiler-windows",
-    all_files = "%{win_compiler_deps}",
-    compiler_files = "%{win_compiler_deps}",
-    ar_files = "%{win_compiler_deps}",
-    as_files = "%{win_compiler_deps}",
-    dwp_files = ":empty",
-    linker_files = "%{win_compiler_deps}",
-    objcopy_files = ":empty",
-    strip_files = ":empty",
-    supports_param_files = 1,
-    toolchain_identifier = "local_windows",
-    toolchain_config = ":cc-compiler-windows-config",
-)
-
-cc_toolchain_config(
-    name = "cc-compiler-windows-config",
-    cpu = "x64_windows",
-    builtin_include_directories = [%{cxx_builtin_include_directories}],
-    msvc_cl_path = "%{msvc_cl_path}",
-    msvc_env_include = "%{msvc_env_include}",
-    msvc_env_lib = "%{msvc_env_lib}",
-    msvc_env_path = "%{msvc_env_path}",
-    msvc_env_tmp = "%{msvc_env_tmp}",
-    msvc_lib_path = "%{msvc_lib_path}",
-    msvc_link_path = "%{msvc_link_path}",
-    msvc_ml_path = "%{msvc_ml_path}",
-    compiler = "msvc",
-)
-
-filegroup(
-    name = "empty",
-    srcs = [],
-)
-
-filegroup(
-    name = "crosstool_wrapper_driver_is_not_gcc",
-    srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"],
-)
-
-filegroup(
-    name = "windows_msvc_wrapper_files",
-    srcs = glob(["windows/msvc_*"]),
-)
diff --git a/third_party/xla/third_party/gpus/crosstool/LICENSE b/third_party/xla/third_party/gpus/crosstool/LICENSE
deleted file mode 100644
index d3da228..0000000
--- a/third_party/xla/third_party/gpus/crosstool/LICENSE
+++ /dev/null
@@ -1,203 +0,0 @@
-Copyright 2015 The TensorFlow Authors.  All rights reserved.
-
-                                 Apache License
-                           Version 2.0, January 2004
-                        http://www.apache.org/licenses/
-
-   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
-   1. Definitions.
-
-      "License" shall mean the terms and conditions for use, reproduction,
-      and distribution as defined by Sections 1 through 9 of this document.
-
-      "Licensor" shall mean the copyright owner or entity authorized by
-      the copyright owner that is granting the License.
-
-      "Legal Entity" shall mean the union of the acting entity and all
-      other entities that control, are controlled by, or are under common
-      control with that entity. For the purposes of this definition,
-      "control" means (i) the power, direct or indirect, to cause the
-      direction or management of such entity, whether by contract or
-      otherwise, or (ii) ownership of fifty percent (50%) or more of the
-      outstanding shares, or (iii) beneficial ownership of such entity.
-
-      "You" (or "Your") shall mean an individual or Legal Entity
-      exercising permissions granted by this License.
-
-      "Source" form shall mean the preferred form for making modifications,
-      including but not limited to software source code, documentation
-      source, and configuration files.
-
-      "Object" form shall mean any form resulting from mechanical
-      transformation or translation of a Source form, including but
-      not limited to compiled object code, generated documentation,
-      and conversions to other media types.
-
-      "Work" shall mean the work of authorship, whether in Source or
-      Object form, made available under the License, as indicated by a
-      copyright notice that is included in or attached to the work
-      (an example is provided in the Appendix below).
-
-      "Derivative Works" shall mean any work, whether in Source or Object
-      form, that is based on (or derived from) the Work and for which the
-      editorial revisions, annotations, elaborations, or other modifications
-      represent, as a whole, an original work of authorship. For the purposes
-      of this License, Derivative Works shall not include works that remain
-      separable from, or merely link (or bind by name) to the interfaces of,
-      the Work and Derivative Works thereof.
-
-      "Contribution" shall mean any work of authorship, including
-      the original version of the Work and any modifications or additions
-      to that Work or Derivative Works thereof, that is intentionally
-      submitted to Licensor for inclusion in the Work by the copyright owner
-      or by an individual or Legal Entity authorized to submit on behalf of
-      the copyright owner. For the purposes of this definition, "submitted"
-      means any form of electronic, verbal, or written communication sent
-      to the Licensor or its representatives, including but not limited to
-      communication on electronic mailing lists, source code control systems,
-      and issue tracking systems that are managed by, or on behalf of, the
-      Licensor for the purpose of discussing and improving the Work, but
-      excluding communication that is conspicuously marked or otherwise
-      designated in writing by the copyright owner as "Not a Contribution."
-
-      "Contributor" shall mean Licensor and any individual or Legal Entity
-      on behalf of whom a Contribution has been received by Licensor and
-      subsequently incorporated within the Work.
-
-   2. Grant of Copyright License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      copyright license to reproduce, prepare Derivative Works of,
-      publicly display, publicly perform, sublicense, and distribute the
-      Work and such Derivative Works in Source or Object form.
-
-   3. Grant of Patent License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      (except as stated in this section) patent license to make, have made,
-      use, offer to sell, sell, import, and otherwise transfer the Work,
-      where such license applies only to those patent claims licensable
-      by such Contributor that are necessarily infringed by their
-      Contribution(s) alone or by combination of their Contribution(s)
-      with the Work to which such Contribution(s) was submitted. If You
-      institute patent litigation against any entity (including a
-      cross-claim or counterclaim in a lawsuit) alleging that the Work
-      or a Contribution incorporated within the Work constitutes direct
-      or contributory patent infringement, then any patent licenses
-      granted to You under this License for that Work shall terminate
-      as of the date such litigation is filed.
-
-   4. Redistribution. You may reproduce and distribute copies of the
-      Work or Derivative Works thereof in any medium, with or without
-      modifications, and in Source or Object form, provided that You
-      meet the following conditions:
-
-      (a) You must give any other recipients of the Work or
-          Derivative Works a copy of this License; and
-
-      (b) You must cause any modified files to carry prominent notices
-          stating that You changed the files; and
-
-      (c) You must retain, in the Source form of any Derivative Works
-          that You distribute, all copyright, patent, trademark, and
-          attribution notices from the Source form of the Work,
-          excluding those notices that do not pertain to any part of
-          the Derivative Works; and
-
-      (d) If the Work includes a "NOTICE" text file as part of its
-          distribution, then any Derivative Works that You distribute must
-          include a readable copy of the attribution notices contained
-          within such NOTICE file, excluding those notices that do not
-          pertain to any part of the Derivative Works, in at least one
-          of the following places: within a NOTICE text file distributed
-          as part of the Derivative Works; within the Source form or
-          documentation, if provided along with the Derivative Works; or,
-          within a display generated by the Derivative Works, if and
-          wherever such third-party notices normally appear. The contents
-          of the NOTICE file are for informational purposes only and
-          do not modify the License. You may add Your own attribution
-          notices within Derivative Works that You distribute, alongside
-          or as an addendum to the NOTICE text from the Work, provided
-          that such additional attribution notices cannot be construed
-          as modifying the License.
-
-      You may add Your own copyright statement to Your modifications and
-      may provide additional or different license terms and conditions
-      for use, reproduction, or distribution of Your modifications, or
-      for any such Derivative Works as a whole, provided Your use,
-      reproduction, and distribution of the Work otherwise complies with
-      the conditions stated in this License.
-
-   5. Submission of Contributions. Unless You explicitly state otherwise,
-      any Contribution intentionally submitted for inclusion in the Work
-      by You to the Licensor shall be under the terms and conditions of
-      this License, without any additional terms or conditions.
-      Notwithstanding the above, nothing herein shall supersede or modify
-      the terms of any separate license agreement you may have executed
-      with Licensor regarding such Contributions.
-
-   6. Trademarks. This License does not grant permission to use the trade
-      names, trademarks, service marks, or product names of the Licensor,
-      except as required for reasonable and customary use in describing the
-      origin of the Work and reproducing the content of the NOTICE file.
-
-   7. Disclaimer of Warranty. Unless required by applicable law or
-      agreed to in writing, Licensor provides the Work (and each
-      Contributor provides its Contributions) on an "AS IS" BASIS,
-      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-      implied, including, without limitation, any warranties or conditions
-      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
-      PARTICULAR PURPOSE. You are solely responsible for determining the
-      appropriateness of using or redistributing the Work and assume any
-      risks associated with Your exercise of permissions under this License.
-
-   8. Limitation of Liability. In no event and under no legal theory,
-      whether in tort (including negligence), contract, or otherwise,
-      unless required by applicable law (such as deliberate and grossly
-      negligent acts) or agreed to in writing, shall any Contributor be
-      liable to You for damages, including any direct, indirect, special,
-      incidental, or consequential damages of any character arising as a
-      result of this License or out of the use or inability to use the
-      Work (including but not limited to damages for loss of goodwill,
-      work stoppage, computer failure or malfunction, or any and all
-      other commercial damages or losses), even if such Contributor
-      has been advised of the possibility of such damages.
-
-   9. Accepting Warranty or Additional Liability. While redistributing
-      the Work or Derivative Works thereof, You may choose to offer,
-      and charge a fee for, acceptance of support, warranty, indemnity,
-      or other liability obligations and/or rights consistent with this
-      License. However, in accepting such obligations, You may act only
-      on Your own behalf and on Your sole responsibility, not on behalf
-      of any other Contributor, and only if You agree to indemnify,
-      defend, and hold each Contributor harmless for any liability
-      incurred by, or claims asserted against, such Contributor by reason
-      of your accepting any such warranty or additional liability.
-
-   END OF TERMS AND CONDITIONS
-
-   APPENDIX: How to apply the Apache License to your work.
-
-      To apply the Apache License to your work, attach the following
-      boilerplate notice, with the fields enclosed by brackets "[]"
-      replaced with your own identifying information. (Don't include
-      the brackets!)  The text should be enclosed in the appropriate
-      comment syntax for the file format. We also recommend that a
-      file or class name and description of purpose be included on the
-      same "printed page" as the copyright notice for easier
-      identification within third-party archives.
-
-   Copyright 2015, The TensorFlow Authors.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License.
diff --git a/third_party/xla/third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl b/third_party/xla/third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl
deleted file mode 100644
index ffa305c..0000000
--- a/third_party/xla/third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl
+++ /dev/null
@@ -1,1085 +0,0 @@
-"""cc_toolchain_config rule for configuring CUDA toolchains on Linux, Mac, and Windows."""
-
-load(
-    "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl",
-    "action_config",
-    "artifact_name_pattern",
-    "env_entry",
-    "env_set",
-    "feature",
-    "feature_set",
-    "flag_group",
-    "flag_set",
-    "tool",
-    "tool_path",
-    "variable_with_value",
-    "with_feature_set",
-)
-load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES")
-
-def all_assembly_actions():
-    return [
-        ACTION_NAMES.assemble,
-        ACTION_NAMES.preprocess_assemble,
-    ]
-
-def all_compile_actions():
-    return [
-        ACTION_NAMES.assemble,
-        ACTION_NAMES.c_compile,
-        ACTION_NAMES.cpp_compile,
-        ACTION_NAMES.cpp_header_parsing,
-        ACTION_NAMES.cpp_module_codegen,
-        ACTION_NAMES.cpp_module_compile,
-        ACTION_NAMES.linkstamp_compile,
-        ACTION_NAMES.preprocess_assemble,
-    ]
-
-def all_c_compile_actions():
-    return [
-        ACTION_NAMES.c_compile,
-    ]
-
-def all_cpp_compile_actions():
-    return [
-        ACTION_NAMES.cpp_compile,
-        ACTION_NAMES.cpp_header_parsing,
-        ACTION_NAMES.cpp_module_codegen,
-        ACTION_NAMES.cpp_module_compile,
-        ACTION_NAMES.linkstamp_compile,
-    ]
-
-def all_preprocessed_actions():
-    return [
-        ACTION_NAMES.c_compile,
-        ACTION_NAMES.cpp_compile,
-        ACTION_NAMES.cpp_header_parsing,
-        ACTION_NAMES.cpp_module_codegen,
-        ACTION_NAMES.cpp_module_compile,
-        ACTION_NAMES.linkstamp_compile,
-        ACTION_NAMES.preprocess_assemble,
-    ]
-
-def all_link_actions():
-    return [
-        ACTION_NAMES.cpp_link_executable,
-        ACTION_NAMES.cpp_link_dynamic_library,
-        ACTION_NAMES.cpp_link_nodeps_dynamic_library,
-    ]
-
-def all_executable_link_actions():
-    return [
-        ACTION_NAMES.cpp_link_executable,
-    ]
-
-def all_shared_library_link_actions():
-    return [
-        ACTION_NAMES.cpp_link_dynamic_library,
-        ACTION_NAMES.cpp_link_nodeps_dynamic_library,
-    ]
-
-def all_archive_actions():
-    return [ACTION_NAMES.cpp_link_static_library]
-
-def all_strip_actions():
-    return [ACTION_NAMES.strip]
-
-def _library_to_link(flag_prefix, value, iterate = None):
-    return flag_group(
-        flags = [
-            "{}%{{libraries_to_link.{}}}".format(
-                flag_prefix,
-                iterate if iterate else "name",
-            ),
-        ],
-        iterate_over = ("libraries_to_link." + iterate if iterate else None),
-        expand_if_equal = variable_with_value(
-            name = "libraries_to_link.type",
-            value = value,
-        ),
-    )
-
-def _surround_static_library(prefix, suffix):
-    return [
-        flag_group(
-            flags = [prefix, "%{libraries_to_link.name}", suffix],
-            expand_if_true = "libraries_to_link.is_whole_archive",
-        ),
-        flag_group(
-            flags = ["%{libraries_to_link.name}"],
-            expand_if_false = "libraries_to_link.is_whole_archive",
-        ),
-    ]
-
-def _prefix_static_library(prefix):
-    return [
-        flag_group(
-            flags = ["%{libraries_to_link.name}"],
-            expand_if_false = "libraries_to_link.is_whole_archive",
-        ),
-        flag_group(
-            flags = [prefix + "%{libraries_to_link.name}"],
-            expand_if_true = "libraries_to_link.is_whole_archive",
-        ),
-    ]
-
-def _static_library_to_link(alwayslink_prefix, alwayslink_suffix = None):
-    if alwayslink_suffix:
-        flag_groups = _surround_static_library(alwayslink_prefix, alwayslink_suffix)
-    else:
-        flag_groups = _prefix_static_library(alwayslink_prefix)
-    return flag_group(
-        flag_groups = flag_groups,
-        expand_if_equal = variable_with_value(
-            name = "libraries_to_link.type",
-            value = "static_library",
-        ),
-    )
-
-def _iterate_flag_group(iterate_over, flags = [], flag_groups = []):
-    return flag_group(
-        iterate_over = iterate_over,
-        expand_if_available = iterate_over,
-        flag_groups = flag_groups,
-        flags = flags,
-    )
-
-def _libraries_to_link_group(flavour):
-    if flavour == "linux":
-        return _iterate_flag_group(
-            iterate_over = "libraries_to_link",
-            flag_groups = [
-                flag_group(
-                    flags = ["-Wl,--start-lib"],
-                    expand_if_equal = variable_with_value(
-                        name = "libraries_to_link.type",
-                        value = "object_file_group",
-                    ),
-                ),
-                _library_to_link("", "object_file_group", "object_files"),
-                flag_group(
-                    flags = ["-Wl,--end-lib"],
-                    expand_if_equal = variable_with_value(
-                        name = "libraries_to_link.type",
-                        value = "object_file_group",
-                    ),
-                ),
-                _library_to_link("", "object_file"),
-                _library_to_link("", "interface_library"),
-                _static_library_to_link("-Wl,-whole-archive", "-Wl,-no-whole-archive"),
-                _library_to_link("-l", "dynamic_library"),
-                _library_to_link("-l:", "versioned_dynamic_library"),
-            ],
-        )
-    elif flavour == "darwin":
-        return _iterate_flag_group(
-            iterate_over = "libraries_to_link",
-            flag_groups = [
-                _library_to_link("", "object_file_group", "object_files"),
-                _library_to_link("", "object_file"),
-                _library_to_link("", "interface_library"),
-                _static_library_to_link("-Wl,-force_load,"),
-                _library_to_link("-l", "dynamic_library"),
-                _library_to_link("-l:", "versioned_dynamic_library"),
-            ],
-        )
-    elif flavour == "msvc":
-        return _iterate_flag_group(
-            iterate_over = "libraries_to_link",
-            flag_groups = [
-                _library_to_link("", "object_file_group", "object_files"),
-                _library_to_link("", "object_file"),
-                _library_to_link("", "interface_library"),
-                _static_library_to_link("/WHOLEARCHIVE:"),
-            ],
-        )
-
-def _action_configs_with_tool(path, actions):
-    return [
-        action_config(
-            action_name = name,
-            enabled = True,
-            tools = [tool(path = path)],
-        )
-        for name in actions
-    ]
-
-def _action_configs(assembly_path, c_compiler_path, cc_compiler_path, archiver_path, linker_path, strip_path):
-    return _action_configs_with_tool(
-        assembly_path,
-        all_assembly_actions(),
-    ) + _action_configs_with_tool(
-        c_compiler_path,
-        all_c_compile_actions(),
-    ) + _action_configs_with_tool(
-        cc_compiler_path,
-        all_cpp_compile_actions(),
-    ) + _action_configs_with_tool(
-        archiver_path,
-        all_archive_actions(),
-    ) + _action_configs_with_tool(
-        linker_path,
-        all_link_actions(),
-    ) + _action_configs_with_tool(
-        strip_path,
-        all_strip_actions(),
-    )
-
-def _tool_paths(cpu, ctx):
-    if cpu in ["local", "darwin"]:
-        return [
-            tool_path(name = "gcc", path = ctx.attr.host_compiler_path),
-            tool_path(name = "ar", path = ctx.attr.host_compiler_prefix + (
-                "/ar" if cpu == "local" else "/libtool"
-            )),
-            tool_path(name = "compat-ld", path = ctx.attr.host_compiler_prefix + "/ld"),
-            tool_path(name = "cpp", path = ctx.attr.host_compiler_prefix + "/cpp"),
-            tool_path(name = "dwp", path = ctx.attr.host_compiler_prefix + "/dwp"),
-            tool_path(name = "gcov", path = ctx.attr.host_compiler_prefix + "/gcov"),
-            tool_path(name = "ld", path = ctx.attr.host_compiler_prefix + "/ld"),
-            tool_path(name = "nm", path = ctx.attr.host_compiler_prefix + "/nm"),
-            tool_path(name = "objcopy", path = ctx.attr.host_compiler_prefix + "/objcopy"),
-            tool_path(name = "objdump", path = ctx.attr.host_compiler_prefix + "/objdump"),
-            tool_path(name = "strip", path = ctx.attr.host_compiler_prefix + "/strip"),
-        ]
-    elif cpu == "x64_windows":
-        return [
-            tool_path(name = "ar", path = ctx.attr.msvc_lib_path),
-            tool_path(name = "ml", path = ctx.attr.msvc_ml_path),
-            tool_path(name = "cpp", path = ctx.attr.msvc_cl_path),
-            tool_path(name = "gcc", path = ctx.attr.msvc_cl_path),
-            tool_path(name = "gcov", path = "wrapper/bin/msvc_nop.bat"),
-            tool_path(name = "ld", path = ctx.attr.msvc_link_path),
-            tool_path(name = "nm", path = "wrapper/bin/msvc_nop.bat"),
-            tool_path(
-                name = "objcopy",
-                path = "wrapper/bin/msvc_nop.bat",
-            ),
-            tool_path(
-                name = "objdump",
-                path = "wrapper/bin/msvc_nop.bat",
-            ),
-            tool_path(
-                name = "strip",
-                path = "wrapper/bin/msvc_nop.bat",
-            ),
-        ]
-    else:
-        fail("Unreachable")
-
-def _sysroot_group():
-    return flag_group(
-        flags = ["--sysroot=%{sysroot}"],
-        expand_if_available = "sysroot",
-    )
-
-def _no_canonical_prefixes_group(extra_flags):
-    return flag_group(
-        flags = [
-            "-no-canonical-prefixes",
-        ] + extra_flags,
-    )
-
-def _cuda_set(cuda_path, actions):
-    if cuda_path:
-        return [flag_set(
-            actions = actions,
-            flag_groups = [
-                flag_group(
-                    flags = ["--cuda-path=" + cuda_path],
-                ),
-            ],
-        )]
-    else:
-        return []
-
-def _nologo():
-    return flag_group(flags = ["/nologo"])
-
-def _features(cpu, compiler, ctx):
-    if cpu in ["local", "darwin"]:
-        return [
-            feature(name = "no_legacy_features"),
-            feature(
-                name = "all_compile_flags",
-                enabled = True,
-                flag_sets = [
-                    flag_set(
-                        actions = all_compile_actions(),
-                        flag_groups = [
-                            flag_group(
-                                flags = ["-MD", "-MF", "%{dependency_file}"],
-                                expand_if_available = "dependency_file",
-                            ),
-                            flag_group(
-                                flags = ["-gsplit-dwarf"],
-                                expand_if_available = "per_object_debug_info_file",
-                            ),
-                        ],
-                    ),
-                    flag_set(
-                        actions = all_preprocessed_actions(),
-                        flag_groups = [
-                            flag_group(
-                                flags = ["-frandom-seed=%{output_file}"],
-                                expand_if_available = "output_file",
-                            ),
-                            _iterate_flag_group(
-                                flags = ["-D%{preprocessor_defines}"],
-                                iterate_over = "preprocessor_defines",
-                            ),
-                            _iterate_flag_group(
-                                flags = ["-include", "%{includes}"],
-                                iterate_over = "includes",
-                            ),
-                            _iterate_flag_group(
-                                flags = ["-iquote", "%{quote_include_paths}"],
-                                iterate_over = "quote_include_paths",
-                            ),
-                            _iterate_flag_group(
-                                flags = ["-I%{include_paths}"],
-                                iterate_over = "include_paths",
-                            ),
-                            _iterate_flag_group(
-                                flags = ["-isystem", "%{system_include_paths}"],
-                                iterate_over = "system_include_paths",
-                            ),
-                            _iterate_flag_group(
-                                flags = ["-F", "%{framework_include_paths}"],
-                                iterate_over = "framework_include_paths",
-                            ),
-                        ],
-                    ),
-                    flag_set(
-                        actions = all_cpp_compile_actions(),
-                        flag_groups = [
-                            flag_group(flags = [
-                                "-fmerge-all-constants",
-                            ]),
-                        ] if compiler == "clang" else [],
-                    ),
-                    flag_set(
-                        actions = all_compile_actions(),
-                        flag_groups = [
-                            flag_group(
-                                flags = [
-                                    "-Wno-builtin-macro-redefined",
-                                    "-D__DATE__=\"redacted\"",
-                                    "-D__TIMESTAMP__=\"redacted\"",
-                                    "-D__TIME__=\"redacted\"",
-                                ],
-                            ),
-                            flag_group(
-                                flags = ["-fPIC"],
-                                expand_if_available = "pic",
-                            ),
-                            flag_group(
-                                flags = ["-fPIE"],
-                                expand_if_not_available = "pic",
-                            ),
-                            flag_group(
-                                flags = [
-                                    "-U_FORTIFY_SOURCE",
-                                    "-D_FORTIFY_SOURCE=1",
-                                    "-fstack-protector",
-                                    "-Wall",
-                                ] + ctx.attr.host_compiler_warnings + [
-                                    "-fno-omit-frame-pointer",
-                                ],
-                            ),
-                            _no_canonical_prefixes_group(
-                                ctx.attr.extra_no_canonical_prefixes_flags,
-                            ),
-                        ],
-                    ),
-                    flag_set(
-                        actions = all_compile_actions(),
-                        flag_groups = [flag_group(flags = ["-DNDEBUG"])],
-                        with_features = [with_feature_set(features = ["disable-assertions"])],
-                    ),
-                    flag_set(
-                        actions = all_compile_actions(),
-                        flag_groups = [
-                            flag_group(
-                                flags = [
-                                    "-g0",
-                                    "-O2",
-                                    "-ffunction-sections",
-                                    "-fdata-sections",
-                                ],
-                            ),
-                        ],
-                        with_features = [with_feature_set(features = ["opt"])],
-                    ),
-                    flag_set(
-                        actions = all_compile_actions(),
-                        flag_groups = [flag_group(flags = ["-g"])],
-                        with_features = [with_feature_set(features = ["dbg"])],
-                    ),
-                ] + _cuda_set(
-                    ctx.attr.cuda_path,
-                    all_compile_actions(),
-                ) + [
-                    flag_set(
-                        actions = all_compile_actions(),
-                        flag_groups = [
-                            _iterate_flag_group(
-                                flags = ["%{user_compile_flags}"],
-                                iterate_over = "user_compile_flags",
-                            ),
-                            _sysroot_group(),
-                            flag_group(
-                                expand_if_available = "source_file",
-                                flags = ["-c", "%{source_file}"],
-                            ),
-                            flag_group(
-                                expand_if_available = "output_assembly_file",
-                                flags = ["-S"],
-                            ),
-                            flag_group(
-                                expand_if_available = "output_preprocess_file",
-                                flags = ["-E"],
-                            ),
-                            flag_group(
-                                expand_if_available = "output_file",
-                                flags = ["-o", "%{output_file}"],
-                            ),
-                        ],
-                    ),
-                ],
-            ),
-            feature(
-                name = "all_archive_flags",
-                enabled = True,
-                flag_sets = [
-                    flag_set(
-                        actions = all_archive_actions(),
-                        flag_groups = [
-                            flag_group(
-                                expand_if_available = "linker_param_file",
-                                flags = ["@%{linker_param_file}"],
-                            ),
-                            flag_group(flags = ["rcsD"]),
-                            flag_group(
-                                flags = ["%{output_execpath}"],
-                                expand_if_available = "output_execpath",
-                            ),
-                            flag_group(
-                                iterate_over = "libraries_to_link",
-                                flag_groups = [
-                                    flag_group(
-                                        flags = ["%{libraries_to_link.name}"],
-                                        expand_if_equal = variable_with_value(
-                                            name = "libraries_to_link.type",
-                                            value = "object_file",
-                                        ),
-                                    ),
-                                    flag_group(
-                                        flags = ["%{libraries_to_link.object_files}"],
-                                        iterate_over = "libraries_to_link.object_files",
-                                        expand_if_equal = variable_with_value(
-                                            name = "libraries_to_link.type",
-                                            value = "object_file_group",
-                                        ),
-                                    ),
-                                ],
-                                expand_if_available = "libraries_to_link",
-                            ),
-                        ],
-                    ),
-                ],
-            ),
-            feature(
-                name = "all_link_flags",
-                enabled = True,
-                flag_sets = [
-                    flag_set(
-                        actions = all_shared_library_link_actions(),
-                        flag_groups = [flag_group(flags = ["-shared"])],
-                    ),
-                    flag_set(
-                        actions = all_link_actions(),
-                        flag_groups = ([
-                            flag_group(flags = ["-Wl,-no-as-needed"])
-                        ] if cpu == "local" else []) + ([
-                            flag_group(flags = ["-B" + ctx.attr.linker_bin_path])
-                        ] if ctx.attr.linker_bin_path else []) + [
-                            flag_group(
-                                flags = ["@%{linker_param_file}"],
-                                expand_if_available = "linker_param_file",
-                            ),
-                            _iterate_flag_group(
-                                flags = ["%{linkstamp_paths}"],
-                                iterate_over = "linkstamp_paths",
-                            ),
-                            flag_group(
-                                flags = ["-o", "%{output_execpath}"],
-                                expand_if_available = "output_execpath",
-                            ),
-                            _iterate_flag_group(
-                                flags = ["-L%{library_search_directories}"],
-                                iterate_over = "library_search_directories",
-                            ),
-                            _iterate_flag_group(
-                                iterate_over = "runtime_library_search_directories",
-                                flags = [
-                                    "-Wl,-rpath,$ORIGIN/%{runtime_library_search_directories}",
-                                ] if cpu == "local" else [
-                                    "-Wl,-rpath,@loader_path/%{runtime_library_search_directories}",
-                                ],
-                            ),
-                            _libraries_to_link_group("darwin" if cpu == "darwin" else "linux"),
-                            _iterate_flag_group(
-                                flags = ["%{user_link_flags}"],
-                                iterate_over = "user_link_flags",
-                            ),
-                            flag_group(
-                                flags = ["-Wl,--gdb-index"],
-                                expand_if_available = "is_using_fission",
-                            ),
-                            flag_group(
-                                flags = ["-Wl,-S"],
-                                expand_if_available = "strip_debug_symbols",
-                            ),
-                            flag_group(flags = ["-lc++" if cpu == "darwin" else "-lstdc++"]),
-                            _no_canonical_prefixes_group(
-                                ctx.attr.extra_no_canonical_prefixes_flags,
-                            ),
-                        ],
-                    ),
-                    flag_set(
-                        actions = all_executable_link_actions(),
-                        flag_groups = [flag_group(flags = ["-pie"])],
-                    ),
-                ] + ([
-                    flag_set(
-                        actions = all_link_actions(),
-                        flag_groups = [flag_group(flags = [
-                            "-Wl,-z,relro,-z,now",
-                        ])],
-                    ),
-                ] if cpu == "local" else []) + ([
-                    flag_set(
-                        actions = all_link_actions(),
-                        flag_groups = [
-                            flag_group(flags = ["-Wl,--gc-sections"]),
-                            flag_group(
-                                flags = ["-Wl,--build-id=md5", "-Wl,--hash-style=gnu"],
-                            ),
-                        ],
-                    ),
-                ] if cpu == "local" else []) + ([
-                    flag_set(
-                        actions = all_link_actions(),
-                        flag_groups = [flag_group(flags = ["-undefined", "dynamic_lookup"])],
-                    ),
-                ] if cpu == "darwin" else []) + _cuda_set(
-                    ctx.attr.cuda_path,
-                    all_link_actions(),
-                ) + [
-                    flag_set(
-                        actions = all_link_actions(),
-                        flag_groups = [
-                            _sysroot_group(),
-                        ],
-                    ),
-                ],
-            ),
-            feature(name = "disable-assertions"),
-            feature(
-                name = "opt",
-                implies = ["disable-assertions"],
-            ),
-            feature(name = "fastbuild"),
-            feature(name = "dbg"),
-            feature(name = "supports_dynamic_linker", enabled = True),
-            feature(name = "pic", enabled = True),
-            feature(name = "supports_pic", enabled = True),
-            feature(name = "has_configured_linker_path", enabled = True),
-        ]
-    elif cpu == "x64_windows":
-        return [
-            feature(name = "compiler_param_file"),
-            feature(name = "no_legacy_features"),
-            feature(
-                name = "common_flags",
-                enabled = True,
-                env_sets = [
-                    env_set(
-                        actions = all_compile_actions() + all_link_actions() + all_archive_actions(),
-                        env_entries = [
-                            env_entry(key = "PATH", value = ctx.attr.msvc_env_path),
-                            env_entry(key = "INCLUDE", value = ctx.attr.msvc_env_include),
-                            env_entry(key = "LIB", value = ctx.attr.msvc_env_lib),
-                            env_entry(key = "TMP", value = ctx.attr.msvc_env_tmp),
-                            env_entry(key = "TEMP", value = ctx.attr.msvc_env_tmp),
-                        ],
-                    ),
-                ],
-            ),
-            feature(
-                name = "all_compile_flags",
-                enabled = True,
-                flag_sets = [
-                    flag_set(
-                        actions = all_compile_actions(),
-                        flag_groups = [
-                            _nologo(),
-                            flag_group(
-                                flags = [
-                                    "/DCOMPILER_MSVC",
-                                    "/DNOMINMAX",
-                                    "/D_WIN32_WINNT=0x0600",
-                                    "/D_CRT_SECURE_NO_DEPRECATE",
-                                    "/D_CRT_SECURE_NO_WARNINGS",
-                                    "/D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS",
-                                    "/bigobj",
-                                    "/Zm500",
-                                    "/J",
-                                    "/Gy",
-                                    "/GF",
-                                    "/EHsc",
-                                    "/wd4351",
-                                    "/wd4291",
-                                    "/wd4250",
-                                    "/wd4996",
-                                ],
-                            ),
-                            _iterate_flag_group(
-                                flags = ["/I%{quote_include_paths}"],
-                                iterate_over = "quote_include_paths",
-                            ),
-                            _iterate_flag_group(
-                                flags = ["/I%{include_paths}"],
-                                iterate_over = "include_paths",
-                            ),
-                            _iterate_flag_group(
-                                flags = ["/I%{system_include_paths}"],
-                                iterate_over = "system_include_paths",
-                            ),
-                            _iterate_flag_group(
-                                flags = ["/D%{preprocessor_defines}"],
-                                iterate_over = "preprocessor_defines",
-                            ),
-                        ],
-                    ),
-                    flag_set(
-                        actions = all_preprocessed_actions(),
-                        flag_groups = [flag_group(flags = ["/showIncludes"])],
-                    ),
-                    flag_set(
-                        actions = all_compile_actions(),
-                        flag_groups = [flag_group(flags = ["/MT"])],
-                        with_features = [with_feature_set(features = ["static_link_msvcrt_no_debug"])],
-                    ),
-                    flag_set(
-                        actions = all_compile_actions(),
-                        flag_groups = [flag_group(flags = ["/MD"])],
-                        with_features = [with_feature_set(features = ["dynamic_link_msvcrt_no_debug"])],
-                    ),
-                    flag_set(
-                        actions = all_compile_actions(),
-                        flag_groups = [flag_group(flags = ["/MTd"])],
-                        with_features = [with_feature_set(features = ["static_link_msvcrt_debug"])],
-                    ),
-                    flag_set(
-                        actions = all_compile_actions(),
-                        flag_groups = [flag_group(flags = ["/MDd"])],
-                        with_features = [with_feature_set(features = ["dynamic_link_msvcrt_debug"])],
-                    ),
-                    flag_set(
-                        actions = all_compile_actions(),
-                        flag_groups = [flag_group(flags = ["/Od", "/Z7", "/DDEBUG"])],
-                        with_features = [with_feature_set(features = ["dbg"])],
-                    ),
-                    flag_set(
-                        actions = all_compile_actions(),
-                        flag_groups = [flag_group(flags = ["/Od", "/Z7", "/DDEBUG"])],
-                        with_features = [with_feature_set(features = ["fastbuild"])],
-                    ),
-                    flag_set(
-                        actions = all_compile_actions(),
-                        flag_groups = [flag_group(flags = ["/O2", "/DNDEBUG"])],
-                        with_features = [with_feature_set(features = ["opt"])],
-                    ),
-                    flag_set(
-                        actions = all_preprocessed_actions(),
-                        flag_groups = [
-                            _iterate_flag_group(
-                                flags = ["%{user_compile_flags}"],
-                                iterate_over = "user_compile_flags",
-                            ),
-                        ] + ([
-                            flag_group(flags = ctx.attr.host_unfiltered_compile_flags),
-                        ] if ctx.attr.host_unfiltered_compile_flags else []),
-                    ),
-                    flag_set(
-                        actions = [ACTION_NAMES.assemble],
-                        flag_groups = [
-                            flag_group(
-                                flag_groups = [
-                                    flag_group(
-                                        flags = ["/Fo%{output_file}", "/Zi"],
-                                        expand_if_not_available = "output_preprocess_file",
-                                    ),
-                                ],
-                                expand_if_available = "output_file",
-                                expand_if_not_available = "output_assembly_file",
-                            ),
-                        ],
-                    ),
-                    flag_set(
-                        actions = all_preprocessed_actions(),
-                        flag_groups = [
-                            flag_group(
-                                flag_groups = [
-                                    flag_group(
-                                        flags = ["/Fo%{output_file}"],
-                                        expand_if_not_available = "output_preprocess_file",
-                                    ),
-                                ],
-                                expand_if_available = "output_file",
-                                expand_if_not_available = "output_assembly_file",
-                            ),
-                            flag_group(
-                                flag_groups = [
-                                    flag_group(
-                                        flags = ["/Fa%{output_file}"],
-                                        expand_if_available = "output_assembly_file",
-                                    ),
-                                ],
-                                expand_if_available = "output_file",
-                            ),
-                            flag_group(
-                                flag_groups = [
-                                    flag_group(
-                                        flags = ["/P", "/Fi%{output_file}"],
-                                        expand_if_available = "output_preprocess_file",
-                                    ),
-                                ],
-                                expand_if_available = "output_file",
-                            ),
-                        ],
-                    ),
-                    flag_set(
-                        actions = all_compile_actions(),
-                        flag_groups = [
-                            flag_group(
-                                flags = ["/c", "%{source_file}"],
-                                expand_if_available = "source_file",
-                            ),
-                        ],
-                    ),
-                ],
-            ),
-            feature(
-                name = "all_archive_flags",
-                enabled = True,
-                flag_sets = [
-                    flag_set(
-                        actions = all_archive_actions(),
-                        flag_groups = [
-                            _nologo(),
-                            flag_group(
-                                flags = ["/OUT:%{output_execpath}"],
-                                expand_if_available = "output_execpath",
-                            ),
-                        ],
-                    ),
-                ],
-            ),
-            feature(
-                name = "all_link_flags",
-                enabled = True,
-                flag_sets = [
-                    flag_set(
-                        actions = all_shared_library_link_actions(),
-                        flag_groups = [flag_group(flags = ["/DLL"])],
-                    ),
-                    flag_set(
-                        actions = all_link_actions(),
-                        flag_groups = [
-                            _nologo(),
-                            _iterate_flag_group(
-                                flags = ["%{linkstamp_paths}"],
-                                iterate_over = "linkstamp_paths",
-                            ),
-                            flag_group(
-                                flags = ["/OUT:%{output_execpath}"],
-                                expand_if_available = "output_execpath",
-                            ),
-                        ],
-                    ),
-                    flag_set(
-                        actions = all_shared_library_link_actions(),
-                        flag_groups = [
-                            flag_group(
-                                flags = ["/IMPLIB:%{interface_library_output_path}"],
-                                expand_if_available = "interface_library_output_path",
-                            ),
-                        ],
-                    ),
-                    flag_set(
-                        actions = all_link_actions() +
-                                  all_archive_actions(),
-                        flag_groups = [
-                            _libraries_to_link_group("msvc"),
-                        ],
-                    ),
-                    flag_set(
-                        actions = all_link_actions(),
-                        flag_groups = [
-                            flag_group(flags = ["/SUBSYSTEM:CONSOLE"]),
-                            _iterate_flag_group(
-                                flags = ["%{user_link_flags}"],
-                                iterate_over = "user_link_flags",
-                            ),
-                            flag_group(flags = ["/MACHINE:X64"]),
-                        ],
-                    ),
-                    flag_set(
-                        actions = all_link_actions() +
-                                  all_archive_actions(),
-                        flag_groups = [
-                            flag_group(
-                                flags = ["@%{linker_param_file}"],
-                                expand_if_available = "linker_param_file",
-                            ),
-                        ],
-                    ),
-                    flag_set(
-                        actions = all_link_actions(),
-                        flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmt.lib"])],
-                        with_features = [with_feature_set(features = ["static_link_msvcrt_no_debug"])],
-                    ),
-                    flag_set(
-                        actions = all_link_actions(),
-                        flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrt.lib"])],
-                        with_features = [with_feature_set(features = ["dynamic_link_msvcrt_no_debug"])],
-                    ),
-                    flag_set(
-                        actions = all_link_actions(),
-                        flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmtd.lib"])],
-                        with_features = [with_feature_set(features = ["static_link_msvcrt_debug"])],
-                    ),
-                    flag_set(
-                        actions = all_link_actions(),
-                        flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrtd.lib"])],
-                        with_features = [with_feature_set(features = ["dynamic_link_msvcrt_debug"])],
-                    ),
-                    flag_set(
-                        actions = all_link_actions(),
-                        flag_groups = [flag_group(flags = ["/DEBUG:FULL", "/INCREMENTAL:NO"])],
-                        with_features = [with_feature_set(features = ["dbg"])],
-                    ),
-                    flag_set(
-                        actions = all_link_actions(),
-                        flag_groups = [
-                            flag_group(flags = ["/DEBUG:FASTLINK", "/INCREMENTAL:NO"]),
-                        ],
-                        with_features = [with_feature_set(features = ["fastbuild"])],
-                    ),
-                    flag_set(
-                        actions = all_link_actions(),
-                        flag_groups = [
-                            flag_group(
-                                flags = ["/DEF:%{def_file_path}", "/ignore:4070"],
-                                expand_if_available = "def_file_path",
-                            ),
-                        ],
-                    ),
-                ],
-            ),
-            feature(name = "parse_showincludes", enabled = True),
-            feature(name = "no_stripping", enabled = True),
-            feature(
-                name = "targets_windows",
-                enabled = True,
-                implies = ["copy_dynamic_libraries_to_binary"],
-            ),
-            feature(name = "copy_dynamic_libraries_to_binary"),
-            feature(
-                name = "generate_pdb_file",
-                requires = [
-                    feature_set(features = ["dbg"]),
-                    feature_set(features = ["fastbuild"]),
-                ],
-            ),
-            feature(name = "static_link_msvcrt"),
-            feature(
-                name = "static_link_msvcrt_no_debug",
-                requires = [
-                    feature_set(features = ["fastbuild"]),
-                    feature_set(features = ["opt"]),
-                ],
-            ),
-            feature(
-                name = "dynamic_link_msvcrt_no_debug",
-                requires = [
-                    feature_set(features = ["fastbuild"]),
-                    feature_set(features = ["opt"]),
-                ],
-            ),
-            feature(
-                name = "static_link_msvcrt_debug",
-                requires = [feature_set(features = ["dbg"])],
-            ),
-            feature(
-                name = "dynamic_link_msvcrt_debug",
-                requires = [feature_set(features = ["dbg"])],
-            ),
-            feature(
-                name = "dbg",
-                implies = ["generate_pdb_file"],
-            ),
-            feature(
-                name = "fastbuild",
-                implies = ["generate_pdb_file"],
-            ),
-            feature(
-                name = "opt",
-            ),
-            feature(name = "windows_export_all_symbols"),
-            feature(name = "no_windows_export_all_symbols"),
-            feature(name = "supports_dynamic_linker", enabled = True),
-            feature(
-                name = "supports_interface_shared_libraries",
-                enabled = True,
-            ),
-            feature(name = "has_configured_linker_path", enabled = True),
-        ]
-    else:
-        fail("Unreachable")
-
-def _impl(ctx):
-    cpu = ctx.attr.cpu
-    compiler = ctx.attr.compiler
-
-    if (cpu == "darwin"):
-        toolchain_identifier = "local_darwin"
-        target_cpu = "darwin"
-        target_libc = "macosx"
-        compiler = "compiler"
-        action_configs = _action_configs(
-            assembly_path = ctx.attr.host_compiler_path,
-            c_compiler_path = ctx.attr.host_compiler_path,
-            cc_compiler_path = ctx.attr.host_compiler_path,
-            archiver_path = ctx.attr.host_compiler_prefix + "/libtool",
-            linker_path = ctx.attr.host_compiler_path,
-            strip_path = ctx.attr.host_compiler_prefix + "/strip",
-        )
-        artifact_name_patterns = []
-    elif (cpu == "local"):
-        toolchain_identifier = "local_linux"
-        target_cpu = "local"
-        target_libc = "local"
-        action_configs = _action_configs(
-            assembly_path = ctx.attr.host_compiler_path,
-            c_compiler_path = ctx.attr.host_compiler_path,
-            cc_compiler_path = ctx.attr.host_compiler_path,
-            archiver_path = ctx.attr.host_compiler_prefix + "/ar",
-            linker_path = ctx.attr.host_compiler_path,
-            strip_path = ctx.attr.host_compiler_prefix + "/strip",
-        )
-        artifact_name_patterns = []
-    elif (cpu == "x64_windows"):
-        toolchain_identifier = "local_windows"
-        target_cpu = "x64_windows"
-        target_libc = "msvcrt"
-        compiler = "msvc-cl"
-        action_configs = _action_configs(
-            assembly_path = ctx.attr.msvc_ml_path,
-            c_compiler_path = ctx.attr.msvc_cl_path,
-            cc_compiler_path = ctx.attr.msvc_cl_path,
-            archiver_path = ctx.attr.msvc_lib_path,
-            linker_path = ctx.attr.msvc_link_path,
-            strip_path = "fake_tool_strip_not_supported",
-        )
-        artifact_name_patterns = [
-            artifact_name_pattern(
-                category_name = "object_file",
-                prefix = "",
-                extension = ".obj",
-            ),
-            artifact_name_pattern(
-                category_name = "static_library",
-                prefix = "",
-                extension = ".lib",
-            ),
-            artifact_name_pattern(
-                category_name = "alwayslink_static_library",
-                prefix = "",
-                extension = ".lo.lib",
-            ),
-            artifact_name_pattern(
-                category_name = "executable",
-                prefix = "",
-                extension = ".exe",
-            ),
-            artifact_name_pattern(
-                category_name = "dynamic_library",
-                prefix = "",
-                extension = ".dll",
-            ),
-            artifact_name_pattern(
-                category_name = "interface_library",
-                prefix = "",
-                extension = ".if.lib",
-            ),
-        ]
-    else:
-        fail("Unreachable")
-
-    out = ctx.actions.declare_file(ctx.label.name)
-    ctx.actions.write(out, "Fake executable")
-    return [
-        cc_common.create_cc_toolchain_config_info(
-            ctx = ctx,
-            features = _features(cpu, compiler, ctx),
-            action_configs = action_configs,
-            artifact_name_patterns = artifact_name_patterns,
-            cxx_builtin_include_directories = ctx.attr.builtin_include_directories,
-            toolchain_identifier = toolchain_identifier,
-            host_system_name = "local",
-            target_system_name = "local",
-            target_cpu = target_cpu,
-            target_libc = target_libc,
-            compiler = compiler,
-            abi_version = "local",
-            abi_libc_version = "local",
-            tool_paths = _tool_paths(cpu, ctx),
-            make_variables = [],
-            builtin_sysroot = ctx.attr.builtin_sysroot,
-            cc_target_os = None,
-        ),
-        DefaultInfo(
-            executable = out,
-        ),
-    ]
-
-cc_toolchain_config = rule(
-    implementation = _impl,
-    attrs = {
-        "cpu": attr.string(mandatory = True, values = ["darwin", "local", "x64_windows"]),
-        "compiler": attr.string(values = ["clang", "msvc", "unknown"], default = "unknown"),
-        "builtin_include_directories": attr.string_list(),
-        "extra_no_canonical_prefixes_flags": attr.string_list(),
-        "host_compiler_path": attr.string(),
-        "host_compiler_prefix": attr.string(),
-        "host_compiler_warnings": attr.string_list(),
-        "host_unfiltered_compile_flags": attr.string_list(),
-        "linker_bin_path": attr.string(),
-        "builtin_sysroot": attr.string(),
-        "cuda_path": attr.string(),
-        "msvc_cl_path": attr.string(default = "msvc_not_used"),
-        "msvc_env_include": attr.string(default = "msvc_not_used"),
-        "msvc_env_lib": attr.string(default = "msvc_not_used"),
-        "msvc_env_path": attr.string(default = "msvc_not_used"),
-        "msvc_env_tmp": attr.string(default = "msvc_not_used"),
-        "msvc_lib_path": attr.string(default = "msvc_not_used"),
-        "msvc_link_path": attr.string(default = "msvc_not_used"),
-        "msvc_ml_path": attr.string(default = "msvc_not_used"),
-    },
-    provides = [CcToolchainConfigInfo],
-    executable = True,
-)
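
The rule deleted above drives most of its behaviour through feature gating: a flag_set only applies when every feature named in its with_feature_set is enabled, and the opt feature implies disable-assertions, which is what pulls -DNDEBUG into optimized builds. A minimal, runnable Python paraphrase of that gating (not Starlark; applicable_flags and LINUX_COMPILE_FLAG_SETS are illustrative names, and only a subset of the rule's flag groups is shown):

# Illustrative paraphrase of the with_feature_set gating used by the deleted
# cc_toolchain_config rule; applicable_flags() is a made-up helper name.
def applicable_flags(flag_sets, enabled_features):
    flags = []
    for fs in flag_sets:
        # A flag set applies only when all of its required features are enabled.
        if fs.get("with_features", set()) <= enabled_features:
            flags += fs["flags"]
    return flags

# A subset of the flag groups from the Linux branch of _features() above.
LINUX_COMPILE_FLAG_SETS = [
    {"flags": ["-U_FORTIFY_SOURCE", "-D_FORTIFY_SOURCE=1", "-fstack-protector", "-Wall"]},
    {"flags": ["-DNDEBUG"], "with_features": {"disable-assertions"}},
    {"flags": ["-g0", "-O2", "-ffunction-sections", "-fdata-sections"], "with_features": {"opt"}},
    {"flags": ["-g"], "with_features": {"dbg"}},
]

# "opt" implies "disable-assertions" in the rule, so an opt build also gets -DNDEBUG:
print(applicable_flags(LINUX_COMPILE_FLAG_SETS, {"opt", "disable-assertions"}))
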
diff --git a/third_party/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
deleted file mode 100755
index 81e54ad..0000000
--- a/third_party/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
+++ /dev/null
@@ -1,306 +0,0 @@
-#!/usr/bin/env python
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Crosstool wrapper for compiling CUDA programs.
-
-SYNOPSIS:
-  crosstool_wrapper_is_not_gcc [options passed in by cc_library()
-                                or cc_binary() rule]
-
-DESCRIPTION:
-  This script is expected to be called by the cc_library() or cc_binary() bazel
-  rules. When the option "-x cuda" is present in the list of arguments passed
-  to this script, it invokes the nvcc CUDA compiler. Most arguments are passed
-  through unchanged, concatenated into a single string for --compiler-options
-  of nvcc. When "-x cuda" is not present, this wrapper invokes the host (CPU)
-  compiler with the input arguments as is.
-
-NOTES:
-  Changes to the contents of this file must be propagated from
-  //third_party/gpus/crosstool/crosstool_wrapper_is_not_gcc to
-  //third_party/gpus/crosstool/v*/*/clang/bin/crosstool_wrapper_is_not_gcc
-"""
-
-__author__ = 'keveman@google.com (Manjunath Kudlur)'
-
-from argparse import ArgumentParser
-import os
-import subprocess
-import re
-import sys
-import pipes
-
-# Template values set by cuda_autoconf.
-CPU_COMPILER = ('%{cpu_compiler}')
-GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}')
-
-NVCC_PATH = '%{nvcc_path}'
-PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
-NVCC_VERSION = '%{cuda_version}'
-
-def Log(s):
-  print('gpus/crosstool: {0}'.format(s))
-
-
-def GetOptionValue(argv, option):
-  """Extract the list of values for option from the argv list.
-
-  Args:
-    argv: A list of strings, possibly the argv passed to main().
-    option: The option whose value to extract, with the leading '-'.
-
-  Returns:
-    A list of values, either directly following the option,
-    (eg., -opt val1 val2) or values collected from multiple occurrences of
-    the option (eg., -opt val1 -opt val2).
-  """
-
-  parser = ArgumentParser()
-  parser.add_argument(option, nargs='*', action='append')
-  option = option.lstrip('-').replace('-', '_')
-  args, _ = parser.parse_known_args(argv)
-  if not args or not vars(args)[option]:
-    return []
-  else:
-    return sum(vars(args)[option], [])
-
-
-def GetHostCompilerOptions(argv):
-  """Collect the -isystem, -iquote, and --sysroot option values from argv.
-
-  Args:
-    argv: A list of strings, possibly the argv passed to main().
-
-  Returns:
-    The string that can be used as the --compiler-options to nvcc.
-  """
-
-  parser = ArgumentParser()
-  parser.add_argument('-isystem', nargs='*', action='append')
-  parser.add_argument('-iquote', nargs='*', action='append')
-  parser.add_argument('--sysroot', nargs=1)
-  parser.add_argument('-g', nargs='*', action='append')
-  parser.add_argument('-fno-canonical-system-headers', action='store_true')
-  parser.add_argument('-no-canonical-prefixes', action='store_true')
-
-  args, _ = parser.parse_known_args(argv)
-
-  opts = ''
-
-  if args.isystem:
-    opts += ' -isystem ' + ' -isystem '.join(sum(args.isystem, []))
-  if args.iquote:
-    opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, []))
-  if args.g:
-    opts += ' -g' + ' -g'.join(sum(args.g, []))
-  if args.fno_canonical_system_headers:
-    opts += ' -fno-canonical-system-headers'
-  if args.no_canonical_prefixes:
-    opts += ' -no-canonical-prefixes'
-  if args.sysroot:
-    opts += ' --sysroot ' + args.sysroot[0]
-
-  return opts
-
-def _update_options(nvcc_options):
-  if NVCC_VERSION in ("7.0",):
-    return nvcc_options
-
-  update_options = { "relaxed-constexpr" : "expt-relaxed-constexpr" }
-  return [ update_options[opt] if opt in update_options else opt
-                    for opt in nvcc_options ]
-
-def GetNvccOptions(argv):
-  """Collect the -nvcc_options values from argv.
-
-  Args:
-    argv: A list of strings, possibly the argv passed to main().
-
-  Returns:
-    The string that can be passed directly to nvcc.
-  """
-
-  parser = ArgumentParser()
-  parser.add_argument('-nvcc_options', nargs='*', action='append')
-
-  args, _ = parser.parse_known_args(argv)
-
-  if args.nvcc_options:
-    options = _update_options(sum(args.nvcc_options, []))
-    return ' '.join(['--'+a for a in options])
-  return ''
-
-def system(cmd):
-  """Invokes cmd with os.system().
-
-  Args:
-    cmd: The command.
-
-  Returns:
-    The exit code if the process exited with exit() or -signal
-    if the process was terminated by a signal.
-  """
-  retv = os.system(cmd)
-  if os.WIFEXITED(retv):
-    return os.WEXITSTATUS(retv)
-  else:
-    return -os.WTERMSIG(retv)
-
-def InvokeNvcc(argv, log=False):
-  """Call nvcc with arguments assembled from argv.
-
-  Args:
-    argv: A list of strings, possibly the argv passed to main().
-    log: True if logging is requested.
-
-  Returns:
-    The return value of calling system('nvcc ' + args)
-  """
-
-  host_compiler_options = GetHostCompilerOptions(argv)
-  nvcc_compiler_options = GetNvccOptions(argv)
-  opt_option = GetOptionValue(argv, '-O')
-  m_options = GetOptionValue(argv, '-m')
-  m_host_options = ''.join([' -m' + m for m in m_options if m not in ['32', '64']])
-  m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
-  host_compiler_options = ' '.join([host_compiler_options, m_host_options])
-  include_options = GetOptionValue(argv, '-I')
-  out_file = GetOptionValue(argv, '-o')
-  depfiles = GetOptionValue(argv, '-MF')
-  defines = GetOptionValue(argv, '-D')
-  defines = ''.join([' -D' + define for define in defines])
-  undefines = GetOptionValue(argv, '-U')
-  undefines = ''.join([' -U' + define for define in undefines])
-  std_options = GetOptionValue(argv, '-std')
-  # Supported -std flags as of CUDA 9.0. Only keep last to mimic gcc/clang.
-  nvcc_allowed_std_options = ["c++03", "c++11", "c++14"]
-  nvcc_std_map = {}
-  if int(NVCC_VERSION.split('.')[0]) >= 11:
-      nvcc_std_map["c++1z"] = "c++17"
-      nvcc_allowed_std_options += ["c++17", "c++1z"]
-  std_options = ''.join([' -std=' +
-      (nvcc_std_map[define] if define in nvcc_std_map else define)
-      for define in std_options if define in nvcc_allowed_std_options][-1:])
-  fatbin_options = ''.join([' --fatbin-options=' + option
-      for option in GetOptionValue(argv, '-Xcuda-fatbinary')])
-
-  # The list of source files gets passed after the -c option. I don't know of
-  # any other reliable way to just get the list of source files to be compiled.
-  src_files = GetOptionValue(argv, '-c')
-
-  # Pass -w through from host to nvcc, but don't do anything fancier with
-  # warnings-related flags, since they're not necessarily the same across
-  # compilers.
-  warning_options = ' -w' if '-w' in argv else ''
-
-  if len(src_files) == 0:
-    return 1
-  if len(out_file) != 1:
-    return 1
-
-  opt = (' -O2' if (len(opt_option) > 0 and int(opt_option[0]) > 0)
-         else ' -g')
-
-  includes = (' -I ' + ' -I '.join(include_options)
-              if len(include_options) > 0
-              else '')
-
-  # Unfortunately, there are other options that have a -c prefix too.
-  # So allow only the arguments that look like C/C++ source files.
-  src_files = [f for f in src_files if
-               re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
-  srcs = ' '.join(src_files)
-  out = ' -o ' + out_file[0]
-
-  nvccopts = '-D_FORCE_INLINES '
-  capabilities_sm = set(GetOptionValue(argv, "--cuda-gpu-arch"))
-  capabilities_compute = set(GetOptionValue(argv, '--cuda-include-ptx'))
-  # When both "code=sm_xy" and "code=compute_xy" are requested for a single
-  # arch, they can be combined using "code=xy,compute_xy" which avoids a
-  # redundant PTX generation during compilation.
-  capabilities_both = capabilities_sm.intersection(capabilities_compute)
-  for capability in capabilities_both:
-    capability = capability[len('sm_'):]
-    nvccopts += r'-gencode=arch=compute_%s,code=\"sm_%s,compute_%s\" ' % (
-        capability, capability, capability)
-  for capability in capabilities_sm - capabilities_both:
-    capability = capability[len('sm_'):]
-    nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s\" ' % (capability,
-                                                               capability)
-  for capability in capabilities_compute - capabilities_both:
-    capability = capability[len('sm_'):]
-    nvccopts += r'-gencode=arch=compute_%s,\"code=compute_%s\" ' % (capability,
-                                                                    capability)
-  nvccopts += nvcc_compiler_options
-  nvccopts += undefines
-  nvccopts += defines
-  nvccopts += std_options
-  nvccopts += m_options
-  nvccopts += warning_options
-  # Force C++17 dialect (note, everything in just one string!)
-  nvccopts += ' --std c++17 '
-  nvccopts += fatbin_options
-
-  if depfiles:
-    # Generate the dependency file
-    depfile = depfiles[0]
-    cmd = (NVCC_PATH + ' ' + nvccopts +
-           ' --compiler-options "' + host_compiler_options + '"' +
-           ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH +
-           ' -I .' +
-           ' -x cu ' + opt + includes + ' ' + srcs + ' -M -o ' + depfile)
-    if log: Log(cmd)
-    exit_status = system(cmd)
-    if exit_status != 0:
-      return exit_status
-
-  cmd = (NVCC_PATH + ' ' + nvccopts +
-         ' --compiler-options "' + host_compiler_options + ' -fPIC"' +
-         ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH +
-         ' -I .' +
-         ' -x cu ' + opt + includes + ' -c ' + srcs + out)
-
-  # TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'.
-  # Need to investigate and fix.
-  cmd = 'PATH=' + PREFIX_DIR + ':$PATH ' + cmd
-  if log: Log(cmd)
-  return system(cmd)
-
-
-def main():
-  parser = ArgumentParser()
-  parser.add_argument('-x', nargs=1)
-  parser.add_argument('--cuda_log', action='store_true')
-  args, leftover = parser.parse_known_args(sys.argv[1:])
-
-  if args.x and args.x[0] == 'cuda':
-    if args.cuda_log: Log('-x cuda')
-    leftover = [pipes.quote(s) for s in leftover]
-    if args.cuda_log: Log('using nvcc')
-    return InvokeNvcc(leftover, log=args.cuda_log)
-
-  # Strip our flags before passing through to the CPU compiler for files which
-  # are not -x cuda. We can't just pass 'leftover' because it also strips -x.
-  # We not only want to pass -x to the CPU compiler, but also keep it in its
-  # relative location in the argv list (the compiler is actually sensitive to
-  # this).
-  cpu_compiler_flags = [flag for flag in sys.argv[1:]
-                             if not flag.startswith(('--cuda_log'))]
-
-  return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)
-
-if __name__ == '__main__':
-  sys.exit(main())
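
The heart of the wrapper deleted above is the -gencode assembly in InvokeNvcc: an architecture requested both as SASS (--cuda-gpu-arch=sm_XY) and as PTX (--cuda-include-ptx=sm_XY) is folded into a single -gencode entry so the PTX is not generated twice. A standalone sketch of that logic, assuming the same sm_XY naming and with the shell escaping of the combined code= list omitted for readability (gencode_flags is an illustrative name):

# Standalone sketch of the -gencode construction in InvokeNvcc() above.
def gencode_flags(sass_archs, ptx_archs):
    both = sass_archs & ptx_archs
    flags = []
    for arch in sorted(both):
        xy = arch[len("sm_"):]
        # SASS and PTX requested for the same arch: emit one combined entry.
        flags.append("-gencode=arch=compute_%s,code=sm_%s,compute_%s" % (xy, xy, xy))
    for arch in sorted(sass_archs - both):
        xy = arch[len("sm_"):]
        flags.append("-gencode=arch=compute_%s,code=sm_%s" % (xy, xy))
    for arch in sorted(ptx_archs - both):
        xy = arch[len("sm_"):]
        flags.append("-gencode=arch=compute_%s,code=compute_%s" % (xy, xy))
    return flags

print(gencode_flags({"sm_80", "sm_90"}, {"sm_90"}))
# -> ['-gencode=arch=compute_90,code=sm_90,compute_90',
#     '-gencode=arch=compute_80,code=sm_80']
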
diff --git a/third_party/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
deleted file mode 100755
index 8fb2231..0000000
--- a/third_party/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
+++ /dev/null
@@ -1,265 +0,0 @@
-#!/usr/bin/env python
-"""Crosstool wrapper for compiling ROCm programs.
-
-SYNOPSIS:
-  crosstool_wrapper_driver_rocm [options passed in by cc_library()
-                                or cc_binary() rule]
-
-DESCRIPTION:
-  This script is expected to be called by the cc_library() or cc_binary() bazel
-  rules. When the option "-x rocm" is present in the list of arguments passed
-  to this script, it invokes the hipcc compiler. Most arguments are passed
-  through unchanged, concatenated into a single string for --compiler-options
-  of hipcc. When "-x rocm" is not present, this wrapper invokes gcc with the
-  input arguments as is.
-"""
-
-__author__ = 'whchung@gmail.com (Wen-Heng (Jack) Chung)'
-
-from argparse import ArgumentParser
-import os
-import subprocess
-import re
-import sys
-import pipes
-
-# Template values set by rocm_configure.bzl.
-CPU_COMPILER = ('%{cpu_compiler}')
-
-HIPCC_PATH = '%{hipcc_path}'
-HIPCC_ENV = '%{hipcc_env}'
-HIP_RUNTIME_PATH = '%{hip_runtime_path}'
-HIP_RUNTIME_LIBRARY = '%{hip_runtime_library}'
-ROCR_RUNTIME_PATH = '%{rocr_runtime_path}'
-ROCR_RUNTIME_LIBRARY = '%{rocr_runtime_library}'
-VERBOSE = '%{crosstool_verbose}'=='1'
-
-def Log(s):
-  print('gpus/crosstool: {0}'.format(s))
-
-
-def GetOptionValue(argv, option):
-  """Extract the list of values for option from the argv list.
-
-  Args:
-    argv: A list of strings, possibly the argv passed to main().
-    option: The option whose value to extract, without the leading '-'.
-
-  Returns:
-    A list of values, either directly following the option,
-    (eg., -opt val1 val2) or values collected from multiple occurrences of
-    the option (eg., -opt val1 -opt val2).
-  """
-
-  parser = ArgumentParser()
-  parser.add_argument('-' + option, nargs='*', action='append')
-  args, _ = parser.parse_known_args(argv)
-  if not args or not vars(args)[option]:
-    return []
-  else:
-    return sum(vars(args)[option], [])
-
-
-def GetHostCompilerOptions(argv):
-  """Collect the -isystem, -iquote, and --sysroot option values from argv.
-
-  Args:
-    argv: A list of strings, possibly the argv passed to main().
-
-  Returns:
-    The string that can be used as the --compiler-options to hipcc.
-  """
-
-  parser = ArgumentParser()
-  parser.add_argument('-isystem', nargs='*', action='append')
-  parser.add_argument('-iquote', nargs='*', action='append')
-  parser.add_argument('--sysroot', nargs=1)
-  parser.add_argument('-g', nargs='*', action='append')
-  parser.add_argument('-fno-canonical-system-headers', action='store_true')
-
-  args, _ = parser.parse_known_args(argv)
-
-  opts = ''
-
-  if args.isystem:
-    opts += ' -isystem ' + ' -isystem '.join(sum(args.isystem, []))
-  if args.iquote:
-    opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, []))
-  if args.g:
-    opts += ' -g' + ' -g'.join(sum(args.g, []))
-  #if args.fno_canonical_system_headers:
-  #  opts += ' -fno-canonical-system-headers'
-  if args.sysroot:
-    opts += ' --sysroot ' + args.sysroot[0]
-
-  return opts
-
-def system(cmd):
-  """Invokes cmd with os.system().
-
-  Args:
-    cmd: The command.
-
-  Returns:
-    The exit code if the process exited with exit() or -signal
-    if the process was terminated by a signal.
-  """
-  retv = os.system(cmd)
-  if os.WIFEXITED(retv):
-    return os.WEXITSTATUS(retv)
-  else:
-    return -os.WTERMSIG(retv)
-
-
-def InvokeHipcc(argv, log=False):
-  """Call hipcc with arguments assembled from argv.
-
-  Args:
-    argv: A list of strings, possibly the argv passed to main().
-    log: True if logging is requested.
-
-  Returns:
-    The return value of calling os.system('hipcc ' + args)
-  """
-
-  host_compiler_options = GetHostCompilerOptions(argv)
-  opt_option = GetOptionValue(argv, 'O')
-  m_options = GetOptionValue(argv, 'm')
-  m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
-  include_options = GetOptionValue(argv, 'I')
-  out_file = GetOptionValue(argv, 'o')
-  depfiles = GetOptionValue(argv, 'MF')
-  defines = GetOptionValue(argv, 'D')
-  defines = ''.join([' -D' + define for define in defines])
-  undefines = GetOptionValue(argv, 'U')
-  undefines = ''.join([' -U' + define for define in undefines])
-  std_options = GetOptionValue(argv, 'std')
-  hipcc_allowed_std_options = ["c++11", "c++14", "c++17"]
-  std_options = ''.join([' -std=' + define
-      for define in std_options if define in hipcc_allowed_std_options])
-
-  # The list of source files gets passed after the -c option. I don't know of
-  # any other reliable way to just get the list of source files to be compiled.
-  src_files = GetOptionValue(argv, 'c')
-
-  if len(src_files) == 0:
-    return 1
-  if len(out_file) != 1:
-    return 1
-
-  opt = (' -O2' if (len(opt_option) > 0 and int(opt_option[0]) > 0)
-         else ' -g')
-
-  includes = (' -I ' + ' -I '.join(include_options)
-              if len(include_options) > 0
-              else '')
-
-  # Unfortunately, there are other options that have a -c prefix too.
-  # So allow only the arguments that look like C/C++ source files.
-  src_files = [f for f in src_files if
-               re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
-  srcs = ' '.join(src_files)
-  out = ' -o ' + out_file[0]
-
-  hipccopts = ' '
-  # In a hip-clang environment, we need to make sure the HIP header is included
-  # before any standard math header like <complex> in every source file;
-  # otherwise we get a build error.
-  # We also need to keep the warning about uninitialised shared variables as a
-  # warning only, even when the -Werror option is specified.
-  hipccopts += ' --include=hip/hip_runtime.h '
-  # Force C++17 dialect (note, everything in just one string!)
-  hipccopts += ' --std=c++17 '
-  # Use -fno-gpu-rdc by default for early GPU kernel finalization.
-  # This flag causes GPU kernels to be generated at compile time instead of
-  # link time, which allows the default host compiler (gcc) to be used as the
-  # linker for TensorFlow on the ROCm platform.
-  hipccopts += ' -fno-gpu-rdc '
-  hipccopts += ' -fcuda-flush-denormals-to-zero '
-  hipccopts += undefines
-  hipccopts += defines
-  hipccopts += std_options
-  hipccopts += m_options
-
-  if depfiles:
-    # Generate the dependency file
-    depfile = depfiles[0]
-    cmd = (HIPCC_PATH + ' ' + hipccopts +
-           host_compiler_options +
-           ' -I .' + includes + ' ' + srcs + ' -M -o ' + depfile)
-    cmd = HIPCC_ENV.replace(';', ' ') + ' ' + cmd
-    if log: Log(cmd)
-    if VERBOSE: print(cmd)
-    exit_status = os.system(cmd)
-    if exit_status != 0:
-      return exit_status
-
-  cmd = (HIPCC_PATH + ' ' + hipccopts +
-         host_compiler_options + ' -fPIC' +
-         ' -I .' + opt + includes + ' -c ' + srcs + out)
-
-  cmd = HIPCC_ENV.replace(';', ' ') + ' '\
-        + cmd
-  if log: Log(cmd)
-  if VERBOSE: print(cmd)
-  return system(cmd)
-
-
-def main():
-  # ignore PWD env var
-  os.environ['PWD']=''
-
-  parser = ArgumentParser(fromfile_prefix_chars='@')
-  parser.add_argument('-x', nargs=1)
-  parser.add_argument('--rocm_log', action='store_true')
-  parser.add_argument('-pass-exit-codes', action='store_true')
-  args, leftover = parser.parse_known_args(sys.argv[1:])
-
-  if VERBOSE: print('PWD=' + os.getcwd())
-  if VERBOSE: print('HIPCC_ENV=' + HIPCC_ENV)
-
-  if args.x and args.x[0] == 'rocm':
-    # compilation for GPU objects
-    if args.rocm_log: Log('-x rocm')
-    leftover = [pipes.quote(s) for s in leftover]
-    if args.rocm_log: Log('using hipcc')
-    return InvokeHipcc(leftover, log=args.rocm_log)
-
-  elif args.pass_exit_codes:
-    # link
-    # With hipcc now invoked with -fno-gpu-rdc by default, it is fine to use
-    # the host compiler as the linker, but we still have to link against the
-    # HCC/HIP runtime. This restriction may be relaxed as the bazel scripts
-    # are improved to fine-tune the dependencies on the ROCm libraries.
-    gpu_linker_flags = [flag for flag in sys.argv[1:]
-                               if not flag.startswith(('--rocm_log'))]
-
-    gpu_linker_flags.append('-L' + ROCR_RUNTIME_PATH)
-    gpu_linker_flags.append('-Wl,-rpath=' + ROCR_RUNTIME_PATH)
-    gpu_linker_flags.append('-l' + ROCR_RUNTIME_LIBRARY)
-    gpu_linker_flags.append('-L' + HIP_RUNTIME_PATH)
-    gpu_linker_flags.append('-Wl,-rpath=' + HIP_RUNTIME_PATH)
-    gpu_linker_flags.append('-l' + HIP_RUNTIME_LIBRARY)
-    gpu_linker_flags.append("-lrt")
-    gpu_linker_flags.append("-lstdc++")
-
-    if VERBOSE: print(' '.join([CPU_COMPILER] + gpu_linker_flags))
-    return subprocess.call([CPU_COMPILER] + gpu_linker_flags)
-
-  else:
-    # compilation for host objects
-
-    # Strip our flags before passing through to the CPU compiler for files which
-    # are not -x rocm. We can't just pass 'leftover' because it also strips -x.
-    # We not only want to pass -x to the CPU compiler, but also keep it in its
-    # relative location in the argv list (the compiler is actually sensitive to
-    # this).
-    cpu_compiler_flags = [flag for flag in sys.argv[1:]
-                               if not flag.startswith(('--rocm_log'))]
-
-    # XXX: SE code needs to be built with gcc, but needs this macro defined
-    cpu_compiler_flags.append("-D__HIP_PLATFORM_HCC__")
-    if VERBOSE: print(' '.join([CPU_COMPILER] + cpu_compiler_flags))
-    return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)
-
-if __name__ == '__main__':
-  sys.exit(main())
diff --git a/third_party/xla/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl b/third_party/xla/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl
deleted file mode 100644
index e0541de..0000000
--- a/third_party/xla/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl
+++ /dev/null
@@ -1,1162 +0,0 @@
-"""cc_toolchain_config rule for configuring ROCm toolchain on Linux."""
-
-load(
-    "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl",
-    "feature",
-    "feature_set",
-    "flag_group",
-    "flag_set",
-    "tool_path",
-    "variable_with_value",
-    "with_feature_set",
-)
-load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES")
-
-all_compile_actions = [
-    ACTION_NAMES.c_compile,
-    ACTION_NAMES.cpp_compile,
-    ACTION_NAMES.linkstamp_compile,
-    ACTION_NAMES.assemble,
-    ACTION_NAMES.preprocess_assemble,
-    ACTION_NAMES.cpp_header_parsing,
-    ACTION_NAMES.cpp_module_compile,
-    ACTION_NAMES.cpp_module_codegen,
-    ACTION_NAMES.clif_match,
-    ACTION_NAMES.lto_backend,
-]
-
-all_cpp_compile_actions = [
-    ACTION_NAMES.cpp_compile,
-    ACTION_NAMES.linkstamp_compile,
-    ACTION_NAMES.cpp_header_parsing,
-    ACTION_NAMES.cpp_module_compile,
-    ACTION_NAMES.cpp_module_codegen,
-    ACTION_NAMES.clif_match,
-]
-
-preprocessor_compile_actions = [
-    ACTION_NAMES.c_compile,
-    ACTION_NAMES.cpp_compile,
-    ACTION_NAMES.linkstamp_compile,
-    ACTION_NAMES.preprocess_assemble,
-    ACTION_NAMES.cpp_header_parsing,
-    ACTION_NAMES.cpp_module_compile,
-    ACTION_NAMES.clif_match,
-]
-
-codegen_compile_actions = [
-    ACTION_NAMES.c_compile,
-    ACTION_NAMES.cpp_compile,
-    ACTION_NAMES.linkstamp_compile,
-    ACTION_NAMES.assemble,
-    ACTION_NAMES.preprocess_assemble,
-    ACTION_NAMES.cpp_module_codegen,
-    ACTION_NAMES.lto_backend,
-]
-
-all_link_actions = [
-    ACTION_NAMES.cpp_link_executable,
-    ACTION_NAMES.cpp_link_dynamic_library,
-    ACTION_NAMES.cpp_link_nodeps_dynamic_library,
-]
-
-lto_index_actions = [
-    ACTION_NAMES.lto_index_for_executable,
-    ACTION_NAMES.lto_index_for_dynamic_library,
-    ACTION_NAMES.lto_index_for_nodeps_dynamic_library,
-]
-
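-# The implementation below wires the host tool paths and the feature list into
-# a CcToolchainConfigInfo provider that Bazel's C++ rules consume for ROCm
-# builds.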
-def _impl(ctx):
-    tool_paths = [
-        tool_path(name = "gcc", path = ctx.attr.host_compiler_path),
-        tool_path(name = "ar", path = ctx.attr.host_compiler_prefix + "/ar"),
-        tool_path(name = "compat-ld", path = ctx.attr.host_compiler_prefix + "/ld"),
-        tool_path(name = "cpp", path = ctx.attr.host_compiler_prefix + "/cpp"),
-        tool_path(name = "dwp", path = ctx.attr.host_compiler_prefix + "/dwp"),
-        tool_path(name = "gcov", path = ctx.attr.host_compiler_prefix + "/gcov"),
-        tool_path(name = "ld", path = ctx.attr.host_compiler_prefix + "/ld"),
-        tool_path(name = "nm", path = ctx.attr.host_compiler_prefix + "/nm"),
-        tool_path(name = "objcopy", path = ctx.attr.host_compiler_prefix + "/objcopy"),
-        tool_path(name = "objdump", path = ctx.attr.host_compiler_prefix + "/objdump"),
-        tool_path(name = "strip", path = ctx.attr.host_compiler_prefix + "/strip"),
-    ]
-
-    action_configs = []
-
-    supports_pic_feature = feature(
-        name = "supports_pic",
-        enabled = True,
-    )
-    supports_start_end_lib_feature = feature(
-        name = "supports_start_end_lib",
-        enabled = True,
-    )
-
-    default_compile_flags_feature = feature(
-        name = "default_compile_flags",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = all_compile_actions,
-                flag_groups = ([
-                    flag_group(
-                        flags = ctx.attr.compile_flags,
-                    ),
-                ] if ctx.attr.compile_flags else []),
-            ),
-            flag_set(
-                actions = all_compile_actions,
-                flag_groups = ([
-                    flag_group(
-                        flags = ctx.attr.dbg_compile_flags,
-                    ),
-                ] if ctx.attr.dbg_compile_flags else []),
-                with_features = [with_feature_set(features = ["dbg"])],
-            ),
-            flag_set(
-                actions = all_compile_actions,
-                flag_groups = ([
-                    flag_group(
-                        flags = ctx.attr.opt_compile_flags,
-                    ),
-                ] if ctx.attr.opt_compile_flags else []),
-                with_features = [with_feature_set(features = ["opt"])],
-            ),
-            flag_set(
-                actions = all_cpp_compile_actions + [ACTION_NAMES.lto_backend],
-                flag_groups = ([
-                    flag_group(
-                        flags = ctx.attr.cxx_flags,
-                    ),
-                ] if ctx.attr.cxx_flags else []),
-            ),
-        ],
-    )
-
-    default_link_flags_feature = feature(
-        name = "default_link_flags",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = all_link_actions + lto_index_actions,
-                flag_groups = ([
-                    flag_group(
-                        flags = ctx.attr.link_flags,
-                    ),
-                ] if ctx.attr.link_flags else []),
-            ),
-            flag_set(
-                actions = all_link_actions + lto_index_actions,
-                flag_groups = ([
-                    flag_group(
-                        flags = ctx.attr.opt_link_flags,
-                    ),
-                ] if ctx.attr.opt_link_flags else []),
-                with_features = [with_feature_set(features = ["opt"])],
-            ),
-        ],
-    )
-
-    dbg_feature = feature(name = "dbg")
-
-    opt_feature = feature(name = "opt")
-
-    sysroot_feature = feature(
-        name = "sysroot",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.preprocess_assemble,
-                    ACTION_NAMES.linkstamp_compile,
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                    ACTION_NAMES.cpp_header_parsing,
-                    ACTION_NAMES.cpp_module_compile,
-                    ACTION_NAMES.cpp_module_codegen,
-                    ACTION_NAMES.lto_backend,
-                    ACTION_NAMES.clif_match,
-                ] + all_link_actions + lto_index_actions,
-                flag_groups = [
-                    flag_group(
-                        flags = ["--sysroot=%{sysroot}"],
-                        expand_if_available = "sysroot",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    fdo_optimize_feature = feature(
-        name = "fdo_optimize",
-        flag_sets = [
-            flag_set(
-                actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile],
-                flag_groups = [
-                    flag_group(
-                        flags = [
-                            "-fprofile-use=%{fdo_profile_path}",
-                            "-fprofile-correction",
-                        ],
-                        expand_if_available = "fdo_profile_path",
-                    ),
-                ],
-            ),
-        ],
-        provides = ["profile"],
-    )
-
-    supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True)
-
-    user_compile_flags_feature = feature(
-        name = "user_compile_flags",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = all_compile_actions,
-                flag_groups = [
-                    flag_group(
-                        flags = ["%{user_compile_flags}"],
-                        iterate_over = "user_compile_flags",
-                        expand_if_available = "user_compile_flags",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    unfiltered_compile_flags_feature = feature(
-        name = "unfiltered_compile_flags",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = all_compile_actions,
-                flag_groups = ([
-                    flag_group(
-                        flags = ctx.attr.unfiltered_compile_flags,
-                    ),
-                ] if ctx.attr.unfiltered_compile_flags else []),
-            ),
-        ],
-    )
-
-    library_search_directories_feature = feature(
-        name = "library_search_directories",
-        flag_sets = [
-            flag_set(
-                actions = all_link_actions + lto_index_actions,
-                flag_groups = [
-                    flag_group(
-                        flags = ["-L%{library_search_directories}"],
-                        iterate_over = "library_search_directories",
-                        expand_if_available = "library_search_directories",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    static_libgcc_feature = feature(
-        name = "static_libgcc",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.cpp_link_executable,
-                    ACTION_NAMES.cpp_link_dynamic_library,
-                    ACTION_NAMES.lto_index_for_executable,
-                    ACTION_NAMES.lto_index_for_dynamic_library,
-                ],
-                flag_groups = [flag_group(flags = ["-static-libgcc"])],
-                with_features = [
-                    with_feature_set(features = ["static_link_cpp_runtimes"]),
-                ],
-            ),
-        ],
-    )
-
-    pic_feature = feature(
-        name = "pic",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.assemble,
-                    ACTION_NAMES.preprocess_assemble,
-                    ACTION_NAMES.linkstamp_compile,
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                    ACTION_NAMES.cpp_module_codegen,
-                    ACTION_NAMES.cpp_module_compile,
-                ],
-                flag_groups = [
-                    flag_group(flags = ["-fPIC"], expand_if_available = "pic"),
-                ],
-            ),
-        ],
-    )
-
-    per_object_debug_info_feature = feature(
-        name = "per_object_debug_info",
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.assemble,
-                    ACTION_NAMES.preprocess_assemble,
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                    ACTION_NAMES.cpp_module_codegen,
-                ],
-                flag_groups = [
-                    flag_group(
-                        flags = ["-gsplit-dwarf"],
-                        expand_if_available = "per_object_debug_info_file",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    preprocessor_defines_feature = feature(
-        name = "preprocessor_defines",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.preprocess_assemble,
-                    ACTION_NAMES.linkstamp_compile,
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                    ACTION_NAMES.cpp_header_parsing,
-                    ACTION_NAMES.cpp_module_compile,
-                    ACTION_NAMES.clif_match,
-                ],
-                flag_groups = [
-                    flag_group(
-                        flags = ["-D%{preprocessor_defines}"],
-                        iterate_over = "preprocessor_defines",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    cs_fdo_optimize_feature = feature(
-        name = "cs_fdo_optimize",
-        flag_sets = [
-            flag_set(
-                actions = [ACTION_NAMES.lto_backend],
-                flag_groups = [
-                    flag_group(
-                        flags = [
-                            "-fprofile-use=%{fdo_profile_path}",
-                            "-Xclang-only=-Wno-profile-instr-unprofiled",
-                            "-Xclang-only=-Wno-profile-instr-out-of-date",
-                            "-fprofile-correction",
-                        ],
-                        expand_if_available = "fdo_profile_path",
-                    ),
-                ],
-            ),
-        ],
-        provides = ["csprofile"],
-    )
-
-    autofdo_feature = feature(
-        name = "autofdo",
-        flag_sets = [
-            flag_set(
-                actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile],
-                flag_groups = [
-                    flag_group(
-                        flags = [
-                            "-fauto-profile=%{fdo_profile_path}",
-                            "-fprofile-correction",
-                        ],
-                        expand_if_available = "fdo_profile_path",
-                    ),
-                ],
-            ),
-        ],
-        provides = ["profile"],
-    )
-
-    runtime_library_search_directories_feature = feature(
-        name = "runtime_library_search_directories",
-        flag_sets = [
-            flag_set(
-                actions = all_link_actions + lto_index_actions,
-                flag_groups = [
-                    flag_group(
-                        iterate_over = "runtime_library_search_directories",
-                        flag_groups = [
-                            flag_group(
-                                flags = [
-                                    "-Wl,-rpath,$EXEC_ORIGIN/%{runtime_library_search_directories}",
-                                ],
-                                expand_if_true = "is_cc_test",
-                            ),
-                            flag_group(
-                                flags = [
-                                    "-Wl,-rpath,$ORIGIN/%{runtime_library_search_directories}",
-                                ],
-                                expand_if_false = "is_cc_test",
-                            ),
-                        ],
-                        expand_if_available =
-                            "runtime_library_search_directories",
-                    ),
-                ],
-                with_features = [
-                    with_feature_set(features = ["static_link_cpp_runtimes"]),
-                ],
-            ),
-            flag_set(
-                actions = all_link_actions + lto_index_actions,
-                flag_groups = [
-                    flag_group(
-                        iterate_over = "runtime_library_search_directories",
-                        flag_groups = [
-                            flag_group(
-                                flags = [
-                                    "-Wl,-rpath,$ORIGIN/%{runtime_library_search_directories}",
-                                ],
-                            ),
-                        ],
-                        expand_if_available =
-                            "runtime_library_search_directories",
-                    ),
-                ],
-                with_features = [
-                    with_feature_set(
-                        not_features = ["static_link_cpp_runtimes"],
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    fission_support_feature = feature(
-        name = "fission_support",
-        flag_sets = [
-            flag_set(
-                actions = all_link_actions + lto_index_actions,
-                flag_groups = [
-                    flag_group(
-                        flags = ["-Wl,--gdb-index"],
-                        expand_if_available = "is_using_fission",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    shared_flag_feature = feature(
-        name = "shared_flag",
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.cpp_link_dynamic_library,
-                    ACTION_NAMES.cpp_link_nodeps_dynamic_library,
-                    ACTION_NAMES.lto_index_for_dynamic_library,
-                    ACTION_NAMES.lto_index_for_nodeps_dynamic_library,
-                ],
-                flag_groups = [flag_group(flags = ["-shared"])],
-            ),
-        ],
-    )
-
-    random_seed_feature = feature(
-        name = "random_seed",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                    ACTION_NAMES.cpp_module_codegen,
-                    ACTION_NAMES.cpp_module_compile,
-                ],
-                flag_groups = [
-                    flag_group(
-                        flags = ["-frandom-seed=%{output_file}"],
-                        expand_if_available = "output_file",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    includes_feature = feature(
-        name = "includes",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.preprocess_assemble,
-                    ACTION_NAMES.linkstamp_compile,
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                    ACTION_NAMES.cpp_header_parsing,
-                    ACTION_NAMES.cpp_module_compile,
-                    ACTION_NAMES.clif_match,
-                    ACTION_NAMES.objc_compile,
-                    ACTION_NAMES.objcpp_compile,
-                ],
-                flag_groups = [
-                    flag_group(
-                        flags = ["-include", "%{includes}"],
-                        iterate_over = "includes",
-                        expand_if_available = "includes",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    fdo_instrument_feature = feature(
-        name = "fdo_instrument",
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                ] + all_link_actions + lto_index_actions,
-                flag_groups = [
-                    flag_group(
-                        flags = [
-                            "-fprofile-generate=%{fdo_instrument_path}",
-                            "-fno-data-sections",
-                        ],
-                        expand_if_available = "fdo_instrument_path",
-                    ),
-                ],
-            ),
-        ],
-        provides = ["profile"],
-    )
-
-    cs_fdo_instrument_feature = feature(
-        name = "cs_fdo_instrument",
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                    ACTION_NAMES.lto_backend,
-                ] + all_link_actions + lto_index_actions,
-                flag_groups = [
-                    flag_group(
-                        flags = [
-                            "-fcs-profile-generate=%{cs_fdo_instrument_path}",
-                        ],
-                        expand_if_available = "cs_fdo_instrument_path",
-                    ),
-                ],
-            ),
-        ],
-        provides = ["csprofile"],
-    )
-
-    include_paths_feature = feature(
-        name = "include_paths",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.preprocess_assemble,
-                    ACTION_NAMES.linkstamp_compile,
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                    ACTION_NAMES.cpp_header_parsing,
-                    ACTION_NAMES.cpp_module_compile,
-                    ACTION_NAMES.clif_match,
-                    ACTION_NAMES.objc_compile,
-                    ACTION_NAMES.objcpp_compile,
-                ],
-                flag_groups = [
-                    flag_group(
-                        flags = ["-iquote", "%{quote_include_paths}"],
-                        iterate_over = "quote_include_paths",
-                    ),
-                    flag_group(
-                        flags = ["-I%{include_paths}"],
-                        iterate_over = "include_paths",
-                    ),
-                    flag_group(
-                        flags = ["-isystem", "%{system_include_paths}"],
-                        iterate_over = "system_include_paths",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    symbol_counts_feature = feature(
-        name = "symbol_counts",
-        flag_sets = [
-            flag_set(
-                actions = all_link_actions + lto_index_actions,
-                flag_groups = [
-                    flag_group(
-                        flags = [
-                            "-Wl,--print-symbol-counts=%{symbol_counts_output}",
-                        ],
-                        expand_if_available = "symbol_counts_output",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    llvm_coverage_map_format_feature = feature(
-        name = "llvm_coverage_map_format",
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.preprocess_assemble,
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                    ACTION_NAMES.cpp_module_compile,
-                    ACTION_NAMES.objc_compile,
-                    ACTION_NAMES.objcpp_compile,
-                ],
-                flag_groups = [
-                    flag_group(
-                        flags = [
-                            "-fprofile-instr-generate",
-                            "-fcoverage-mapping",
-                        ],
-                    ),
-                ],
-            ),
-            flag_set(
-                actions = all_link_actions + lto_index_actions + [
-                    "objc-executable",
-                    "objc++-executable",
-                ],
-                flag_groups = [
-                    flag_group(flags = ["-fprofile-instr-generate"]),
-                ],
-            ),
-        ],
-        requires = [feature_set(features = ["coverage"])],
-        provides = ["profile"],
-    )
-
-    strip_debug_symbols_feature = feature(
-        name = "strip_debug_symbols",
-        flag_sets = [
-            flag_set(
-                actions = all_link_actions + lto_index_actions,
-                flag_groups = [
-                    flag_group(
-                        flags = ["-Wl,-S"],
-                        expand_if_available = "strip_debug_symbols",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    build_interface_libraries_feature = feature(
-        name = "build_interface_libraries",
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.cpp_link_dynamic_library,
-                    ACTION_NAMES.cpp_link_nodeps_dynamic_library,
-                    ACTION_NAMES.lto_index_for_dynamic_library,
-                    ACTION_NAMES.lto_index_for_nodeps_dynamic_library,
-                ],
-                flag_groups = [
-                    flag_group(
-                        flags = [
-                            "%{generate_interface_library}",
-                            "%{interface_library_builder_path}",
-                            "%{interface_library_input_path}",
-                            "%{interface_library_output_path}",
-                        ],
-                        expand_if_available = "generate_interface_library",
-                    ),
-                ],
-                with_features = [
-                    with_feature_set(
-                        features = ["supports_interface_shared_libraries"],
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    libraries_to_link_feature = feature(
-        name = "libraries_to_link",
-        flag_sets = [
-            flag_set(
-                actions = all_link_actions + lto_index_actions,
-                flag_groups = [
-                    flag_group(
-                        iterate_over = "libraries_to_link",
-                        flag_groups = [
-                            flag_group(
-                                flags = ["-Wl,--start-lib"],
-                                expand_if_equal = variable_with_value(
-                                    name = "libraries_to_link.type",
-                                    value = "object_file_group",
-                                ),
-                            ),
-                            flag_group(
-                                flags = ["-Wl,-whole-archive"],
-                                expand_if_true =
-                                    "libraries_to_link.is_whole_archive",
-                            ),
-                            flag_group(
-                                flags = ["%{libraries_to_link.object_files}"],
-                                iterate_over = "libraries_to_link.object_files",
-                                expand_if_equal = variable_with_value(
-                                    name = "libraries_to_link.type",
-                                    value = "object_file_group",
-                                ),
-                            ),
-                            flag_group(
-                                flags = ["%{libraries_to_link.name}"],
-                                expand_if_equal = variable_with_value(
-                                    name = "libraries_to_link.type",
-                                    value = "object_file",
-                                ),
-                            ),
-                            flag_group(
-                                flags = ["%{libraries_to_link.name}"],
-                                expand_if_equal = variable_with_value(
-                                    name = "libraries_to_link.type",
-                                    value = "interface_library",
-                                ),
-                            ),
-                            flag_group(
-                                flags = ["%{libraries_to_link.name}"],
-                                expand_if_equal = variable_with_value(
-                                    name = "libraries_to_link.type",
-                                    value = "static_library",
-                                ),
-                            ),
-                            flag_group(
-                                flags = ["-l%{libraries_to_link.name}"],
-                                expand_if_equal = variable_with_value(
-                                    name = "libraries_to_link.type",
-                                    value = "dynamic_library",
-                                ),
-                            ),
-                            flag_group(
-                                flags = ["-l:%{libraries_to_link.name}"],
-                                expand_if_equal = variable_with_value(
-                                    name = "libraries_to_link.type",
-                                    value = "versioned_dynamic_library",
-                                ),
-                            ),
-                            flag_group(
-                                flags = ["-Wl,-no-whole-archive"],
-                                expand_if_true = "libraries_to_link.is_whole_archive",
-                            ),
-                            flag_group(
-                                flags = ["-Wl,--end-lib"],
-                                expand_if_equal = variable_with_value(
-                                    name = "libraries_to_link.type",
-                                    value = "object_file_group",
-                                ),
-                            ),
-                        ],
-                        expand_if_available = "libraries_to_link",
-                    ),
-                    flag_group(
-                        flags = ["-Wl,@%{thinlto_param_file}"],
-                        expand_if_true = "thinlto_param_file",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    user_link_flags_feature = feature(
-        name = "user_link_flags",
-        flag_sets = [
-            flag_set(
-                actions = all_link_actions + lto_index_actions,
-                flag_groups = [
-                    flag_group(
-                        flags = ["%{user_link_flags}"],
-                        iterate_over = "user_link_flags",
-                        expand_if_available = "user_link_flags",
-                    ),
-                ] + ([flag_group(flags = ctx.attr.link_libs)] if ctx.attr.link_libs else []),
-            ),
-        ],
-    )
-
-    fdo_prefetch_hints_feature = feature(
-        name = "fdo_prefetch_hints",
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                    ACTION_NAMES.lto_backend,
-                ],
-                flag_groups = [
-                    flag_group(
-                        flags = [
-                            "-Xclang-only=-mllvm",
-                            "-Xclang-only=-prefetch-hints-file=%{fdo_prefetch_hints_path}",
-                        ],
-                        expand_if_available = "fdo_prefetch_hints_path",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    linkstamps_feature = feature(
-        name = "linkstamps",
-        flag_sets = [
-            flag_set(
-                actions = all_link_actions + lto_index_actions,
-                flag_groups = [
-                    flag_group(
-                        flags = ["%{linkstamp_paths}"],
-                        iterate_over = "linkstamp_paths",
-                        expand_if_available = "linkstamp_paths",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    gcc_coverage_map_format_feature = feature(
-        name = "gcc_coverage_map_format",
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.preprocess_assemble,
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                    ACTION_NAMES.cpp_module_compile,
-                    ACTION_NAMES.objc_compile,
-                    ACTION_NAMES.objcpp_compile,
-                    "objc-executable",
-                    "objc++-executable",
-                ],
-                flag_groups = [
-                    flag_group(
-                        flags = ["-fprofile-arcs", "-ftest-coverage"],
-                        expand_if_available = "gcov_gcno_file",
-                    ),
-                ],
-            ),
-            flag_set(
-                actions = all_link_actions + lto_index_actions,
-                flag_groups = [flag_group(flags = ["--coverage"])],
-            ),
-        ],
-        requires = [feature_set(features = ["coverage"])],
-        provides = ["profile"],
-    )
-
-    archiver_flags_feature = feature(
-        name = "archiver_flags",
-        flag_sets = [
-            flag_set(
-                actions = [ACTION_NAMES.cpp_link_static_library],
-                flag_groups = [
-                    flag_group(flags = ["rcsD"]),
-                    flag_group(
-                        flags = ["%{output_execpath}"],
-                        expand_if_available = "output_execpath",
-                    ),
-                ],
-            ),
-            flag_set(
-                actions = [ACTION_NAMES.cpp_link_static_library],
-                flag_groups = [
-                    flag_group(
-                        iterate_over = "libraries_to_link",
-                        flag_groups = [
-                            flag_group(
-                                flags = ["%{libraries_to_link.name}"],
-                                expand_if_equal = variable_with_value(
-                                    name = "libraries_to_link.type",
-                                    value = "object_file",
-                                ),
-                            ),
-                            flag_group(
-                                flags = ["%{libraries_to_link.object_files}"],
-                                iterate_over = "libraries_to_link.object_files",
-                                expand_if_equal = variable_with_value(
-                                    name = "libraries_to_link.type",
-                                    value = "object_file_group",
-                                ),
-                            ),
-                        ],
-                        expand_if_available = "libraries_to_link",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    force_pic_flags_feature = feature(
-        name = "force_pic_flags",
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.cpp_link_executable,
-                    ACTION_NAMES.lto_index_for_executable,
-                ],
-                flag_groups = [
-                    flag_group(
-                        flags = ["-pie"],
-                        expand_if_available = "force_pic",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    dependency_file_feature = feature(
-        name = "dependency_file",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.assemble,
-                    ACTION_NAMES.preprocess_assemble,
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                    ACTION_NAMES.cpp_module_compile,
-                    ACTION_NAMES.objc_compile,
-                    ACTION_NAMES.objcpp_compile,
-                    ACTION_NAMES.cpp_header_parsing,
-                    ACTION_NAMES.clif_match,
-                ],
-                flag_groups = [
-                    flag_group(
-                        flags = ["-MD", "-MF", "%{dependency_file}"],
-                        expand_if_available = "dependency_file",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    dynamic_library_linker_tool_path = tool_paths
-    dynamic_library_linker_tool_feature = feature(
-        name = "dynamic_library_linker_tool",
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.cpp_link_dynamic_library,
-                    ACTION_NAMES.cpp_link_nodeps_dynamic_library,
-                    ACTION_NAMES.lto_index_for_dynamic_library,
-                    ACTION_NAMES.lto_index_for_nodeps_dynamic_library,
-                ],
-                flag_groups = [
-                    flag_group(
-                        flags = [" + cppLinkDynamicLibraryToolPath + "],
-                        expand_if_available = "generate_interface_library",
-                    ),
-                ],
-                with_features = [
-                    with_feature_set(
-                        features = ["supports_interface_shared_libraries"],
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    output_execpath_flags_feature = feature(
-        name = "output_execpath_flags",
-        flag_sets = [
-            flag_set(
-                actions = all_link_actions + lto_index_actions,
-                flag_groups = [
-                    flag_group(
-                        flags = ["-o", "%{output_execpath}"],
-                        expand_if_available = "output_execpath",
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    # Note that we also set --coverage for c++-link-nodeps-dynamic-library. The
-    # generated code contains references to gcov symbols, and the dynamic linker
-    # can't resolve them unless the library is linked against gcov.
-    coverage_feature = feature(
-        name = "coverage",
-        provides = ["profile"],
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.preprocess_assemble,
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                    ACTION_NAMES.cpp_header_parsing,
-                    ACTION_NAMES.cpp_module_compile,
-                ],
-                flag_groups = ([
-                    flag_group(flags = ctx.attr.coverage_compile_flags),
-                ] if ctx.attr.coverage_compile_flags else []),
-            ),
-            flag_set(
-                actions = all_link_actions + lto_index_actions,
-                flag_groups = ([
-                    flag_group(flags = ctx.attr.coverage_link_flags),
-                ] if ctx.attr.coverage_link_flags else []),
-            ),
-        ],
-    )
-
-    build_id_feature = feature(
-        name = "build-id",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = all_link_actions,
-                flag_groups = [
-                    flag_group(
-                        flags = ["-Wl,--build-id=md5", "-Wl,--hash-style=gnu"],
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    no_canonical_prefixes_feature = feature(
-        name = "no-canonical-prefixes",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                    ACTION_NAMES.cpp_link_executable,
-                    ACTION_NAMES.cpp_link_dynamic_library,
-                    ACTION_NAMES.cpp_link_nodeps_dynamic_library,
-                ],
-                flag_groups = [
-                    flag_group(
-                        flags = [
-                            "-no-canonical-prefixes",
-                            "-fno-canonical-system-headers",
-                        ]
-                    ),
-                ],
-            ),
-        ],
-    )
-
-    linker_bin_path_feature = feature(
-        name = "linker-bin-path",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = all_link_actions,
-                flag_groups = [flag_group(flags = ["-B" + ctx.attr.linker_bin_path])],
-            ),
-        ],
-    )
-
-    features = [
-        dependency_file_feature,
-        random_seed_feature,
-        pic_feature,
-        per_object_debug_info_feature,
-        preprocessor_defines_feature,
-        includes_feature,
-        include_paths_feature,
-        fdo_instrument_feature,
-        cs_fdo_instrument_feature,
-        cs_fdo_optimize_feature,
-        fdo_prefetch_hints_feature,
-        autofdo_feature,
-        build_interface_libraries_feature,
-        dynamic_library_linker_tool_feature,
-        symbol_counts_feature,
-        shared_flag_feature,
-        linkstamps_feature,
-        output_execpath_flags_feature,
-        runtime_library_search_directories_feature,
-        library_search_directories_feature,
-        archiver_flags_feature,
-        force_pic_flags_feature,
-        fission_support_feature,
-        strip_debug_symbols_feature,
-        coverage_feature,
-        supports_pic_feature,
-    ] + (
-        [
-            supports_start_end_lib_feature,
-        ] if ctx.attr.supports_start_end_lib else []
-    ) + [
-        default_compile_flags_feature,
-        default_link_flags_feature,
-        libraries_to_link_feature,
-        user_link_flags_feature,
-        static_libgcc_feature,
-        fdo_optimize_feature,
-        supports_dynamic_linker_feature,
-        dbg_feature,
-        opt_feature,
-        user_compile_flags_feature,
-        sysroot_feature,
-        unfiltered_compile_flags_feature,
-        build_id_feature,
-        no_canonical_prefixes_feature,
-        linker_bin_path_feature,
-    ]
-
-    return cc_common.create_cc_toolchain_config_info(
-        ctx = ctx,
-        features = features,
-        action_configs = action_configs,
-        cxx_builtin_include_directories = ctx.attr.cxx_builtin_include_directories,
-        toolchain_identifier = ctx.attr.toolchain_identifier,
-        host_system_name = ctx.attr.host_system_name,
-        target_system_name = ctx.attr.target_system_name,
-        target_cpu = ctx.attr.cpu,
-        target_libc = ctx.attr.target_libc,
-        compiler = ctx.attr.compiler,
-        abi_version = ctx.attr.abi_version,
-        abi_libc_version = ctx.attr.abi_libc_version,
-        tool_paths = tool_paths,
-    )
-
-cc_toolchain_config = rule(
-    implementation = _impl,
-    attrs = {
-        "cpu": attr.string(mandatory = True),
-        "compiler": attr.string(mandatory = True),
-        "toolchain_identifier": attr.string(mandatory = True),
-        "host_system_name": attr.string(mandatory = True),
-        "target_system_name": attr.string(mandatory = True),
-        "target_libc": attr.string(mandatory = True),
-        "abi_version": attr.string(mandatory = True),
-        "abi_libc_version": attr.string(mandatory = True),
-        "cxx_builtin_include_directories": attr.string_list(),
-        "compile_flags": attr.string_list(),
-        "dbg_compile_flags": attr.string_list(),
-        "opt_compile_flags": attr.string_list(),
-        "cxx_flags": attr.string_list(),
-        "link_flags": attr.string_list(),
-        "link_libs": attr.string_list(),
-        "opt_link_flags": attr.string_list(),
-        "unfiltered_compile_flags": attr.string_list(),
-        "coverage_compile_flags": attr.string_list(),
-        "coverage_link_flags": attr.string_list(),
-        "supports_start_end_lib": attr.bool(),
-        "host_compiler_path": attr.string(),
-        "host_compiler_prefix": attr.string(),
-        "linker_bin_path": attr.string(),
-    },
-    provides = [CcToolchainConfigInfo],
-)
-
diff --git a/third_party/xla/third_party/gpus/cuda/BUILD b/third_party/xla/third_party/gpus/cuda/BUILD
deleted file mode 100644
index e69de29..0000000
--- a/third_party/xla/third_party/gpus/cuda/BUILD
+++ /dev/null
diff --git a/third_party/xla/third_party/gpus/cuda/BUILD.tpl b/third_party/xla/third_party/gpus/cuda/BUILD.tpl
deleted file mode 100644
index 700e040..0000000
--- a/third_party/xla/third_party/gpus/cuda/BUILD.tpl
+++ /dev/null
@@ -1,245 +0,0 @@
-load(":build_defs.bzl", "cuda_header_library")
-load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
-load("@bazel_skylib//lib:selects.bzl", "selects")
-
-licenses(["restricted"])  # MPL2, portions GPL v3, LGPL v3, BSD-like
-
-package(default_visibility = ["//visibility:public"])
-
-# Config setting for whether TensorFlow is built with CUDA support using clang.
-#
-# TODO(b/174244321), DEPRECATED: this target will be removed when all users
-# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_clang.
-selects.config_setting_group(
-    name = "using_clang",
-    match_all = [
-        "@local_config_cuda//:is_cuda_enabled",
-        "@local_config_cuda//:is_cuda_compiler_clang",
-    ],
-)
-
-# Config setting for whether TensorFlow is built with CUDA support using nvcc.
-#
-# TODO(b/174244321), DEPRECATED: this target will be removed when all users
-# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_nvcc.
-selects.config_setting_group(
-    name = "using_nvcc",
-    match_all = [
-        "@local_config_cuda//:is_cuda_enabled",
-        "@local_config_cuda//:is_cuda_compiler_nvcc",
-    ],
-)
-
-# Equivalent to using_clang && -c opt.
-selects.config_setting_group(
-    name = "using_clang_opt",
-    match_all = [
-        ":using_clang",
-        ":_opt",
-    ],
-)
-
-config_setting(
-    name = "_opt",
-    values = {"compilation_mode": "opt"},
-)
-
-# Provides CUDA headers for '#include "third_party/gpus/cuda/include/cuda.h"'
-# All clients including TensorFlow should use these directives.
-cuda_header_library(
-    name = "cuda_headers",
-    hdrs = [
-        "cuda/cuda_config.h",
-        ":cuda-include",
-    ],
-    include_prefix = "third_party/gpus",
-    includes = [
-        ".",  # required to include cuda/cuda/cuda_config.h as cuda/config.h
-        "cuda/include",
-    ],
-)
-
-cc_library(
-    name = "cudart_static",
-    srcs = ["cuda/lib/libcudart_static.a"],
-    linkopts = [
-        "-ldl",
-        "-lrt",
-        "-lpthread",
-    ],
-)
-
-cc_library(
-    name = "cuda_driver",
-    srcs = ["cuda/lib/libcuda.so"],
-)
-
-cc_library(
-    name = "cudart",
-    srcs = glob(["cuda/lib/libcudart.so.*"]),
-    data = glob(["cuda/lib/libcudart.so.*"]),
-    linkstatic = 1,
-)
-
-cuda_header_library(
-    name = "cublas_headers",
-    hdrs = [":cublas-include"],
-    include_prefix = "third_party/gpus/cuda/include",
-    includes = ["cublas/include"],
-    strip_include_prefix = "cublas/include",
-    deps = [":cuda_headers"],
-)
-
-cuda_header_library(
-    name = "cusolver_headers",
-    hdrs = [":cusolver-include"],
-    include_prefix = "third_party/gpus/cuda/include",
-    includes = ["cusolver/include"],
-    strip_include_prefix = "cusolver/include",
-    deps = [":cuda_headers"],
-)
-
-cuda_header_library(
-    name = "cufft_headers",
-    hdrs = [":cufft-include"],
-    include_prefix = "third_party/gpus/cuda/include",
-    includes = ["cufft/include"],
-    strip_include_prefix = "cufft/include",
-    deps = [":cuda_headers"],
-)
-
-cuda_header_library(
-    name = "cusparse_headers",
-    hdrs = [":cusparse-include"],
-    include_prefix = "third_party/gpus/cuda/include",
-    includes = ["cusparse/include"],
-    strip_include_prefix = "cusparse/include",
-    deps = [":cuda_headers"],
-)
-
-cuda_header_library(
-    name = "curand_headers",
-    hdrs = [":curand-include"],
-    include_prefix = "third_party/gpus/cuda/include",
-    includes = ["curand/include"],
-    strip_include_prefix = "curand/include",
-    deps = [":cuda_headers"],
-)
-
-cc_library(
-    name = "cublas",
-    srcs = glob(["cuda/lib/libcublas.so.*"]),
-    data = glob(["cuda/lib/libcublas.so.*"]),
-    linkstatic = 1,
-)
-
-cc_library(
-    name = "cublasLt",
-    srcs = glob(["cuda/lib/libcublasLt.so.*"]),
-    data = glob(["cuda/lib/libcublasLt.so.*"]),
-    linkstatic = 1,
-)
-
-cc_library(
-    name = "cusolver",
-    srcs = glob(["cuda/lib/libcusolver.so.*"]),
-    data = glob(["cuda/lib/libcusolver.so.*"]),
-    linkopts = ["-lgomp"],
-    linkstatic = 1,
-)
-
-cc_library(
-    name = "cudnn",
-    srcs = glob(["cuda/lib/libcudnn.so.*"]),
-    data = glob(["cuda/lib/libcudnn.so.*"]),
-    linkstatic = 1,
-)
-
-cc_library(
-    name = "cudnn_header",
-    hdrs = [":cudnn-include"],
-    include_prefix = "third_party/gpus/cudnn",
-    strip_include_prefix = "cudnn/include",
-    deps = [":cuda_headers"],
-)
-
-cc_library(
-    name = "cufft",
-    srcs = glob(["cuda/lib/libcufft.so.*"]),
-    data = glob(["cuda/lib/libcufft.so.*"]),
-    linkstatic = 1,
-)
-
-cc_library(
-    name = "curand",
-    srcs = glob(["cuda/lib/libcurand.so.*"]),
-    data = glob(["cuda/lib/libcurand.so.*"]),
-    linkstatic = 1,
-)
-
-cc_library(
-    name = "cuda",
-    deps = [
-        ":cublas",
-        ":cublasLt",
-        ":cuda_headers",
-        ":cudart",
-        ":cudnn",
-        ":cufft",
-        ":curand",
-    ],
-)
-
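-# CUB headers ship with the CUDA toolkit, so this alias simply reuses
-# :cuda_headers.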
-alias(
-    name = "cub_headers",
-    actual = ":cuda_headers",
-)
-
-cuda_header_library(
-    name = "cupti_headers",
-    hdrs = [":cuda-extras"],
-    include_prefix = "third_party/gpus",
-    includes = ["cuda/extras/CUPTI/include/"],
-    deps = [":cuda_headers"],
-)
-
-cuda_header_library(
-    name = "nvml_headers",
-    hdrs = [":nvml"],
-    include_prefix = "third_party/gpus",
-    includes = ["cuda/nvml/include/"],
-    deps = [":cuda_headers"],
-)
-
-cc_library(
-    name = "cupti_dsos",
-    data = glob(["cuda/lib/libcupti.so.*"]),
-)
-
-cc_library(
-    name = "cusparse",
-    srcs = glob(["cuda/lib/libcusparse.so.*"]),
-    data = glob(["cuda/lib/libcusparse.so.*"]),
-    linkopts = ["-lgomp"],
-    linkstatic = 1,
-)
-
-cc_library(
-    name = "libdevice_root",
-    data = [":cuda-nvvm"],
-)
-
-bzl_library(
-    name = "build_defs_bzl",
-    srcs = ["build_defs.bzl"],
-    deps = [
-        "@bazel_skylib//lib:selects",
-    ],
-)
-
-py_library(
-    name = "cuda_config_py",
-    srcs = ["cuda/cuda_config.py"],
-)
-
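-# %{copy_rules} is substituted by the CUDA repository rule with the generated
-# copy rules that provide targets such as :cuda-include and :cudnn-include
-# referenced above.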
-%{copy_rules}
diff --git a/third_party/xla/third_party/gpus/cuda/LICENSE b/third_party/xla/third_party/gpus/cuda/LICENSE
deleted file mode 100644
index d3da228..0000000
--- a/third_party/xla/third_party/gpus/cuda/LICENSE
+++ /dev/null
@@ -1,203 +0,0 @@
-Copyright 2015 The TensorFlow Authors.  All rights reserved.
-
-                                 Apache License
-                           Version 2.0, January 2004
-                        http://www.apache.org/licenses/
-
-   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
-   1. Definitions.
-
-      "License" shall mean the terms and conditions for use, reproduction,
-      and distribution as defined by Sections 1 through 9 of this document.
-
-      "Licensor" shall mean the copyright owner or entity authorized by
-      the copyright owner that is granting the License.
-
-      "Legal Entity" shall mean the union of the acting entity and all
-      other entities that control, are controlled by, or are under common
-      control with that entity. For the purposes of this definition,
-      "control" means (i) the power, direct or indirect, to cause the
-      direction or management of such entity, whether by contract or
-      otherwise, or (ii) ownership of fifty percent (50%) or more of the
-      outstanding shares, or (iii) beneficial ownership of such entity.
-
-      "You" (or "Your") shall mean an individual or Legal Entity
-      exercising permissions granted by this License.
-
-      "Source" form shall mean the preferred form for making modifications,
-      including but not limited to software source code, documentation
-      source, and configuration files.
-
-      "Object" form shall mean any form resulting from mechanical
-      transformation or translation of a Source form, including but
-      not limited to compiled object code, generated documentation,
-      and conversions to other media types.
-
-      "Work" shall mean the work of authorship, whether in Source or
-      Object form, made available under the License, as indicated by a
-      copyright notice that is included in or attached to the work
-      (an example is provided in the Appendix below).
-
-      "Derivative Works" shall mean any work, whether in Source or Object
-      form, that is based on (or derived from) the Work and for which the
-      editorial revisions, annotations, elaborations, or other modifications
-      represent, as a whole, an original work of authorship. For the purposes
-      of this License, Derivative Works shall not include works that remain
-      separable from, or merely link (or bind by name) to the interfaces of,
-      the Work and Derivative Works thereof.
-
-      "Contribution" shall mean any work of authorship, including
-      the original version of the Work and any modifications or additions
-      to that Work or Derivative Works thereof, that is intentionally
-      submitted to Licensor for inclusion in the Work by the copyright owner
-      or by an individual or Legal Entity authorized to submit on behalf of
-      the copyright owner. For the purposes of this definition, "submitted"
-      means any form of electronic, verbal, or written communication sent
-      to the Licensor or its representatives, including but not limited to
-      communication on electronic mailing lists, source code control systems,
-      and issue tracking systems that are managed by, or on behalf of, the
-      Licensor for the purpose of discussing and improving the Work, but
-      excluding communication that is conspicuously marked or otherwise
-      designated in writing by the copyright owner as "Not a Contribution."
-
-      "Contributor" shall mean Licensor and any individual or Legal Entity
-      on behalf of whom a Contribution has been received by Licensor and
-      subsequently incorporated within the Work.
-
-   2. Grant of Copyright License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      copyright license to reproduce, prepare Derivative Works of,
-      publicly display, publicly perform, sublicense, and distribute the
-      Work and such Derivative Works in Source or Object form.
-
-   3. Grant of Patent License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      (except as stated in this section) patent license to make, have made,
-      use, offer to sell, sell, import, and otherwise transfer the Work,
-      where such license applies only to those patent claims licensable
-      by such Contributor that are necessarily infringed by their
-      Contribution(s) alone or by combination of their Contribution(s)
-      with the Work to which such Contribution(s) was submitted. If You
-      institute patent litigation against any entity (including a
-      cross-claim or counterclaim in a lawsuit) alleging that the Work
-      or a Contribution incorporated within the Work constitutes direct
-      or contributory patent infringement, then any patent licenses
-      granted to You under this License for that Work shall terminate
-      as of the date such litigation is filed.
-
-   4. Redistribution. You may reproduce and distribute copies of the
-      Work or Derivative Works thereof in any medium, with or without
-      modifications, and in Source or Object form, provided that You
-      meet the following conditions:
-
-      (a) You must give any other recipients of the Work or
-          Derivative Works a copy of this License; and
-
-      (b) You must cause any modified files to carry prominent notices
-          stating that You changed the files; and
-
-      (c) You must retain, in the Source form of any Derivative Works
-          that You distribute, all copyright, patent, trademark, and
-          attribution notices from the Source form of the Work,
-          excluding those notices that do not pertain to any part of
-          the Derivative Works; and
-
-      (d) If the Work includes a "NOTICE" text file as part of its
-          distribution, then any Derivative Works that You distribute must
-          include a readable copy of the attribution notices contained
-          within such NOTICE file, excluding those notices that do not
-          pertain to any part of the Derivative Works, in at least one
-          of the following places: within a NOTICE text file distributed
-          as part of the Derivative Works; within the Source form or
-          documentation, if provided along with the Derivative Works; or,
-          within a display generated by the Derivative Works, if and
-          wherever such third-party notices normally appear. The contents
-          of the NOTICE file are for informational purposes only and
-          do not modify the License. You may add Your own attribution
-          notices within Derivative Works that You distribute, alongside
-          or as an addendum to the NOTICE text from the Work, provided
-          that such additional attribution notices cannot be construed
-          as modifying the License.
-
-      You may add Your own copyright statement to Your modifications and
-      may provide additional or different license terms and conditions
-      for use, reproduction, or distribution of Your modifications, or
-      for any such Derivative Works as a whole, provided Your use,
-      reproduction, and distribution of the Work otherwise complies with
-      the conditions stated in this License.
-
-   5. Submission of Contributions. Unless You explicitly state otherwise,
-      any Contribution intentionally submitted for inclusion in the Work
-      by You to the Licensor shall be under the terms and conditions of
-      this License, without any additional terms or conditions.
-      Notwithstanding the above, nothing herein shall supersede or modify
-      the terms of any separate license agreement you may have executed
-      with Licensor regarding such Contributions.
-
-   6. Trademarks. This License does not grant permission to use the trade
-      names, trademarks, service marks, or product names of the Licensor,
-      except as required for reasonable and customary use in describing the
-      origin of the Work and reproducing the content of the NOTICE file.
-
-   7. Disclaimer of Warranty. Unless required by applicable law or
-      agreed to in writing, Licensor provides the Work (and each
-      Contributor provides its Contributions) on an "AS IS" BASIS,
-      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-      implied, including, without limitation, any warranties or conditions
-      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
-      PARTICULAR PURPOSE. You are solely responsible for determining the
-      appropriateness of using or redistributing the Work and assume any
-      risks associated with Your exercise of permissions under this License.
-
-   8. Limitation of Liability. In no event and under no legal theory,
-      whether in tort (including negligence), contract, or otherwise,
-      unless required by applicable law (such as deliberate and grossly
-      negligent acts) or agreed to in writing, shall any Contributor be
-      liable to You for damages, including any direct, indirect, special,
-      incidental, or consequential damages of any character arising as a
-      result of this License or out of the use or inability to use the
-      Work (including but not limited to damages for loss of goodwill,
-      work stoppage, computer failure or malfunction, or any and all
-      other commercial damages or losses), even if such Contributor
-      has been advised of the possibility of such damages.
-
-   9. Accepting Warranty or Additional Liability. While redistributing
-      the Work or Derivative Works thereof, You may choose to offer,
-      and charge a fee for, acceptance of support, warranty, indemnity,
-      or other liability obligations and/or rights consistent with this
-      License. However, in accepting such obligations, You may act only
-      on Your own behalf and on Your sole responsibility, not on behalf
-      of any other Contributor, and only if You agree to indemnify,
-      defend, and hold each Contributor harmless for any liability
-      incurred by, or claims asserted against, such Contributor by reason
-      of your accepting any such warranty or additional liability.
-
-   END OF TERMS AND CONDITIONS
-
-   APPENDIX: How to apply the Apache License to your work.
-
-      To apply the Apache License to your work, attach the following
-      boilerplate notice, with the fields enclosed by brackets "[]"
-      replaced with your own identifying information. (Don't include
-      the brackets!)  The text should be enclosed in the appropriate
-      comment syntax for the file format. We also recommend that a
-      file or class name and description of purpose be included on the
-      same "printed page" as the copyright notice for easier
-      identification within third-party archives.
-
-   Copyright 2015, The TensorFlow Authors.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License.
diff --git a/third_party/xla/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/xla/third_party/gpus/cuda/build_defs.bzl.tpl
deleted file mode 100644
index 189d3e3..0000000
--- a/third_party/xla/third_party/gpus/cuda/build_defs.bzl.tpl
+++ /dev/null
@@ -1,151 +0,0 @@
-# Macros for building CUDA code.
-def if_cuda(if_true, if_false = []):
-    """Shorthand for select()'ing on whether we're building with CUDA.
-
-    Returns a select statement which evaluates to if_true if we're building
-    with CUDA enabled.  Otherwise, the select statement evaluates to if_false.
-    """
-    return select({
-        "@local_config_cuda//:is_cuda_enabled": if_true,
-        "//conditions:default": if_false,
-    })
-
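-# Hypothetical usage sketch (added for illustration; not part of the original
-# template). The target names below are placeholders:
-#
-#   cc_library(
-#       name = "gpu_support",
-#       srcs = ["gpu_support.cc"],
-#       deps = ["//some:base_lib"] + if_cuda(["//some:gpu_only_lib"]),
-#   )
-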
-def if_cuda_clang(if_true, if_false = []):
-   """Shorthand for select()'ing on wheteher we're building with cuda-clang.
-
-    Returns a select statement which evaluates to if_true if we're building
-    with cuda-clang.  Otherwise, the select statement evaluates to if_false.
-   """
-   return select({
-       "@local_config_cuda//cuda:using_clang": if_true,
-       "//conditions:default": if_false
-   })
-
-def if_cuda_exec(if_true, if_false = []):
-    """Synonym for if_cuda.
-
-    Selects if_true in both the target and the exec configuration. In
-    principle, if_cuda would only need to select if_true in the target
-    configuration and not in the exec configuration, but this is not
-    currently implemented.
-    """
-    return if_cuda(if_true, if_false)
-
-def cuda_compiler(if_cuda_clang, if_nvcc, neither = []):
-    """Shorthand for select()'ing on wheteher we're building with cuda-clang or nvcc.
-
-     Returns a select statement which evaluates to if_cuda_clang if we're building
-     with cuda-clang, if_nvcc if we're building with NVCC.
-     Otherwise, the select statement evaluates to neither.
-
-    """
-    if %{cuda_is_configured}:
-        return select({
-            "@local_config_cuda//cuda:using_clang": if_cuda_clang,
-            "@local_config_cuda//:is_cuda_compiler_nvcc": if_nvcc,
-            "//conditions:default": neither
-        })
-    else:
-        return select({
-            "//conditions:default": neither
-        })
-
-def if_cuda_clang_opt(if_true, if_false = []):
-   """Shorthand for select()'ing on wheteher we're building with cuda-clang
-   in opt mode.
-
-    Returns a select statement which evaluates to if_true if we're building
-    with cuda-clang in opt mode. Otherwise, the select statement evaluates to
-    if_false.
-
-   """
-   return select({
-       "@local_config_cuda//cuda:using_clang_opt": if_true,
-       "//conditions:default": if_false
-   })
-
-def cuda_default_copts():
-    """Default options for all CUDA compilations."""
-    return if_cuda([
-        "-x", "cuda",
-        "-DGOOGLE_CUDA=1",
-    ] + %{cuda_extra_copts}) + if_cuda_clang_opt(
-        # Some important CUDA optimizations are only enabled at O3.
-        ["-O3"]
-    ) + cuda_compiler(
-        if_cuda_clang = [ "-Xcuda-fatbinary", "--compress-all"],
-        if_nvcc = [
-            "-Xcuda-fatbinary=--compress-all",
-            # Ensure that NVCC matches clang's constexpr behavior.
-            "-nvcc_options=expt-relaxed-constexpr"
-        ]
-    )
-
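-# Illustrative note (added for clarity; not part of the original template):
-# with cuda-clang in opt mode, the defaults expand to roughly
-# ["-x", "cuda", "-DGOOGLE_CUDA=1"] + %{cuda_extra_copts} + ["-O3",
-# "-Xcuda-fatbinary", "--compress-all"]; the if_nvcc branch is used instead
-# when compiling with NVCC.
-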
-def cuda_gpu_architectures():
-    """Returns a list of supported GPU architectures."""
-    return %{cuda_gpu_architectures}
-
-def if_cuda_is_configured(x, no_cuda = []):
-    """Tests if the CUDA was enabled during the configure process.
-
-    Unlike if_cuda(), this does not require that we are building with
-    --config=cuda. Used to allow non-CUDA code to depend on CUDA libraries.
-    """
-    if %{cuda_is_configured}:
-      return select({"//conditions:default": x})
-    return select({"//conditions:default": no_cuda})
-
-def cuda_header_library(
-        name,
-        hdrs,
-        include_prefix = None,
-        strip_include_prefix = None,
-        deps = [],
-        **kwargs):
-    """Generates a cc_library containing both virtual and system include paths.
-
-    Generates both a header-only target with virtual includes plus the full
-    target without virtual includes. This works around the fact that bazel can't
-    mix 'includes' and 'include_prefix' in the same target."""
-
-    native.cc_library(
-        name = name + "_virtual",
-        hdrs = hdrs,
-        include_prefix = include_prefix,
-        strip_include_prefix = strip_include_prefix,
-        deps = deps,
-        visibility = ["//visibility:private"],
-    )
-
-    native.cc_library(
-        name = name,
-        textual_hdrs = hdrs,
-        deps = deps + [":%s_virtual" % name],
-        **kwargs
-    )
-
-def cuda_library(copts = [], **kwargs):
-    """Wrapper over cc_library which adds default CUDA options."""
-    native.cc_library(copts = cuda_default_copts() + copts, **kwargs)
-
-def cuda_cc_test(copts = [], **kwargs):
-    """Wrapper over cc_test which adds default CUDA options."""
-    native.cc_test(copts = copts + if_cuda(["-DGOOGLE_CUDA=1"]), **kwargs)
-
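-# Hypothetical usage sketch (added for illustration; not part of the original
-# template). The target and file names are placeholders; the load path assumes
-# the instantiated @local_config_cuda repository:
-#
-#   load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library")
-#
-#   cuda_library(
-#       name = "my_kernels",
-#       srcs = ["my_kernels.cu.cc"],
-#       deps = ["@local_config_cuda//cuda:cuda_headers"],
-#   )
-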
-EnableCudaInfo = provider()
-
-def _enable_cuda_flag_impl(ctx):
-    value = ctx.build_setting_value
-    if ctx.attr.enable_override:
-        print(
-            "\n\033[1;33mWarning:\033[0m '--define=using_cuda_nvcc' will be " +
-            "unsupported soon. Use '--@local_config_cuda//:enable_cuda' " +
-            "instead."
-        )
-        value = True
-    return EnableCudaInfo(value = value)
-
-enable_cuda_flag = rule(
-    implementation = _enable_cuda_flag_impl,
-    build_setting = config.bool(flag = True),
-    attrs = {"enable_override": attr.bool()},
-)
diff --git a/third_party/xla/third_party/gpus/cuda/cuda_config.h.tpl b/third_party/xla/third_party/gpus/cuda/cuda_config.h.tpl
deleted file mode 100644
index 03ecd01..0000000
--- a/third_party/xla/third_party/gpus/cuda/cuda_config.h.tpl
+++ /dev/null
@@ -1,33 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef CUDA_CUDA_CONFIG_H_
-#define CUDA_CUDA_CONFIG_H_
-
-#define TF_CUDA_VERSION "%{cuda_version}"
-#define TF_CUDART_VERSION "%{cudart_version}"
-#define TF_CUPTI_VERSION "%{cupti_version}"
-#define TF_CUBLAS_VERSION "%{cublas_version}"
-#define TF_CUSOLVER_VERSION "%{cusolver_version}"
-#define TF_CURAND_VERSION "%{curand_version}"
-#define TF_CUFFT_VERSION "%{cufft_version}"
-#define TF_CUSPARSE_VERSION "%{cusparse_version}"
-#define TF_CUDNN_VERSION "%{cudnn_version}"
-
-#define TF_CUDA_TOOLKIT_PATH "%{cuda_toolkit_path}"
-
-#define TF_CUDA_COMPUTE_CAPABILITIES %{cuda_compute_capabilities}
-
-#endif  // CUDA_CUDA_CONFIG_H_
diff --git a/third_party/xla/third_party/gpus/cuda_configure.bzl b/third_party/xla/third_party/gpus/cuda_configure.bzl
deleted file mode 100644
index ff3e8d6..0000000
--- a/third_party/xla/third_party/gpus/cuda_configure.bzl
+++ /dev/null
@@ -1,1224 +0,0 @@
-"""Repository rule for CUDA autoconfiguration.
-
-`cuda_configure` depends on the following environment variables:
-
-  * `TF_NEED_CUDA`: Whether to enable building with CUDA.
-  * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path
-  * `TF_CUDA_CLANG`: Whether to use clang as a cuda compiler.
-  * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for
-    both host and device code compilation if TF_CUDA_CLANG is 1.
-  * `TF_SYSROOT`: The sysroot to use when compiling.
-  * `TF_DOWNLOAD_CLANG`: Whether to download a recent release of clang
-    compiler and use it to build tensorflow. When this option is set
-    CLANG_CUDA_COMPILER_PATH is ignored.
-  * `TF_CUDA_PATHS`: The base paths to look for CUDA and cuDNN. Default is
-    `/usr/local/cuda,usr/`.
-  * `CUDA_TOOLKIT_PATH` (deprecated): The path to the CUDA toolkit. Default is
-    `/usr/local/cuda`.
-  * `TF_CUDA_VERSION`: The version of the CUDA toolkit. If this is blank, then
-    use the system default.
-  * `TF_CUDNN_VERSION`: The version of the cuDNN library.
-  * `CUDNN_INSTALL_PATH` (deprecated): The path to the cuDNN library. Default is
-    `/usr/local/cuda`.
-  * `TF_CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default is
-    `3.5,5.2`.
-  * `PYTHON_BIN_PATH`: The python binary path
-"""
-
-load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
-load(
-    "//third_party/remote_config:common.bzl",
-    "config_repo_label",
-    "err_out",
-    "execute",
-    "get_bash_bin",
-    "get_host_environ",
-    "get_python_bin",
-    "raw_exec",
-    "read_dir",
-    "realpath",
-    "which",
-)
-
-_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
-_GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX"
-_CLANG_CUDA_COMPILER_PATH = "CLANG_CUDA_COMPILER_PATH"
-_TF_SYSROOT = "TF_SYSROOT"
-_CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
-_TF_CUDA_VERSION = "TF_CUDA_VERSION"
-_TF_CUDNN_VERSION = "TF_CUDNN_VERSION"
-_CUDNN_INSTALL_PATH = "CUDNN_INSTALL_PATH"
-_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES"
-_TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO"
-_TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG"
-_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"
-
-def to_list_of_strings(elements):
-    """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'.
-
-    This is to be used to put a list of strings into the bzl file templates
-    so it gets interpreted as list of strings in Starlark.
-
-    Args:
-      elements: list of string elements
-
-    Returns:
-      single string of elements wrapped in quotes separated by a comma."""
-    quoted_strings = ["\"" + element + "\"" for element in elements]
-    return ", ".join(quoted_strings)
-
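-# Illustrative example (added for clarity; not part of the original file):
-# to_list_of_strings(["a", "b", "c"]) returns the single string '"a", "b", "c"',
-# suitable for substituting into a template where a Starlark list of strings
-# is expected.
-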
-def verify_build_defines(params):
-    """Verify all variables that crosstool/BUILD.tpl expects are substituted.
-
-    Args:
-      params: dict of variables that will be passed to the BUILD.tpl template.
-    """
-    missing = []
-    for param in [
-        "cxx_builtin_include_directories",
-        "extra_no_canonical_prefixes_flags",
-        "host_compiler_path",
-        "host_compiler_prefix",
-        "host_compiler_warnings",
-        "linker_bin_path",
-        "compiler_deps",
-        "unfiltered_compile_flags",
-    ]:
-        if ("%{" + param + "}") not in params:
-            missing.append(param)
-
-    if missing:
-        auto_configure_fail(
-            "BUILD.tpl template is missing these variables: " +
-            str(missing) +
-            ".\nWe only got: " +
-            str(params) +
-            ".",
-        )
-
-# TODO(dzc): Once these functions have been factored out of Bazel's
-# cc_configure.bzl, load them from @bazel_tools instead.
-# BEGIN cc_configure common functions.
-def find_cc(repository_ctx):
-    """Find the C++ compiler."""
-
-    if _use_cuda_clang(repository_ctx):
-        target_cc_name = "clang"
-        cc_path_envvar = _CLANG_CUDA_COMPILER_PATH
-        if _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG):
-            return "extra_tools/bin/clang"
-    else:
-        target_cc_name = "gcc"
-        cc_path_envvar = _GCC_HOST_COMPILER_PATH
-    cc_name = target_cc_name
-
-    cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar)
-    if cc_name_from_env:
-        cc_name = cc_name_from_env
-    if cc_name.startswith("/"):
-        # Absolute path, maybe we should make this supported by our which function.
-        return cc_name
-    cc = which(repository_ctx, cc_name)
-    if cc == None:
-        fail(("Cannot find {}, either correct your path or set the {}" +
-              " environment variable").format(target_cc_name, cc_path_envvar))
-    return cc
-
-_INC_DIR_MARKER_BEGIN = "#include <...>"
-
-# OSX adds " (framework directory)" at the end of the line; strip it.
-_OSX_FRAMEWORK_SUFFIX = " (framework directory)"
-_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX)
-
-def _cxx_inc_convert(path):
-    """Convert path returned by cc -E xc++ in a complete path."""
-    path = path.strip()
-    if path.endswith(_OSX_FRAMEWORK_SUFFIX):
-        path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
-    return path
-
-def _normalize_include_path(repository_ctx, path):
-    """Normalizes include paths before writing them to the crosstool.
-
-      If path points inside the 'crosstool' folder of the repository, a relative
-      path is returned.
-      If path points outside the 'crosstool' folder, an absolute path is returned.
-      """
-    path = str(repository_ctx.path(path))
-    crosstool_folder = str(repository_ctx.path(".").get_child("crosstool"))
-
-    if path.startswith(crosstool_folder):
-        # We drop the path to "$REPO/crosstool" and a trailing path separator.
-        return path[len(crosstool_folder) + 1:]
-    return path
-
-def _is_compiler_option_supported(repository_ctx, cc, option):
-    """Checks that `option` is supported by the C compiler. Doesn't %-escape the option."""
-    result = repository_ctx.execute([
-        cc,
-        option,
-        "-o",
-        "/dev/null",
-        "-c",
-        str(repository_ctx.path("tools/cpp/empty.cc")),
-    ])
-    return result.stderr.find(option) == -1
-
-def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sysroot):
-    """Compute the list of default C or C++ include directories."""
-    if lang_is_cpp:
-        lang = "c++"
-    else:
-        lang = "c"
-    sysroot = []
-    if tf_sysroot:
-        sysroot += ["--sysroot", tf_sysroot]
-    result = raw_exec(repository_ctx, [cc, "-E", "-x" + lang, "-", "-v"] +
-                                      sysroot)
-    stderr = err_out(result)
-    index1 = stderr.find(_INC_DIR_MARKER_BEGIN)
-    if index1 == -1:
-        return []
-    index1 = stderr.find("\n", index1)
-    if index1 == -1:
-        return []
-    index2 = stderr.rfind("\n ")
-    if index2 == -1 or index2 < index1:
-        return []
-    index2 = stderr.find("\n", index2 + 1)
-    if index2 == -1:
-        inc_dirs = stderr[index1 + 1:]
-    else:
-        inc_dirs = stderr[index1 + 1:index2].strip()
-
-    print_resource_dir_supported = _is_compiler_option_supported(
-        repository_ctx,
-        cc,
-        "-print-resource-dir",
-    )
-
-    if print_resource_dir_supported:
-        resource_dir = repository_ctx.execute(
-            [cc, "-print-resource-dir"],
-        ).stdout.strip() + "/share"
-        inc_dirs += "\n" + resource_dir
-
-    return [
-        _normalize_include_path(repository_ctx, _cxx_inc_convert(p))
-        for p in inc_dirs.split("\n")
-    ]
-
-def get_cxx_inc_directories(repository_ctx, cc, tf_sysroot):
-    """Compute the list of default C and C++ include directories."""
-
-    # For some reason `clang -xc` sometimes returns include paths that are
-    # different from the ones from `clang -xc++`. (Symlink and a dir)
-    # So we run the compiler with both `-xc` and `-xc++` and merge the resulting lists.
-    includes_cpp = _get_cxx_inc_directories_impl(
-        repository_ctx,
-        cc,
-        True,
-        tf_sysroot,
-    )
-    includes_c = _get_cxx_inc_directories_impl(
-        repository_ctx,
-        cc,
-        False,
-        tf_sysroot,
-    )
-
-    return includes_cpp + [
-        inc
-        for inc in includes_c
-        if inc not in includes_cpp
-    ]
-
-def auto_configure_fail(msg):
-    """Output failure message when cuda configuration fails."""
-    red = "\033[0;31m"
-    no_color = "\033[0m"
-    fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg))
-
-# END cc_configure common functions (see TODO above).
-
-def _cuda_include_path(repository_ctx, cuda_config):
-    """Generates the Starlark string with cuda include directories.
-
-      Args:
-        repository_ctx: The repository context.
-        cuda_config: The CUDA config as returned by _get_cuda_config.
-
-      Returns:
-        A list of CUDA include directories.
-      """
-    nvcc_path = repository_ctx.path(
-        "%s/bin/nvcc" % cuda_config.cuda_toolkit_path,
-    )
-
-    # The expected exit code of this command is non-zero. Bazel remote execution
-    # only caches commands with zero exit code. So force a zero exit code.
-    cmd = "%s -v /dev/null -o /dev/null ; [ $? -eq 1 ]" % str(nvcc_path)
-    result = raw_exec(repository_ctx, [get_bash_bin(repository_ctx), "-c", cmd])
-    target_dir = ""
-    for one_line in err_out(result).splitlines():
-        if one_line.startswith("#$ _TARGET_DIR_="):
-            target_dir = (
-                cuda_config.cuda_toolkit_path + "/" + one_line.replace(
-                    "#$ _TARGET_DIR_=",
-                    "",
-                ) + "/include"
-            )
-    inc_entries = []
-    if target_dir != "":
-        inc_entries.append(realpath(repository_ctx, target_dir))
-    inc_entries.append(realpath(repository_ctx, cuda_config.cuda_toolkit_path + "/include"))
-    return inc_entries
-
-def enable_cuda(repository_ctx):
-    """Returns whether to build with CUDA support."""
-    return int(get_host_environ(repository_ctx, "TF_NEED_CUDA", False))
-
-def matches_version(environ_version, detected_version):
-    """Checks whether the user-specified version matches the detected version.
-
-      This function performs a weak matching so that if the user specifies
-      only the major or major and minor versions, the versions are still
-      considered matching if the version parts match. To illustrate:
-
-          environ_version  detected_version  result
-          -----------------------------------------
-          5.1.3            5.1.3             True
-          5.1              5.1.3             True
-          5                5.1               True
-          5.1.3            5.1               False
-          5.2.3            5.1.3             False
-
-      Args:
-        environ_version: The version specified by the user via environment
-          variables.
-        detected_version: The version autodetected from the CUDA installation on
-          the system.
-      Returns: True if user-specified version matches detected version and False
-        otherwise.
-    """
-    environ_version_parts = environ_version.split(".")
-    detected_version_parts = detected_version.split(".")
-    if len(detected_version_parts) < len(environ_version_parts):
-        return False
-    for i, part in enumerate(detected_version_parts):
-        if i >= len(environ_version_parts):
-            break
-        if part != environ_version_parts[i]:
-            return False
-    return True
-
-def compute_capabilities(repository_ctx):
-    """Returns a list of strings representing cuda compute capabilities.
-
-    Args:
-      repository_ctx: the repo rule's context.
-    Returns: list of cuda architectures to compile for. 'compute_xy' refers to
-      both PTX and SASS, 'sm_xy' refers to SASS only.
-    """
-    capabilities = get_host_environ(
-        repository_ctx,
-        _TF_CUDA_COMPUTE_CAPABILITIES,
-        "compute_35,compute_52",
-    ).split(",")
-
-    # Map old 'x.y' capabilities to 'compute_xy'.
-    if len(capabilities) > 0 and all([len(x.split(".")) == 2 for x in capabilities]):
-        # If all capabilities are in 'x.y' format, only include PTX for the
-        # highest capability.
-        cc_list = sorted([x.replace(".", "") for x in capabilities])
-        capabilities = ["sm_%s" % x for x in cc_list[:-1]] + ["compute_%s" % cc_list[-1]]
-    for i, capability in enumerate(capabilities):
-        parts = capability.split(".")
-        if len(parts) != 2:
-            continue
-        capabilities[i] = "compute_%s%s" % (parts[0], parts[1])
-
-    # Make list unique
-    capabilities = dict(zip(capabilities, capabilities)).keys()
-
-    # Validate capabilities.
-    for capability in capabilities:
-        if not capability.startswith(("compute_", "sm_")):
-            auto_configure_fail("Invalid compute capability: %s" % capability)
-        for prefix in ["compute_", "sm_"]:
-            if not capability.startswith(prefix):
-                continue
-            if len(capability) == len(prefix) + 2 and capability[-2:].isdigit():
-                continue
-            auto_configure_fail("Invalid compute capability: %s" % capability)
-
-    return capabilities
-
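-# Worked example (added for clarity; not part of the original file): with
-# TF_CUDA_COMPUTE_CAPABILITIES="7.0,8.0" every entry is in 'x.y' form, so the
-# result is ["sm_70", "compute_80"] (SASS only for all but the highest
-# capability, PTX for the highest). "sm_70,compute_80" passes through as-is.
-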
-def lib_name(base_name, version = None, static = False):
-    """Constructs the platform-specific name of a library.
-
-      Args:
-        base_name: The name of the library, such as "cudart"
-        version: The version of the library.
-        static: True if the library is static or False if it is a shared object.
-
-      Returns:
-        The platform-specific name of the library.
-      """
-    version = "" if not version else "." + version
-    if static:
-        return "lib%s.a" % base_name
-    return "lib%s.so%s" % (base_name, version)
-
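-# Illustrative examples (added for clarity; not part of the original file):
-# lib_name("cudart", "11.0") returns "libcudart.so.11.0" and
-# lib_name("cudart_static", static = True) returns "libcudart_static.a".
-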
-def _lib_path(lib, basedir, version, static):
-    file_name = lib_name(lib, version, static)
-    return "%s/%s" % (basedir, file_name)
-
-def _should_check_soname(version, static):
-    return version and not static
-
-def _check_cuda_lib_params(lib, basedir, version, static = False):
-    return (
-        _lib_path(lib, basedir, version, static),
-        _should_check_soname(version, static),
-    )
-
-def _check_cuda_libs(repository_ctx, script_path, libs):
-    python_bin = get_python_bin(repository_ctx)
-    contents = repository_ctx.read(script_path).splitlines()
-
-    cmd = "from os import linesep;"
-    cmd += "f = open('script.py', 'w');"
-    for line in contents:
-        cmd += "f.write('%s' + linesep);" % line
-    cmd += "f.close();"
-    cmd += "from os import system;"
-    args = " ".join(["\"" + path + "\" " + str(check) for path, check in libs])
-    cmd += "system('%s script.py %s');" % (python_bin, args)
-
-    all_paths = [path for path, _ in libs]
-    checked_paths = execute(repository_ctx, [python_bin, "-c", cmd]).stdout.splitlines()
-
-    if all_paths != checked_paths:
-        auto_configure_fail("Error with installed CUDA libs. Expected '%s'. Actual '%s'." % (all_paths, checked_paths))
-
-def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
-    """Returns the CUDA and cuDNN libraries on the system.
-
-      Also verifies that the libraries actually exist on the system.
-
-      Args:
-        repository_ctx: The repository context.
-        check_cuda_libs_script: The path to a script verifying that the cuda
-          libraries exist on the system.
-        cuda_config: The CUDA config as returned by _get_cuda_config
-
-      Returns:
-        Map of library names to structs of filename and path.
-      """
-    check_cuda_libs_params = {
-        "cuda": _check_cuda_lib_params(
-            "cuda",
-            cuda_config.config["cuda_library_dir"] + "/stubs",
-            version = None,
-        ),
-        "cudart": _check_cuda_lib_params(
-            "cudart",
-            cuda_config.config["cuda_library_dir"],
-            cuda_config.cudart_version,
-        ),
-        "cudart_static": _check_cuda_lib_params(
-            "cudart_static",
-            cuda_config.config["cuda_library_dir"],
-            cuda_config.cudart_version,
-            static = True,
-        ),
-        "cublas": _check_cuda_lib_params(
-            "cublas",
-            cuda_config.config["cublas_library_dir"],
-            cuda_config.cublas_version,
-        ),
-        "cublasLt": _check_cuda_lib_params(
-            "cublasLt",
-            cuda_config.config["cublas_library_dir"],
-            cuda_config.cublas_version,
-        ),
-        "cusolver": _check_cuda_lib_params(
-            "cusolver",
-            cuda_config.config["cusolver_library_dir"],
-            cuda_config.cusolver_version,
-        ),
-        "curand": _check_cuda_lib_params(
-            "curand",
-            cuda_config.config["curand_library_dir"],
-            cuda_config.curand_version,
-        ),
-        "cufft": _check_cuda_lib_params(
-            "cufft",
-            cuda_config.config["cufft_library_dir"],
-            cuda_config.cufft_version,
-        ),
-        "cudnn": _check_cuda_lib_params(
-            "cudnn",
-            cuda_config.config["cudnn_library_dir"],
-            cuda_config.cudnn_version,
-        ),
-        "cupti": _check_cuda_lib_params(
-            "cupti",
-            cuda_config.config["cupti_library_dir"],
-            cuda_config.cupti_version,
-        ),
-        "cusparse": _check_cuda_lib_params(
-            "cusparse",
-            cuda_config.config["cusparse_library_dir"],
-            cuda_config.cusparse_version,
-        ),
-    }
-
-    # Verify that the libs actually exist at their locations.
-    _check_cuda_libs(repository_ctx, check_cuda_libs_script, check_cuda_libs_params.values())
-
-    paths = {filename: v[0] for (filename, v) in check_cuda_libs_params.items()}
-    return paths
-
-# TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl,
-# and nccl_configure.bzl.
-def find_cuda_config(repository_ctx, cuda_libraries):
-    """Returns CUDA config dictionary from running find_cuda_config.py"""
-    python_bin = get_python_bin(repository_ctx)
-    exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_cuda_config] + cuda_libraries)
-    if exec_result.return_code:
-        auto_configure_fail("Failed to run find_cuda_config.py: %s" % err_out(exec_result))
-
-    # Parse the dict from stdout.
-    return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()])
-
-def _get_cuda_config(repository_ctx):
-    """Detects and returns information about the CUDA installation on the system.
-
-      Args:
-        repository_ctx: The repository context.
-
-      Returns:
-        A struct containing the following fields:
-          cuda_toolkit_path: The CUDA toolkit installation directory.
-          cudnn_install_basedir: The cuDNN installation directory.
-          cuda_version: The version of CUDA on the system.
-          cudart_version: The CUDA runtime version on the system.
-          cudnn_version: The version of cuDNN on the system.
-          compute_capabilities: A list of the system's CUDA compute capabilities.
-      """
-    config = find_cuda_config(repository_ctx, ["cuda", "cudnn"])
-    toolkit_path = config["cuda_toolkit_path"]
-
-    cuda_version = config["cuda_version"].split(".")
-    cuda_major = cuda_version[0]
-    cuda_minor = cuda_version[1]
-
-    cuda_version = "%s.%s" % (cuda_major, cuda_minor)
-    cudnn_version = "%s" % config["cudnn_version"]
-
-    if int(cuda_major) >= 11:
-        # The libcudart soname in CUDA 11.x is versioned as 11.0 for backward compatibility.
-        if int(cuda_major) == 11:
-            cudart_version = "11.0"
-            cupti_version = cuda_version
-        else:
-            cudart_version = "%s" % cuda_major
-            cupti_version = cudart_version
-        cublas_version = "%s" % config["cublas_version"].split(".")[0]
-        cusolver_version = "%s" % config["cusolver_version"].split(".")[0]
-        curand_version = "%s" % config["curand_version"].split(".")[0]
-        cufft_version = "%s" % config["cufft_version"].split(".")[0]
-        cusparse_version = "%s" % config["cusparse_version"].split(".")[0]
-    elif (int(cuda_major), int(cuda_minor)) >= (10, 1):
-        # cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
-        # It changed from 'x.y' to just 'x' in CUDA 10.1.
-        cuda_lib_version = "%s" % cuda_major
-        cudart_version = cuda_version
-        cupti_version = cuda_version
-        cublas_version = cuda_lib_version
-        cusolver_version = cuda_lib_version
-        curand_version = cuda_lib_version
-        cufft_version = cuda_lib_version
-        cusparse_version = cuda_lib_version
-    else:
-        cudart_version = cuda_version
-        cupti_version = cuda_version
-        cublas_version = cuda_version
-        cusolver_version = cuda_version
-        curand_version = cuda_version
-        cufft_version = cuda_version
-        cusparse_version = cuda_version
-
-    return struct(
-        cuda_toolkit_path = toolkit_path,
-        cuda_version = cuda_version,
-        cupti_version = cupti_version,
-        cuda_version_major = cuda_major,
-        cudart_version = cudart_version,
-        cublas_version = cublas_version,
-        cusolver_version = cusolver_version,
-        curand_version = curand_version,
-        cufft_version = cufft_version,
-        cusparse_version = cusparse_version,
-        cudnn_version = cudnn_version,
-        compute_capabilities = compute_capabilities(repository_ctx),
-        config = config,
-    )
-
-def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
-    if not out:
-        out = tpl.replace(":", "/")
-    repository_ctx.template(
-        out,
-        Label("//third_party/gpus/%s.tpl" % tpl),
-        substitutions,
-    )
-
-def _file(repository_ctx, label):
-    repository_ctx.template(
-        label.replace(":", "/"),
-        Label("//third_party/gpus/%s.tpl" % label),
-        {},
-    )
-
-_DUMMY_CROSSTOOL_BZL_FILE = """
-def error_gpu_disabled():
-  fail("ERROR: Building with --config=cuda but TensorFlow is not configured " +
-       "to build with GPU support. Please re-run ./configure and enter 'Y' " +
-       "at the prompt to build with GPU support.")
-
-  native.genrule(
-      name = "error_gen_crosstool",
-      outs = ["CROSSTOOL"],
-      cmd = "echo 'Should not be run.' && exit 1",
-  )
-
-  native.filegroup(
-      name = "crosstool",
-      srcs = [":CROSSTOOL"],
-      output_licenses = ["unencumbered"],
-  )
-"""
-
-_DUMMY_CROSSTOOL_BUILD_FILE = """
-load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")
-
-error_gpu_disabled()
-"""
-
-def _create_dummy_repository(repository_ctx):
-    # Set up BUILD file for cuda/.
-    _tpl(
-        repository_ctx,
-        "cuda:build_defs.bzl",
-        {
-            "%{cuda_is_configured}": "False",
-            "%{cuda_extra_copts}": "[]",
-            "%{cuda_gpu_architectures}": "[]",
-        },
-    )
-    _tpl(
-        repository_ctx,
-        "cuda:BUILD",
-        {
-            "%{copy_rules}": """
-filegroup(name="cuda-include")
-filegroup(name="cublas-include")
-filegroup(name="cusolver-include")
-filegroup(name="cufft-include")
-filegroup(name="cusparse-include")
-filegroup(name="curand-include")
-filegroup(name="cudnn-include")
-""",
-        },
-    )
-
-    # Create dummy files for the CUDA toolkit since they are still required by
-    # tensorflow/tsl/platform/default/build_config:cuda.
-    repository_ctx.file("cuda/cuda/include/cuda.h")
-    repository_ctx.file("cuda/cuda/include/cublas.h")
-    repository_ctx.file("cuda/cuda/include/cudnn.h")
-    repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h")
-    repository_ctx.file("cuda/cuda/lib/libcuda.so")
-    repository_ctx.file("cuda/cuda/lib/libcudart_static.a")
-    repository_ctx.file("cuda/cuda/nvml/include/nvml.h")
-
-    # Set up cuda_config.h, which is used by
-    # tensorflow/compiler/xla/stream_executor/dso_loader.cc.
-    _tpl(
-        repository_ctx,
-        "cuda:cuda_config.h",
-        {
-            "%{cuda_version}": "",
-            "%{cudart_version}": "",
-            "%{cupti_version}": "",
-            "%{cublas_version}": "",
-            "%{cusolver_version}": "",
-            "%{curand_version}": "",
-            "%{cufft_version}": "",
-            "%{cusparse_version}": "",
-            "%{cudnn_version}": "",
-            "%{cuda_toolkit_path}": "",
-            "%{cuda_compute_capabilities}": "",
-        },
-        "cuda/cuda/cuda_config.h",
-    )
-
-    # Set up cuda_config.py, which is used by gen_build_info to provide
-    # static build environment info to the API
-    _tpl(
-        repository_ctx,
-        "cuda:cuda_config.py",
-        _py_tmpl_dict({}),
-        "cuda/cuda/cuda_config.py",
-    )
-
-    # If cuda_configure is not configured to build with GPU support, and the user
-    # attempts to build with --config=cuda, add a dummy build rule to intercept
-    # this and fail with an actionable error message.
-    repository_ctx.file(
-        "crosstool/error_gpu_disabled.bzl",
-        _DUMMY_CROSSTOOL_BZL_FILE,
-    )
-    repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
-
-def _norm_path(path):
-    """Returns a path with '/' and remove the trailing slash."""
-    path = path.replace("\\", "/")
-    if path[-1] == "/":
-        path = path[:-1]
-    return path
-
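-# Illustrative example (added for clarity; not part of the original file):
-# _norm_path("C:\\cuda\\include\\") returns "C:/cuda/include".
-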
-def make_copy_files_rule(repository_ctx, name, srcs, outs):
-    """Returns a rule to copy a set of files."""
-    cmds = []
-
-    # Copy files.
-    for src, out in zip(srcs, outs):
-        cmds.append('cp -f "%s" "$(location %s)"' % (src, out))
-    outs = [('        "%s",' % out) for out in outs]
-    return """genrule(
-    name = "%s",
-    outs = [
-%s
-    ],
-    cmd = \"""%s \""",
-)""" % (name, "\n".join(outs), " && \\\n".join(cmds))
-
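-# Illustrative example (added for clarity; not part of the original file):
-# make_copy_files_rule(ctx, name = "cufft-include",
-#                      srcs = ["/usr/local/cuda/include/cufft.h"],
-#                      outs = ["cufft/include/cufft.h"]) returns a genrule
-# string whose cmd is
-#   cp -f "/usr/local/cuda/include/cufft.h" "$(location cufft/include/cufft.h)"
-# with "cufft/include/cufft.h" as the single entry in outs.
-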
-def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir):
-    """Returns a rule to recursively copy a directory.
-    If exceptions is not None, it must be a list of files or directories in
-    'src_dir'; these will be excluded from copying.
-    """
-    src_dir = _norm_path(src_dir)
-    out_dir = _norm_path(out_dir)
-    outs = read_dir(repository_ctx, src_dir)
-    outs = [('        "%s",' % out.replace(src_dir, out_dir)) for out in outs]
-
-    # '@D' already contains the relative path for a single file, see
-    # http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables
-    out_dir = "$(@D)/%s" % out_dir if len(outs) > 1 else "$(@D)"
-    return """genrule(
-    name = "%s",
-    outs = [
-%s
-    ],
-    cmd = \"""cp -rLf "%s/." "%s/" \""",
-)""" % (name, "\n".join(outs), src_dir, out_dir)
-
-def _flag_enabled(repository_ctx, flag_name):
-    return get_host_environ(repository_ctx, flag_name) == "1"
-
-def _use_cuda_clang(repository_ctx):
-    return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
-
-def _tf_sysroot(repository_ctx):
-    return get_host_environ(repository_ctx, _TF_SYSROOT, "")
-
-def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
-    copts = ["--no-cuda-include-ptx=all"] if _use_cuda_clang(repository_ctx) else []
-    for capability in compute_capabilities:
-        if capability.startswith("compute_"):
-            capability = capability.replace("compute_", "sm_")
-            copts.append("--cuda-include-ptx=%s" % capability)
-        copts.append("--cuda-gpu-arch=%s" % capability)
-
-    return str(copts)
-
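-# Worked example (added for clarity; not part of the original file): with
-# cuda-clang enabled and compute_capabilities = ["sm_70", "compute_80"], the
-# copts list is ["--no-cuda-include-ptx=all", "--cuda-gpu-arch=sm_70",
-# "--cuda-include-ptx=sm_80", "--cuda-gpu-arch=sm_80"], returned as a string.
-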
-def _tpl_path(repository_ctx, filename):
-    return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % filename))
-
-def _create_local_cuda_repository(repository_ctx):
-    """Creates the repository containing files set up to build with CUDA."""
-
-    # Resolve all labels before doing any real work. Resolving causes the
-    # function to be restarted with all previous state being lost. This
-    # can easily lead to an O(n^2) runtime in the number of labels.
-    # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
-    tpl_paths = {filename: _tpl_path(repository_ctx, filename) for filename in [
-        "cuda:BUILD",
-        "cuda:build_defs.bzl",
-        "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
-        "crosstool:BUILD",
-        "crosstool:cc_toolchain_config.bzl",
-        "cuda:cuda_config.h",
-        "cuda:cuda_config.py",
-    ]}
-
-    cuda_config = _get_cuda_config(repository_ctx)
-
-    cuda_include_path = cuda_config.config["cuda_include_dir"]
-    cublas_include_path = cuda_config.config["cublas_include_dir"]
-    cudnn_header_dir = cuda_config.config["cudnn_include_dir"]
-    cupti_header_dir = cuda_config.config["cupti_include_dir"]
-    nvvm_libdevice_dir = cuda_config.config["nvvm_library_dir"]
-    nvml_header_dir = cuda_config.config["nvml_header_dir"]
-
-    # Create genrule to copy files from the installed CUDA toolkit into execroot.
-    copy_rules = [
-        make_copy_dir_rule(
-            repository_ctx,
-            name = "cuda-include",
-            src_dir = cuda_include_path,
-            out_dir = "cuda/include",
-        ),
-        make_copy_dir_rule(
-            repository_ctx,
-            name = "cuda-nvvm",
-            src_dir = nvvm_libdevice_dir,
-            out_dir = "cuda/nvvm/libdevice",
-        ),
-        make_copy_dir_rule(
-            repository_ctx,
-            name = "cuda-extras",
-            src_dir = cupti_header_dir,
-            out_dir = "cuda/extras/CUPTI/include",
-        ),
-        make_copy_dir_rule(
-            repository_ctx,
-            name = "nvml",
-            src_dir = nvml_header_dir,
-            out_dir = "cuda/nvml/include",
-        ),
-    ]
-
-    copy_rules.append(make_copy_files_rule(
-        repository_ctx,
-        name = "cublas-include",
-        srcs = [
-            cublas_include_path + "/cublas.h",
-            cublas_include_path + "/cublas_v2.h",
-            cublas_include_path + "/cublas_api.h",
-            cublas_include_path + "/cublasLt.h",
-        ],
-        outs = [
-            "cublas/include/cublas.h",
-            "cublas/include/cublas_v2.h",
-            "cublas/include/cublas_api.h",
-            "cublas/include/cublasLt.h",
-        ],
-    ))
-
-    cusolver_include_path = cuda_config.config["cusolver_include_dir"]
-    copy_rules.append(make_copy_files_rule(
-        repository_ctx,
-        name = "cusolver-include",
-        srcs = [
-            cusolver_include_path + "/cusolver_common.h",
-            cusolver_include_path + "/cusolverDn.h",
-        ],
-        outs = [
-            "cusolver/include/cusolver_common.h",
-            "cusolver/include/cusolverDn.h",
-        ],
-    ))
-
-    cufft_include_path = cuda_config.config["cufft_include_dir"]
-    copy_rules.append(make_copy_files_rule(
-        repository_ctx,
-        name = "cufft-include",
-        srcs = [
-            cufft_include_path + "/cufft.h",
-        ],
-        outs = [
-            "cufft/include/cufft.h",
-        ],
-    ))
-
-    cusparse_include_path = cuda_config.config["cusparse_include_dir"]
-    copy_rules.append(make_copy_files_rule(
-        repository_ctx,
-        name = "cusparse-include",
-        srcs = [
-            cusparse_include_path + "/cusparse.h",
-        ],
-        outs = [
-            "cusparse/include/cusparse.h",
-        ],
-    ))
-
-    curand_include_path = cuda_config.config["curand_include_dir"]
-    copy_rules.append(make_copy_files_rule(
-        repository_ctx,
-        name = "curand-include",
-        srcs = [
-            curand_include_path + "/curand.h",
-        ],
-        outs = [
-            "curand/include/curand.h",
-        ],
-    ))
-
-    check_cuda_libs_script = repository_ctx.path(Label("@local_xla//third_party/gpus:check_cuda_libs.py"))
-    cuda_libs = _find_libs(repository_ctx, check_cuda_libs_script, cuda_config)
-    cuda_lib_srcs = []
-    cuda_lib_outs = []
-    for path in cuda_libs.values():
-        cuda_lib_srcs.append(path)
-        cuda_lib_outs.append("cuda/lib/" + path.rpartition("/")[-1])
-    copy_rules.append(make_copy_files_rule(
-        repository_ctx,
-        name = "cuda-lib",
-        srcs = cuda_lib_srcs,
-        outs = cuda_lib_outs,
-    ))
-
-    # copy files mentioned in third_party/nccl/build_defs.bzl.tpl
-    bin_files = ["crt/link.stub", "bin2c", "fatbinary", "nvlink", "nvprune"]
-    copy_rules.append(make_copy_files_rule(
-        repository_ctx,
-        name = "cuda-bin",
-        srcs = [cuda_config.cuda_toolkit_path + "/bin/" + f for f in bin_files],
-        outs = ["cuda/bin/" + f for f in bin_files],
-    ))
-
-    # Select the headers based on the cuDNN version.
-    cudnn_headers = ["cudnn.h"]
-    if cuda_config.cudnn_version.rsplit("_", 1)[-1] >= "8":
-        cudnn_headers += [
-            "cudnn_backend.h",
-            "cudnn_adv_infer.h",
-            "cudnn_adv_train.h",
-            "cudnn_cnn_infer.h",
-            "cudnn_cnn_train.h",
-            "cudnn_ops_infer.h",
-            "cudnn_ops_train.h",
-            "cudnn_version.h",
-        ]
-
-    cudnn_srcs = []
-    cudnn_outs = []
-    for header in cudnn_headers:
-        cudnn_srcs.append(cudnn_header_dir + "/" + header)
-        cudnn_outs.append("cudnn/include/" + header)
-
-    copy_rules.append(make_copy_files_rule(
-        repository_ctx,
-        name = "cudnn-include",
-        srcs = cudnn_srcs,
-        outs = cudnn_outs,
-    ))
-
-    # Set up BUILD file for cuda/
-    repository_ctx.template(
-        "cuda/build_defs.bzl",
-        tpl_paths["cuda:build_defs.bzl"],
-        {
-            "%{cuda_is_configured}": "True",
-            "%{cuda_extra_copts}": _compute_cuda_extra_copts(
-                repository_ctx,
-                cuda_config.compute_capabilities,
-            ),
-            "%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities),
-        },
-    )
-
-    repository_ctx.template(
-        "cuda/BUILD",
-        tpl_paths["cuda:BUILD"],
-        {
-            "%{copy_rules}": "\n".join(copy_rules),
-        },
-    )
-
-    is_cuda_clang = _use_cuda_clang(repository_ctx)
-    tf_sysroot = _tf_sysroot(repository_ctx)
-
-    should_download_clang = is_cuda_clang and _flag_enabled(
-        repository_ctx,
-        _TF_DOWNLOAD_CLANG,
-    )
-    if should_download_clang:
-        download_clang(repository_ctx, "crosstool/extra_tools")
-
-    # Set up crosstool/
-    cc = find_cc(repository_ctx)
-    cc_fullpath = cc if not should_download_clang else "crosstool/" + cc
-
-    host_compiler_includes = get_cxx_inc_directories(
-        repository_ctx,
-        cc_fullpath,
-        tf_sysroot,
-    )
-    cuda_defines = {}
-    cuda_defines["%{builtin_sysroot}"] = tf_sysroot
-    cuda_defines["%{cuda_toolkit_path}"] = ""
-    cuda_defines["%{compiler}"] = "unknown"
-    if is_cuda_clang:
-        cuda_defines["%{cuda_toolkit_path}"] = cuda_config.config["cuda_toolkit_path"]
-        cuda_defines["%{compiler}"] = "clang"
-
-    host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX)
-    if not host_compiler_prefix:
-        host_compiler_prefix = "/usr/bin"
-
-    cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix
-
-    # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see
-    # https://github.com/bazelbuild/bazel/issues/760).
-    # However, this stops our custom clang toolchain from picking the provided
-    # LLD linker, so we're only adding '-B/usr/bin' when using non-downloaded
-    # toolchain.
-    # TODO: when bazel stops adding '-B/usr/bin' by default, remove this
-    #       flag from the CROSSTOOL completely (see
-    #       https://github.com/bazelbuild/bazel/issues/5634)
-    if should_download_clang:
-        cuda_defines["%{linker_bin_path}"] = ""
-    else:
-        cuda_defines["%{linker_bin_path}"] = host_compiler_prefix
-
-    cuda_defines["%{extra_no_canonical_prefixes_flags}"] = ""
-    cuda_defines["%{unfiltered_compile_flags}"] = ""
-    if is_cuda_clang:
-        cuda_defines["%{host_compiler_path}"] = str(cc)
-        cuda_defines["%{host_compiler_warnings}"] = """
-        # Some parts of the codebase set -Werror and hit this warning, so
-        # switch it off for now.
-        "-Wno-invalid-partial-specialization"
-    """
-        cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(host_compiler_includes)
-        cuda_defines["%{compiler_deps}"] = ":empty"
-        repository_ctx.file(
-            "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
-            "",
-        )
-    else:
-        cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
-        cuda_defines["%{host_compiler_warnings}"] = ""
-
-        # nvcc has the system include paths built in and will automatically
-        # search them; we cannot work around that, so we add the relevant cuda
-        # system paths to the allowed compiler specific include paths.
-        cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(
-            host_compiler_includes + _cuda_include_path(
-                repository_ctx,
-                cuda_config,
-            ) + [cupti_header_dir, cudnn_header_dir, nvml_header_dir],
-        )
-
-        # For gcc, do not canonicalize system header paths; some versions of gcc
-        # pick the shortest possible path for system includes when creating the
-        # .d file - given that includes that are prefixed with "../" multiple
-        # times quickly grow longer than the root of the tree, this can lead to
-        # bazel's header check failing.
-        cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\""
-
-        nvcc_path = "%s/nvcc" % cuda_config.config["cuda_binary_dir"]
-        cuda_defines["%{compiler_deps}"] = ":crosstool_wrapper_driver_is_not_gcc"
-
-        wrapper_defines = {
-            "%{cpu_compiler}": str(cc),
-            "%{cuda_version}": cuda_config.cuda_version,
-            "%{nvcc_path}": nvcc_path,
-            "%{gcc_host_compiler_path}": str(cc),
-        }
-        repository_ctx.template(
-            "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
-            tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc"],
-            wrapper_defines,
-        )
-
-    verify_build_defines(cuda_defines)
-
-    # Only expand template variables in the BUILD file
-    repository_ctx.template(
-        "crosstool/BUILD",
-        tpl_paths["crosstool:BUILD"],
-        cuda_defines,
-    )
-
-    # No templating of cc_toolchain_config - use attributes and templatize the
-    # BUILD file.
-    repository_ctx.template(
-        "crosstool/cc_toolchain_config.bzl",
-        tpl_paths["crosstool:cc_toolchain_config.bzl"],
-        {},
-    )
-
-    # Set up cuda_config.h, which is used by
-    # tensorflow/compiler/xla/stream_executor/dso_loader.cc.
-    repository_ctx.template(
-        "cuda/cuda/cuda_config.h",
-        tpl_paths["cuda:cuda_config.h"],
-        {
-            "%{cuda_version}": cuda_config.cuda_version,
-            "%{cudart_version}": cuda_config.cudart_version,
-            "%{cupti_version}": cuda_config.cupti_version,
-            "%{cublas_version}": cuda_config.cublas_version,
-            "%{cusolver_version}": cuda_config.cusolver_version,
-            "%{curand_version}": cuda_config.curand_version,
-            "%{cufft_version}": cuda_config.cufft_version,
-            "%{cusparse_version}": cuda_config.cusparse_version,
-            "%{cudnn_version}": cuda_config.cudnn_version,
-            "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
-            "%{cuda_compute_capabilities}": ", ".join([
-                cc.split("_")[1]
-                for cc in cuda_config.compute_capabilities
-            ]),
-        },
-    )
-
-    # Set up cuda_config.py, which is used by gen_build_info to provide
-    # static build environment info to the API
-    repository_ctx.template(
-        "cuda/cuda/cuda_config.py",
-        tpl_paths["cuda:cuda_config.py"],
-        _py_tmpl_dict({
-            "cuda_version": cuda_config.cuda_version,
-            "cudnn_version": cuda_config.cudnn_version,
-            "cuda_compute_capabilities": cuda_config.compute_capabilities,
-            "cpu_compiler": str(cc),
-        }),
-    )
-
-def _py_tmpl_dict(d):
-    return {"%{cuda_config}": str(d)}
-
-def _create_remote_cuda_repository(repository_ctx, remote_config_repo):
-    """Creates pointers to a remotely configured repo set up to build with CUDA."""
-    _tpl(
-        repository_ctx,
-        "cuda:build_defs.bzl",
-        {
-            "%{cuda_is_configured}": "True",
-            "%{cuda_extra_copts}": _compute_cuda_extra_copts(
-                repository_ctx,
-                compute_capabilities(repository_ctx),
-            ),
-        },
-    )
-    repository_ctx.template(
-        "cuda/BUILD",
-        config_repo_label(remote_config_repo, "cuda:BUILD"),
-        {},
-    )
-    repository_ctx.template(
-        "cuda/build_defs.bzl",
-        config_repo_label(remote_config_repo, "cuda:build_defs.bzl"),
-        {},
-    )
-    repository_ctx.template(
-        "cuda/cuda/cuda_config.h",
-        config_repo_label(remote_config_repo, "cuda:cuda/cuda_config.h"),
-        {},
-    )
-    repository_ctx.template(
-        "cuda/cuda/cuda_config.py",
-        config_repo_label(remote_config_repo, "cuda:cuda/cuda_config.py"),
-        _py_tmpl_dict({}),
-    )
-
-    repository_ctx.template(
-        "crosstool/BUILD",
-        config_repo_label(remote_config_repo, "crosstool:BUILD"),
-        {},
-    )
-
-    repository_ctx.template(
-        "crosstool/cc_toolchain_config.bzl",
-        config_repo_label(remote_config_repo, "crosstool:cc_toolchain_config.bzl"),
-        {},
-    )
-
-    repository_ctx.template(
-        "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
-        config_repo_label(remote_config_repo, "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc"),
-        {},
-    )
-
-def _cuda_autoconf_impl(repository_ctx):
-    """Implementation of the cuda_autoconf repository rule."""
-    build_file = Label("//third_party/gpus:local_config_cuda.BUILD")
-
-    if not enable_cuda(repository_ctx):
-        _create_dummy_repository(repository_ctx)
-    elif get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO) != None:
-        has_cuda_version = get_host_environ(repository_ctx, _TF_CUDA_VERSION) != None
-        has_cudnn_version = get_host_environ(repository_ctx, _TF_CUDNN_VERSION) != None
-        if not has_cuda_version or not has_cudnn_version:
-            auto_configure_fail("%s and %s must also be set if %s is specified" %
-                                (_TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_CONFIG_REPO))
-        _create_remote_cuda_repository(
-            repository_ctx,
-            get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO),
-        )
-    else:
-        _create_local_cuda_repository(repository_ctx)
-
-    repository_ctx.symlink(build_file, "BUILD")
-
-_ENVIRONS = [
-    _GCC_HOST_COMPILER_PATH,
-    _GCC_HOST_COMPILER_PREFIX,
-    _CLANG_CUDA_COMPILER_PATH,
-    "TF_NEED_CUDA",
-    "TF_CUDA_CLANG",
-    _TF_DOWNLOAD_CLANG,
-    _CUDA_TOOLKIT_PATH,
-    _CUDNN_INSTALL_PATH,
-    _TF_CUDA_VERSION,
-    _TF_CUDNN_VERSION,
-    _TF_CUDA_COMPUTE_CAPABILITIES,
-    "NVVMIR_LIBRARY_DIR",
-    _PYTHON_BIN_PATH,
-    "TMP",
-    "TMPDIR",
-    "TF_CUDA_PATHS",
-]
-
-remote_cuda_configure = repository_rule(
-    implementation = _create_local_cuda_repository,
-    environ = _ENVIRONS,
-    remotable = True,
-    attrs = {
-        "environ": attr.string_dict(),
-        "_find_cuda_config": attr.label(
-            default = Label("@local_xla//third_party/gpus:find_cuda_config.py"),
-        ),
-    },
-)
-
-cuda_configure = repository_rule(
-    implementation = _cuda_autoconf_impl,
-    environ = _ENVIRONS + [_TF_CUDA_CONFIG_REPO],
-    attrs = {
-        "_find_cuda_config": attr.label(
-            default = Label("@local_xla//third_party/gpus:find_cuda_config.py"),
-        ),
-    },
-)
-"""Detects and configures the local CUDA toolchain.
-
-Add the following to your WORKSPACE FILE:
-
-```python
-cuda_configure(name = "local_config_cuda")
-```
-
-Args:
-  name: A unique name for this workspace rule.
-"""
diff --git a/third_party/xla/third_party/gpus/find_cuda_config.py b/third_party/xla/third_party/gpus/find_cuda_config.py
deleted file mode 100644
index 78292c7..0000000
--- a/third_party/xla/third_party/gpus/find_cuda_config.py
+++ /dev/null
@@ -1,614 +0,0 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Prints CUDA library and header directories and versions found on the system.
-
-The script searches for CUDA library and header files on the system, inspects
-them to determine their version and prints the configuration to stdout.
-The paths to inspect and the required versions are specified through environment
-variables. If no valid configuration is found, the script prints to stderr and
-returns an error code.
-
-The list of libraries to find is specified as arguments. Supported libraries are
-CUDA (includes cuBLAS), cuDNN, NCCL, and TensorRT.
-
-The script takes a list of base directories specified by the TF_CUDA_PATHS
-environment variable as a comma-separated glob list. The script looks for headers
-and library files in a hard-coded set of subdirectories from these base paths.
-If TF_CUDA_PATHS is not specified, an OS-specific default is used:
-
-  Linux:   /usr/local/cuda, /usr, and paths from 'ldconfig -p'.
-
-For backwards compatibility, some libraries also use alternative base
-directories from other environment variables if they are specified. List of
-library-specific environment variables:
-
-  Library   Version env variable  Additional base directories
-  ----------------------------------------------------------------
-  CUDA      TF_CUDA_VERSION       CUDA_TOOLKIT_PATH
-  cuBLAS    TF_CUBLAS_VERSION     CUDA_TOOLKIT_PATH
-  cuDNN     TF_CUDNN_VERSION      CUDNN_INSTALL_PATH
-  NCCL      TF_NCCL_VERSION       NCCL_INSTALL_PATH, NCCL_HDR_PATH
-  TensorRT  TF_TENSORRT_VERSION   TENSORRT_INSTALL_PATH
-
-Version environment variables can be of the form 'x' or 'x.y' to request a
-specific version, or left empty or unspecified to accept any version.
-
-The output for a found library is of the form:
-<library>_version: x.y.z
-<library>_include_dir: ...
-<library>_library_dir: ...
-"""
-
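
As a rough illustration of how this script is consumed, the repository rules run it as a subprocess with the libraries to find as arguments and parse the `key: value` lines it prints. A minimal sketch, assuming `find_cuda_config.py` is in the working directory (the exact invocation inside `cuda_configure.bzl` may differ):

```python
import os
import subprocess

# Hypothetical driver: ask the script to locate CUDA and cuDNN and parse the
# "key: value" lines it prints on stdout.
env = dict(os.environ, TF_CUDA_PATHS="/usr/local/cuda,/usr", TF_CUDA_VERSION="12")
proc = subprocess.run(
    ["python3", "find_cuda_config.py", "cuda", "cudnn"],
    capture_output=True, text=True, env=env, check=True)

config = {}
for line in proc.stdout.splitlines():
    key, _, value = line.partition(": ")
    config[key] = value

print(config.get("cuda_version"), config.get("cuda_include_dir"))
```
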
-import io
-import os
-import glob
-import re
-import subprocess
-import sys
-
-# pylint: disable=g-import-not-at-top
-try:
-  from shutil import which
-except ImportError:
-  from distutils.spawn import find_executable as which
-# pylint: enable=g-import-not-at-top
-
-
-class ConfigError(Exception):
-  pass
-
-
-def _matches_version(actual_version, required_version):
-  """Checks whether some version meets the requirements.
-
-      All elements of the required_version need to be present in the
-      actual_version.
-
-          required_version  actual_version  result
-          -----------------------------------------
-          1                 1.1             True
-          1.2               1               False
-          1.2               1.3             False
-                            1               True
-
-      Args:
-        required_version: The version specified by the user.
-        actual_version: The version detected from the CUDA installation.
-      Returns: Whether the actual version matches the required one.
-  """
-  if actual_version is None:
-    return False
-
-  # Strip spaces from the versions.
-  actual_version = actual_version.strip()
-  required_version = required_version.strip()
-  return actual_version.startswith(required_version)
-
-
-def _at_least_version(actual_version, required_version):
-  actual = [int(v) for v in actual_version.split(".")]
-  required = [int(v) for v in required_version.split(".")]
-  return actual >= required
-
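
A small illustration of how these two helpers behave, assuming the module above is importable as `find_cuda_config`:

```python
from find_cuda_config import _matches_version, _at_least_version

# _matches_version does a string prefix match, so a bare major version accepts
# any minor/patch, while a more specific requirement rejects a shorter or
# different actual version (matching the table in the docstring).
assert _matches_version("1.1", "1")        # required "1", actual "1.1" -> True
assert not _matches_version("1", "1.2")    # required "1.2", actual "1" -> False
assert not _matches_version("1.3", "1.2")  # required "1.2", actual "1.3" -> False

# _at_least_version compares the numeric components as tuples.
assert _at_least_version("10.1", "10.1")
assert _at_least_version("11.0", "10.1")
assert not _at_least_version("9.2", "10.1")
```
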
-
-def _get_header_version(path, name):
-  """Returns preprocessor defines in C header file."""
-  for line in io.open(path, "r", encoding="utf-8").readlines():
-    match = re.match(r"\s*#\s*define %s\s+(\d+)" % name, line)
-    if match:
-      return match.group(1)
-  return ""
-
-
-def _cartesian_product(first, second):
-  """Returns all path combinations of first and second."""
-  return [os.path.join(f, s) for f in first for s in second]
-
-
-def _get_ld_config_paths():
-  """Returns all directories from 'ldconfig -p'."""
-  ldconfig_path = which("ldconfig") or "/sbin/ldconfig"
-  output = subprocess.check_output([ldconfig_path, "-p"])
-  pattern = re.compile(".* => (.*)")
-  result = set()
-  for line in output.splitlines():
-    try:
-      match = pattern.match(line.decode("ascii"))
-    except UnicodeDecodeError:
-      match = False
-    if match:
-      result.add(os.path.dirname(match.group(1)))
-  return sorted(list(result))
-
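
The ldconfig parsing above keeps only the directory component of each resolved library path; a standalone sketch of the same idea on one captured line:

```python
import os
import re

# Each useful line of `ldconfig -p` looks like:
#   "\tlibcudart.so.12 (libc6,x86-64) => /usr/local/cuda/lib64/libcudart.so.12"
line = "\tlibcudart.so.12 (libc6,x86-64) => /usr/local/cuda/lib64/libcudart.so.12"
match = re.compile(".* => (.*)").match(line)
if match:
    print(os.path.dirname(match.group(1)))  # /usr/local/cuda/lib64
```
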
-
-def _get_default_cuda_paths(cuda_version):
-  if not cuda_version:
-    cuda_version = "*"
-  elif not "." in cuda_version:
-    cuda_version = cuda_version + ".*"
-
-  return ["/usr/local/cuda-%s" % cuda_version, "/usr/local/cuda", "/usr",
-         "/usr/local/cudnn"] + _get_ld_config_paths()
-
-
-def _header_paths():
-  """Returns hard-coded set of relative paths to look for header files."""
-  return [
-      "",
-      "include",
-      "include/cuda",
-      "include/*-linux-gnu",
-      "extras/CUPTI/include",
-      "include/cuda/CUPTI",
-      "local/cuda/extras/CUPTI/include",
-      "targets/x86_64-linux/include",
-  ]
-
-
-def _library_paths():
-  """Returns hard-coded set of relative paths to look for library files."""
-  return [
-      "",
-      "lib64",
-      "lib",
-      "lib/*-linux-gnu",
-      "lib/x64",
-      "extras/CUPTI/*",
-      "local/cuda/lib64",
-      "local/cuda/extras/CUPTI/lib64",
-  ]
-
-
-def _not_found_error(base_paths, relative_paths, filepattern):
-  base_paths = "".join(["\n        '%s'" % path for path in sorted(base_paths)])
-  relative_paths = "".join(["\n        '%s'" % path for path in relative_paths])
-  return ConfigError(
-      "Could not find any %s in any subdirectory:%s\nof:%s\n" %
-      (filepattern, relative_paths, base_paths))
-
-
-def _find_file(base_paths, relative_paths, filepattern):
-  for path in _cartesian_product(base_paths, relative_paths):
-    for file in glob.glob(os.path.join(path, filepattern)):
-      return file
-  raise _not_found_error(base_paths, relative_paths, filepattern)
-
-
-def _find_library(base_paths, library_name, required_version):
-  """Returns first valid path to the requested library."""
-  filepattern = ".".join(["lib" + library_name, "so"] +
-                         required_version.split(".")[:1]) + "*"
-  return _find_file(base_paths, _library_paths(), filepattern)
-
-
-def _find_versioned_file(base_paths, relative_paths, filepatterns,
-                         required_version, get_version):
-  """Returns first valid path to a file that matches the requested version."""
-  if type(filepatterns) not in [list, tuple]:
-    filepatterns = [filepatterns]
-  for path in _cartesian_product(base_paths, relative_paths):
-    for filepattern in filepatterns:
-      for file in glob.glob(os.path.join(path, filepattern)):
-        actual_version = get_version(file)
-        if _matches_version(actual_version, required_version):
-          return file, actual_version
-  raise _not_found_error(
-      base_paths, relative_paths,
-      ", ".join(filepatterns) + " matching version '%s'" % required_version)
-
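
The finder helpers all follow the same pattern: take the cartesian product of base paths and hard-coded relative paths, glob for a filename pattern, and return the first hit whose extracted version satisfies the request. A self-contained sketch of that loop, with a hypothetical `get_version` callback:

```python
import glob
import os

def find_versioned(base_paths, relative_paths, pattern, required, get_version):
    """Returns the first file under base x relative dirs matching pattern and version prefix."""
    for base in base_paths:
        for rel in relative_paths:
            for path in glob.glob(os.path.join(base, rel, pattern)):
                version = get_version(path)
                if version and version.startswith(required):
                    return path, version
    raise FileNotFoundError("no %s matching version '%s'" % (pattern, required))

# Hypothetical usage: look for cudnn*.h under two base prefixes, accepting any 8.x.
# found, ver = find_versioned(["/usr", "/usr/local/cuda"],
#                             ["include", "include/x86_64-linux-gnu"],
#                             "cudnn*.h", "8", lambda p: "8.9.7")
```
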
-
-def _find_header(base_paths, header_name, required_version, get_version):
-  """Returns first valid path to a header that matches the requested version."""
-  return _find_versioned_file(base_paths, _header_paths(), header_name,
-                              required_version, get_version)
-
-
-def _find_cuda_config(base_paths, required_version):
-
-  def get_header_version(path):
-    version = int(_get_header_version(path, "CUDA_VERSION"))
-    if not version:
-      return None
-    return "%d.%d" % (version // 1000, version % 1000 // 10)
-
-  cuda_header_path, header_version = _find_header(base_paths, "cuda.h",
-                                                  required_version,
-                                                  get_header_version)
-  cuda_version = header_version  # x.y, see above.
-
-  cuda_library_path = _find_library(base_paths, "cudart", cuda_version)
-
-  def get_nvcc_version(path):
-    pattern = r"Cuda compilation tools, release \d+\.\d+, V(\d+\.\d+\.\d+)"
-    for line in subprocess.check_output([path, "--version"]).splitlines():
-      match = re.match(pattern, line.decode("ascii"))
-      if match:
-        return match.group(1)
-    return None
-
-  nvcc_name = "nvcc"
-  nvcc_path, nvcc_version = _find_versioned_file(base_paths, [
-      "",
-      "bin",
-      "local/cuda/bin",
-  ], nvcc_name, cuda_version, get_nvcc_version)
-
-  nvvm_path = _find_file(base_paths, [
-      "nvvm/libdevice",
-      "share/cuda",
-      "lib/nvidia-cuda-toolkit/libdevice",
-      "local/cuda/nvvm/libdevice",
-  ], "libdevice*.10.bc")
-
-  cupti_header_path = _find_file(base_paths, _header_paths(), "cupti.h")
-  nvml_header_dir = _find_file(base_paths, _header_paths(), "nvml.h")
-  cupti_library_path = _find_library(base_paths, "cupti", required_version)
-
-  cuda_binary_dir = os.path.dirname(nvcc_path)
-  nvvm_library_dir = os.path.dirname(nvvm_path)
-
-  # XLA requires the toolkit path to find ptxas and libdevice.
-  # TODO(csigg): pass in both directories instead.
-  cuda_toolkit_paths = (
-      os.path.normpath(os.path.join(cuda_binary_dir, "..")),
-      os.path.normpath(os.path.join(nvvm_library_dir, "../..")),
-  )
-  if cuda_toolkit_paths[0] != cuda_toolkit_paths[1]:
-    raise ConfigError("Inconsistent CUDA toolkit path: %s vs %s" %
-                      cuda_toolkit_paths)
-
-  return {
-      "cuda_version": cuda_version,
-      "cuda_include_dir": os.path.dirname(cuda_header_path),
-      "cuda_library_dir": os.path.dirname(cuda_library_path),
-      "cuda_binary_dir": cuda_binary_dir,
-      "nvvm_library_dir": nvvm_library_dir,
-      "cupti_include_dir": os.path.dirname(cupti_header_path),
-      "cupti_library_dir": os.path.dirname(cupti_library_path),
-      "cuda_toolkit_path": cuda_toolkit_paths[0],
-      "nvml_header_dir": os.path.dirname(nvml_header_dir),
-  }
-
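
The CUDA version is derived from the numeric `CUDA_VERSION` define in cuda.h, which encodes major * 1000 + minor * 10; a worked example of the decoding used above:

```python
# CUDA_VERSION 12020 corresponds to CUDA 12.2.
version = 12020
major, minor = version // 1000, version % 1000 // 10
print("%d.%d" % (major, minor))  # 12.2
```
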
-
-def _find_cublas_config(base_paths, required_version, cuda_version):
-
-  if _at_least_version(cuda_version, "10.1"):
-
-    def get_header_version(path):
-      version = (
-          _get_header_version(path, name)
-          for name in ("CUBLAS_VER_MAJOR", "CUBLAS_VER_MINOR",
-                       "CUBLAS_VER_PATCH"))
-      return ".".join(version)
-
-    header_path, header_version = _find_header(base_paths, "cublas_api.h",
-                                               required_version,
-                                               get_header_version)
-    # cuBLAS uses the major version only.
-    cublas_version = header_version.split(".")[0]
-
-  else:
-    # There is no version info available before CUDA 10.1, just find the file.
-    header_version = cuda_version
-    header_path = _find_file(base_paths, _header_paths(), "cublas_api.h")
-    # cuBLAS version is the same as CUDA version (x.y).
-    cublas_version = required_version
-
-  library_path = _find_library(base_paths, "cublas", cublas_version)
-
-  return {
-      "cublas_version": header_version,
-      "cublas_include_dir": os.path.dirname(header_path),
-      "cublas_library_dir": os.path.dirname(library_path),
-  }
-
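
For CUDA 10.1 and newer, the cuBLAS version is assembled from three separate header defines, and only the major component is used to pick the shared library; the same pattern repeats for cuSOLVER, cuRAND, cuFFT and cuSPARSE below. Roughly, with made-up define values:

```python
# Assemble "11.4.2" from CUBLAS_VER_MAJOR/MINOR/PATCH-style values, then use the
# major version to build the library glob pattern, e.g. "libcublas.so.11*".
parts = ("11", "4", "2")               # values read from cublas_api.h
header_version = ".".join(parts)       # "11.4.2"
major = header_version.split(".")[0]   # "11"
filepattern = ".".join(["lib" + "cublas", "so", major]) + "*"
print(header_version, filepattern)     # 11.4.2 libcublas.so.11*
```
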
-
-def _find_cusolver_config(base_paths, required_version, cuda_version):
-
-  if _at_least_version(cuda_version, "11.0"):
-
-    def get_header_version(path):
-      version = (
-          _get_header_version(path, name)
-          for name in ("CUSOLVER_VER_MAJOR", "CUSOLVER_VER_MINOR",
-                       "CUSOLVER_VER_PATCH"))
-      return ".".join(version)
-
-    header_path, header_version = _find_header(base_paths, "cusolver_common.h",
-                                               required_version,
-                                               get_header_version)
-    cusolver_version = header_version.split(".")[0]
-
-  else:
-    header_version = cuda_version
-    header_path = _find_file(base_paths, _header_paths(), "cusolver_common.h")
-    cusolver_version = required_version
-
-  library_path = _find_library(base_paths, "cusolver", cusolver_version)
-
-  return {
-      "cusolver_version": header_version,
-      "cusolver_include_dir": os.path.dirname(header_path),
-      "cusolver_library_dir": os.path.dirname(library_path),
-  }
-
-
-def _find_curand_config(base_paths, required_version, cuda_version):
-
-  if _at_least_version(cuda_version, "11.0"):
-
-    def get_header_version(path):
-      version = (
-          _get_header_version(path, name)
-          for name in ("CURAND_VER_MAJOR", "CURAND_VER_MINOR",
-                       "CURAND_VER_PATCH"))
-      return ".".join(version)
-
-    header_path, header_version = _find_header(base_paths, "curand.h",
-                                               required_version,
-                                               get_header_version)
-    curand_version = header_version.split(".")[0]
-
-  else:
-    header_version = cuda_version
-    header_path = _find_file(base_paths, _header_paths(), "curand.h")
-    curand_version = required_version
-
-  library_path = _find_library(base_paths, "curand", curand_version)
-
-  return {
-      "curand_version": header_version,
-      "curand_include_dir": os.path.dirname(header_path),
-      "curand_library_dir": os.path.dirname(library_path),
-  }
-
-
-def _find_cufft_config(base_paths, required_version, cuda_version):
-
-  if _at_least_version(cuda_version, "11.0"):
-
-    def get_header_version(path):
-      version = (
-          _get_header_version(path, name)
-          for name in ("CUFFT_VER_MAJOR", "CUFFT_VER_MINOR", "CUFFT_VER_PATCH"))
-      return ".".join(version)
-
-    header_path, header_version = _find_header(base_paths, "cufft.h",
-                                               required_version,
-                                               get_header_version)
-    cufft_version = header_version.split(".")[0]
-
-  else:
-    header_version = cuda_version
-    header_path = _find_file(base_paths, _header_paths(), "cufft.h")
-    cufft_version = required_version
-
-  library_path = _find_library(base_paths, "cufft", cufft_version)
-
-  return {
-      "cufft_version": header_version,
-      "cufft_include_dir": os.path.dirname(header_path),
-      "cufft_library_dir": os.path.dirname(library_path),
-  }
-
-
-def _find_cudnn_config(base_paths, required_version):
-
-  def get_header_version(path):
-    version = [
-        _get_header_version(path, name)
-        for name in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL")]
-    return ".".join(version) if version[0] else None
-
-  header_path, header_version = _find_header(base_paths,
-                                             ("cudnn.h", "cudnn_version.h"),
-                                             required_version,
-                                             get_header_version)
-  cudnn_version = header_version.split(".")[0]
-
-  library_path = _find_library(base_paths, "cudnn", cudnn_version)
-
-  return {
-      "cudnn_version": cudnn_version,
-      "cudnn_include_dir": os.path.dirname(header_path),
-      "cudnn_library_dir": os.path.dirname(library_path),
-  }
-
-
-def _find_cusparse_config(base_paths, required_version, cuda_version):
-
-  if _at_least_version(cuda_version, "11.0"):
-
-    def get_header_version(path):
-      version = (
-          _get_header_version(path, name)
-          for name in ("CUSPARSE_VER_MAJOR", "CUSPARSE_VER_MINOR",
-                       "CUSPARSE_VER_PATCH"))
-      return ".".join(version)
-
-    header_path, header_version = _find_header(base_paths, "cusparse.h",
-                                               required_version,
-                                               get_header_version)
-    cusparse_version = header_version.split(".")[0]
-
-  else:
-    header_version = cuda_version
-    header_path = _find_file(base_paths, _header_paths(), "cusparse.h")
-    cusparse_version = required_version
-
-  library_path = _find_library(base_paths, "cusparse", cusparse_version)
-
-  return {
-      "cusparse_version": header_version,
-      "cusparse_include_dir": os.path.dirname(header_path),
-      "cusparse_library_dir": os.path.dirname(library_path),
-  }
-
-
-def _find_nccl_config(base_paths, required_version):
-
-  def get_header_version(path):
-    version = (
-        _get_header_version(path, name)
-        for name in ("NCCL_MAJOR", "NCCL_MINOR", "NCCL_PATCH"))
-    return ".".join(version)
-
-  header_path, header_version = _find_header(base_paths, "nccl.h",
-                                             required_version,
-                                             get_header_version)
-  nccl_version = header_version.split(".")[0]
-
-  library_path = _find_library(base_paths, "nccl", nccl_version)
-
-  return {
-      "nccl_version": nccl_version,
-      "nccl_include_dir": os.path.dirname(header_path),
-      "nccl_library_dir": os.path.dirname(library_path),
-  }
-
-
-def _find_tensorrt_config(base_paths, required_version):
-
-  def get_header_version(path):
-    version = (
-        _get_header_version(path, name)
-        for name in ("NV_TENSORRT_MAJOR", "NV_TENSORRT_MINOR",
-                     "NV_TENSORRT_PATCH"))
-    # `version` is a generator object, so we convert it to a list before using
-    # it (multiple times below).
-    version = list(version)
-    if not all(version):
-      return None  # Versions not found, so that _matches_version returns False.
-    return ".".join(version)
-
-  header_path, header_version = _find_header(base_paths, "NvInferVersion.h",
-                                             required_version,
-                                             get_header_version)
-
-  tensorrt_version = header_version.split(".")[0]
-  library_path = _find_library(base_paths, "nvinfer", tensorrt_version)
-
-  return {
-      "tensorrt_version": header_version,
-      "tensorrt_include_dir": os.path.dirname(header_path),
-      "tensorrt_library_dir": os.path.dirname(library_path),
-  }
-
-
-def _list_from_env(env_name, default=[]):
-  """Returns comma-separated list from environment variable."""
-  if env_name in os.environ:
-    return os.environ[env_name].split(",")
-  return default
-
-
-def _get_legacy_path(env_name, default=[]):
-  """Returns a path specified by a legacy environment variable.
-
-  CUDNN_INSTALL_PATH, NCCL_INSTALL_PATH, TENSORRT_INSTALL_PATH set to
-  '/usr/lib/x86_64-linux-gnu' would previously find both library and header
-  paths. Detect those and return '/usr', otherwise forward to _list_from_env().
-  """
-  if env_name in os.environ:
-    match = re.match(r"^(/[^/ ]*)+/lib/\w+-linux-gnu/?$", os.environ[env_name])
-    if match:
-      return [match.group(1)]
-  return _list_from_env(env_name, default)
-
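
The legacy-path handling maps an old-style library directory such as '/usr/lib/x86_64-linux-gnu' back to its base prefix; a standalone illustration of the regex:

```python
import re

# "/usr/lib/x86_64-linux-gnu" was historically accepted as an install path; the
# regex strips the lib/<triple> suffix so both headers and libraries are found
# under the base prefix ("/usr").
legacy = "/usr/lib/x86_64-linux-gnu"
match = re.match(r"^(/[^/ ]*)+/lib/\w+-linux-gnu/?$", legacy)
print(match.group(1) if match else legacy)  # /usr
```
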
-
-def find_cuda_config():
-  """Returns a dictionary of CUDA library and header file paths."""
-  libraries = [argv.lower() for argv in sys.argv[1:]]
-  cuda_version = os.environ.get("TF_CUDA_VERSION", "")
-  base_paths = _list_from_env("TF_CUDA_PATHS",
-                              _get_default_cuda_paths(cuda_version))
-  base_paths = [path for path in base_paths if os.path.exists(path)]
-
-  result = {}
-  if "cuda" in libraries:
-    cuda_paths = _list_from_env("CUDA_TOOLKIT_PATH", base_paths)
-    res = _find_cuda_config(cuda_paths, cuda_version)
-
-    result.update(res)
-
-    cuda_version = result["cuda_version"]
-    cublas_paths = base_paths
-    if tuple(int(v) for v in cuda_version.split(".")) < (10, 1):
-      # Before CUDA 10.1, cuBLAS was in the same directory as the toolkit.
-      cublas_paths = cuda_paths
-    cublas_version = os.environ.get("TF_CUBLAS_VERSION", "")
-    result.update(
-        _find_cublas_config(cublas_paths, cublas_version, cuda_version))
-
-    cusolver_paths = base_paths
-    if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
-      cusolver_paths = cuda_paths
-    cusolver_version = os.environ.get("TF_CUSOLVER_VERSION", "")
-    result.update(
-        _find_cusolver_config(cusolver_paths, cusolver_version, cuda_version))
-
-    curand_paths = base_paths
-    if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
-      curand_paths = cuda_paths
-    curand_version = os.environ.get("TF_CURAND_VERSION", "")
-    result.update(
-        _find_curand_config(curand_paths, curand_version, cuda_version))
-
-    cufft_paths = base_paths
-    if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
-      cufft_paths = cuda_paths
-    cufft_version = os.environ.get("TF_CUFFT_VERSION", "")
-    result.update(_find_cufft_config(cufft_paths, cufft_version, cuda_version))
-
-    cusparse_paths = base_paths
-    if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
-      cusparse_paths = cuda_paths
-    cusparse_version = os.environ.get("TF_CUSPARSE_VERSION", "")
-    result.update(
-        _find_cusparse_config(cusparse_paths, cusparse_version, cuda_version))
-
-  if "cudnn" in libraries:
-    cudnn_paths = _get_legacy_path("CUDNN_INSTALL_PATH", base_paths)
-    cudnn_version = os.environ.get("TF_CUDNN_VERSION", "")
-    result.update(_find_cudnn_config(cudnn_paths, cudnn_version))
-
-  if "nccl" in libraries:
-    nccl_paths = _get_legacy_path("NCCL_INSTALL_PATH", base_paths)
-    nccl_version = os.environ.get("TF_NCCL_VERSION", "")
-    result.update(_find_nccl_config(nccl_paths, nccl_version))
-
-  if "tensorrt" in libraries:
-    tensorrt_paths = _get_legacy_path("TENSORRT_INSTALL_PATH", base_paths)
-    tensorrt_version = os.environ.get("TF_TENSORRT_VERSION", "")
-    result.update(_find_tensorrt_config(tensorrt_paths, tensorrt_version))
-
-  for k, v in result.items():
-    if k.endswith("_dir") or k.endswith("_path"):
-      result[k] = os.path.realpath(v)
-
-  return result
-
-
-def main():
-  try:
-    for key, value in sorted(find_cuda_config().items()):
-      print("%s: %s" % (key, value))
-  except ConfigError as e:
-    sys.stderr.write(str(e) + '\n')
-    sys.exit(1)
-
-
-if __name__ == "__main__":
-  main()
diff --git a/third_party/xla/third_party/gpus/find_rocm_config.py b/third_party/xla/third_party/gpus/find_rocm_config.py
deleted file mode 100644
index cd64efe..0000000
--- a/third_party/xla/third_party/gpus/find_rocm_config.py
+++ /dev/null
@@ -1,446 +0,0 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Prints ROCm library and header directories and versions found on the system.
-
-The script searches for ROCm library and header files on the system, inspects
-them to determine their version and prints the configuration to stdout.
-The path to inspect is specified through an environment variable (ROCM_PATH).
-If no valid configuration is found, the script prints to stderr and
-returns an error code.
-
-The script takes the directory specified by the ROCM_PATH environment variable.
-The script looks for headers and library files in a hard-coded set of
-subdirectories under that base directory. If ROCM_PATH is not specified,
-"/opt/rocm" is used as its default value.
-
-"""
-
-import io
-import os
-import re
-import sys
-
-
-class ConfigError(Exception):
-  pass
-
-
-def _get_default_rocm_path():
-  return "/opt/rocm"
-
-
-def _get_rocm_install_path():
-  """Determines and returns the ROCm installation path."""
-  rocm_install_path = _get_default_rocm_path()
-  if "ROCM_PATH" in os.environ:
-    rocm_install_path = os.environ["ROCM_PATH"]
-  # rocm_install_path = os.path.realpath(rocm_install_path)
-  return rocm_install_path
-
-
-def _get_composite_version_number(major, minor, patch):
-  return 10000 * major + 100 * minor + patch
-
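
The composite number packs major/minor/patch into a single integer, which is what the version gates elsewhere (for example `>= 40100` for hipFFT and `>= 40500` for hipSOLVER in find_rocm_config() below) compare against:

```python
def composite(major, minor, patch):
    # Same encoding as _get_composite_version_number: ROCm 5.2.0 -> 50200.
    return 10000 * major + 100 * minor + patch

print(composite(5, 2, 0))            # 50200
print(composite(4, 1, 0))            # 40100
print(composite(5, 2, 0) >= 40500)   # True: the hipSOLVER lookup is enabled
```
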
-
-def _get_header_version(path, name):
-  """Returns preprocessor defines in C header file."""
-  for line in io.open(path, "r", encoding="utf-8"):
-    match = re.match(r"#define %s +(\d+)" % name, line)
-    if match:
-      value = match.group(1)
-      return int(value)
-
-  raise ConfigError('#define "{}" is either\n'.format(name) +
-                    "  not present in file {} OR\n".format(path) +
-                    "  its value is not an integer literal")
-
-
-def _find_rocm_config(rocm_install_path):
-
-  def rocm_version_numbers(path):
-    possible_version_files = [
-        "include/rocm-core/rocm_version.h",  # ROCm 5.2
-        "include/rocm_version.h",  # ROCm 5.1 and prior
-    ]
-    version_file = None
-    for f in possible_version_files:
-      version_file_path = os.path.join(path, f)
-      if os.path.exists(version_file_path):
-        version_file = version_file_path
-        break
-    if not version_file:
-      raise ConfigError(
-          "ROCm version file not found in {}".format(possible_version_files))
-
-    major = _get_header_version(version_file, "ROCM_VERSION_MAJOR")
-    minor = _get_header_version(version_file, "ROCM_VERSION_MINOR")
-    patch = _get_header_version(version_file, "ROCM_VERSION_PATCH")
-    return major, minor, patch
-
-  major, minor, patch = rocm_version_numbers(rocm_install_path)
-
-  rocm_config = {
-      "rocm_version_number": _get_composite_version_number(major, minor, patch)
-  }
-
-  return rocm_config
-
-
-def _find_hipruntime_config(rocm_install_path):
-
-  def hipruntime_version_number(path):
-    possible_version_files = [
-        "include/hip/hip_version.h",  # ROCm 5.2
-        "hip/include/hip/hip_version.h",  # ROCm 5.1 and prior
-    ]
-    version_file = None
-    for f in possible_version_files:
-      version_file_path = os.path.join(path, f)
-      if os.path.exists(version_file_path):
-        version_file = version_file_path
-        break
-    if not version_file:
-      raise ConfigError("HIP Runtime version file not found in {}".format(
-          possible_version_files))
-
-    # This header file has an explicit #define for HIP_VERSION, whose value
-    # is (HIP_VERSION_MAJOR * 100 + HIP_VERSION_MINOR)
-    # Retrieve the major and minor versions and re-calculate HIP_VERSION here,
-    # since we do not want to get into the business of parsing arithmetic expressions.
-    major = _get_header_version(version_file, "HIP_VERSION_MAJOR")
-    minor = _get_header_version(version_file, "HIP_VERSION_MINOR")
-    return 100 * major + minor
-
-  hipruntime_config = {
-      "hipruntime_version_number": hipruntime_version_number(rocm_install_path)
-  }
-
-  return hipruntime_config
-
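
A one-line check of the encoding that this re-calculation reproduces:

```python
# HIP_VERSION is defined as HIP_VERSION_MAJOR * 100 + HIP_VERSION_MINOR,
# so HIP 5.2 yields 502.
major, minor = 5, 2
print(100 * major + minor)  # 502
```
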
-
-def _find_miopen_config(rocm_install_path):
-
-  def miopen_version_numbers(path):
-    possible_version_files = [
-        "include/miopen/version.h",  # ROCm 5.2 and prior
-        "miopen/include/miopen/version.h",  # ROCm 5.1 and prior
-    ]
-    version_file = None
-    for f in possible_version_files:
-      version_file_path = os.path.join(path, f)
-      if os.path.exists(version_file_path):
-        version_file = version_file_path
-        break
-    if not version_file:
-      raise ConfigError(
-          'MIOpen version file "{}" not found'.format(version_file))
-    major = _get_header_version(version_file, "MIOPEN_VERSION_MAJOR")
-    minor = _get_header_version(version_file, "MIOPEN_VERSION_MINOR")
-    patch = _get_header_version(version_file, "MIOPEN_VERSION_PATCH")
-    return major, minor, patch
-
-  major, minor, patch = miopen_version_numbers(rocm_install_path)
-
-  miopen_config = {
-      "miopen_version_number":
-          _get_composite_version_number(major, minor, patch)
-  }
-
-  return miopen_config
-
-
-def _find_rocblas_config(rocm_install_path):
-
-  def rocblas_version_numbers(path):
-    possible_version_files = [
-        "include/rocblas/internal/rocblas-version.h",  # ROCm 5.2
-        "rocblas/include/internal/rocblas-version.h",  # ROCm 5.1 and prior
-    ]
-    version_file = None
-    for f in possible_version_files:
-      version_file_path = os.path.join(path, f)
-      if os.path.exists(version_file_path):
-        version_file = version_file_path
-        break
-    if not version_file:
-      raise ConfigError(
-          "rocblas version file not found in {}".format(
-              possible_version_files))
-    major = _get_header_version(version_file, "ROCBLAS_VERSION_MAJOR")
-    minor = _get_header_version(version_file, "ROCBLAS_VERSION_MINOR")
-    patch = _get_header_version(version_file, "ROCBLAS_VERSION_PATCH")
-    return major, minor, patch
-
-  major, minor, patch = rocblas_version_numbers(rocm_install_path)
-
-  rocblas_config = {
-      "rocblas_version_number":
-          _get_composite_version_number(major, minor, patch)
-  }
-
-  return rocblas_config
-
-
-def _find_rocrand_config(rocm_install_path):
-
-  def rocrand_version_number(path):
-    possible_version_files = [
-        "include/rocrand/rocrand_version.h",  # ROCm 5.1
-        "rocrand/include/rocrand_version.h",  # ROCm 5.0 and prior
-    ]
-    version_file = None
-    for f in possible_version_files:
-      version_file_path = os.path.join(path, f)
-      if os.path.exists(version_file_path):
-        version_file = version_file_path
-        break
-    if not version_file:
-      raise ConfigError(
-          "rocrand version file not found in {}".format(possible_version_files))
-    version_number = _get_header_version(version_file, "ROCRAND_VERSION")
-    return version_number
-
-  rocrand_config = {
-      "rocrand_version_number": rocrand_version_number(rocm_install_path)
-  }
-
-  return rocrand_config
-
-
-def _find_rocfft_config(rocm_install_path):
-
-  def rocfft_version_numbers(path):
-    possible_version_files = [
-        "include/rocfft/rocfft-version.h",  # ROCm 5.2
-        "rocfft/include/rocfft-version.h",  # ROCm 5.1 and prior
-    ]
-    version_file = None
-    for f in possible_version_files:
-      version_file_path = os.path.join(path, f)
-      if os.path.exists(version_file_path):
-        version_file = version_file_path
-        break
-    if not version_file:
-      raise ConfigError(
-          "rocfft version file not found in {}".format(possible_version_files))
-    major = _get_header_version(version_file, "rocfft_version_major")
-    minor = _get_header_version(version_file, "rocfft_version_minor")
-    patch = _get_header_version(version_file, "rocfft_version_patch")
-    return major, minor, patch
-
-  major, minor, patch = rocfft_version_numbers(rocm_install_path)
-
-  rocfft_config = {
-      "rocfft_version_number":
-          _get_composite_version_number(major, minor, patch)
-  }
-
-  return rocfft_config
-
-
-def _find_hipfft_config(rocm_install_path):
-
-  def hipfft_version_numbers(path):
-    possible_version_files = [
-        "include/hipfft/hipfft-version.h",  # ROCm 5.2
-        "hipfft/include/hipfft-version.h",  # ROCm 5.1 and prior
-    ]
-    version_file = None
-    for f in possible_version_files:
-      version_file_path = os.path.join(path, f)
-      if os.path.exists(version_file_path):
-        version_file = version_file_path
-        break
-    if not version_file:
-      raise ConfigError(
-          "hipfft version file not found in {}".format(possible_version_files))
-    major = _get_header_version(version_file, "hipfftVersionMajor")
-    minor = _get_header_version(version_file, "hipfftVersionMinor")
-    patch = _get_header_version(version_file, "hipfftVersionPatch")
-    return major, minor, patch
-
-  major, minor, patch = hipfft_version_numbers(rocm_install_path)
-
-  hipfft_config = {
-      "hipfft_version_number":
-          _get_composite_version_number(major, minor, patch)
-  }
-
-  return hipfft_config
-
-
-def _find_roctracer_config(rocm_install_path):
-
-  def roctracer_version_numbers(path):
-    possible_version_files = [
-        "include/roctracer/roctracer.h",  # ROCm 5.2
-        "roctracer/include/roctracer.h",  # ROCm 5.1 and prior
-    ]
-    version_file = None
-    for f in possible_version_files:
-      version_file_path = os.path.join(path, f)
-      if os.path.exists(version_file_path):
-        version_file = version_file_path
-        break
-    if not version_file:
-      raise ConfigError("roctracer version file not found in {}".format(
-          possible_version_files))
-    major = _get_header_version(version_file, "ROCTRACER_VERSION_MAJOR")
-    minor = _get_header_version(version_file, "ROCTRACER_VERSION_MINOR")
-    # roctracer header does not have a patch version number
-    patch = 0
-    return major, minor, patch
-
-  major, minor, patch = roctracer_version_numbers(rocm_install_path)
-
-  roctracer_config = {
-      "roctracer_version_number":
-          _get_composite_version_number(major, minor, patch)
-  }
-
-  return roctracer_config
-
-
-def _find_hipsparse_config(rocm_install_path):
-
-  def hipsparse_version_numbers(path):
-    possible_version_files = [
-        "include/hipsparse/hipsparse-version.h",  # ROCm 5.2
-        "hipsparse/include/hipsparse-version.h",  # ROCm 5.1 and prior
-    ]
-    version_file = None
-    for f in possible_version_files:
-      version_file_path = os.path.join(path, f)
-      if os.path.exists(version_file_path):
-        version_file = version_file_path
-        break
-    if not version_file:
-      raise ConfigError("hipsparse version file not found in {}".format(
-          possible_version_files))
-    major = _get_header_version(version_file, "hipsparseVersionMajor")
-    minor = _get_header_version(version_file, "hipsparseVersionMinor")
-    patch = _get_header_version(version_file, "hipsparseVersionPatch")
-    return major, minor, patch
-
-  major, minor, patch = hipsparse_version_numbers(rocm_install_path)
-
-  hipsparse_config = {
-      "hipsparse_version_number":
-          _get_composite_version_number(major, minor, patch)
-  }
-
-  return hipsparse_config
-
-def _find_hipsolver_config(rocm_install_path):
-
-  def hipsolver_version_numbers(path):
-    possible_version_files = [
-        "include/hipsolver/internal/hipsolver-version.h",  # ROCm 5.2
-        "hipsolver/include/internal/hipsolver-version.h",  # ROCm 5.1
-        "hipsolver/include/hipsolver-version.h",  # ROCm 5.0 and prior
-    ]
-    version_file = None
-    for f in possible_version_files:
-      version_file_path = os.path.join(path, f)
-      if os.path.exists(version_file_path):
-        version_file = version_file_path
-        break
-    if not version_file:
-      raise ConfigError("hipsolver version file not found in {}".format(
-          possible_version_files))
-    major = _get_header_version(version_file, "hipsolverVersionMajor")
-    minor = _get_header_version(version_file, "hipsolverVersionMinor")
-    patch = _get_header_version(version_file, "hipsolverVersionPatch")
-    return major, minor, patch
-
-  major, minor, patch = hipsolver_version_numbers(rocm_install_path)
-
-  hipsolver_config = {
-      "hipsolver_version_number":
-          _get_composite_version_number(major, minor, patch)
-  }
-
-  return hipsolver_config
-
-
-def _find_rocsolver_config(rocm_install_path):
-
-  def rocsolver_version_numbers(path):
-    possible_version_files = [
-        "include/rocsolver/rocsolver-version.h",  # ROCm 5.2
-        "rocsolver/include/rocsolver-version.h",  # ROCm 5.1 and prior
-    ]
-    version_file = None
-    for f in possible_version_files:
-      version_file_path = os.path.join(path, f)
-      if os.path.exists(version_file_path):
-        version_file = version_file_path
-        break
-    if not version_file:
-      raise ConfigError("rocsolver version file not found in {}".format(
-          possible_version_files))
-    major = _get_header_version(version_file, "ROCSOLVER_VERSION_MAJOR")
-    minor = _get_header_version(version_file, "ROCSOLVER_VERSION_MINOR")
-    patch = _get_header_version(version_file, "ROCSOLVER_VERSION_PATCH")
-    return major, minor, patch
-
-  major, minor, patch = rocsolver_version_numbers(rocm_install_path)
-
-  rocsolver_config = {
-      "rocsolver_version_number":
-          _get_composite_version_number(major, minor, patch)
-  }
-
-  return rocsolver_config
-
-
-def find_rocm_config():
-  """Returns a dictionary of ROCm components config info."""
-  rocm_install_path = _get_rocm_install_path()
-  if not os.path.exists(rocm_install_path):
-    raise ConfigError(
-        'Specified ROCM_PATH "{}" does not exist'.format(rocm_install_path))
-
-  result = {}
-
-  result["rocm_toolkit_path"] = rocm_install_path
-  result.update(_find_rocm_config(rocm_install_path))
-  result.update(_find_hipruntime_config(rocm_install_path))
-  result.update(_find_miopen_config(rocm_install_path))
-  result.update(_find_rocblas_config(rocm_install_path))
-  result.update(_find_rocrand_config(rocm_install_path))
-  result.update(_find_rocfft_config(rocm_install_path))
-  if result["rocm_version_number"] >= 40100:
-    result.update(_find_hipfft_config(rocm_install_path))
-  result.update(_find_roctracer_config(rocm_install_path))
-  result.update(_find_hipsparse_config(rocm_install_path))
-  if result["rocm_version_number"] >= 40500:
-    result.update(_find_hipsolver_config(rocm_install_path))
-  result.update(_find_rocsolver_config(rocm_install_path))
-
-  return result
-
-
-def main():
-  try:
-    for key, value in sorted(find_rocm_config().items()):
-      print("%s: %s" % (key, value))
-  except ConfigError as e:
-    sys.stderr.write("\nERROR: {}\n\n".format(str(e)))
-    sys.exit(1)
-
-
-if __name__ == "__main__":
-  main()
diff --git a/third_party/xla/third_party/gpus/local_config_cuda.BUILD b/third_party/xla/third_party/gpus/local_config_cuda.BUILD
deleted file mode 100644
index bed22cc..0000000
--- a/third_party/xla/third_party/gpus/local_config_cuda.BUILD
+++ /dev/null
@@ -1,60 +0,0 @@
-load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
-load("@local_config_cuda//cuda:build_defs.bzl", "enable_cuda_flag")
-
-package(default_visibility = ["//visibility:public"])
-
-# Build flag to enable CUDA support.
-#
-# Enable with '--@local_config_cuda//:enable_cuda', or indirectly with
-# ./configure or '--config=cuda'.
-enable_cuda_flag(
-    name = "enable_cuda",
-    build_setting_default = False,
-    enable_override = select({
-        ":define_using_cuda_nvcc": True,
-        "//conditions:default": False,
-    }),
-)
-
-# Config setting whether CUDA support has been requested.
-#
-# Enable path: ./configure > --config=cuda (.tf_configure.bazelrc)
-#     > --//tensorflow:enable_cuda (.bazelrc) > :is_cuda_enabled
-config_setting(
-    name = "is_cuda_enabled",
-    flag_values = {":enable_cuda": "True"},
-)
-
-# Build flag to select CUDA compiler.
-#
-# Set with '--@local_config_cuda//:cuda_compiler=...', or indirectly with
-# ./configure, '--config=cuda' or '--config=cuda_clang'.
-string_flag(
-    name = "cuda_compiler",
-    build_setting_default = "nvcc",
-    values = [
-        "clang",
-        "nvcc",
-    ],
-)
-
-# Config setting whether CUDA device code should be compiled with clang.
-config_setting(
-    name = "is_cuda_compiler_clang",
-    flag_values = {":cuda_compiler": "clang"},
-)
-
-# Config setting whether CUDA device code should be compiled with nvcc.
-config_setting(
-    name = "is_cuda_compiler_nvcc",
-    flag_values = {":cuda_compiler": "nvcc"},
-)
-
-# Config setting to keep `--define=using_cuda_nvcc=true` working.
-# TODO(b/174244321): Remove when downstream projects have been fixed, along
-# with the enable_cuda_flag rule in cuda:build_defs.bzl.tpl.
-config_setting(
-    name = "define_using_cuda_nvcc",
-    define_values = {"using_cuda_nvcc": "true"},
-    visibility = ["//visibility:private"],
-)
diff --git a/third_party/xla/third_party/gpus/rocm/BUILD b/third_party/xla/third_party/gpus/rocm/BUILD
deleted file mode 100644
index e69de29..0000000
--- a/third_party/xla/third_party/gpus/rocm/BUILD
+++ /dev/null
diff --git a/third_party/xla/third_party/gpus/rocm/BUILD.tpl b/third_party/xla/third_party/gpus/rocm/BUILD.tpl
deleted file mode 100644
index aa3688e..0000000
--- a/third_party/xla/third_party/gpus/rocm/BUILD.tpl
+++ /dev/null
@@ -1,182 +0,0 @@
-load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
-
-licenses(["restricted"])  # MPL2, portions GPL v3, LGPL v3, BSD-like
-
-package(default_visibility = ["//visibility:public"])
-
-config_setting(
-    name = "using_hipcc",
-    values = {
-        "define": "using_rocm_hipcc=true",
-    },
-)
-
-cc_library(
-    name = "rocm_headers",
-    hdrs = [
-        "rocm/rocm_config.h",
-        %{rocm_headers}
-    ],
-    includes = [
-        ".",
-        "rocm/include",
-        "rocm/include/rocrand",
-        "rocm/include/roctracer",
-    ],
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "hip",
-    srcs = ["rocm/lib/%{hip_lib}"],
-    data = ["rocm/lib/%{hip_lib}"],
-    includes = [
-        ".",
-        "rocm/include",
-    ],
-    linkstatic = 1,
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "rocblas",
-    srcs = ["rocm/lib/%{rocblas_lib}"],
-    data = ["rocm/lib/%{rocblas_lib}"],
-    includes = [
-        ".",
-        "rocm/include",
-    ],
-    linkstatic = 1,
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "%{hipfft_or_rocfft}",
-    srcs = ["rocm/lib/%{hipfft_or_rocfft_lib}"],
-    data = ["rocm/lib/%{hipfft_or_rocfft_lib}"],
-    includes = [
-        ".",
-        "rocm/include",
-    ],
-    linkstatic = 1,
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "hiprand",
-    srcs = ["rocm/lib/%{hiprand_lib}"],
-    data = ["rocm/lib/%{hiprand_lib}"],
-    includes = [
-        ".",
-        "rocm/include",
-        "rocm/include/rocrand",
-    ],
-    linkstatic = 1,
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "miopen",
-    srcs = ["rocm/lib/%{miopen_lib}"],
-    data = ["rocm/lib/%{miopen_lib}"],
-    includes = [
-        ".",
-        "rocm/include",
-    ],
-    linkstatic = 1,
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "rccl",
-    srcs = ["rocm/lib/%{rccl_lib}"],
-    data = ["rocm/lib/%{rccl_lib}"],
-    includes = [
-        ".",
-        "rocm/include",
-    ],
-    linkstatic = 1,
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "rocm",
-    visibility = ["//visibility:public"],
-    deps = [
-        ":rocm_headers",
-        ":hip",
-        ":rocblas",
-        ":hipblas",
-        ":%{hipfft_or_rocfft}",
-        ":hiprand",
-        ":miopen",
-        ":hipsparse",
-        ":roctracer",
-        ":rocsolver",
-        ":hipsolver",
-    ],
-)
-
-bzl_library(
-    name = "build_defs_bzl",
-    srcs = ["build_defs.bzl"],
-)
-
-cc_library(
-    name = "rocprim",
-    srcs = [
-        "rocm/include/hipcub/hipcub_version.hpp",
-        "rocm/include/rocprim/rocprim_version.hpp",
-    ],
-    hdrs = glob([
-        "rocm/include/hipcub/**",
-        "rocm/include/rocprim/**",
-    ]),
-    includes = [
-        ".",
-        "rocm/include/hipcub",
-        "rocm/include/rocprim",
-    ],
-    visibility = ["//visibility:public"],
-    deps = [
-        "@local_config_rocm//rocm:rocm_headers",
-    ],
-)
-
-cc_library(
-    name = "hipsparse",
-    srcs = ["rocm/lib/%{hipsparse_lib}"],
-    data = ["rocm/lib/%{hipsparse_lib}"],
-)
-
-cc_library(
-    name = "roctracer",
-    data = ["rocm/lib/%{roctracer_lib}"],
-)
-
-cc_library(
-    name = "rocsolver",
-    srcs = ["rocm/lib/%{rocsolver_lib}"],
-    data = ["rocm/lib/%{rocsolver_lib}"],
-)
-
-cc_library(
-    name = "hipsolver",
-    srcs = ["rocm/lib/%{hipsolver_lib}"],
-    data = ["rocm/lib/%{hipsolver_lib}"],
-)
-
-cc_library(
-    name = "hipblas",
-    srcs = ["rocm/lib/%{hipblas_lib}"],
-    data = ["rocm/lib/%{hipblas_lib}"],
-)
-
-filegroup(
-    name = "rocm_root",
-    srcs = [
-        "rocm/bin/clang-offload-bundler",
-    ],
-)
-
-%{copy_rules}
diff --git a/third_party/xla/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/xla/third_party/gpus/rocm/build_defs.bzl.tpl
deleted file mode 100644
index 2b4595b..0000000
--- a/third_party/xla/third_party/gpus/rocm/build_defs.bzl.tpl
+++ /dev/null
@@ -1,61 +0,0 @@
-# Macros for building ROCm code.
-def if_rocm(if_true, if_false = []):
-    """Shorthand for select()'ing on whether we're building with ROCm.
-
-    Returns a select statement which evaluates to if_true if we're building
-    with ROCm enabled.  Otherwise, the select statement evaluates to if_false.
-
-    """
-    return select({
-        "@local_config_rocm//rocm:using_hipcc": if_true,
-        "//conditions:default": if_false
-    })
-
-
-def rocm_default_copts():
-    """Default options for all ROCm compilations."""
-    return if_rocm(["-x", "rocm"] + %{rocm_extra_copts})
-
-def rocm_copts(opts = []):
-    """Gets the appropriate set of copts for (maybe) ROCm compilation.
-
-      If we're doing ROCm compilation, returns copts for our particular ROCm
-      compiler.  If we're not doing ROCm compilation, returns an empty list.
-
-      """
-    return rocm_default_copts() + select({
-        "//conditions:default": [],
-        "@local_config_rocm//rocm:using_hipcc": ([
-            "",
-        ]),
-    }) + if_rocm_is_configured(opts)
-
-def rocm_gpu_architectures():
-    """Returns a list of supported GPU architectures."""
-    return %{rocm_gpu_architectures}
-
-def rocm_version_number():
-    """Returns a list of supported GPU architectures."""
-    return %{rocm_version_number}
-
-def if_rocm_is_configured(x):
-    """Tests if the ROCm was enabled during the configure process.
-
-    Unlike if_rocm(), this does not require that we are building with
-    --config=rocm. Used to allow non-ROCm code to depend on ROCm libraries.
-    """
-    if %{rocm_is_configured}:
-      return select({"//conditions:default": x})
-    return select({"//conditions:default": []})
-
-def rocm_hipblaslt():
-    return %{rocm_is_configured} and %{rocm_hipblaslt}
-
-def if_rocm_hipblaslt(x):
-    if %{rocm_is_configured} and (%{rocm_hipblaslt} == "True"):
-      return select({"//conditions:default": x})
-    return select({"//conditions:default": []})
-
-def rocm_library(copts = [], **kwargs):
-    """Wrapper over cc_library which adds default ROCm options."""
-    native.cc_library(copts = rocm_default_copts() + copts, **kwargs)
diff --git a/third_party/xla/third_party/gpus/rocm/rocm_config.h.tpl b/third_party/xla/third_party/gpus/rocm/rocm_config.h.tpl
deleted file mode 100644
index 20506f6..0000000
--- a/third_party/xla/third_party/gpus/rocm/rocm_config.h.tpl
+++ /dev/null
@@ -1,26 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef ROCM_ROCM_CONFIG_H_
-#define ROCM_ROCM_CONFIG_H_
-
-#define TF_ROCM_TOOLKIT_PATH "%{rocm_toolkit_path}"
-
-#define TF_ROCM_VERSION %{rocm_version_number}
-#define TF_MIOPEN_VERSION %{miopen_version_number}
-#define TF_HIPRUNTIME_VERSION %{hipruntime_version_number}
-#define TF_HIPBLASLT %{hipblaslt_flag}
-
-#endif  // ROCM_ROCM_CONFIG_H_
diff --git a/third_party/xla/third_party/gpus/rocm_configure.bzl b/third_party/xla/third_party/gpus/rocm_configure.bzl
deleted file mode 100644
index c461ccf..0000000
--- a/third_party/xla/third_party/gpus/rocm_configure.bzl
+++ /dev/null
@@ -1,851 +0,0 @@
-"""Repository rule for ROCm autoconfiguration.
-
-`rocm_configure` depends on the following environment variables:
-
-  * `TF_NEED_ROCM`: Whether to enable building with ROCm.
-  * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path
-  * `ROCM_PATH`: The path to the ROCm toolkit. Default is `/opt/rocm`.
-  * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets.
-"""
-
-load(
-    ":cuda_configure.bzl",
-    "make_copy_dir_rule",
-    "make_copy_files_rule",
-    "to_list_of_strings",
-)
-load(
-    "//third_party/remote_config:common.bzl",
-    "config_repo_label",
-    "err_out",
-    "execute",
-    "files_exist",
-    "get_bash_bin",
-    "get_cpu_value",
-    "get_host_environ",
-    "get_python_bin",
-    "raw_exec",
-    "realpath",
-    "which",
-)
-
-_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
-_GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX"
-_ROCM_TOOLKIT_PATH = "ROCM_PATH"
-_TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS"
-_TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO"
-
-_DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm"
-
-def verify_build_defines(params):
-    """Verify all variables that crosstool/BUILD.rocm.tpl expects are substituted.
-
-    Args:
-      params: dict of variables that will be passed to the BUILD.tpl template.
-    """
-    missing = []
-    for param in [
-        "cxx_builtin_include_directories",
-        "extra_no_canonical_prefixes_flags",
-        "host_compiler_path",
-        "host_compiler_prefix",
-        "linker_bin_path",
-        "unfiltered_compile_flags",
-    ]:
-        if ("%{" + param + "}") not in params:
-            missing.append(param)
-
-    if missing:
-        auto_configure_fail(
-            "BUILD.rocm.tpl template is missing these variables: " +
-            str(missing) +
-            ".\nWe only got: " +
-            str(params) +
-            ".",
-        )
-
-def find_cc(repository_ctx):
-    """Find the C++ compiler."""
-
-    # Return a dummy value for GCC detection here to avoid error
-    target_cc_name = "gcc"
-    cc_path_envvar = _GCC_HOST_COMPILER_PATH
-    cc_name = target_cc_name
-
-    cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar)
-    if cc_name_from_env:
-        cc_name = cc_name_from_env
-    if cc_name.startswith("/"):
-        # Absolute path, maybe we should make this supported by our which function.
-        return cc_name
-    cc = which(repository_ctx, cc_name)
-    if cc == None:
-        fail(("Cannot find {}, either correct your path or set the {}" +
-              " environment variable").format(target_cc_name, cc_path_envvar))
-    return cc
-
-_INC_DIR_MARKER_BEGIN = "#include <...>"
-
-def _cxx_inc_convert(path):
-    """Convert path returned by cc -E xc++ in a complete path."""
-    path = path.strip()
-    return path
-
-def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp):
-    """Compute the list of default C or C++ include directories."""
-    if lang_is_cpp:
-        lang = "c++"
-    else:
-        lang = "c"
-
-    # TODO: We pass -no-canonical-prefixes here to match the compiler flags,
-    #       but in rocm_clang CROSSTOOL file that is a `feature` and we should
-    #       handle the case when it's disabled and no flag is passed
-    result = raw_exec(repository_ctx, [
-        cc,
-        "-no-canonical-prefixes",
-        "-E",
-        "-x" + lang,
-        "-",
-        "-v",
-    ])
-    stderr = err_out(result)
-    index1 = stderr.find(_INC_DIR_MARKER_BEGIN)
-    if index1 == -1:
-        return []
-    index1 = stderr.find("\n", index1)
-    if index1 == -1:
-        return []
-    index2 = stderr.rfind("\n ")
-    if index2 == -1 or index2 < index1:
-        return []
-    index2 = stderr.find("\n", index2 + 1)
-    if index2 == -1:
-        inc_dirs = stderr[index1 + 1:]
-    else:
-        inc_dirs = stderr[index1 + 1:index2].strip()
-
-    return [
-        str(repository_ctx.path(_cxx_inc_convert(p)))
-        for p in inc_dirs.split("\n")
-    ]
-
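
The include-directory probe runs the compiler with `-E -v` and scrapes the paths printed between the `#include <...> search starts here:` marker and the end of the list. A simplified, self-contained sketch of the extraction (the Starlark version above uses find/rfind offsets on the raw stderr rather than this line-based form):

```python
# Typical tail of `gcc -E -xc++ - -v` stderr (simplified sample text).
stderr = """#include <...> search starts here:
 /usr/lib/gcc/x86_64-linux-gnu/11/include
 /usr/local/include
 /usr/include
End of search list."""

start = stderr.find("#include <...>")
start = stderr.find("\n", start) + 1          # skip the marker line itself
end = stderr.find("End of search list.")
inc_dirs = [line.strip() for line in stderr[start:end].splitlines() if line.strip()]
print(inc_dirs)
```
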
-def get_cxx_inc_directories(repository_ctx, cc):
-    """Compute the list of default C and C++ include directories."""
-
-    # For some reason `clang -xc` sometimes returns include paths that are
-    # different from the ones from `clang -xc++`. (Symlink and a dir)
-    # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists
-    includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True)
-    includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False)
-
-    includes_cpp_set = depset(includes_cpp)
-    return includes_cpp + [
-        inc
-        for inc in includes_c
-        if inc not in includes_cpp_set.to_list()
-    ]
-
-def auto_configure_fail(msg):
-    """Output failure message when rocm configuration fails."""
-    red = "\033[0;31m"
-    no_color = "\033[0m"
-    fail("\n%sROCm Configuration Error:%s %s\n" % (red, no_color, msg))
-
-def auto_configure_warning(msg):
-    """Output warning message during auto configuration."""
-    yellow = "\033[1;33m"
-    no_color = "\033[0m"
-    print("\n%sAuto-Configuration Warning:%s %s\n" % (yellow, no_color, msg))
-
-# END cc_configure common functions (see TODO above).
-
-def _rocm_include_path(repository_ctx, rocm_config, bash_bin):
-    """Generates the cxx_builtin_include_directory entries for rocm inc dirs.
-
-    Args:
-      repository_ctx: The repository context.
-      rocm_config: The ROCm config, providing the toolkit path and version number.
-      bash_bin: The path to the bash binary used to run shell commands.
-
-    Returns:
-      A list of ROCm include directories which can be added as
-      cxx_builtin_include_directory entries to the CROSSTOOL file.
-    """
-    inc_dirs = []
-
-    # Add HSA headers (needs to match $HSA_PATH)
-    inc_dirs.append(rocm_config.rocm_toolkit_path + "/hsa/include")
-
-    # Add HIP headers (needs to match $HIP_PATH)
-    inc_dirs.append(rocm_config.rocm_toolkit_path + "/hip/include")
-    if int(rocm_config.rocm_version_number) >= 50200:
-        inc_dirs.append(rocm_config.rocm_toolkit_path + "/include")
-        inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/hip")
-        inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocprim")
-        inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocsolver")
-        inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocblas")
-
-    # Add HIP-Clang headers (realpath relative to compiler binary)
-    rocm_toolkit_path = realpath(repository_ctx, rocm_config.rocm_toolkit_path, bash_bin)
-    inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/8.0/include")
-    inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/9.0.0/include")
-    inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/10.0.0/include")
-    inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/11.0.0/include")
-    inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/12.0.0/include")
-    inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/13.0.0/include")
-    inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/14.0.0/include")
-    inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/15.0.0/include")
-    inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/16.0.0/include")
-    inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17.0.0/include")
-
-    # Support hcc based off clang 10.0.0 (for ROCm 3.3)
-    inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/")
-    inc_dirs.append(rocm_toolkit_path + "/hcc/lib/clang/10.0.0/include")
-
-    # Add hcc headers
-    inc_dirs.append(rocm_toolkit_path + "/hcc/include")
-
-    return inc_dirs
-
-def _enable_rocm(repository_ctx):
-    enable_rocm = get_host_environ(repository_ctx, "TF_NEED_ROCM")
-    if enable_rocm == "1":
-        if get_cpu_value(repository_ctx) != "Linux":
-            auto_configure_warning("ROCm configure is only supported on Linux")
-            return False
-        return True
-    return False
-
-def _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin):
-    """Returns a list of strings representing AMDGPU targets."""
-    amdgpu_targets_str = get_host_environ(repository_ctx, _TF_ROCM_AMDGPU_TARGETS)
-    if not amdgpu_targets_str:
-        cmd = "%s/bin/rocm_agent_enumerator" % rocm_toolkit_path
-        result = execute(repository_ctx, [bash_bin, "-c", cmd])
-        targets = [target for target in result.stdout.strip().split("\n") if target != "gfx000"]
-        targets = {x: None for x in targets}
-        targets = list(targets.keys())
-        amdgpu_targets_str = ",".join(targets)
-    amdgpu_targets = amdgpu_targets_str.split(",")
-    for amdgpu_target in amdgpu_targets:
-        if amdgpu_target[:3] != "gfx":
-            auto_configure_fail("Invalid AMDGPU target: %s" % amdgpu_target)
-    return amdgpu_targets
-
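A note on the target detection above: when `TF_ROCM_AMDGPU_TARGETS` is unset, the list is taken from `rocm_agent_enumerator`, the `gfx000` placeholder is dropped, and duplicates are removed while preserving order via a dict round-trip before the `gfx` prefix check. A minimal standalone sketch of that post-processing (the enumerator output here is invented for illustration):

```python
# Hypothetical rocm_agent_enumerator output; real output depends on the machine.
stdout = "gfx000\ngfx906\ngfx906\ngfx90a\n"

targets = [t for t in stdout.strip().split("\n") if t != "gfx000"]
targets = list({t: None for t in targets}.keys())  # dedupe, keep first-seen order
amdgpu_targets = ",".join(targets).split(",")

for t in amdgpu_targets:
    if not t.startswith("gfx"):
        raise ValueError("Invalid AMDGPU target: %s" % t)

print(amdgpu_targets)  # ['gfx906', 'gfx90a']
```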
-def _hipcc_env(repository_ctx):
-    """Returns the environment variable string for hipcc.
-
-    Args:
-        repository_ctx: The repository context.
-
-    Returns:
-        A string containing environment variables for hipcc.
-    """
-    hipcc_env = ""
-    for name in [
-        "HIP_CLANG_PATH",
-        "DEVICE_LIB_PATH",
-        "HIP_VDI_HOME",
-        "HIPCC_VERBOSE",
-        "HIPCC_COMPILE_FLAGS_APPEND",
-        "HIPPCC_LINK_FLAGS_APPEND",
-        "HCC_AMDGPU_TARGET",
-        "HIP_PLATFORM",
-    ]:
-        env_value = get_host_environ(repository_ctx, name)
-        if env_value:
-            hipcc_env = (hipcc_env + " " + name + "=\"" + env_value + "\";")
-    return hipcc_env.strip()
-
-def _crosstool_verbose(repository_ctx):
-    """Returns the environment variable value CROSSTOOL_VERBOSE.
-
-    Args:
-        repository_ctx: The repository context.
-
-    Returns:
-        A string containing value of environment variable CROSSTOOL_VERBOSE.
-    """
-    return get_host_environ(repository_ctx, "CROSSTOOL_VERBOSE", "0")
-
-def _lib_name(lib, version = "", static = False):
-    """Constructs the name of a library on Linux.
-
-    Args:
-      lib: The name of the library, such as "hip"
-      version: The version of the library.
-      static: True if the library is static or False if it is a shared object.
-
-    Returns:
-      The platform-specific name of the library.
-    """
-    if static:
-        return "lib%s.a" % lib
-    else:
-        if version:
-            version = ".%s" % version
-        return "lib%s.so%s" % (lib, version)
-
-def _rocm_lib_paths(repository_ctx, lib, basedir):
-    file_name = _lib_name(lib, version = "", static = False)
-    return [
-        repository_ctx.path("%s/lib64/%s" % (basedir, file_name)),
-        repository_ctx.path("%s/lib64/stubs/%s" % (basedir, file_name)),
-        repository_ctx.path("%s/lib/x86_64-linux-gnu/%s" % (basedir, file_name)),
-        repository_ctx.path("%s/lib/%s" % (basedir, file_name)),
-        repository_ctx.path("%s/%s" % (basedir, file_name)),
-    ]
-
-def _batch_files_exist(repository_ctx, libs_paths, bash_bin):
-    all_paths = []
-    for row in libs_paths:
-        lib_paths = row[1]
-        for lib_path in lib_paths:
-            all_paths.append(lib_path)
-    return files_exist(repository_ctx, all_paths, bash_bin)
-
-def _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin):
-    test_results = _batch_files_exist(repository_ctx, libs_paths, bash_bin)
-
-    libs = {}
-    i = 0
-    for row in libs_paths:
-        name = row[0]
-        lib_paths = row[1]
-        optional = (len(row) > 2 and row[2] == True)
-        selected_path = None
-        for path in lib_paths:
-            if test_results[i] and selected_path == None:
-                # For each lib select the first path that exists.
-                selected_path = path
-            i = i + 1
-        if selected_path == None:
-            if optional:
-                libs[name] = None
-                continue
-            else:
-                auto_configure_fail("Cannot find rocm library %s" % name)
-
-        libs[name] = struct(file_name = selected_path.basename, path = realpath(repository_ctx, selected_path, bash_bin))
-
-    return libs
-
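For context, `_select_rocm_lib_paths` flattens every candidate path into one batched `files_exist` query and then walks the flat result list with a running index, keeping the first existing path per library (or `None` for optional ones). A self-contained sketch of that walk, with the existence results stubbed out:

```python
def select_first_existing(libs_paths, exists):
    """libs_paths: list of (name, [candidate paths]); exists: flat batched results."""
    selected = {}
    i = 0
    for name, paths in libs_paths:
        chosen = None
        for path in paths:
            if exists[i] and chosen is None:
                chosen = path  # first hit wins
            i += 1
        selected[name] = chosen
    return selected

libs = [("rocblas", ["/opt/rocm/lib64/librocblas.so", "/opt/rocm/lib/librocblas.so"])]
print(select_first_existing(libs, [False, True]))
# {'rocblas': '/opt/rocm/lib/librocblas.so'}
```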
-def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin):
-    """Returns the ROCm libraries on the system.
-
-    Args:
-      repository_ctx: The repository context.
-      rocm_config: The ROCm config as returned by _get_rocm_config
-      hipfft_or_rocfft: "hipfft" or "rocfft", depending on the ROCm version
-      miopen_path: The path to search for the MIOpen library
-      rccl_path: The path to search for the RCCL library
-      bash_bin: the path to the bash interpreter
-
-    Returns:
-      Map of library names to structs of filename and path
-    """
-    libs_paths = [
-        (name, _rocm_lib_paths(repository_ctx, name, path))
-        for name, path in [
-            ("amdhip64", rocm_config.rocm_toolkit_path + "/hip"),
-            ("rocblas", rocm_config.rocm_toolkit_path),
-            (hipfft_or_rocfft, rocm_config.rocm_toolkit_path),
-            ("hiprand", rocm_config.rocm_toolkit_path),
-            ("MIOpen", miopen_path),
-            ("rccl", rccl_path),
-            ("hipsparse", rocm_config.rocm_toolkit_path),
-            ("roctracer64", rocm_config.rocm_toolkit_path + "/roctracer"),
-            ("rocsolver", rocm_config.rocm_toolkit_path),
-        ]
-    ]
-    if int(rocm_config.rocm_version_number) >= 40500:
-        libs_paths.append(("hipsolver", _rocm_lib_paths(repository_ctx, "hipsolver", rocm_config.rocm_toolkit_path)))
-        libs_paths.append(("hipblas", _rocm_lib_paths(repository_ctx, "hipblas", rocm_config.rocm_toolkit_path)))
-
-    # hipblaslt may be missing even on ROCm versions that ship it
-    # (it is not installed by default in some containers), so autodetect it.
-    libs_paths.append(("hipblaslt", _rocm_lib_paths(repository_ctx, "hipblaslt", rocm_config.rocm_toolkit_path), True))
-    return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin)
-
-def find_rocm_config(repository_ctx):
-    """Returns ROCm config dictionary from running find_rocm_config.py"""
-    python_bin = get_python_bin(repository_ctx)
-    exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_rocm_config])
-    if exec_result.return_code:
-        auto_configure_fail("Failed to run find_rocm_config.py: %s" % err_out(exec_result))
-
-    # Parse the dict from stdout.
-    return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()])
-
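The parsing above assumes `find_rocm_config.py` prints one `key: value` pair per line on stdout, each split once on `": "`. A minimal sketch of that contract with made-up values (the keys match the ones consumed by `_get_rocm_config` below):

```python
# Hypothetical find_rocm_config.py output; keys mirror those read below.
stdout = (
    "rocm_toolkit_path: /opt/rocm\n"
    "rocm_version_number: 50700\n"
    "miopen_version_number: 2021\n"
    "hipruntime_version_number: 50731921\n"
)
config = dict([tuple(x.split(": ")) for x in stdout.splitlines()])
print(config["rocm_toolkit_path"])                   # /opt/rocm
print(int(config["rocm_version_number"]) >= 40500)   # True
```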
-def _get_rocm_config(repository_ctx, bash_bin):
-    """Detects and returns information about the ROCm installation on the system.
-
-    Args:
-      repository_ctx: The repository context.
-      bash_bin: the path to the bash interpreter
-
-    Returns:
-      A struct containing the following fields:
-        rocm_toolkit_path: The ROCm toolkit installation directory.
-        amdgpu_targets: A list of the system's AMDGPU targets.
-        rocm_version_number: The version of ROCm on the system.
-        miopen_version_number: The version of MIOpen on the system.
-        hipruntime_version_number: The version of HIP Runtime on the system.
-    """
-    config = find_rocm_config(repository_ctx)
-    rocm_toolkit_path = config["rocm_toolkit_path"]
-    rocm_version_number = config["rocm_version_number"]
-    miopen_version_number = config["miopen_version_number"]
-    hipruntime_version_number = config["hipruntime_version_number"]
-    return struct(
-        amdgpu_targets = _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin),
-        rocm_toolkit_path = rocm_toolkit_path,
-        rocm_version_number = rocm_version_number,
-        miopen_version_number = miopen_version_number,
-        hipruntime_version_number = hipruntime_version_number,
-    )
-
-def _tpl_path(repository_ctx, labelname):
-    return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % labelname))
-
-def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
-    if not out:
-        out = tpl.replace(":", "/")
-    repository_ctx.template(
-        out,
-        _tpl_path(repository_ctx, tpl),
-        substitutions,
-    )
-
-_DUMMY_CROSSTOOL_BZL_FILE = """
-def error_gpu_disabled():
-  fail("ERROR: Building with --config=rocm but TensorFlow is not configured " +
-       "to build with GPU support. Please re-run ./configure and enter 'Y' " +
-       "at the prompt to build with GPU support.")
-
-  native.genrule(
-      name = "error_gen_crosstool",
-      outs = ["CROSSTOOL"],
-      cmd = "echo 'Should not be run.' && exit 1",
-  )
-
-  native.filegroup(
-      name = "crosstool",
-      srcs = [":CROSSTOOL"],
-      output_licenses = ["unencumbered"],
-  )
-"""
-
-_DUMMY_CROSSTOOL_BUILD_FILE = """
-load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")
-
-error_gpu_disabled()
-"""
-
-def _create_dummy_repository(repository_ctx):
-    # Set up BUILD file for rocm/.
-    _tpl(
-        repository_ctx,
-        "rocm:build_defs.bzl",
-        {
-            "%{rocm_is_configured}": "False",
-            "%{rocm_extra_copts}": "[]",
-            "%{rocm_gpu_architectures}": "[]",
-            "%{rocm_version_number}": "0",
-            "%{rocm_hipblaslt}": "False",
-        },
-    )
-    _tpl(
-        repository_ctx,
-        "rocm:BUILD",
-        {
-            "%{hip_lib}": _lib_name("hip"),
-            "%{rocblas_lib}": _lib_name("rocblas"),
-            "%{hipblas_lib}": _lib_name("hipblas"),
-            "%{miopen_lib}": _lib_name("miopen"),
-            "%{rccl_lib}": _lib_name("rccl"),
-            "%{hipfft_or_rocfft}": _lib_name("hipfft"),
-            "%{hipfft_or_rocfft_lib}": _lib_name("hipfft"),
-            "%{hiprand_lib}": _lib_name("hiprand"),
-            "%{hipsparse_lib}": _lib_name("hipsparse"),
-            "%{roctracer_lib}": _lib_name("roctracer64"),
-            "%{rocsolver_lib}": _lib_name("rocsolver"),
-            "%{hipsolver_lib}": _lib_name("hipsolver"),
-            "%{hipblaslt_lib}": _lib_name("hipblaslt"),
-            "%{copy_rules}": "",
-            "%{rocm_headers}": "",
-        },
-    )
-
-    # Create dummy files for the ROCm toolkit since they are still required by
-    # tensorflow/compiler/xla/stream_executor/rocm:rocm_rpath
-    repository_ctx.file("rocm/hip/include/hip/hip_runtime.h", "")
-
-    # Set up rocm_config.h, which is used by
-    # tensorflow/compiler/xla/stream_executor/dso_loader.cc.
-    _tpl(
-        repository_ctx,
-        "rocm:rocm_config.h",
-        {
-            "%{rocm_toolkit_path}": _DEFAULT_ROCM_TOOLKIT_PATH,
-            "%{hipblaslt_flag}": "0",
-        },
-        "rocm/rocm/rocm_config.h",
-    )
-
-    # If rocm_configure is not configured to build with GPU support, and the user
-    # attempts to build with --config=rocm, add a dummy build rule to intercept
-    # this and fail with an actionable error message.
-    repository_ctx.file(
-        "crosstool/error_gpu_disabled.bzl",
-        _DUMMY_CROSSTOOL_BZL_FILE,
-    )
-    repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
-
-def _norm_path(path):
-    """Returns a path with '/' and remove the trailing slash."""
-    path = path.replace("\\", "/")
-    if path[-1] == "/":
-        path = path[:-1]
-    return path
-
-def _genrule(src_dir, genrule_name, command, outs):
-    """Returns a string with a genrule.
-
-    Genrule executes the given command and produces the given outputs.
-    """
-    return (
-        "genrule(\n" +
-        '    name = "' +
-        genrule_name + '",\n' +
-        "    outs = [\n" +
-        outs +
-        "\n    ],\n" +
-        '    cmd = """\n' +
-        command +
-        '\n   """,\n' +
-        ")\n"
-    )
-
-def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets):
-    amdgpu_target_flags = ["--amdgpu-target=" +
-                           amdgpu_target for amdgpu_target in amdgpu_targets]
-    return str(amdgpu_target_flags)
-
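Worth noting: `_compute_rocm_extra_copts` returns the *string* form of the flag list, because it is substituted textually into `rocm/build_defs.bzl` rather than passed around as a Starlark value. A tiny illustration (Python stand-in, no repository_ctx needed):

```python
# Illustrative only; mirrors _compute_rocm_extra_copts without a repository_ctx.
amdgpu_targets = ["gfx906", "gfx90a"]
amdgpu_target_flags = ["--amdgpu-target=" + t for t in amdgpu_targets]
print(str(amdgpu_target_flags))
# Python prints ['--amdgpu-target=gfx906', '--amdgpu-target=gfx90a'];
# Starlark's str() renders the same list with double quotes.
```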
-def _create_local_rocm_repository(repository_ctx):
-    """Creates the repository containing files set up to build with ROCm."""
-
-    tpl_paths = {labelname: _tpl_path(repository_ctx, labelname) for labelname in [
-        "rocm:build_defs.bzl",
-        "rocm:BUILD",
-        "crosstool:BUILD.rocm",
-        "crosstool:hipcc_cc_toolchain_config.bzl",
-        "crosstool:clang/bin/crosstool_wrapper_driver_rocm",
-        "rocm:rocm_config.h",
-    ]}
-
-    bash_bin = get_bash_bin(repository_ctx)
-    rocm_config = _get_rocm_config(repository_ctx, bash_bin)
-
-    # For ROCm 4.1 and above use hipfft, older ROCm versions use rocfft
-    rocm_version_number = int(rocm_config.rocm_version_number)
-    hipfft_or_rocfft = "rocfft" if rocm_version_number < 40100 else "hipfft"
-
-    # For ROCm 5.2 and above, find MIOpen and RCCL in the main rocm lib path
-    miopen_path = rocm_config.rocm_toolkit_path + "/miopen" if rocm_version_number < 50200 else rocm_config.rocm_toolkit_path
-    rccl_path = rocm_config.rocm_toolkit_path + "/rccl" if rocm_version_number < 50200 else rocm_config.rocm_toolkit_path
-
-    # Copy header and library files to execroot.
-    # rocm_toolkit_path
-    rocm_toolkit_path = rocm_config.rocm_toolkit_path
-    copy_rules = [
-        make_copy_dir_rule(
-            repository_ctx,
-            name = "rocm-include",
-            src_dir = rocm_toolkit_path + "/include",
-            out_dir = "rocm/include",
-        ),
-    ]
-
-    # explicitly copy (into the local_config_rocm repo) the $ROCM_PATH/hiprand/include and
-    # $ROCM_PATH/rocrand/include dirs, only once the softlink to them in $ROCM_PATH/include
-    # dir has been removed. This removal will happen in a near-future ROCm release.
-    hiprand_include = ""
-    hiprand_include_softlink = rocm_config.rocm_toolkit_path + "/include/hiprand"
-    softlink_exists = files_exist(repository_ctx, [hiprand_include_softlink], bash_bin)
-    if not softlink_exists[0]:
-        hiprand_include = '":hiprand-include",\n'
-        copy_rules.append(
-            make_copy_dir_rule(
-                repository_ctx,
-                name = "hiprand-include",
-                src_dir = rocm_toolkit_path + "/hiprand/include",
-                out_dir = "rocm/include/hiprand",
-            ),
-        )
-
-    rocrand_include = ""
-    rocrand_include_softlink = rocm_config.rocm_toolkit_path + "/include/rocrand"
-    softlink_exists = files_exist(repository_ctx, [rocrand_include_softlink], bash_bin)
-    if not softlink_exists[0]:
-        rocrand_include = '":rocrand-include",\n'
-        copy_rules.append(
-            make_copy_dir_rule(
-                repository_ctx,
-                name = "rocrand-include",
-                src_dir = rocm_toolkit_path + "/rocrand/include",
-                out_dir = "rocm/include/rocrand",
-            ),
-        )
-
-    rocm_libs = _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin)
-    rocm_lib_srcs = []
-    rocm_lib_outs = []
-    for lib in rocm_libs.values():
-        if lib:
-            rocm_lib_srcs.append(lib.path)
-            rocm_lib_outs.append("rocm/lib/" + lib.file_name)
-    copy_rules.append(make_copy_files_rule(
-        repository_ctx,
-        name = "rocm-lib",
-        srcs = rocm_lib_srcs,
-        outs = rocm_lib_outs,
-    ))
-
-    clang_offload_bundler_path = rocm_toolkit_path + "/llvm/bin/clang-offload-bundler"
-
-    # copy files mentioned in third_party/gpus/rocm/BUILD
-    copy_rules.append(make_copy_files_rule(
-        repository_ctx,
-        name = "rocm-bin",
-        srcs = [
-            clang_offload_bundler_path,
-        ],
-        outs = [
-            "rocm/bin/" + "clang-offload-bundler",
-        ],
-    ))
-
-    have_hipblaslt = "1" if rocm_libs["hipblaslt"] != None else "0"
-
-    # Set up BUILD file for rocm/
-    repository_ctx.template(
-        "rocm/build_defs.bzl",
-        tpl_paths["rocm:build_defs.bzl"],
-        {
-            "%{rocm_is_configured}": "True",
-            "%{rocm_extra_copts}": _compute_rocm_extra_copts(
-                repository_ctx,
-                rocm_config.amdgpu_targets,
-            ),
-            "%{rocm_gpu_architectures}": str(rocm_config.amdgpu_targets),
-            "%{rocm_version_number}": str(rocm_version_number),
-            "%{rocm_hipblaslt}": "True" if rocm_libs["hipblaslt"] != None else "False",
-        },
-    )
-
-    repository_dict = {
-        "%{hip_lib}": rocm_libs["amdhip64"].file_name,
-        "%{rocblas_lib}": rocm_libs["rocblas"].file_name,
-        "%{hipfft_or_rocfft}": hipfft_or_rocfft,
-        "%{hipfft_or_rocfft_lib}": rocm_libs[hipfft_or_rocfft].file_name,
-        "%{hiprand_lib}": rocm_libs["hiprand"].file_name,
-        "%{miopen_lib}": rocm_libs["MIOpen"].file_name,
-        "%{rccl_lib}": rocm_libs["rccl"].file_name,
-        "%{hipsparse_lib}": rocm_libs["hipsparse"].file_name,
-        "%{roctracer_lib}": rocm_libs["roctracer64"].file_name,
-        "%{rocsolver_lib}": rocm_libs["rocsolver"].file_name,
-        "%{copy_rules}": "\n".join(copy_rules),
-        "%{rocm_headers}": ('":rocm-include",\n' +
-                            hiprand_include +
-                            rocrand_include),
-    }
-    if rocm_libs["hipblaslt"] != None:
-        repository_dict["%{hipblaslt_lib}"] = rocm_libs["hipblaslt"].file_name
-
-    if rocm_version_number >= 40500:
-        repository_dict["%{hipsolver_lib}"] = rocm_libs["hipsolver"].file_name
-        repository_dict["%{hipblas_lib}"] = rocm_libs["hipblas"].file_name
-
-    repository_ctx.template(
-        "rocm/BUILD",
-        tpl_paths["rocm:BUILD"],
-        repository_dict,
-    )
-
-    # Set up crosstool/
-
-    cc = find_cc(repository_ctx)
-
-    host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc)
-
-    host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX, "/usr/bin")
-
-    rocm_defines = {}
-
-    rocm_defines["%{host_compiler_prefix}"] = host_compiler_prefix
-
-    rocm_defines["%{linker_bin_path}"] = rocm_config.rocm_toolkit_path + "/hcc/compiler/bin"
-
-    # For gcc, do not canonicalize system header paths; some versions of gcc
-    # pick the shortest possible path for system includes when creating the
-    # .d file - given that includes that are prefixed with "../" multiple
-    # times quickly grow longer than the root of the tree, this can lead to
-    # bazel's header check failing.
-    rocm_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\""
-
-    rocm_defines["%{unfiltered_compile_flags}"] = to_list_of_strings([
-        "-DTENSORFLOW_USE_ROCM=1",
-        "-D__HIP_PLATFORM_HCC__",
-        "-DEIGEN_USE_HIP",
-    ])
-
-    rocm_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
-
-    rocm_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(
-        host_compiler_includes + _rocm_include_path(repository_ctx, rocm_config, bash_bin),
-    )
-
-    verify_build_defines(rocm_defines)
-
-    # Only expand template variables in the BUILD file
-    repository_ctx.template(
-        "crosstool/BUILD",
-        tpl_paths["crosstool:BUILD.rocm"],
-        rocm_defines,
-    )
-
-    # No templating of cc_toolchain_config - use attributes and templatize the
-    # BUILD file.
-    repository_ctx.template(
-        "crosstool/cc_toolchain_config.bzl",
-        tpl_paths["crosstool:hipcc_cc_toolchain_config.bzl"],
-    )
-
-    repository_ctx.template(
-        "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
-        tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_rocm"],
-        {
-            "%{cpu_compiler}": str(cc),
-            "%{hipcc_path}": rocm_config.rocm_toolkit_path + "/bin/hipcc",
-            "%{hipcc_env}": _hipcc_env(repository_ctx),
-            "%{rocr_runtime_path}": rocm_config.rocm_toolkit_path + "/lib",
-            "%{rocr_runtime_library}": "hsa-runtime64",
-            "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/hip/lib",
-            "%{hip_runtime_library}": "amdhip64",
-            "%{crosstool_verbose}": _crosstool_verbose(repository_ctx),
-            "%{gcc_host_compiler_path}": str(cc),
-        },
-    )
-
-    # Set up rocm_config.h, which is used by
-    # tensorflow/compiler/xla/stream_executor/dso_loader.cc.
-    repository_ctx.template(
-        "rocm/rocm/rocm_config.h",
-        tpl_paths["rocm:rocm_config.h"],
-        {
-            "%{rocm_amdgpu_targets}": ",".join(
-                ["\"%s\"" % c for c in rocm_config.amdgpu_targets],
-            ),
-            "%{rocm_toolkit_path}": rocm_config.rocm_toolkit_path,
-            "%{rocm_version_number}": rocm_config.rocm_version_number,
-            "%{miopen_version_number}": rocm_config.miopen_version_number,
-            "%{hipruntime_version_number}": rocm_config.hipruntime_version_number,
-            "%{hipblaslt_flag}": have_hipblaslt,
-        },
-    )
-
-def _create_remote_rocm_repository(repository_ctx, remote_config_repo):
-    """Creates pointers to a remotely configured repo set up to build with ROCm."""
-    _tpl(
-        repository_ctx,
-        "rocm:build_defs.bzl",
-        {
-            "%{rocm_is_configured}": "True",
-            "%{rocm_extra_copts}": _compute_rocm_extra_copts(
-                repository_ctx,
-                [],  #_compute_capabilities(repository_ctx)
-            ),
-        },
-    )
-    repository_ctx.template(
-        "rocm/BUILD",
-        config_repo_label(remote_config_repo, "rocm:BUILD"),
-        {},
-    )
-    repository_ctx.template(
-        "rocm/build_defs.bzl",
-        config_repo_label(remote_config_repo, "rocm:build_defs.bzl"),
-        {},
-    )
-    repository_ctx.template(
-        "rocm/rocm/rocm_config.h",
-        config_repo_label(remote_config_repo, "rocm:rocm/rocm_config.h"),
-        {},
-    )
-    repository_ctx.template(
-        "crosstool/BUILD",
-        config_repo_label(remote_config_repo, "crosstool:BUILD"),
-        {},
-    )
-    repository_ctx.template(
-        "crosstool/cc_toolchain_config.bzl",
-        config_repo_label(remote_config_repo, "crosstool:cc_toolchain_config.bzl"),
-        {},
-    )
-    repository_ctx.template(
-        "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
-        config_repo_label(remote_config_repo, "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc"),
-        {},
-    )
-
-def _rocm_autoconf_impl(repository_ctx):
-    """Implementation of the rocm_autoconf repository rule."""
-    if not _enable_rocm(repository_ctx):
-        _create_dummy_repository(repository_ctx)
-    elif get_host_environ(repository_ctx, _TF_ROCM_CONFIG_REPO) != None:
-        _create_remote_rocm_repository(
-            repository_ctx,
-            get_host_environ(repository_ctx, _TF_ROCM_CONFIG_REPO),
-        )
-    else:
-        _create_local_rocm_repository(repository_ctx)
-
-_ENVIRONS = [
-    _GCC_HOST_COMPILER_PATH,
-    _GCC_HOST_COMPILER_PREFIX,
-    "TF_NEED_ROCM",
-    _ROCM_TOOLKIT_PATH,
-    _TF_ROCM_AMDGPU_TARGETS,
-]
-
-remote_rocm_configure = repository_rule(
-    implementation = _create_local_rocm_repository,
-    environ = _ENVIRONS,
-    remotable = True,
-    attrs = {
-        "environ": attr.string_dict(),
-        "_find_rocm_config": attr.label(
-            default = Label("@local_xla//third_party/gpus:find_rocm_config.py"),
-        ),
-    },
-)
-
-rocm_configure = repository_rule(
-    implementation = _rocm_autoconf_impl,
-    environ = _ENVIRONS + [_TF_ROCM_CONFIG_REPO],
-    attrs = {
-        "_find_rocm_config": attr.label(
-            default = Label("@local_xla//third_party/gpus:find_rocm_config.py"),
-        ),
-    },
-)
-"""Detects and configures the local ROCm toolchain.
-
-Add the following to your WORKSPACE file:
-
-```python
-rocm_configure(name = "local_config_rocm")
-```
-
-Args:
-  name: A unique name for this workspace rule.
-"""
diff --git a/third_party/xla/third_party/grpc/BUILD b/third_party/xla/third_party/grpc/BUILD
deleted file mode 100644
index e69de29..0000000
--- a/third_party/xla/third_party/grpc/BUILD
+++ /dev/null
diff --git a/third_party/xla/third_party/grpc/generate_cc_env_fix.patch b/third_party/xla/third_party/grpc/generate_cc_env_fix.patch
deleted file mode 100644
index 51832fe..0000000
--- a/third_party/xla/third_party/grpc/generate_cc_env_fix.patch
+++ /dev/null
@@ -1,10 +0,0 @@
---- a/bazel/generate_cc.bzl
-+++ b/bazel/generate_cc.bzl
-@@ -141,6 +141,7 @@ def generate_cc_impl(ctx):
-         outputs = out_files,
-         executable = ctx.executable._protoc,
-         arguments = arguments,
-+        use_default_shell_env = True,
-     )
-
-     return struct(files = depset(out_files))
diff --git a/third_party/xla/third_party/grpc/register_go_toolchain.patch b/third_party/xla/third_party/grpc/register_go_toolchain.patch
deleted file mode 100644
index eabe6cc..0000000
--- a/third_party/xla/third_party/grpc/register_go_toolchain.patch
+++ /dev/null
@@ -1,13 +0,0 @@
-diff --git a/bazel/grpc_extra_deps.bzl b/bazel/grpc_extra_deps.bzl
-index 4c1dfad2e8..d3d9ce15ba 100644
---- a/bazel/grpc_extra_deps.bzl
-+++ b/bazel/grpc_extra_deps.bzl
-@@ -33,7 +33,7 @@ def grpc_extra_deps():
-     api_dependencies()
-
-     go_rules_dependencies()
--    go_register_toolchains()
-+    go_register_toolchains(version = "1.18.4")
-
-     apple_rules_dependencies()
-
diff --git a/third_party/xla/third_party/grpc/upb_platform_fix.patch b/third_party/xla/third_party/grpc/upb_platform_fix.patch
deleted file mode 100644
index 6edd660..0000000
--- a/third_party/xla/third_party/grpc/upb_platform_fix.patch
+++ /dev/null
@@ -1,13 +0,0 @@
-diff --git a/BUILD b/BUILD
-index ad85b202..2311b2e4 100644
---- a/BUILD
-+++ b/BUILD
-@@ -44,7 +44,7 @@ config_setting(
-
- config_setting(
-     name = "windows",
--    constraint_values = ["@bazel_tools//platforms:windows"],
-+    constraint_values = ["@platforms//os:windows"],
- )
-
- config_setting(
diff --git a/third_party/xla/third_party/hwloc/BUILD b/third_party/xla/third_party/hwloc/BUILD
deleted file mode 100644
index 3848c08..0000000
--- a/third_party/xla/third_party/hwloc/BUILD
+++ /dev/null
@@ -1,12 +0,0 @@
-# BUILD file to make this directory a package.
-
-package(
-    default_visibility = ["//visibility:public"],
-    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
-    licenses = ["notice"],
-)
-
-exports_files(
-    ["static-components.h"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/hwloc/BUILD.system b/third_party/xla/third_party/hwloc/BUILD.system
deleted file mode 100644
index 2989102..0000000
--- a/third_party/xla/third_party/hwloc/BUILD.system
+++ /dev/null
@@ -1,22 +0,0 @@
-# hwloc: Portable Hardware Locality Library
-
-licenses(["notice"])
-
-config_setting(
-    name = "with_numa_support",
-    define_values = {"with_numa_support": "true"},
-)
-
-filegroup(
-    name = "COPYING",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "hwloc",
-    linkopts = select({
-        ":with_numa_support": ["-lhwloc"],
-        "//conditions:default": [],
-    }),
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/hwloc/hwloc.BUILD b/third_party/xla/third_party/hwloc/hwloc.BUILD
deleted file mode 100644
index 57f32c3..0000000
--- a/third_party/xla/third_party/hwloc/hwloc.BUILD
+++ /dev/null
@@ -1,316 +0,0 @@
-# hwloc: Portable Hardware Locality Library
-
-package(
-    default_visibility = ["//visibility:public"],
-)
-
-licenses(["notice"])
-
-exports_files(["COPYING"])
-
-load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
-load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
-
-COMMON_INCLUDE_COPTS = [
-    "-I.",
-    "-Ihwloc",
-    "-Iinclude",
-]
-
-DISABLE_WARNINGS_COPTS = [
-    "-Wno-vla",
-]
-
-VAR_SETTINGS_COPTS = [
-    "-DHWLOC_DUMPED_HWDATA_DIR=",
-    "-DRUNSTATEDIR=",
-]
-
-_INCLUDE_HWLOC_AUTOIGEN_CONFIG_H_COMMON_SUBS = {
-    "#undef HWLOC_VERSION_MAJOR": "#define HWLOC_VERSION_MAJOR 2",
-    "#undef HWLOC_VERSION_MINOR": "#define HWLOC_VERSION_MINOR 0",
-    "#undef HWLOC_VERSION_RELEASE": "#define HWLOC_VERSION_RELEASE 3",
-    "#undef HWLOC_VERSION_GREEK": "#define HWLOC_VERSION_GREEK \"\"",
-    "#undef HWLOC_VERSION": "#define HWLOC_VERSION \"2.0.3\"",
-    "#undef hwloc_pid_t": "#define hwloc_pid_t pid_t",
-    "#undef hwloc_thread_t": "#define hwloc_thread_t pthread_t",
-    "#  undef HWLOC_HAVE_STDINT_H": "#  define HWLOC_HAVE_STDINT_H 1",
-    "#undef HWLOC_SYM_TRANSFORM": "#define HWLOC_SYM_TRANSFORM 0",
-    "#undef HWLOC_SYM_PREFIX_CAPS": "#define HWLOC_SYM_PREFIX_CAPS HWLOC_",
-    "#undef HWLOC_SYM_PREFIX": "#define HWLOC_SYM_PREFIX hwloc_",
-}
-
-_INCLUDE_HWLOC_AUTOIGEN_CONFIG_H_LINUX_SUBS = dict(_INCLUDE_HWLOC_AUTOIGEN_CONFIG_H_COMMON_SUBS)
-
-_INCLUDE_HWLOC_AUTOIGEN_CONFIG_H_LINUX_SUBS.update({
-    "#undef HWLOC_LINUX_SYS": "#define HWLOC_LINUX_SYS 1",
-})
-
-expand_template(
-    name = "include_hwloc_autogen_config_h",
-    out = "include/hwloc/autogen/config.h",
-    substitutions = select({
-        "@local_tsl//tsl:linux_x86_64": _INCLUDE_HWLOC_AUTOIGEN_CONFIG_H_LINUX_SUBS,
-        "//conditions:default": _INCLUDE_HWLOC_AUTOIGEN_CONFIG_H_COMMON_SUBS,
-    }),
-    template = "include/hwloc/autogen/config.h.in",
-)
-
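For orientation, `expand_template` (from bazel_skylib) performs literal string substitutions over the `.in` template: every `#undef FOO` listed in the map becomes a `#define`, and anything not listed stays an `#undef`, i.e. the feature remains disabled. A rough Python illustration of the effect, not of Bazel's implementation:

```python
subs = {
    "#undef HWLOC_VERSION_MAJOR": "#define HWLOC_VERSION_MAJOR 2",
    "#undef hwloc_pid_t": "#define hwloc_pid_t pid_t",
}

# A tiny stand-in for include/hwloc/autogen/config.h.in.
template = "\n".join([
    "#undef HWLOC_VERSION_MAJOR",
    "#undef hwloc_pid_t",
    "#undef HWLOC_LINUX_SYS",  # not in subs -> left disabled
])

out = template
for old, new in subs.items():
    out = out.replace(old, new)
print(out)
```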
-_INCLUDE_PRIVATE_HWLOC_AUTOIGEN_CONFIG_H_COMMON_SUBS = {
-    "#undef HAVE_CLOCK_GETTIME": "#define HAVE_CLOCK_GETTIME 1",
-    "#undef HAVE_CTYPE_H": "#define HAVE_CTYPE_H 1",
-    "#undef HAVE_DECL_CTL_HW": "#define HAVE_DECL_CTL_HW 0",
-    "#undef HAVE_DECL_FABSF": "#define HAVE_DECL_FABSF 1",
-    "#undef HAVE_DECL_GETEXECNAME": "#define HAVE_DECL_GETEXECNAME 0",
-    "#undef HAVE_DECL_GETMODULEFILENAME": "#define HAVE_DECL_GETMODULEFILENAME 0",
-    "#undef HAVE_DECL_GETPROGNAME": "#define HAVE_DECL_GETPROGNAME 0",
-    "#undef HAVE_DECL_HW_NCPU": "#define HAVE_DECL_HW_NCPU 0",
-    "#undef HAVE_DECL_MODFF": "#define HAVE_DECL_MODFF 1",
-    "#undef HAVE_DECL_PTHREAD_GETAFFINITY_NP": "#define HAVE_DECL_PTHREAD_GETAFFINITY_NP 1",
-    "#undef HAVE_DECL_PTHREAD_SETAFFINITY_NP": "#define HAVE_DECL_PTHREAD_SETAFFINITY_NP 1",
-    "#undef HAVE_DECL_RUNNING_ON_VALGRIND": "#define HAVE_DECL_RUNNING_ON_VALGRIND 0",
-    "#undef HAVE_DECL_SCHED_GETCPU": "#define HAVE_DECL_SCHED_GETCPU 1",
-    "#undef HAVE_DECL_SNPRINTF": "#define HAVE_DECL_SNPRINTF 1",
-    "#undef HAVE_DECL_STRTOULL": "#define HAVE_DECL_STRTOULL 1",
-    "#undef HAVE_DECL__PUTENV": "#define HAVE_DECL__PUTENV 0",
-    "#undef HAVE_DECL__SC_LARGE_PAGESIZE": "#define HAVE_DECL__SC_LARGE_PAGESIZE 0",
-    "#undef HAVE_DECL__SC_NPROCESSORS_CONF": "#define HAVE_DECL__SC_NPROCESSORS_CONF 1",
-    "#undef HAVE_DECL__SC_NPROCESSORS_ONLN": "#define HAVE_DECL__SC_NPROCESSORS_ONLN 1",
-    "#undef HAVE_DECL__SC_NPROC_CONF": "#define HAVE_DECL__SC_NPROC_CONF 0",
-    "#undef HAVE_DECL__SC_NPROC_ONLN": "#define HAVE_DECL__SC_NPROC_ONLN 0",
-    "#undef HAVE_DECL__SC_PAGESIZE": "#define HAVE_DECL__SC_PAGESIZE 1",
-    "#undef HAVE_DECL__SC_PAGE_SIZE": "#define HAVE_DECL__SC_PAGE_SIZE 1",
-    "#undef HAVE_DECL__STRDUP": "#define HAVE_DECL__STRDUP 0",
-    "#undef HAVE_DIRENT_H": "#define HAVE_DIRENT_H 1",
-    "#undef HAVE_DLFCN_H": "#define HAVE_DLFCN_H 1",
-    "#undef HAVE_FFSL": "#define HAVE_FFSL 1",
-    "#undef HAVE_FFS": "#define HAVE_FFS 1",
-    "#undef HAVE_GETPAGESIZE": "#define HAVE_GETPAGESIZE 1",
-    "#undef HAVE_INTTYPES_H": "#define HAVE_INTTYPES_H 1",
-    "#undef HAVE_LANGINFO_H": "#define HAVE_LANGINFO_H 1",
-    "#undef HAVE_LOCALE_H": "#define HAVE_LOCALE_H 1",
-    "#undef HAVE_MALLOC_H": "#define HAVE_MALLOC_H 1",
-    "#undef HAVE_MEMALIGN": "#define HAVE_MEMALIGN 1",
-    "#undef HAVE_MEMORY_H": "#define HAVE_MEMORY_H 1",
-    "#undef HAVE_MKSTEMP": "#define HAVE_MKSTEMP 1",
-    "#undef HAVE_NL_LANGINFO": "#define HAVE_NL_LANGINFO 1",
-    "#undef HAVE_OPENAT": "#define HAVE_OPENAT 1",
-    "#undef HAVE_POSIX_MEMALIGN": "#define HAVE_POSIX_MEMALIGN 1",
-    "#undef HAVE_PTHREAD_T": "#define HAVE_PTHREAD_T 1",
-    "#undef HAVE_PUTWC": "#define HAVE_PUTWC 1",
-    "#undef HAVE_SETLOCALE": "#define HAVE_SETLOCALE 1",
-    "#undef HAVE_SSIZE_T": "#define HAVE_SSIZE_T 1",
-    "#undef HAVE_STDINT_H": "#define HAVE_STDINT_H 1",
-    "#undef HAVE_STDLIB_H": "#define HAVE_STDLIB_H 1",
-    "#undef HAVE_STRCASECMP": "#define HAVE_STRCASECMP 1",
-    "#undef HAVE_STRFTIME": "#define HAVE_STRFTIME 1",
-    "#undef HAVE_STRINGS_H": "#define HAVE_STRINGS_H 1",
-    "#undef HAVE_STRING_H": "#define HAVE_STRING_H 1",
-    "#undef HAVE_STRNCASECMP": "#define HAVE_STRNCASECMP 1",
-    "#undef HAVE_SYS_MMAN_H": "#define HAVE_SYS_MMAN_H 1",
-    "#undef HAVE_SYS_PARAM_H": "#define HAVE_SYS_PARAM_H 1",
-    "#undef HAVE_SYS_STAT_H": "#define HAVE_SYS_STAT_H 1",
-    "#undef HAVE_SYS_SYSCTL_H": "#define HAVE_SYS_SYSCTL_H 1",
-    "#undef HAVE_SYS_TYPES_H": "#define HAVE_SYS_TYPES_H 1",
-    "#undef HAVE_SYS_UTSNAME_H": "#define HAVE_SYS_UTSNAME_H 1",
-    "#undef HAVE_TIME_H": "#define HAVE_TIME_H 1",
-    "#undef HAVE_UNAME": "#define HAVE_UNAME 1",
-    "#undef HAVE_UNISTD_H": "#define HAVE_UNISTD_H 1",
-    "#undef HAVE_USELOCALE": "#define HAVE_USELOCALE 1",
-    "#undef HAVE_WCHAR_T": "#define HAVE_WCHAR_T 1",
-    "#undef HAVE_X11_KEYSYM_H": "#define HAVE_X11_KEYSYM_H 1",
-    "#undef HAVE_X11_XLIB_H": "#define HAVE_X11_XLIB_H 1",
-    "#undef HAVE_X11_XUTIL_H": "#define HAVE_X11_XUTIL_H 1",
-    "#undef HAVE___PROGNAME": "#define HAVE___PROGNAME 1",
-    "#undef HWLOC_C_HAVE_VISIBILITY": "#define HWLOC_C_HAVE_VISIBILITY 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_ALIGNED": "#define HWLOC_HAVE_ATTRIBUTE_ALIGNED 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_ALWAYS_INLINE": "#define HWLOC_HAVE_ATTRIBUTE_ALWAYS_INLINE 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_COLD": "#define HWLOC_HAVE_ATTRIBUTE_COLD 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_CONSTRUCTOR": "#define HWLOC_HAVE_ATTRIBUTE_CONSTRUCTOR 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_CONST": "#define HWLOC_HAVE_ATTRIBUTE_CONST 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_DEPRECATED": "#define HWLOC_HAVE_ATTRIBUTE_DEPRECATED 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_FORMAT": "#define HWLOC_HAVE_ATTRIBUTE_FORMAT 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_HOT": "#define HWLOC_HAVE_ATTRIBUTE_HOT 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_MALLOC": "#define HWLOC_HAVE_ATTRIBUTE_MALLOC 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_MAY_ALIAS": "#define HWLOC_HAVE_ATTRIBUTE_MAY_ALIAS 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_NONNULL": "#define HWLOC_HAVE_ATTRIBUTE_NONNULL 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_NORETURN": "#define HWLOC_HAVE_ATTRIBUTE_NORETURN 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_NO_INSTRUMENT_FUNCTION": "#define HWLOC_HAVE_ATTRIBUTE_NO_INSTRUMENT_FUNCTION 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_PACKED": "#define HWLOC_HAVE_ATTRIBUTE_PACKED 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_PURE": "#define HWLOC_HAVE_ATTRIBUTE_PURE 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_SENTINEL": "#define HWLOC_HAVE_ATTRIBUTE_SENTINEL 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_UNUSED": "#define HWLOC_HAVE_ATTRIBUTE_UNUSED 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_WARN_UNUSED_RESULT": "#define HWLOC_HAVE_ATTRIBUTE_WARN_UNUSED_RESULT 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE_WEAK_ALIAS": "#define HWLOC_HAVE_ATTRIBUTE_WEAK_ALIAS 1",
-    "#undef HWLOC_HAVE_ATTRIBUTE": "#define HWLOC_HAVE_ATTRIBUTE 1",
-    "#undef HWLOC_HAVE_CPU_SET_S": "#define HWLOC_HAVE_CPU_SET_S 1",
-    "#undef HWLOC_HAVE_CPU_SET": "#define HWLOC_HAVE_CPU_SET 1",
-    "#undef HWLOC_HAVE_DECL_FFSL": "#define HWLOC_HAVE_DECL_FFSL 1",
-    "#undef HWLOC_HAVE_DECL_FFS": "#define HWLOC_HAVE_DECL_FFS 1",
-    "#undef HWLOC_HAVE_DECL_STRCASECMP": "#define HWLOC_HAVE_DECL_STRCASECMP 1",
-    "#undef HWLOC_HAVE_DECL_STRNCASECMP": "#define HWLOC_HAVE_DECL_STRNCASECMP 1",
-    "#undef HWLOC_HAVE_FFSL": "#define HWLOC_HAVE_FFSL 1",
-    "#undef HWLOC_HAVE_FFS": "#define HWLOC_HAVE_FFS 1",
-    "#undef HWLOC_HAVE_LIBTERMCAP": "#define HWLOC_HAVE_LIBTERMCAP 1",
-    "#undef HWLOC_HAVE_LINUXIO": "#define HWLOC_HAVE_LINUXIO 1",
-    "#undef HWLOC_HAVE_PTHREAD_MUTEX": "#define HWLOC_HAVE_PTHREAD_MUTEX 1",
-    "#undef HWLOC_HAVE_SCHED_SETAFFINITY": "#define HWLOC_HAVE_SCHED_SETAFFINITY 1",
-    "#undef HWLOC_HAVE_STDINT_H": "#define HWLOC_HAVE_STDINT_H 1",
-    "#undef HWLOC_HAVE_SYSCALL": "#define HWLOC_HAVE_SYSCALL 1",
-    "#undef HWLOC_HAVE_X11_KEYSYM": "#define HWLOC_HAVE_X11_KEYSYM 1",
-    "#undef HWLOC_HAVE_X86_CPUID": "#define HWLOC_HAVE_X86_CPUID 1",
-    "#undef HWLOC_SIZEOF_UNSIGNED_INT": "#define HWLOC_SIZEOF_UNSIGNED_INT 4",
-    "#undef HWLOC_SIZEOF_UNSIGNED_LONG": "#define HWLOC_SIZEOF_UNSIGNED_LONG 8",
-    "#undef HWLOC_SYM_PREFIX_CAPS": "#define HWLOC_SYM_PREFIX_CAPS HWLOC_",
-    "#undef HWLOC_SYM_PREFIX": "#define HWLOC_SYM_PREFIX hwloc_",
-    "#undef HWLOC_SYM_TRANSFORM": "#define HWLOC_SYM_TRANSFORM 0",
-    "#undef HWLOC_USE_NCURSES": "#define HWLOC_USE_NCURSES 1",
-    "#undef HWLOC_VERSION_GREEK": "#define HWLOC_VERSION_GREEK \"\"",
-    "#undef HWLOC_VERSION_MAJOR": "#define HWLOC_VERSION_MAJOR 2",
-    "#undef HWLOC_VERSION_MINOR": "#define HWLOC_VERSION_MINOR 0",
-    "#undef HWLOC_VERSION_RELEASE": "#define HWLOC_VERSION_RELEASE 3",
-    "#undef HWLOC_VERSION": "#define HWLOC_VERSION \"2.0.3\"",
-    "#undef HWLOC_X86_64_ARCH": "#define HWLOC_X86_64_ARCH 1",
-    "#undef LT_OBJDIR": "#define LT_OBJDIR \".libs/\"",
-    "#undef PACKAGE_BUGREPORT": "#define PACKAGE_BUGREPORT \"http://github.com/open-mpi/hwloc/issues",
-    "#undef PACKAGE_NAME": "#define PACKAGE_NAME \"hwloc\"",
-    "#undef PACKAGE_STRING": "#define PACKAGE_STRING \"hwloc 2.0.3\"",
-    "#undef PACKAGE_TARNAME": "#define PACKAGE_TARNAME \"hwloc\"",
-    "#undef PACKAGE_URL": "#define PACKAGE_URL \"\"",
-    "#undef PACKAGE_VERSION": "#define PACKAGE_VERSION \"2.0.3\"",
-    "#undef PACKAGE": "#define PACKAGE \"hwloc\"",
-    "#undef SIZEOF_UNSIGNED_INT": "#define SIZEOF_UNSIGNED_INT 4",
-    "#undef SIZEOF_UNSIGNED_LONG": "#define SIZEOF_UNSIGNED_LONG 8",
-    "#undef SIZEOF_VOID_P": "#define SIZEOF_VOID_P 8",
-    "#undef STDC_HEADERS": "#define STDC_HEADERS 1",
-    "# undef _HPUX_SOURCE": "# define _HPUX_SOURCE 1",
-    "# undef _ALL_SOURCE": "# define _ALL_SOURCE 1",
-    "# undef _GNU_SOURCE": "# define _GNU_SOURCE 1",
-    "# undef _POSIX_PTHREAD_SEMANTICS": "# define _POSIX_PTHREAD_SEMANTICS 1",
-    "# undef _TANDEM_SOURCE": "# define _TANDEM_SOURCE 1",
-    "# undef __EXTENSIONS__": "# define __EXTENSIONS__ 1",
-    "#undef VERSION": "#define VERSION \"2.0.3\"",
-    "#undef _HPUX_SOURCE": "#define _HPUX_SOURCE 1",
-    "#undef hwloc_pid_t": "#define hwloc_pid_t pid_t",
-    "#undef hwloc_thread_t": "#define hwloc_thread_t pthread_t",
-}
-
-_INCLUDE_PRIVATE_HWLOC_AUTOIGEN_CONFIG_H_CUDA_SUBS = {
-    "#undef HAVE_CUDA_RUNTIME_API_H": "#define HAVE_CUDA_RUNTIME_API_H 1",
-    "#undef HAVE_CUDA_H": "#define HAVE_CUDA_H 1",
-    "#undef HAVE_CUDA": "#define HAVE_CUDA 1",
-}
-
-_INCLUDE_PRIVATE_HWLOC_AUTOIGEN_CONFIG_H_LINUX_SUBS = {
-    "#undef HAVE_PROGRAM_INVOCATION_NAME": "#define HAVE_PROGRAM_INVOCATION_NAME 1",
-    "#undef HWLOC_LINUX_SYS": "#define HWLOC_LINUX_SYS 1",
-}
-
-_INCLUDE_PRIVATE_HWLOC_AUTOIGEN_CONFIG_H_LINUX_SUBS.update(_INCLUDE_PRIVATE_HWLOC_AUTOIGEN_CONFIG_H_COMMON_SUBS)
-
-_INCLUDE_PRIVATE_HWLOC_AUTOIGEN_CONFIG_H_CUDA_SUBS.update(_INCLUDE_PRIVATE_HWLOC_AUTOIGEN_CONFIG_H_LINUX_SUBS)
-
-expand_template(
-    name = "include_private_hwloc_autogen__config_h",
-    out = "include/private/autogen/config.h",
-    substitutions = if_cuda(
-        _INCLUDE_PRIVATE_HWLOC_AUTOIGEN_CONFIG_H_CUDA_SUBS,
-        if_false = _INCLUDE_PRIVATE_HWLOC_AUTOIGEN_CONFIG_H_LINUX_SUBS,
-    ),
-    template = "include/private/autogen/config.h.in",
-)
-
-expand_template(
-    name = "move_static_components_h",
-    out = "hwloc/static-components.h",
-    substitutions = {"&hwloc_linuxio_component": "//&hwloc_linuxio_component"},
-    template = "@local_xla//third_party/hwloc:static-components.h",
-)
-
-cc_library(
-    name = "hwloc",
-    srcs = [
-        "hwloc/base64.c",
-        "hwloc/bind.c",
-        "hwloc/bitmap.c",
-        "hwloc/components.c",
-        "hwloc/cpukinds.c",
-        "hwloc/diff.c",
-        "hwloc/distances.c",
-        "hwloc/memattrs.c",
-        "hwloc/misc.c",
-        "hwloc/pci-common.c",
-        "hwloc/shmem.c",
-        "hwloc/static-components.h",
-        "hwloc/topology.c",
-        "hwloc/topology-hardwired.c",
-        "hwloc/topology-noos.c",
-        "hwloc/topology-synthetic.c",
-        "hwloc/topology-xml.c",
-        "hwloc/topology-xml-nolibxml.c",
-        "hwloc/traversal.c",
-        "include/hwloc/plugins.h",
-        "include/hwloc/shmem.h",
-        "include/private/autogen/config.h",
-        "include/private/components.h",
-        "include/private/debug.h",
-        "include/private/internal-components.h",
-        "include/private/misc.h",
-        "include/private/private.h",
-        "include/private/xml.h",
-    ] + select({
-        "@local_tsl//tsl:linux_x86_64": [
-            "hwloc/topology-linux.c",
-            "hwloc/topology-x86.c",
-            "include/hwloc/linux.h",
-            "include/private/cpuid-x86.h",
-        ],
-        "@local_tsl//tsl:linux_aarch64": [
-            "hwloc/topology-linux.c",
-            "include/hwloc/linux.h",
-        ],
-        "@local_tsl//tsl:linux_ppc64le": [
-            "hwloc/topology-linux.c",
-            "include/hwloc/linux.h",
-        ],
-        "@local_tsl//tsl:freebsd": [
-            "hwloc/topology-freebsd.c",
-            "hwloc/topology-x86.c",
-            "include/private/cpuid-x86.h",
-        ],
-        "//conditions:default": [],
-    }),
-    hdrs = [
-        "include/hwloc.h",
-        "include/hwloc/autogen/config.h",
-        "include/hwloc/bitmap.h",
-        "include/hwloc/cpukinds.h",
-        "include/hwloc/deprecated.h",
-        "include/hwloc/diff.h",
-        "include/hwloc/distances.h",
-        "include/hwloc/export.h",
-        "include/hwloc/helper.h",
-        "include/hwloc/inlines.h",
-        "include/hwloc/memattrs.h",
-        "include/hwloc/rename.h",
-    ],
-    copts = COMMON_INCLUDE_COPTS + DISABLE_WARNINGS_COPTS + VAR_SETTINGS_COPTS,
-    features = [
-        "-parse_headers",
-        "-layering_check",
-    ],
-    includes = [
-        "hwloc",
-        "include",
-    ],
-    deps = [],
-)
-
-cc_binary(
-    name = "hwloc_print",
-    srcs = ["hwloc_print.cc"],
-    copts = COMMON_INCLUDE_COPTS,
-    deps = [
-        ":hwloc",
-    ],
-)
diff --git a/third_party/xla/third_party/hwloc/static-components.h b/third_party/xla/third_party/hwloc/static-components.h
deleted file mode 100644
index e83b311..0000000
--- a/third_party/xla/third_party/hwloc/static-components.h
+++ /dev/null
@@ -1,38 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_THIRD_PARTY_HWLOC_STATIC_COMPONENTS_H_
-#define TENSORFLOW_THIRD_PARTY_HWLOC_STATIC_COMPONENTS_H_
-
-#include <private/internal-components.h>
-static const struct hwloc_component* hwloc_static_components[] = {
-    &hwloc_noos_component,
-    &hwloc_xml_component,
-    &hwloc_synthetic_component,
-    &hwloc_xml_nolibxml_component,
-#ifdef __linux__
-    &hwloc_linux_component,
-    &hwloc_linuxio_component,
-#endif
-#ifdef __FreeBSD__
-    &hwloc_freebsd_component,
-#endif
-#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || \
-    defined(_M_X64)
-    &hwloc_x86_component,
-#endif
-    NULL};
-
-#endif  // TENSORFLOW_THIRD_PARTY_HWLOC_STATIC_COMPONENTS_H_
diff --git a/third_party/xla/third_party/hwloc/workspace.bzl b/third_party/xla/third_party/hwloc/workspace.bzl
deleted file mode 100644
index ce8475c..0000000
--- a/third_party/xla/third_party/hwloc/workspace.bzl
+++ /dev/null
@@ -1,13 +0,0 @@
-"""loads the hwloc library, used by TF."""
-
-load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
-
-def repo():
-    tf_http_archive(
-        name = "hwloc",
-        urls = tf_mirror_urls("https://download.open-mpi.org/release/hwloc/v2.7/hwloc-2.7.1.tar.gz"),
-        sha256 = "4cb0a781ed980b03ad8c48beb57407aa67c4b908e45722954b9730379bc7f6d5",
-        strip_prefix = "hwloc-2.7.1",
-        build_file = "//third_party/hwloc:hwloc.BUILD",
-        system_build_file = "//third_party/hwloc:BUILD.system",
-    )
diff --git a/third_party/xla/third_party/implib_so/BUILD b/third_party/xla/third_party/implib_so/BUILD
deleted file mode 100644
index 8401d61..0000000
--- a/third_party/xla/third_party/implib_so/BUILD
+++ /dev/null
@@ -1,23 +0,0 @@
-package(default_visibility = ["//visibility:public"])
-
-licenses(["notice"])  # MIT
-
-py_binary(
-    name = "get_symbols",
-    srcs = ["get_symbols.py"],
-    visibility = ["//visibility:public"],
-    deps = [
-        "@bazel_tools//tools/python/runfiles",
-        "@implib_so//:implib_gen_lib",
-    ],
-)
-
-py_binary(
-    name = "make_stub",
-    srcs = ["make_stub.py"],
-    visibility = ["//visibility:public"],
-    deps = [
-        "@bazel_tools//tools/python/runfiles",
-        "@implib_so//:implib_gen_lib",
-    ],
-)
diff --git a/third_party/xla/third_party/implib_so/get_symbols.py b/third_party/xla/third_party/implib_so/get_symbols.py
deleted file mode 100644
index 9625052..0000000
--- a/third_party/xla/third_party/implib_so/get_symbols.py
+++ /dev/null
@@ -1,38 +0,0 @@
-"""Given a .so file, lists symbols that should be included in a stub.
-
-Example usage:
-$ bazel run -c opt @local_tsl//third_party/implib_so:get_symbols \
-    /usr/local/cuda/lib64/libcudart.so > third_party/tsl/tsl/cuda/cudart.symbols
-"""
-
-import argparse
-import importlib
-
-# We can't import implib-gen directly because it has a dash in its name.
-implib = importlib.import_module('implib-gen')
-
-
-def _is_exported_function(s):
-  return (
-      s['Bind'] != 'LOCAL'
-      and s['Type'] == 'FUNC'
-      and s['Ndx'] != 'UND'
-      and s['Name'] not in ['', '_init', '_fini']
-      and s['Default']
-  )
-
-
-def main():
-  parser = argparse.ArgumentParser(
-      description='Extracts a list of symbols from a shared library'
-  )
-  parser.add_argument('library', help='Path to the .so file.')
-  args = parser.parse_args()
-  syms = implib.collect_syms(args.library)
-  funs = [s['Name'] for s in syms if _is_exported_function(s)]
-  for f in sorted(funs):
-    print(f)
-
-
-if __name__ == '__main__':
-  main()
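The filter above operates on the symbol records produced by implib-gen's `collect_syms()`; the keys it checks (`Bind`, `Type`, `Ndx`, `Name`, `Default`) come from the ELF dynamic symbol table. A small self-contained illustration with hand-written records (the values are hypothetical, but the filtering logic matches `_is_exported_function`):

```python
def is_exported_function(s):
    return (s['Bind'] != 'LOCAL' and s['Type'] == 'FUNC' and s['Ndx'] != 'UND'
            and s['Name'] not in ['', '_init', '_fini'] and s['Default'])

syms = [
    {'Name': 'cudaMalloc', 'Bind': 'GLOBAL', 'Type': 'FUNC', 'Ndx': '12', 'Default': True},
    {'Name': '_init', 'Bind': 'GLOBAL', 'Type': 'FUNC', 'Ndx': '10', 'Default': True},
    {'Name': 'helper', 'Bind': 'LOCAL', 'Type': 'FUNC', 'Ndx': '12', 'Default': True},
    {'Name': 'gTable', 'Bind': 'GLOBAL', 'Type': 'OBJECT', 'Ndx': '14', 'Default': True},
    {'Name': 'cudaFree', 'Bind': 'GLOBAL', 'Type': 'FUNC', 'Ndx': 'UND', 'Default': True},
]

print(sorted(s['Name'] for s in syms if is_exported_function(s)))
# ['cudaMalloc']
```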
diff --git a/third_party/xla/third_party/implib_so/implib_so.BUILD b/third_party/xla/third_party/implib_so/implib_so.BUILD
deleted file mode 100644
index bbfb289..0000000
--- a/third_party/xla/third_party/implib_so/implib_so.BUILD
+++ /dev/null
@@ -1,20 +0,0 @@
-# Description:
-#   Implib.so is a simple equivalent of Windows DLL import libraries for POSIX
-#   shared libraries.
-
-package(default_visibility = ["//visibility:public"])
-
-licenses(["notice"])  # MIT
-
-exports_files([
-    "LICENSE.txt",
-])
-
-py_library(
-    name = "implib_gen_lib",
-    srcs = ["implib-gen.py"],
-    data = glob([
-        "arch/**/*.S.tpl",
-        "arch/**/*.ini",
-    ]),
-)
diff --git a/third_party/xla/third_party/implib_so/make_stub.py b/third_party/xla/third_party/implib_so/make_stub.py
deleted file mode 100644
index f0e1fe5..0000000
--- a/third_party/xla/third_party/implib_so/make_stub.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""Given a list of symbols, generates a stub."""
-
-import argparse
-import configparser
-import os
-import string
-
-from bazel_tools.tools.python.runfiles import runfiles
-
-r = runfiles.Create()
-
-
-def main():
-  parser = argparse.ArgumentParser(
-      description='Generates stubs for CUDA libraries.'
-  )
-  parser.add_argument('symbols', help='File containing a list of symbols.')
-  parser.add_argument(
-      '--outdir', '-o', help='Path to create wrapper at', default='.'
-  )
-  parser.add_argument(
-      '--target',
-      help='Target platform name, e.g. x86_64, aarch64.',
-      required=True,
-  )
-  args = parser.parse_args()
-
-  config_path = r.Rlocation(f'implib_so/arch/{args.target}/config.ini')
-  table_path = r.Rlocation(f'implib_so/arch/{args.target}/table.S.tpl')
-  trampoline_path = r.Rlocation(
-      f'implib_so/arch/{args.target}/trampoline.S.tpl'
-  )
-
-  cfg = configparser.ConfigParser(inline_comment_prefixes=';')
-  cfg.read(config_path)
-  ptr_size = int(cfg['Arch']['PointerSize'])
-
-  with open(args.symbols, 'r') as f:
-    funs = [s.strip() for s in f.readlines()]
-
-  # Generate assembly code, containing a table for the resolved symbols and the
-  # trampolines.
-  lib_name, _ = os.path.splitext(os.path.basename(args.symbols))
-
-  with open(os.path.join(args.outdir, f'{lib_name}.tramp.S'), 'w') as f:
-    with open(table_path, 'r') as t:
-      table_text = string.Template(t.read()).substitute(
-          lib_suffix=lib_name, table_size=ptr_size * (len(funs) + 1)
-      )
-    f.write(table_text)
-
-    with open(trampoline_path, 'r') as t:
-      tramp_tpl = string.Template(t.read())
-
-    for i, name in enumerate(funs):
-      tramp_text = tramp_tpl.substitute(
-          lib_suffix=lib_name, sym=name, offset=i * ptr_size, number=i
-      )
-      f.write(tramp_text)
-
-  # Generates a list of symbols, formatted as a list of C++ strings.
-  with open(os.path.join(args.outdir, f'{lib_name}.inc'), 'w') as f:
-    sym_names = ''.join(f'  "{name}",\n' for name in funs)
-    f.write(sym_names)
-
-
-if __name__ == '__main__':
-  main()
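For a sense of what `make_stub.py` emits: each symbol gets one trampoline rendered from `trampoline.S.tpl` via `string.Template`, with `lib_suffix`, `sym`, `offset`, and `number` filled in, and offsets advancing by the pointer size read from `config.ini`. The template below is invented for illustration (the real one ships with Implib.so); only the substitution mechanics match the script above:

```python
import string

# Hypothetical stand-in for implib_so/arch/x86_64/trampoline.S.tpl.
tramp_tpl = string.Template(
    ".globl ${sym}\n"
    "${sym}:\n"
    "  jmp *_${lib_suffix}_tramp_table+${offset}(%rip)  # slot ${number}\n\n"
)

ptr_size = 8  # would come from arch/x86_64/config.ini's PointerSize
for i, name in enumerate(["cudaMalloc", "cudaFree"]):
    print(tramp_tpl.substitute(
        lib_suffix="cudart", sym=name, offset=i * ptr_size, number=i), end="")
```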
diff --git a/third_party/xla/third_party/implib_so/workspace.bzl b/third_party/xla/third_party/implib_so/workspace.bzl
deleted file mode 100644
index 01dad3b..0000000
--- a/third_party/xla/third_party/implib_so/workspace.bzl
+++ /dev/null
@@ -1,13 +0,0 @@
-"""Implib.so is a simple equivalent of Windows DLL import libraries for POSIX
-shared libraries."""
-
-load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
-
-def repo():
-    tf_http_archive(
-        name = "implib_so",
-        strip_prefix = "Implib.so-5fb84c2a750434b9df1da67d67b749eb929598f1",
-        sha256 = "10de0a616df24849f2a883747784c115f209708960e44556f5ce384de6f103e8",
-        urls = tf_mirror_urls("https://github.com/yugr/Implib.so/archive/5fb84c2a750434b9df1da67d67b749eb929598f1.tar.gz"),
-        build_file = "//third_party/implib_so:implib_so.BUILD",
-    )
diff --git a/third_party/xla/third_party/jpeg/BUILD b/third_party/xla/third_party/jpeg/BUILD
deleted file mode 100644
index ed1568c..0000000
--- a/third_party/xla/third_party/jpeg/BUILD
+++ /dev/null
@@ -1,3 +0,0 @@
-# Needed to make this a package.
-
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/xla/third_party/jpeg/BUILD.system b/third_party/xla/third_party/jpeg/BUILD.system
deleted file mode 100644
index f4f52da..0000000
--- a/third_party/xla/third_party/jpeg/BUILD.system
+++ /dev/null
@@ -1,12 +0,0 @@
-licenses(["notice"])  # custom notice-style license, see LICENSE.md
-
-filegroup(
-    name = "LICENSE.md",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "jpeg",
-    linkopts = ["-ljpeg"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/jpeg/jpeg.BUILD b/third_party/xla/third_party/jpeg/jpeg.BUILD
deleted file mode 100644
index 9f61f9e..0000000
--- a/third_party/xla/third_party/jpeg/jpeg.BUILD
+++ /dev/null
@@ -1,806 +0,0 @@
-# Description:
-#   libjpeg-turbo is a drop in replacement for jpeglib optimized with SIMD.
-
-load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
-load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
-
-licenses(["notice"])  # custom notice-style license, see LICENSE.md
-
-exports_files(["LICENSE.md"])
-
-WIN_COPTS = [
-    "/Ox",
-    "-DWITH_SIMD",
-    "-wd4996",
-]
-
-libjpegturbo_copts = select({
-    ":android": [
-        "-O3",
-        "-fPIC",
-        "-w",
-    ],
-    ":windows": WIN_COPTS,
-    "//conditions:default": [
-        "-O3",
-        "-w",
-    ],
-}) + select({
-    ":armeabi-v7a": [
-        "-D__ARM_NEON__",
-        "-DNEON_INTRINSICS",
-        "-march=armv7-a",
-        "-mfpu=neon",
-        "-mfloat-abi=softfp",
-        "-fprefetch-loop-arrays",
-    ],
-    ":arm64-v8a": [
-        "-DNEON_INTRINSICS",
-    ],
-    ":linux_ppc64le": [
-        "-mcpu=power8",
-        "-mtune=power8",
-    ],
-    "//conditions:default": [],
-})
-
-cc_library(
-    name = "jpeg",
-    srcs = [
-        "jaricom.c",
-        "jcapimin.c",
-        "jcapistd.c",
-        "jcarith.c",
-        "jccoefct.c",
-        "jccolor.c",
-        "jcdctmgr.c",
-        "jchuff.c",
-        "jchuff.h",
-        "jcinit.c",
-        "jcmainct.c",
-        "jcmarker.c",
-        "jcmaster.c",
-        "jcomapi.c",
-        "jconfig.h",
-        "jconfigint.h",
-        "jcparam.c",
-        "jcphuff.c",
-        "jcprepct.c",
-        "jcsample.c",
-        "jctrans.c",
-        "jdapimin.c",
-        "jdapistd.c",
-        "jdarith.c",
-        "jdatadst.c",
-        "jdatasrc.c",
-        "jdcoefct.c",
-        "jdcoefct.h",
-        "jdcolor.c",
-        "jdct.h",
-        "jddctmgr.c",
-        "jdhuff.c",
-        "jdhuff.h",
-        "jdinput.c",
-        "jdmainct.c",
-        "jdmainct.h",
-        "jdmarker.c",
-        "jdmaster.c",
-        "jdmaster.h",
-        "jdmerge.c",
-        "jdmerge.h",
-        "jdphuff.c",
-        "jdpostct.c",
-        "jdsample.c",
-        "jdsample.h",
-        "jdtrans.c",
-        "jerror.c",
-        "jfdctflt.c",
-        "jfdctfst.c",
-        "jfdctint.c",
-        "jidctflt.c",
-        "jidctfst.c",
-        "jidctint.c",
-        "jidctred.c",
-        "jinclude.h",
-        "jmemmgr.c",
-        "jmemnobs.c",
-        "jmemsys.h",
-        "jpeg_nbits_table.h",
-        "jpegcomp.h",
-        "jquant1.c",
-        "jquant2.c",
-        "jutils.c",
-        "jversion.h",
-    ],
-    hdrs = [
-        "jccolext.c",  # should have been named .inc
-        "jdcol565.c",  # should have been named .inc
-        "jdcolext.c",  # should have been named .inc
-        "jdmrg565.c",  # should have been named .inc
-        "jdmrgext.c",  # should have been named .inc
-        "jerror.h",
-        "jmorecfg.h",
-        "jpegint.h",
-        "jpeglib.h",
-        "jstdhuff.c",  # should have been named .inc
-    ],
-    copts = libjpegturbo_copts,
-    visibility = ["//visibility:public"],
-    deps = select({
-        ":nosimd": [":simd_none"],
-        ":k8": [":simd_x86_64"],
-        ":armeabi-v7a": [":simd_armv7a"],
-        ":arm64-v8a": [":simd_armv8a"],
-        ":linux_ppc64le": [":simd_altivec"],
-        ":windows": [":simd_win_x86_64"],
-        "//conditions:default": [":simd_none"],
-    }),
-)
-
-cc_library(
-    name = "simd_altivec",
-    srcs = [
-        "jchuff.h",
-        "jconfig.h",
-        "jconfigint.h",
-        "jdct.h",
-        "jerror.h",
-        "jinclude.h",
-        "jmorecfg.h",
-        "jpegint.h",
-        "jpeglib.h",
-        "jsimd.h",
-        "jsimddct.h",
-        "simd/jsimd.h",
-        "simd/powerpc/jccolor-altivec.c",
-        "simd/powerpc/jcgray-altivec.c",
-        "simd/powerpc/jcsample-altivec.c",
-        "simd/powerpc/jdcolor-altivec.c",
-        "simd/powerpc/jdmerge-altivec.c",
-        "simd/powerpc/jdsample-altivec.c",
-        "simd/powerpc/jfdctfst-altivec.c",
-        "simd/powerpc/jfdctint-altivec.c",
-        "simd/powerpc/jidctfst-altivec.c",
-        "simd/powerpc/jidctint-altivec.c",
-        "simd/powerpc/jquanti-altivec.c",
-        "simd/powerpc/jsimd.c",
-    ],
-    hdrs = [
-        "simd/powerpc/jccolext-altivec.c",
-        "simd/powerpc/jcgryext-altivec.c",
-        "simd/powerpc/jcsample.h",
-        "simd/powerpc/jdcolext-altivec.c",
-        "simd/powerpc/jdmrgext-altivec.c",
-        "simd/powerpc/jsimd_altivec.h",
-    ],
-    copts = libjpegturbo_copts,
-)
-
-SRCS_SIMD_COMMON = [
-    "jchuff.h",
-    "jconfig.h",
-    "jconfigint.h",
-    "jdct.h",
-    "jerror.h",
-    "jinclude.h",
-    "jmorecfg.h",
-    "jpegint.h",
-    "jpeglib.h",
-    "jsimddct.h",
-    "jsimd.h",
-    "simd/jsimd.h",
-]
-
-cc_library(
-    name = "simd_x86_64",
-    srcs = [
-        "simd/x86_64/jccolor-avx2.o",
-        "simd/x86_64/jccolor-sse2.o",
-        "simd/x86_64/jcgray-avx2.o",
-        "simd/x86_64/jcgray-sse2.o",
-        "simd/x86_64/jchuff-sse2.o",
-        "simd/x86_64/jcphuff-sse2.o",
-        "simd/x86_64/jcsample-avx2.o",
-        "simd/x86_64/jcsample-sse2.o",
-        "simd/x86_64/jdcolor-avx2.o",
-        "simd/x86_64/jdcolor-sse2.o",
-        "simd/x86_64/jdmerge-avx2.o",
-        "simd/x86_64/jdmerge-sse2.o",
-        "simd/x86_64/jdsample-avx2.o",
-        "simd/x86_64/jdsample-sse2.o",
-        "simd/x86_64/jfdctflt-sse.o",
-        "simd/x86_64/jfdctfst-sse2.o",
-        "simd/x86_64/jfdctint-avx2.o",
-        "simd/x86_64/jfdctint-sse2.o",
-        "simd/x86_64/jidctflt-sse2.o",
-        "simd/x86_64/jidctfst-sse2.o",
-        "simd/x86_64/jidctint-avx2.o",
-        "simd/x86_64/jidctint-sse2.o",
-        "simd/x86_64/jidctred-sse2.o",
-        "simd/x86_64/jquantf-sse2.o",
-        "simd/x86_64/jquanti-avx2.o",
-        "simd/x86_64/jquanti-sse2.o",
-        "simd/x86_64/jsimd.c",
-        "simd/x86_64/jsimdcpu.o",
-    ] + SRCS_SIMD_COMMON,
-    copts = libjpegturbo_copts,
-    linkstatic = 1,
-)
-
-genrule(
-    name = "simd_x86_64_assemblage23",
-    srcs = [
-        "jconfig.h",
-        "jconfigint.h",
-        "simd/x86_64/jccolext-avx2.asm",
-        "simd/x86_64/jccolext-sse2.asm",
-        "simd/x86_64/jccolor-avx2.asm",
-        "simd/x86_64/jccolor-sse2.asm",
-        "simd/x86_64/jcgray-avx2.asm",
-        "simd/x86_64/jcgray-sse2.asm",
-        "simd/x86_64/jcgryext-avx2.asm",
-        "simd/x86_64/jcgryext-sse2.asm",
-        "simd/x86_64/jchuff-sse2.asm",
-        "simd/x86_64/jcphuff-sse2.asm",
-        "simd/x86_64/jcsample-avx2.asm",
-        "simd/x86_64/jcsample-sse2.asm",
-        "simd/x86_64/jdcolext-avx2.asm",
-        "simd/x86_64/jdcolext-sse2.asm",
-        "simd/x86_64/jdcolor-avx2.asm",
-        "simd/x86_64/jdcolor-sse2.asm",
-        "simd/x86_64/jdmerge-avx2.asm",
-        "simd/x86_64/jdmerge-sse2.asm",
-        "simd/x86_64/jdmrgext-avx2.asm",
-        "simd/x86_64/jdmrgext-sse2.asm",
-        "simd/x86_64/jdsample-avx2.asm",
-        "simd/x86_64/jdsample-sse2.asm",
-        "simd/x86_64/jfdctflt-sse.asm",
-        "simd/x86_64/jfdctfst-sse2.asm",
-        "simd/x86_64/jfdctint-avx2.asm",
-        "simd/x86_64/jfdctint-sse2.asm",
-        "simd/x86_64/jidctflt-sse2.asm",
-        "simd/x86_64/jidctfst-sse2.asm",
-        "simd/x86_64/jidctint-avx2.asm",
-        "simd/x86_64/jidctint-sse2.asm",
-        "simd/x86_64/jidctred-sse2.asm",
-        "simd/x86_64/jquantf-sse2.asm",
-        "simd/x86_64/jquanti-avx2.asm",
-        "simd/x86_64/jquanti-sse2.asm",
-        "simd/x86_64/jsimdcpu.asm",
-        "simd/nasm/jcolsamp.inc",
-        "simd/nasm/jdct.inc",
-        "simd/nasm/jsimdcfg.inc",
-        "simd/nasm/jsimdcfg.inc.h",
-        "simd/nasm/jsimdext.inc",
-    ],
-    outs = [
-        "simd/x86_64/jccolor-avx2.o",
-        "simd/x86_64/jccolor-sse2.o",
-        "simd/x86_64/jcgray-avx2.o",
-        "simd/x86_64/jcgray-sse2.o",
-        "simd/x86_64/jchuff-sse2.o",
-        "simd/x86_64/jcphuff-sse2.o",
-        "simd/x86_64/jcsample-avx2.o",
-        "simd/x86_64/jcsample-sse2.o",
-        "simd/x86_64/jdcolor-avx2.o",
-        "simd/x86_64/jdcolor-sse2.o",
-        "simd/x86_64/jdmerge-avx2.o",
-        "simd/x86_64/jdmerge-sse2.o",
-        "simd/x86_64/jdsample-avx2.o",
-        "simd/x86_64/jdsample-sse2.o",
-        "simd/x86_64/jfdctflt-sse.o",
-        "simd/x86_64/jfdctfst-sse2.o",
-        "simd/x86_64/jfdctint-avx2.o",
-        "simd/x86_64/jfdctint-sse2.o",
-        "simd/x86_64/jidctflt-sse2.o",
-        "simd/x86_64/jidctfst-sse2.o",
-        "simd/x86_64/jidctint-avx2.o",
-        "simd/x86_64/jidctint-sse2.o",
-        "simd/x86_64/jidctred-sse2.o",
-        "simd/x86_64/jquantf-sse2.o",
-        "simd/x86_64/jquanti-avx2.o",
-        "simd/x86_64/jquanti-sse2.o",
-        "simd/x86_64/jsimdcpu.o",
-    ],
-    cmd = "for out in $(OUTS); do\n" +
-          "  $(location @nasm//:nasm) -f elf64" +
-          "    -DELF -DPIC -D__x86_64__" +
-          "    -I $$(dirname $(location jconfig.h))/" +
-          "    -I $$(dirname $(location jconfigint.h))/" +
-          "    -I $$(dirname $(location simd/nasm/jsimdcfg.inc.h))/" +
-          "    -I $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/" +
-          "    -o $$out" +
-          "    $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/$$(basename $${out%.o}.asm)\n" +
-          "done",
-    tools = ["@nasm"],
-)
-
-expand_template(
-    name = "neon-compat_gen",
-    out = "simd/arm/neon-compat.h",
-    substitutions = {
-        "#cmakedefine HAVE_VLD1_S16_X3": "#define HAVE_VLD1_S16_X3",
-        "#cmakedefine HAVE_VLD1_U16_X2": "#define HAVE_VLD1_U16_X2",
-        "#cmakedefine HAVE_VLD1Q_U8_X4": "#define HAVE_VLD1Q_U8_X4",
-    },
-    template = "simd/arm/neon-compat.h.in",
-)
-
-genrule(
-    name = "neon-compat_hdr_src",
-    srcs = ["simd/arm/neon-compat.h"],
-    outs = ["neon-compat.h"],
-    cmd = "cp $(location simd/arm/neon-compat.h) $@",
-)
-
-cc_library(
-    name = "neon-compat_hdr",
-    hdrs = ["neon-compat.h"],
-    copts = libjpegturbo_copts,
-)
-
-SRCS_SIMD_ARM = [
-    "simd/arm/jccolor-neon.c",
-    "simd/arm/jcgray-neon.c",
-    "simd/arm/jcphuff-neon.c",
-    "simd/arm/jcsample-neon.c",
-    "simd/arm/jdcolor-neon.c",
-    "simd/arm/jdmerge-neon.c",
-    "simd/arm/jdsample-neon.c",
-    "simd/arm/jfdctfst-neon.c",
-    "simd/arm/jfdctint-neon.c",
-    "simd/arm/jidctfst-neon.c",
-    "simd/arm/jidctint-neon.c",
-    "simd/arm/jidctred-neon.c",
-    "simd/arm/jquanti-neon.c",
-]
-
-# .c files in the following list are used like .h files in that they are
-# "#include"-ed in the actual .c files. So, treat them like normal headers, and
-# they *should not* be compiled into individual objects.
-HDRS_SIMD_ARM = [
-    "simd/arm/align.h",
-    "simd/arm/jchuff.h",
-    "simd/arm/jcgryext-neon.c",
-    "simd/arm/jdcolext-neon.c",
-    "simd/arm/jdmrgext-neon.c",
-]
-
-cc_library(
-    name = "simd_armv7a",
-    srcs = [
-        "simd/arm/aarch32/jchuff-neon.c",
-        "simd/arm/aarch32/jsimd.c",
-    ] + SRCS_SIMD_COMMON + SRCS_SIMD_ARM,
-    hdrs = [
-        "simd/arm/aarch32/jccolext-neon.c",
-    ] + HDRS_SIMD_ARM,
-    copts = libjpegturbo_copts,
-    visibility = ["//visibility:private"],
-    deps = [":neon-compat_hdr"],
-)
-
-cc_library(
-    name = "simd_armv8a",
-    srcs = [
-        "simd/arm/aarch64/jchuff-neon.c",
-        "simd/arm/aarch64/jsimd.c",
-    ] + SRCS_SIMD_COMMON + SRCS_SIMD_ARM,
-    hdrs = [
-        "simd/arm/aarch64/jccolext-neon.c",
-    ] + HDRS_SIMD_ARM,
-    copts = libjpegturbo_copts,
-    visibility = ["//visibility:private"],
-    deps = [":neon-compat_hdr"],
-)
-
-cc_library(
-    name = "simd_win_x86_64",
-    srcs = [
-        "simd/x86_64/jccolor-avx2.obj",
-        "simd/x86_64/jccolor-sse2.obj",
-        "simd/x86_64/jcgray-avx2.obj",
-        "simd/x86_64/jcgray-sse2.obj",
-        "simd/x86_64/jchuff-sse2.obj",
-        "simd/x86_64/jcphuff-sse2.obj",
-        "simd/x86_64/jcsample-avx2.obj",
-        "simd/x86_64/jcsample-sse2.obj",
-        "simd/x86_64/jdcolor-avx2.obj",
-        "simd/x86_64/jdcolor-sse2.obj",
-        "simd/x86_64/jdmerge-avx2.obj",
-        "simd/x86_64/jdmerge-sse2.obj",
-        "simd/x86_64/jdsample-avx2.obj",
-        "simd/x86_64/jdsample-sse2.obj",
-        "simd/x86_64/jfdctflt-sse.obj",
-        "simd/x86_64/jfdctfst-sse2.obj",
-        "simd/x86_64/jfdctint-avx2.obj",
-        "simd/x86_64/jfdctint-sse2.obj",
-        "simd/x86_64/jidctflt-sse2.obj",
-        "simd/x86_64/jidctfst-sse2.obj",
-        "simd/x86_64/jidctint-avx2.obj",
-        "simd/x86_64/jidctint-sse2.obj",
-        "simd/x86_64/jidctred-sse2.obj",
-        "simd/x86_64/jquantf-sse2.obj",
-        "simd/x86_64/jquanti-avx2.obj",
-        "simd/x86_64/jquanti-sse2.obj",
-        "simd/x86_64/jsimd.c",
-        "simd/x86_64/jsimdcpu.obj",
-    ] + SRCS_SIMD_COMMON,
-    copts = libjpegturbo_copts,
-)
-
-genrule(
-    name = "simd_win_x86_64_assemble",
-    srcs = [
-        "jconfig.h",
-        "jconfigint.h",
-        "simd/x86_64/jccolext-avx2.asm",
-        "simd/x86_64/jccolext-sse2.asm",
-        "simd/x86_64/jccolor-avx2.asm",
-        "simd/x86_64/jccolor-sse2.asm",
-        "simd/x86_64/jcgray-avx2.asm",
-        "simd/x86_64/jcgray-sse2.asm",
-        "simd/x86_64/jcgryext-avx2.asm",
-        "simd/x86_64/jcgryext-sse2.asm",
-        "simd/x86_64/jchuff-sse2.asm",
-        "simd/x86_64/jcphuff-sse2.asm",
-        "simd/x86_64/jcsample-avx2.asm",
-        "simd/x86_64/jcsample-sse2.asm",
-        "simd/x86_64/jdcolext-avx2.asm",
-        "simd/x86_64/jdcolext-sse2.asm",
-        "simd/x86_64/jdcolor-avx2.asm",
-        "simd/x86_64/jdcolor-sse2.asm",
-        "simd/x86_64/jdmerge-avx2.asm",
-        "simd/x86_64/jdmerge-sse2.asm",
-        "simd/x86_64/jdmrgext-avx2.asm",
-        "simd/x86_64/jdmrgext-sse2.asm",
-        "simd/x86_64/jdsample-avx2.asm",
-        "simd/x86_64/jdsample-sse2.asm",
-        "simd/x86_64/jfdctflt-sse.asm",
-        "simd/x86_64/jfdctfst-sse2.asm",
-        "simd/x86_64/jfdctint-avx2.asm",
-        "simd/x86_64/jfdctint-sse2.asm",
-        "simd/x86_64/jidctflt-sse2.asm",
-        "simd/x86_64/jidctfst-sse2.asm",
-        "simd/x86_64/jidctint-avx2.asm",
-        "simd/x86_64/jidctint-sse2.asm",
-        "simd/x86_64/jidctred-sse2.asm",
-        "simd/x86_64/jquantf-sse2.asm",
-        "simd/x86_64/jquanti-avx2.asm",
-        "simd/x86_64/jquanti-sse2.asm",
-        "simd/x86_64/jsimdcpu.asm",
-        "simd/nasm/jcolsamp.inc",
-        "simd/nasm/jdct.inc",
-        "simd/nasm/jsimdcfg.inc",
-        "simd/nasm/jsimdcfg.inc.h",
-        "simd/nasm/jsimdext.inc",
-    ],
-    outs = [
-        "simd/x86_64/jccolor-avx2.obj",
-        "simd/x86_64/jccolor-sse2.obj",
-        "simd/x86_64/jcgray-avx2.obj",
-        "simd/x86_64/jcgray-sse2.obj",
-        "simd/x86_64/jchuff-sse2.obj",
-        "simd/x86_64/jcphuff-sse2.obj",
-        "simd/x86_64/jcsample-avx2.obj",
-        "simd/x86_64/jcsample-sse2.obj",
-        "simd/x86_64/jdcolor-avx2.obj",
-        "simd/x86_64/jdcolor-sse2.obj",
-        "simd/x86_64/jdmerge-avx2.obj",
-        "simd/x86_64/jdmerge-sse2.obj",
-        "simd/x86_64/jdsample-avx2.obj",
-        "simd/x86_64/jdsample-sse2.obj",
-        "simd/x86_64/jfdctflt-sse.obj",
-        "simd/x86_64/jfdctfst-sse2.obj",
-        "simd/x86_64/jfdctint-avx2.obj",
-        "simd/x86_64/jfdctint-sse2.obj",
-        "simd/x86_64/jidctflt-sse2.obj",
-        "simd/x86_64/jidctfst-sse2.obj",
-        "simd/x86_64/jidctint-avx2.obj",
-        "simd/x86_64/jidctint-sse2.obj",
-        "simd/x86_64/jidctred-sse2.obj",
-        "simd/x86_64/jquantf-sse2.obj",
-        "simd/x86_64/jquanti-avx2.obj",
-        "simd/x86_64/jquanti-sse2.obj",
-        "simd/x86_64/jsimdcpu.obj",
-    ],
-    cmd = "for out in $(OUTS); do\n" +
-          "  $(location @nasm//:nasm) -fwin64 -DWIN64 -D__x86_64__" +
-          "    -I $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/" +
-          "    -I $$(dirname $(location simd/nasm/jdct.inc))/" +
-          "    -I $$(dirname $(location simd/nasm/jdct.inc))/../../win/" +
-          "    -o $$out" +
-          "    $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/$$(basename $${out%.obj}.asm)\n" +
-          "done",
-    tools = ["@nasm"],
-)
-
-cc_library(
-    name = "simd_none",
-    srcs = [
-        "jchuff.h",
-        "jconfig.h",
-        "jconfigint.h",
-        "jdct.h",
-        "jerror.h",
-        "jinclude.h",
-        "jmorecfg.h",
-        "jpegint.h",
-        "jpeglib.h",
-        "jsimd.h",
-        "jsimd_none.c",
-        "jsimddct.h",
-    ],
-    copts = libjpegturbo_copts,
-)
-
-expand_template(
-    name = "jversion",
-    out = "jversion.h",
-    substitutions = {
-        "@COPYRIGHT_YEAR@": "1991-2022",
-    },
-    template = "jversion.h.in",
-)
-
-expand_template(
-    name = "jconfig_win",
-    out = "jconfig_win.h",
-    substitutions = {
-        "@JPEG_LIB_VERSION@": "62",
-        "@VERSION@": "2.1.4",
-        "@LIBJPEG_TURBO_VERSION_NUMBER@": "2001004",
-        "@BITS_IN_JSAMPLE@": "8",
-        "#cmakedefine C_ARITH_CODING_SUPPORTED": "#define C_ARITH_CODING_SUPPORTED",
-        "#cmakedefine D_ARITH_CODING_SUPPORTED": "#define D_ARITH_CODING_SUPPORTED",
-        "#cmakedefine MEM_SRCDST_SUPPORTED": "#define MEM_SRCDST_SUPPORTED",
-        "#cmakedefine WITH_SIMD": "",
-    },
-    template = "win/jconfig.h.in",
-)
-
-JCONFIG_NOWIN_COMMON_SUBSTITUTIONS = {
-    "@JPEG_LIB_VERSION@": "62",
-    "@VERSION@": "2.1.4",
-    "@LIBJPEG_TURBO_VERSION_NUMBER@": "2001004",
-    "#cmakedefine C_ARITH_CODING_SUPPORTED 1": "#define C_ARITH_CODING_SUPPORTED 1",
-    "#cmakedefine D_ARITH_CODING_SUPPORTED 1": "#define D_ARITH_CODING_SUPPORTED 1",
-    "#cmakedefine MEM_SRCDST_SUPPORTED 1": "#define MEM_SRCDST_SUPPORTED 1",
-    "@BITS_IN_JSAMPLE@": "8",
-    "#cmakedefine HAVE_LOCALE_H 1": "#define HAVE_LOCALE_H 1",
-    "#cmakedefine HAVE_STDDEF_H 1": "#define HAVE_STDDEF_H 1",
-    "#cmakedefine HAVE_STDLIB_H 1": "#define HAVE_STDLIB_H 1",
-    "#cmakedefine NEED_SYS_TYPES_H 1": "#define NEED_SYS_TYPES_H 1",
-    "#cmakedefine NEED_BSD_STRINGS 1": "",
-    "#cmakedefine HAVE_UNSIGNED_CHAR 1": "#define HAVE_UNSIGNED_CHAR 1",
-    "#cmakedefine HAVE_UNSIGNED_SHORT 1": "#define HAVE_UNSIGNED_SHORT 1",
-    "#cmakedefine INCOMPLETE_TYPES_BROKEN 1": "",
-    "#cmakedefine RIGHT_SHIFT_IS_UNSIGNED 1": "",
-    "#cmakedefine __CHAR_UNSIGNED__ 1": "",
-    "#undef const": "",
-    "#undef size_t": "",
-}
-
-JCONFIG_NOWIN_SIMD_SUBSTITUTIONS = {
-    "#cmakedefine WITH_SIMD 1": "#define WITH_SIMD 1",
-}
-
-JCONFIG_NOWIN_NOSIMD_SUBSTITUTIONS = {
-    "#cmakedefine WITH_SIMD 1": "",
-}
-
-JCONFIG_NOWIN_SIMD_SUBSTITUTIONS.update(JCONFIG_NOWIN_COMMON_SUBSTITUTIONS)
-
-JCONFIG_NOWIN_NOSIMD_SUBSTITUTIONS.update(JCONFIG_NOWIN_COMMON_SUBSTITUTIONS)
-
-expand_template(
-    name = "jconfig_nowin_nosimd",
-    out = "jconfig_nowin_nosimd.h",
-    substitutions = JCONFIG_NOWIN_NOSIMD_SUBSTITUTIONS,
-    template = "jconfig.h.in",
-)
-
-expand_template(
-    name = "jconfig_nowin_simd",
-    out = "jconfig_nowin_simd.h",
-    substitutions = JCONFIG_NOWIN_SIMD_SUBSTITUTIONS,
-    template = "jconfig.h.in",
-)
-
-JCONFIGINT_COMMON_SUBSTITUTIONS = {
-    "@BUILD@": "20221022",
-    "@VERSION@": "2.1.4",
-    "@CMAKE_PROJECT_NAME@": "libjpeg-turbo",
-    "#undef inline": "",
-    "#cmakedefine HAVE_INTRIN_H": "",
-}
-
-JCONFIGINT_NOWIN_SUBSTITUTIONS = {
-    "#cmakedefine HAVE_BUILTIN_CTZL": "#define HAVE_BUILTIN_CTZL",
-    "@INLINE@": "inline __attribute__((always_inline))",
-    "#define SIZEOF_SIZE_T  @SIZE_T@": "#if (__WORDSIZE==64 && !defined(__native_client__))\n" +
-                                       "#define SIZEOF_SIZE_T 8\n" +
-                                       "#else\n" +
-                                       "#define SIZEOF_SIZE_T 4\n" +
-                                       "#endif\n",
-}
-
-JCONFIGINT_WIN_SUBSTITUTIONS = {
-    "#cmakedefine HAVE_BUILTIN_CTZL": "",
-    "#define INLINE  @INLINE@": "#if defined(__GNUC__)\n" +
-                                "#define INLINE inline __attribute__((always_inline))\n" +
-                                "#elif defined(_MSC_VER)\n" +
-                                "#define INLINE __forceinline\n" +
-                                "#else\n" +
-                                "#define INLINE\n" +
-                                "#endif\n",
-    "#define SIZEOF_SIZE_T  @SIZE_T@": "#if (__WORDSIZE==64)\n" +
-                                       "#define SIZEOF_SIZE_T 8\n" +
-                                       "#else\n" +
-                                       "#define SIZEOF_SIZE_T 4\n" +
-                                       "#endif\n",
-}
-
-JCONFIGINT_NOWIN_SUBSTITUTIONS.update(JCONFIGINT_COMMON_SUBSTITUTIONS)
-
-JCONFIGINT_WIN_SUBSTITUTIONS.update(JCONFIGINT_COMMON_SUBSTITUTIONS)
-
-expand_template(
-    name = "jconfigint_nowin",
-    out = "jconfigint_nowin.h",
-    substitutions = JCONFIGINT_NOWIN_SUBSTITUTIONS,
-    template = "jconfigint.h.in",
-)
-
-expand_template(
-    name = "jconfigint_win",
-    out = "jconfigint_win.h",
-    substitutions = JCONFIGINT_WIN_SUBSTITUTIONS,
-    template = "jconfigint.h.in",
-)
-
-genrule(
-    name = "configure",
-    srcs = [
-        "jconfig_win.h",
-        "jconfig_nowin_nosimd.h",
-        "jconfig_nowin_simd.h",
-    ],
-    outs = ["jconfig.h"],
-    cmd = select({
-        ":windows": "cp $(location jconfig_win.h) $@",
-        ":k8": "cp $(location jconfig_nowin_simd.h) $@",
-        ":armeabi-v7a": "cp $(location jconfig_nowin_simd.h) $@",
-        ":arm64-v8a": "cp $(location jconfig_nowin_simd.h) $@",
-        ":linux_ppc64le": "cp $(location jconfig_nowin_simd.h) $@",
-        "//conditions:default": "cp $(location jconfig_nowin_nosimd.h) $@",
-    }),
-)
-
-genrule(
-    name = "configure_internal",
-    srcs = [
-        "jconfigint_win.h",
-        "jconfigint_nowin.h",
-    ],
-    outs = ["jconfigint.h"],
-    cmd = select({
-        ":windows": "cp $(location jconfigint_win.h) $@",
-        "//conditions:default": "cp $(location jconfigint_nowin.h) $@",
-    }),
-)
-
-# jiminy cricket the way this file is generated is completely outrageous
-genrule(
-    name = "configure_simd",
-    outs = ["simd/jsimdcfg.inc"],
-    cmd = "cat <<'EOF' >$@\n" +
-          "%define DCTSIZE 8\n" +
-          "%define DCTSIZE2 64\n" +
-          "%define RGB_RED 0\n" +
-          "%define RGB_GREEN 1\n" +
-          "%define RGB_BLUE 2\n" +
-          "%define RGB_PIXELSIZE 3\n" +
-          "%define EXT_RGB_RED 0\n" +
-          "%define EXT_RGB_GREEN 1\n" +
-          "%define EXT_RGB_BLUE 2\n" +
-          "%define EXT_RGB_PIXELSIZE 3\n" +
-          "%define EXT_RGBX_RED 0\n" +
-          "%define EXT_RGBX_GREEN 1\n" +
-          "%define EXT_RGBX_BLUE 2\n" +
-          "%define EXT_RGBX_PIXELSIZE 4\n" +
-          "%define EXT_BGR_RED 2\n" +
-          "%define EXT_BGR_GREEN 1\n" +
-          "%define EXT_BGR_BLUE 0\n" +
-          "%define EXT_BGR_PIXELSIZE 3\n" +
-          "%define EXT_BGRX_RED 2\n" +
-          "%define EXT_BGRX_GREEN 1\n" +
-          "%define EXT_BGRX_BLUE 0\n" +
-          "%define EXT_BGRX_PIXELSIZE 4\n" +
-          "%define EXT_XBGR_RED 3\n" +
-          "%define EXT_XBGR_GREEN 2\n" +
-          "%define EXT_XBGR_BLUE 1\n" +
-          "%define EXT_XBGR_PIXELSIZE 4\n" +
-          "%define EXT_XRGB_RED 1\n" +
-          "%define EXT_XRGB_GREEN 2\n" +
-          "%define EXT_XRGB_BLUE 3\n" +
-          "%define EXT_XRGB_PIXELSIZE 4\n" +
-          "%define RGBX_FILLER_0XFF 1\n" +
-          "%define JSAMPLE byte ; unsigned char\n" +
-          "%define SIZEOF_JSAMPLE SIZEOF_BYTE ; sizeof(JSAMPLE)\n" +
-          "%define CENTERJSAMPLE 128\n" +
-          "%define JCOEF word ; short\n" +
-          "%define SIZEOF_JCOEF SIZEOF_WORD ; sizeof(JCOEF)\n" +
-          "%define JDIMENSION dword ; unsigned int\n" +
-          "%define SIZEOF_JDIMENSION SIZEOF_DWORD ; sizeof(JDIMENSION)\n" +
-          "%define JSAMPROW POINTER ; JSAMPLE * (jpeglib.h)\n" +
-          "%define JSAMPARRAY POINTER ; JSAMPROW * (jpeglib.h)\n" +
-          "%define JSAMPIMAGE POINTER ; JSAMPARRAY * (jpeglib.h)\n" +
-          "%define JCOEFPTR POINTER ; JCOEF * (jpeglib.h)\n" +
-          "%define SIZEOF_JSAMPROW SIZEOF_POINTER ; sizeof(JSAMPROW)\n" +
-          "%define SIZEOF_JSAMPARRAY SIZEOF_POINTER ; sizeof(JSAMPARRAY)\n" +
-          "%define SIZEOF_JSAMPIMAGE SIZEOF_POINTER ; sizeof(JSAMPIMAGE)\n" +
-          "%define SIZEOF_JCOEFPTR SIZEOF_POINTER ; sizeof(JCOEFPTR)\n" +
-          "%define DCTELEM word ; short\n" +
-          "%define SIZEOF_DCTELEM SIZEOF_WORD ; sizeof(DCTELEM)\n" +
-          "%define float FP32 ; float\n" +
-          "%define SIZEOF_FAST_FLOAT SIZEOF_FP32 ; sizeof(float)\n" +
-          "%define ISLOW_MULT_TYPE word ; must be short\n" +
-          "%define SIZEOF_ISLOW_MULT_TYPE SIZEOF_WORD ; sizeof(ISLOW_MULT_TYPE)\n" +
-          "%define IFAST_MULT_TYPE word ; must be short\n" +
-          "%define SIZEOF_IFAST_MULT_TYPE SIZEOF_WORD ; sizeof(IFAST_MULT_TYPE)\n" +
-          "%define IFAST_SCALE_BITS 2 ; fractional bits in scale factors\n" +
-          "%define FLOAT_MULT_TYPE FP32 ; must be float\n" +
-          "%define SIZEOF_FLOAT_MULT_TYPE SIZEOF_FP32 ; sizeof(FLOAT_MULT_TYPE)\n" +
-          "%define JSIMD_NONE 0x00\n" +
-          "%define JSIMD_MMX 0x01\n" +
-          "%define JSIMD_3DNOW 0x02\n" +
-          "%define JSIMD_SSE 0x04\n" +
-          "%define JSIMD_SSE2 0x08\n" +
-          "EOF",
-)
-
-string_flag(
-    name = "noasm",
-    build_setting_default = "no",
-)
-
-config_setting(
-    name = "nosimd",
-    flag_values = {":noasm": "yes"},
-)
-
-config_setting(
-    name = "k8",
-    flag_values = {":noasm": "no"},
-    values = {"cpu": "k8"},
-)
-
-config_setting(
-    name = "android",
-    values = {"crosstool_top": "//external:android/crosstool"},
-)
-
-config_setting(
-    name = "armeabi-v7a",
-    flag_values = {":noasm": "no"},
-    values = {"cpu": "armeabi-v7a"},
-)
-
-config_setting(
-    name = "arm64-v8a",
-    flag_values = {":noasm": "no"},
-    values = {"cpu": "arm64-v8a"},
-)
-
-config_setting(
-    name = "windows",
-    flag_values = {":noasm": "no"},
-    values = {"cpu": "x64_windows"},
-)
-
-config_setting(
-    name = "linux_ppc64le",
-    flag_values = {":noasm": "no"},
-    values = {"cpu": "ppc"},
-)
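
For reference on the deleted jpeg.BUILD above: its SIMD backend selection follows the standard Bazel pattern of a string_flag gating config_setting targets that feed a select(). A minimal, self-contained sketch of that pattern, where every target name except :noasm and :nosimd is illustrative rather than taken from the original file:

    load("@bazel_skylib//rules:common_settings.bzl", "string_flag")

    string_flag(
        name = "noasm",
        build_setting_default = "no",  # flip to "yes" to force the portable C backend
    )

    config_setting(
        name = "nosimd",
        flag_values = {":noasm": "yes"},
    )

    cc_library(
        name = "jpeg_minimal",  # hypothetical target
        deps = select({
            ":nosimd": [":simd_none"],
            "//conditions:default": [":simd_x86_64"],  # per-CPU cases omitted
        }),
    )
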
diff --git a/third_party/xla/third_party/jpeg/jpeg_helpers.BUILD.bazel b/third_party/xla/third_party/jpeg/jpeg_helpers.BUILD.bazel
deleted file mode 100644
index 5b01f6e..0000000
--- a/third_party/xla/third_party/jpeg/jpeg_helpers.BUILD.bazel
+++ /dev/null
@@ -1 +0,0 @@
-licenses(["notice"])
diff --git a/third_party/xla/third_party/jpeg/workspace.bzl b/third_party/xla/third_party/jpeg/workspace.bzl
deleted file mode 100644
index 631cc93..0000000
--- a/third_party/xla/third_party/jpeg/workspace.bzl
+++ /dev/null
@@ -1,13 +0,0 @@
-"""loads the jpeg library, used by TF."""
-
-load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
-
-def repo():
-    tf_http_archive(
-        name = "libjpeg_turbo",
-        urls = tf_mirror_urls("https://github.com/libjpeg-turbo/libjpeg-turbo/archive/refs/tags/2.1.4.tar.gz"),
-        sha256 = "a78b05c0d8427a90eb5b4eb08af25309770c8379592bb0b8a863373128e6143f",
-        strip_prefix = "libjpeg-turbo-2.1.4",
-        build_file = "//third_party/jpeg:jpeg.BUILD",
-        system_build_file = "//third_party/jpeg:BUILD.system",
-    )
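
The deleted workspace.bzl defined a single repo() macro for fetching libjpeg-turbo; macros of this shape are normally invoked from a central workspace setup file, roughly as follows (the caller file name, function name, and alias are assumptions, not shown in this diff):

    # Hypothetical excerpt from a workspace2.bzl-style setup file.
    load("//third_party/jpeg:workspace.bzl", jpeg = "repo")

    def initialize_third_party():
        jpeg()  # registers the @libjpeg_turbo external repository
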
diff --git a/third_party/xla/third_party/mkl/BUILD b/third_party/xla/third_party/mkl/BUILD
deleted file mode 100644
index fa6f51f..0000000
--- a/third_party/xla/third_party/mkl/BUILD
+++ /dev/null
@@ -1,77 +0,0 @@
-licenses(["notice"])  # 3-Clause BSD
-
-load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
-
-package(default_visibility = ["//visibility:public"])
-
-alias(
-    name = "build_with_mkl",
-    actual = "@local_tsl//tsl/mkl:build_with_mkl",
-    visibility = ["//visibility:public"],
-)
-
-alias(
-    name = "build_with_mkl_lnx_x64",
-    actual = "@local_tsl//tsl/mkl:build_with_mkl_lnx_x64",
-    visibility = ["//visibility:public"],
-)
-
-alias(
-    name = "build_with_mkl_lnx_openmp",
-    actual = "@local_tsl//tsl/mkl:build_with_mkl_lnx_openmp",
-    visibility = ["//visibility:public"],
-)
-
-alias(
-    name = "build_with_mkl_windows_openmp",
-    actual = "@local_tsl//tsl/mkl:build_with_mkl_windows_openmp",
-    visibility = ["//visibility:public"],
-)
-
-alias(
-    name = "build_with_mkl_aarch64",
-    actual = "@local_tsl//tsl/mkl:build_with_mkl_aarch64",
-    visibility = ["//visibility:public"],
-)
-
-alias(
-    name = "enable_mkl",
-    actual = "@local_tsl//tsl/mkl:enable_mkl",
-    visibility = ["//visibility:public"],
-)
-
-alias(
-    name = "intel_binary_blob",
-    actual = "@local_tsl//tsl/mkl:intel_binary_blob",
-    visibility = ["//visibility:public"],
-)
-
-alias(
-    name = "LICENSE",
-    actual = "@local_tsl//tsl/mkl:LICENSE",
-    visibility = ["//visibility:public"],
-)
-
-alias(
-    name = "mkl_libs_linux",
-    actual = "@local_tsl//tsl/mkl:mkl_libs_linux",
-    visibility = ["//visibility:public"],
-)
-
-alias(
-    name = "mkl_libs_darwin",
-    actual = "@local_tsl//tsl/mkl:mkl_libs_darwin",
-    visibility = ["//visibility:public"],
-)
-
-alias(
-    name = "mkl_libs_windows",
-    actual = "@local_tsl//tsl/mkl:mkl_libs_windows",
-    visibility = ["//visibility:public"],
-)
-
-bzl_library(
-    name = "build_defs_bzl",
-    srcs = ["build_defs.bzl"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/mkl/build_defs.bzl b/third_party/xla/third_party/mkl/build_defs.bzl
deleted file mode 100644
index 76bea5d..0000000
--- a/third_party/xla/third_party/mkl/build_defs.bzl
+++ /dev/null
@@ -1,30 +0,0 @@
-"""Starlark macros for MKL.
-
-if_mkl is a conditional to check if we are building with MKL.
-if_mkl_ml is a conditional to check if we are building with MKL-ML.
-if_mkl_ml_only is a conditional to check for MKL-ML-only (no MKL-DNN) mode.
-if_mkl_lnx_x64 is a conditional to check for MKL
-if_enable_mkl is a conditional to check if building with MKL and MKL is enabled.
-
-mkl_repository is a repository rule for creating MKL repository rule that can
-be pointed to either a local folder, or downloaded from the internet.
-mkl_repository depends on the following environment variables:
-  * `TF_MKL_ROOT`: The root folder where a copy of libmkl is located.
-"""
-
-load(
-    "@local_tsl//tsl/mkl:build_defs.bzl",
-    _if_enable_mkl = "if_enable_mkl",
-    _if_mkl = "if_mkl",
-    _if_mkl_lnx_x64 = "if_mkl_lnx_x64",
-    _if_mkl_ml = "if_mkl_ml",
-    _mkl_deps = "mkl_deps",
-    _mkl_repository = "mkl_repository",
-)
-
-if_mkl = _if_mkl
-if_mkl_ml = _if_mkl_ml
-if_mkl_lnx_x64 = _if_mkl_lnx_x64
-if_enable_mkl = _if_enable_mkl
-mkl_deps = _mkl_deps
-mkl_repository = _mkl_repository
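
The conditionals re-exported above are plain select() wrappers, so a consuming BUILD target would use them roughly like this (the target name, source file, dependency label, and -DENABLE_MKL define are illustrative, not taken from this diff):

    load("@local_xla//third_party/mkl:build_defs.bzl", "if_enable_mkl", "mkl_deps")

    cc_library(
        name = "matmul_op",          # hypothetical target
        srcs = ["matmul_op.cc"],
        copts = if_enable_mkl(["-DENABLE_MKL"], []),  # extra copts only when MKL is enabled
        deps = ["//some:dep"] + mkl_deps(),           # MKL deps only when building with MKL
    )
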
diff --git a/third_party/xla/third_party/mkl_dnn/BUILD b/third_party/xla/third_party/mkl_dnn/BUILD
deleted file mode 100644
index c536923..0000000
--- a/third_party/xla/third_party/mkl_dnn/BUILD
+++ /dev/null
@@ -1,52 +0,0 @@
-load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
-
-package(
-    default_visibility = ["//visibility:public"],
-    licenses = ["notice"],
-)
-
-exports_files(
-    ["LICENSE"],
-    visibility = ["//visibility:public"],
-)
-
-config_setting(
-    name = "build_with_mkl_opensource",
-    define_values = {
-        "build_with_mkl": "true",
-        "build_with_mkl_opensource": "true",
-    },
-    visibility = ["//visibility:public"],
-)
-
-config_setting(
-    name = "build_with_mkldnn_openmp",
-    define_values = {
-        "build_with_mkl": "true",
-        "build_with_openmp": "true",
-    },
-    visibility = ["//visibility:public"],
-)
-
-config_setting(
-    name = "build_with_mkl_aarch64_openmp",
-    define_values = {
-        "build_with_mkl_aarch64": "true",
-        "build_with_openmp": "true",
-    },
-    visibility = ["//visibility:public"],
-)
-
-config_setting(
-    name = "build_with_mkl_aarch64",
-    define_values = {
-        "build_with_mkl_aarch64": "true",
-    },
-    visibility = ["//visibility:public"],
-)
-
-bzl_library(
-    name = "build_defs_bzl",
-    srcs = ["build_defs.bzl"],
-    visibility = ["//visibility:public"],
-)
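
The config_setting targets above are matched through --define values (for example, setting both build_with_mkl_aarch64 and build_with_openmp to "true" satisfies :build_with_mkl_aarch64_openmp); a consumer selects on them directly, sketched here with a hypothetical target:

    cc_library(
        name = "acl_threading_opts",  # hypothetical target
        copts = select({
            "@local_xla//third_party/mkl_dnn:build_with_mkl_aarch64_openmp": ["-fopenmp"],
            "//conditions:default": [],
        }),
    )
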
diff --git a/third_party/xla/third_party/mkl_dnn/LICENSE b/third_party/xla/third_party/mkl_dnn/LICENSE
deleted file mode 100644
index 8dada3e..0000000
--- a/third_party/xla/third_party/mkl_dnn/LICENSE
+++ /dev/null
@@ -1,201 +0,0 @@
-                                 Apache License
-                           Version 2.0, January 2004
-                        http://www.apache.org/licenses/
-
-   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
-   1. Definitions.
-
-      "License" shall mean the terms and conditions for use, reproduction,
-      and distribution as defined by Sections 1 through 9 of this document.
-
-      "Licensor" shall mean the copyright owner or entity authorized by
-      the copyright owner that is granting the License.
-
-      "Legal Entity" shall mean the union of the acting entity and all
-      other entities that control, are controlled by, or are under common
-      control with that entity. For the purposes of this definition,
-      "control" means (i) the power, direct or indirect, to cause the
-      direction or management of such entity, whether by contract or
-      otherwise, or (ii) ownership of fifty percent (50%) or more of the
-      outstanding shares, or (iii) beneficial ownership of such entity.
-
-      "You" (or "Your") shall mean an individual or Legal Entity
-      exercising permissions granted by this License.
-
-      "Source" form shall mean the preferred form for making modifications,
-      including but not limited to software source code, documentation
-      source, and configuration files.
-
-      "Object" form shall mean any form resulting from mechanical
-      transformation or translation of a Source form, including but
-      not limited to compiled object code, generated documentation,
-      and conversions to other media types.
-
-      "Work" shall mean the work of authorship, whether in Source or
-      Object form, made available under the License, as indicated by a
-      copyright notice that is included in or attached to the work
-      (an example is provided in the Appendix below).
-
-      "Derivative Works" shall mean any work, whether in Source or Object
-      form, that is based on (or derived from) the Work and for which the
-      editorial revisions, annotations, elaborations, or other modifications
-      represent, as a whole, an original work of authorship. For the purposes
-      of this License, Derivative Works shall not include works that remain
-      separable from, or merely link (or bind by name) to the interfaces of,
-      the Work and Derivative Works thereof.
-
-      "Contribution" shall mean any work of authorship, including
-      the original version of the Work and any modifications or additions
-      to that Work or Derivative Works thereof, that is intentionally
-      submitted to Licensor for inclusion in the Work by the copyright owner
-      or by an individual or Legal Entity authorized to submit on behalf of
-      the copyright owner. For the purposes of this definition, "submitted"
-      means any form of electronic, verbal, or written communication sent
-      to the Licensor or its representatives, including but not limited to
-      communication on electronic mailing lists, source code control systems,
-      and issue tracking systems that are managed by, or on behalf of, the
-      Licensor for the purpose of discussing and improving the Work, but
-      excluding communication that is conspicuously marked or otherwise
-      designated in writing by the copyright owner as "Not a Contribution."
-
-      "Contributor" shall mean Licensor and any individual or Legal Entity
-      on behalf of whom a Contribution has been received by Licensor and
-      subsequently incorporated within the Work.
-
-   2. Grant of Copyright License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      copyright license to reproduce, prepare Derivative Works of,
-      publicly display, publicly perform, sublicense, and distribute the
-      Work and such Derivative Works in Source or Object form.
-
-   3. Grant of Patent License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      (except as stated in this section) patent license to make, have made,
-      use, offer to sell, sell, import, and otherwise transfer the Work,
-      where such license applies only to those patent claims licensable
-      by such Contributor that are necessarily infringed by their
-      Contribution(s) alone or by combination of their Contribution(s)
-      with the Work to which such Contribution(s) was submitted. If You
-      institute patent litigation against any entity (including a
-      cross-claim or counterclaim in a lawsuit) alleging that the Work
-      or a Contribution incorporated within the Work constitutes direct
-      or contributory patent infringement, then any patent licenses
-      granted to You under this License for that Work shall terminate
-      as of the date such litigation is filed.
-
-   4. Redistribution. You may reproduce and distribute copies of the
-      Work or Derivative Works thereof in any medium, with or without
-      modifications, and in Source or Object form, provided that You
-      meet the following conditions:
-
-      (a) You must give any other recipients of the Work or
-          Derivative Works a copy of this License; and
-
-      (b) You must cause any modified files to carry prominent notices
-          stating that You changed the files; and
-
-      (c) You must retain, in the Source form of any Derivative Works
-          that You distribute, all copyright, patent, trademark, and
-          attribution notices from the Source form of the Work,
-          excluding those notices that do not pertain to any part of
-          the Derivative Works; and
-
-      (d) If the Work includes a "NOTICE" text file as part of its
-          distribution, then any Derivative Works that You distribute must
-          include a readable copy of the attribution notices contained
-          within such NOTICE file, excluding those notices that do not
-          pertain to any part of the Derivative Works, in at least one
-          of the following places: within a NOTICE text file distributed
-          as part of the Derivative Works; within the Source form or
-          documentation, if provided along with the Derivative Works; or,
-          within a display generated by the Derivative Works, if and
-          wherever such third-party notices normally appear. The contents
-          of the NOTICE file are for informational purposes only and
-          do not modify the License. You may add Your own attribution
-          notices within Derivative Works that You distribute, alongside
-          or as an addendum to the NOTICE text from the Work, provided
-          that such additional attribution notices cannot be construed
-          as modifying the License.
-
-      You may add Your own copyright statement to Your modifications and
-      may provide additional or different license terms and conditions
-      for use, reproduction, or distribution of Your modifications, or
-      for any such Derivative Works as a whole, provided Your use,
-      reproduction, and distribution of the Work otherwise complies with
-      the conditions stated in this License.
-
-   5. Submission of Contributions. Unless You explicitly state otherwise,
-      any Contribution intentionally submitted for inclusion in the Work
-      by You to the Licensor shall be under the terms and conditions of
-      this License, without any additional terms or conditions.
-      Notwithstanding the above, nothing herein shall supersede or modify
-      the terms of any separate license agreement you may have executed
-      with Licensor regarding such Contributions.
-
-   6. Trademarks. This License does not grant permission to use the trade
-      names, trademarks, service marks, or product names of the Licensor,
-      except as required for reasonable and customary use in describing the
-      origin of the Work and reproducing the content of the NOTICE file.
-
-   7. Disclaimer of Warranty. Unless required by applicable law or
-      agreed to in writing, Licensor provides the Work (and each
-      Contributor provides its Contributions) on an "AS IS" BASIS,
-      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-      implied, including, without limitation, any warranties or conditions
-      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
-      PARTICULAR PURPOSE. You are solely responsible for determining the
-      appropriateness of using or redistributing the Work and assume any
-      risks associated with Your exercise of permissions under this License.
-
-   8. Limitation of Liability. In no event and under no legal theory,
-      whether in tort (including negligence), contract, or otherwise,
-      unless required by applicable law (such as deliberate and grossly
-      negligent acts) or agreed to in writing, shall any Contributor be
-      liable to You for damages, including any direct, indirect, special,
-      incidental, or consequential damages of any character arising as a
-      result of this License or out of the use or inability to use the
-      Work (including but not limited to damages for loss of goodwill,
-      work stoppage, computer failure or malfunction, or any and all
-      other commercial damages or losses), even if such Contributor
-      has been advised of the possibility of such damages.
-
-   9. Accepting Warranty or Additional Liability. While redistributing
-      the Work or Derivative Works thereof, You may choose to offer,
-      and charge a fee for, acceptance of support, warranty, indemnity,
-      or other liability obligations and/or rights consistent with this
-      License. However, in accepting such obligations, You may act only
-      on Your own behalf and on Your sole responsibility, not on behalf
-      of any other Contributor, and only if You agree to indemnify,
-      defend, and hold each Contributor harmless for any liability
-      incurred by, or claims asserted against, such Contributor by reason
-      of your accepting any such warranty or additional liability.
-
-   END OF TERMS AND CONDITIONS
-
-   APPENDIX: How to apply the Apache License to your work.
-
-      To apply the Apache License to your work, attach the following
-      boilerplate notice, with the fields enclosed by brackets "{}"
-      replaced with your own identifying information. (Don't include
-      the brackets!)  The text should be enclosed in the appropriate
-      comment syntax for the file format. We also recommend that a
-      file or class name and description of purpose be included on the
-      same "printed page" as the copyright notice for easier
-      identification within third-party archives.
-
-   Copyright {yyyy} {name of copyright owner}
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License.
diff --git a/third_party/xla/third_party/mkl_dnn/build_defs.bzl b/third_party/xla/third_party/mkl_dnn/build_defs.bzl
deleted file mode 100644
index b0ed90f..0000000
--- a/third_party/xla/third_party/mkl_dnn/build_defs.bzl
+++ /dev/null
@@ -1,34 +0,0 @@
-"""Starlark macros for oneDNN.
-
-if_mkldnn_openmp checks if we are building x86 backend with OpenMP.
-if_mkldnn_aarch64_acl checks if we are building with Arm Compute Library.
-if_mkldnn_aarch64_acl_openmp checks if we are building ACL with OpenMP.
-"""
-
-def if_mkldnn_openmp(if_true, if_false = []):
-    """Returns `if_true` if OpenMP is used with oneDNN.
-
-    Shorthand for select()'ing on whether we're building with
-    oneDNN open source library only with openmp
-
-    Returns a select statement which evaluates to if_true if we're building
-    with oneDNN open source library only with OpenMP. Otherwise, the
-    select statement evaluates to if_false.
-
-    """
-    return select({
-        "@local_xla//third_party/mkl_dnn:build_with_mkldnn_openmp": if_true,
-        "//conditions:default": if_false,
-    })
-
-def if_mkldnn_aarch64_acl(if_true, if_false = []):
-    return select({
-        "@local_xla//third_party/mkl:build_with_mkl_aarch64": if_true,
-        "//conditions:default": if_false,
-    })
-
-def if_mkldnn_aarch64_acl_openmp(if_true, if_false = []):
-    return select({
-        "@local_xla//third_party/mkl_dnn:build_with_mkl_aarch64_openmp": if_true,
-        "//conditions:default": if_false,
-    })
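
A typical consumer of the helpers in the deleted build_defs.bzl wraps OpenMP-specific defines or copts, for example (the target name, define value, and oneDNN dependency label are assumptions):

    load("@local_xla//third_party/mkl_dnn:build_defs.bzl", "if_mkldnn_openmp")

    cc_library(
        name = "onednn_threading_glue",  # hypothetical target
        defines = if_mkldnn_openmp(["ENABLE_ONEDNN_OPENMP"], []),
        deps = ["@mkl_dnn_v1//:mkl_dnn"],  # illustrative oneDNN dependency
    )
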
diff --git a/third_party/xla/third_party/mkl_dnn/mkldnn_acl.BUILD b/third_party/xla/third_party/mkl_dnn/mkldnn_acl.BUILD
deleted file mode 100644
index d453ee8..0000000
--- a/third_party/xla/third_party/mkl_dnn/mkldnn_acl.BUILD
+++ /dev/null
@@ -1,181 +0,0 @@
-exports_files(["LICENSE"])
-
-load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
-
-_DNNL_COPTS_THREADPOOL = [
-    "-fopenmp-simd",
-    "-fexceptions",
-    "-UUSE_MKL",
-    "-UUSE_CBLAS",
-]
-
-_DNNL_COPTS_OMP = [
-    "-fopenmp",
-    "-fexceptions",
-    "-UUSE_MKL",
-    "-UUSE_CBLAS",
-]
-
-_DNNL_RUNTIME_THREADPOOL = {
-    "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_THREADPOOL",
-    "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_THREADPOOL",
-    "#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE",
-    "#cmakedefine DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE": "#undef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE",
-    "#cmakedefine DNNL_WITH_SYCL": "#undef DNNL_WITH_SYCL",
-    "#cmakedefine DNNL_WITH_LEVEL_ZERO": "#undef DNNL_WITH_LEVEL_ZERO",
-    "#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA",
-    "#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP",
-    "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER",
-    "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL",
-    "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH",
-    "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1",
-    "#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0",
-    "#cmakedefine01 BUILD_PRIMITIVE_ALL": "#define BUILD_PRIMITIVE_ALL 1",
-    "#cmakedefine01 BUILD_BATCH_NORMALIZATION": "#define BUILD_BATCH_NORMALIZATION 0",
-    "#cmakedefine01 BUILD_BINARY": "#define BUILD_BINARY 0",
-    "#cmakedefine01 BUILD_CONCAT": "#define BUILD_CONCAT 0",
-    "#cmakedefine01 BUILD_CONVOLUTION": "#define BUILD_CONVOLUTION 0",
-    "#cmakedefine01 BUILD_DECONVOLUTION": "#define BUILD_DECONVOLUTION 0",
-    "#cmakedefine01 BUILD_ELTWISE": "#define BUILD_ELTWISE 0",
-    "#cmakedefine01 BUILD_INNER_PRODUCT": "#define BUILD_INNER_PRODUCT 0",
-    "#cmakedefine01 BUILD_LAYER_NORMALIZATION": "#define BUILD_LAYER_NORMALIZATION 0",
-    "#cmakedefine01 BUILD_LRN": "#define BUILD_LRN 0",
-    "#cmakedefine01 BUILD_MATMUL": "#define BUILD_MATMUL 0",
-    "#cmakedefine01 BUILD_POOLING": "#define BUILD_POOLING 0",
-    "#cmakedefine01 BUILD_PRELU": "#define BUILD_PRELU 0",
-    "#cmakedefine01 BUILD_REDUCTION": "#define BUILD_REDUCTION 0",
-    "#cmakedefine01 BUILD_REORDER": "#define BUILD_REORDER 0",
-    "#cmakedefine01 BUILD_RESAMPLING": "#define BUILD_RESAMPLING 0",
-    "#cmakedefine01 BUILD_RNN": "#define BUILD_RNN 0",
-    "#cmakedefine01 BUILD_SHUFFLE": "#define BUILD_SHUFFLE 0",
-    "#cmakedefine01 BUILD_SOFTMAX": "#define BUILD_SOFTMAX 0",
-    "#cmakedefine01 BUILD_SUM": "#define BUILD_SUM 0",
-    "#cmakedefine01 BUILD_PRIMITIVE_CPU_ISA_ALL": "#define BUILD_PRIMITIVE_CPU_ISA_ALL 0",
-    "#cmakedefine01 BUILD_SSE41": "#define BUILD_SSE41 0",
-    "#cmakedefine01 BUILD_AVX2": "#define BUILD_AVX2 0",
-    "#cmakedefine01 BUILD_AVX512": "#define BUILD_AVX512 0",
-    "#cmakedefine01 BUILD_AMX": "#define BUILD_AMX 0",
-    "#cmakedefine01 BUILD_PRIMITIVE_GPU_ISA_ALL": "#define BUILD_PRIMITIVE_GPU_ISA_ALL 0",
-    "#cmakedefine01 BUILD_GEN9": "#define BUILD_GEN9 0",
-    "#cmakedefine01 BUILD_GEN11": "#define BUILD_GEN11 0",
-    "#cmakedefine01 BUILD_XELP": "#define BUILD_XELP 0",
-    "#cmakedefine01 BUILD_XEHPG": "#define BUILD_XEHPG 0",
-    "#cmakedefine01 BUILD_XEHPC": "#define BUILD_XEHPC 0",
-    "#cmakedefine01 BUILD_XEHP": "#define BUILD_XEHP 0",
-}
-
-_DNNL_RUNTIME_OMP = {
-    "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_OMP",
-    "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_OMP",
-    "#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE",
-    "#cmakedefine DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE": "#undef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE",
-    "#cmakedefine DNNL_WITH_SYCL": "#undef DNNL_WITH_SYCL",
-    "#cmakedefine DNNL_WITH_LEVEL_ZERO": "#undef DNNL_WITH_LEVEL_ZERO",
-    "#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA",
-    "#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP",
-    "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER",
-    "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL",
-    "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH",
-    "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1",
-    "#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0",
-    "#cmakedefine01 BUILD_PRIMITIVE_ALL": "#define BUILD_PRIMITIVE_ALL 1",
-    "#cmakedefine01 BUILD_BATCH_NORMALIZATION": "#define BUILD_BATCH_NORMALIZATION 0",
-    "#cmakedefine01 BUILD_BINARY": "#define BUILD_BINARY 0",
-    "#cmakedefine01 BUILD_CONCAT": "#define BUILD_CONCAT 0",
-    "#cmakedefine01 BUILD_CONVOLUTION": "#define BUILD_CONVOLUTION 0",
-    "#cmakedefine01 BUILD_DECONVOLUTION": "#define BUILD_DECONVOLUTION 0",
-    "#cmakedefine01 BUILD_ELTWISE": "#define BUILD_ELTWISE 0",
-    "#cmakedefine01 BUILD_INNER_PRODUCT": "#define BUILD_INNER_PRODUCT 0",
-    "#cmakedefine01 BUILD_LAYER_NORMALIZATION": "#define BUILD_LAYER_NORMALIZATION 0",
-    "#cmakedefine01 BUILD_LRN": "#define BUILD_LRN 0",
-    "#cmakedefine01 BUILD_MATMUL": "#define BUILD_MATMUL 0",
-    "#cmakedefine01 BUILD_POOLING": "#define BUILD_POOLING 0",
-    "#cmakedefine01 BUILD_PRELU": "#define BUILD_PRELU 0",
-    "#cmakedefine01 BUILD_REDUCTION": "#define BUILD_REDUCTION 0",
-    "#cmakedefine01 BUILD_REORDER": "#define BUILD_REORDER 0",
-    "#cmakedefine01 BUILD_RESAMPLING": "#define BUILD_RESAMPLING 0",
-    "#cmakedefine01 BUILD_RNN": "#define BUILD_RNN 0",
-    "#cmakedefine01 BUILD_SHUFFLE": "#define BUILD_SHUFFLE 0",
-    "#cmakedefine01 BUILD_SOFTMAX": "#define BUILD_SOFTMAX 0",
-    "#cmakedefine01 BUILD_SUM": "#define BUILD_SUM 0",
-    "#cmakedefine01 BUILD_PRIMITIVE_CPU_ISA_ALL": "#define BUILD_PRIMITIVE_CPU_ISA_ALL 0",
-    "#cmakedefine01 BUILD_SSE41": "#define BUILD_SSE41 0",
-    "#cmakedefine01 BUILD_AVX2": "#define BUILD_AVX2 0",
-    "#cmakedefine01 BUILD_AVX512": "#define BUILD_AVX512 0",
-    "#cmakedefine01 BUILD_AMX": "#define BUILD_AMX 0",
-    "#cmakedefine01 BUILD_PRIMITIVE_GPU_ISA_ALL": "#define BUILD_PRIMITIVE_GPU_ISA_ALL 0",
-    "#cmakedefine01 BUILD_GEN9": "#define BUILD_GEN9 0",
-    "#cmakedefine01 BUILD_GEN11": "#define BUILD_GEN11 0",
-    "#cmakedefine01 BUILD_XELP": "#define BUILD_XELP 0",
-    "#cmakedefine01 BUILD_XEHPG": "#define BUILD_XEHPG 0",
-    "#cmakedefine01 BUILD_XEHPC": "#define BUILD_XEHPC 0",
-    "#cmakedefine01 BUILD_XEHP": "#define BUILD_XEHP 0",
-}
-
-expand_template(
-    name = "dnnl_config_h",
-    out = "include/oneapi/dnnl/dnnl_config.h",
-    substitutions = select({
-        "@local_xla//third_party/mkl_dnn:build_with_mkl_aarch64_openmp": _DNNL_RUNTIME_OMP,
-        "//conditions:default": _DNNL_RUNTIME_THREADPOOL,
-    }),
-    template = "include/oneapi/dnnl/dnnl_config.h.in",
-)
-
-expand_template(
-    name = "dnnl_version_h",
-    out = "include/oneapi/dnnl/dnnl_version.h",
-    substitutions = {
-        "@DNNL_VERSION_MAJOR@": "3",
-        "@DNNL_VERSION_MINOR@": "2",
-        "@DNNL_VERSION_PATCH@": "1",
-        "@DNNL_VERSION_HASH@": "N/A",
-    },
-    template = "include/oneapi/dnnl/dnnl_version.h.in",
-)
-
-cc_library(
-    name = "mkl_dnn_acl",
-    srcs = glob(
-        [
-            "src/common/*.cpp",
-            "src/cpu/**/*.cpp",
-            "src/cpu/*.cpp",
-        ],
-        exclude = [
-            "src/cpu/x64/**",
-            "src/cpu/rv64/**",
-        ],
-    ),
-    copts = select({
-        "@local_xla//third_party/mkl_dnn:build_with_mkl_aarch64_openmp": _DNNL_COPTS_OMP,
-        "//conditions:default": _DNNL_COPTS_THREADPOOL,
-    }),
-    defines = ["DNNL_AARCH64_USE_ACL=1"],
-    includes = [
-        "include",
-        "src",
-        "src/common",
-        "src/cpu",
-        "src/cpu/aarch64/xbyak_aarch64/src",
-        "src/cpu/aarch64/xbyak_aarch64/xbyak_aarch64",
-        "src/cpu/gemm",
-    ],
-    textual_hdrs = glob(
-        [
-            "include/**/*",
-            "include/*",
-            "src/common/*.hpp",
-            "src/cpu/**/*.hpp",
-            "src/cpu/*.hpp",
-            "src/cpu/aarch64/xbyak_aarch64/**/*.h",
-        ],
-    ) + [
-        ":dnnl_config_h",
-        ":dnnl_version_h",
-    ],
-    visibility = ["//visibility:public"],
-    deps = [
-        "@compute_library//:arm_compute",
-    ],
-)
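
The dnnl_config_h rule in the deleted mkldnn_acl.BUILD relies on expand_template doing literal, line-for-line string substitution on the CMake template; a stripped-down sketch of the same mechanism, with placeholder file and target names:

    load("@bazel_skylib//rules:expand_template.bzl", "expand_template")

    expand_template(
        name = "tiny_dnnl_config_h",
        out = "tiny_dnnl_config.h",
        substitutions = {
            # Each key is replaced verbatim wherever it occurs in the template.
            "#cmakedefine DNNL_WITH_SYCL": "#undef DNNL_WITH_SYCL",
        },
        template = "tiny_dnnl_config.h.in",
    )
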
diff --git a/third_party/xla/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/xla/third_party/mkl_dnn/mkldnn_v1.BUILD
deleted file mode 100644
index b19a9a5..0000000
--- a/third_party/xla/third_party/mkl_dnn/mkldnn_v1.BUILD
+++ /dev/null
@@ -1,185 +0,0 @@
-load("@local_tsl//tsl:tsl.bzl", "tf_openmp_copts")
-load("@local_xla//third_party/mkl:build_defs.bzl", "if_mkl")
-load("@local_xla//third_party/mkl_dnn:build_defs.bzl", "if_mkldnn_openmp")
-load("@local_xla//third_party/mkl:build_defs.bzl", "if_mkl_ml")
-load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
-
-exports_files(["LICENSE"])
-
-_CMAKE_COMMON_LIST = {
-    "#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE",
-    "#cmakedefine DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE": "#undef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE",
-    "#cmakedefine DNNL_WITH_SYCL": "#undef DNNL_WITH_SYCL",
-    "#cmakedefine DNNL_WITH_LEVEL_ZERO": "#undef DNNL_WITH_LEVEL_ZERO",
-    "#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA",
-    "#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP",
-    "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER",
-    "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL",
-    "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH",
-    "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1",
-    "#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0",
-    "#cmakedefine01 BUILD_PRIMITIVE_ALL": "#define BUILD_PRIMITIVE_ALL 1",
-    "#cmakedefine01 BUILD_BATCH_NORMALIZATION": "#define BUILD_BATCH_NORMALIZATION 0",
-    "#cmakedefine01 BUILD_BINARY": "#define BUILD_BINARY 0",
-    "#cmakedefine01 BUILD_CONCAT": "#define BUILD_CONCAT 0",
-    "#cmakedefine01 BUILD_CONVOLUTION": "#define BUILD_CONVOLUTION 0",
-    "#cmakedefine01 BUILD_DECONVOLUTION": "#define BUILD_DECONVOLUTION 0",
-    "#cmakedefine01 BUILD_ELTWISE": "#define BUILD_ELTWISE 0",
-    "#cmakedefine01 BUILD_GEMM_KERNELS_ALL": "#define BUILD_GEMM_KERNELS_ALL 1",
-    "#cmakedefine01 BUILD_GEMM_KERNELS_NONE": "#define BUILD_GEMM_KERNELS_NONE 0",
-    "#cmakedefine01 BUILD_GEMM_SSE41": "#define BUILD_GEMM_SSE41 1",
-    "#cmakedefine01 BUILD_GEMM_AVX2": "#define BUILD_GEMM_AVX2 1",
-    "#cmakedefine01 BUILD_GEMM_AVX512": "#define BUILD_GEMM_AVX512 1",
-    "#cmakedefine01 BUILD_GROUP_NORMALIZATION": "#define BUILD_GROUP_NORMALIZATION 1",
-    "#cmakedefine01 BUILD_INNER_PRODUCT": "#define BUILD_INNER_PRODUCT 0",
-    "#cmakedefine01 BUILD_LAYER_NORMALIZATION": "#define BUILD_LAYER_NORMALIZATION 0",
-    "#cmakedefine01 BUILD_LRN": "#define BUILD_LRN 0",
-    "#cmakedefine01 BUILD_MATMUL": "#define BUILD_MATMUL 0",
-    "#cmakedefine01 BUILD_POOLING": "#define BUILD_POOLING 0",
-    "#cmakedefine01 BUILD_PRELU": "#define BUILD_PRELU 0",
-    "#cmakedefine01 BUILD_REDUCTION": "#define BUILD_REDUCTION 0",
-    "#cmakedefine01 BUILD_REORDER": "#define BUILD_REORDER 0",
-    "#cmakedefine01 BUILD_RESAMPLING": "#define BUILD_RESAMPLING 0",
-    "#cmakedefine01 BUILD_RNN": "#define BUILD_RNN 0",
-    "#cmakedefine01 BUILD_SHUFFLE": "#define BUILD_SHUFFLE 0",
-    "#cmakedefine01 BUILD_SOFTMAX": "#define BUILD_SOFTMAX 0",
-    "#cmakedefine01 BUILD_SUM": "#define BUILD_SUM 0",
-    "#cmakedefine01 BUILD_PRIMITIVE_CPU_ISA_ALL": "#define BUILD_PRIMITIVE_CPU_ISA_ALL 1",
-    "#cmakedefine01 BUILD_SSE41": "#define BUILD_SSE41 0",
-    "#cmakedefine01 BUILD_AVX2": "#define BUILD_AVX2 0",
-    "#cmakedefine01 BUILD_AVX512": "#define BUILD_AVX512 0",
-    "#cmakedefine01 BUILD_AMX": "#define BUILD_AMX 0",
-    "#cmakedefine01 BUILD_PRIMITIVE_GPU_ISA_ALL": "#define BUILD_PRIMITIVE_GPU_ISA_ALL 0",
-    "#cmakedefine01 BUILD_GEN9": "#define BUILD_GEN9 0",
-    "#cmakedefine01 BUILD_GEN11": "#define BUILD_GEN11 0",
-    "#cmakedefine01 BUILD_XELP": "#define BUILD_XELP 0",
-    "#cmakedefine01 BUILD_XEHPG": "#define BUILD_XEHPG 0",
-    "#cmakedefine01 BUILD_XEHPC": "#define BUILD_XEHPC 0",
-    "#cmakedefine01 BUILD_XEHP": "#define BUILD_XEHP 0",
-}
-
-_DNNL_RUNTIME_OMP = {
-    "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_OMP",
-    "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_OMP",
-}
-
-_DNNL_RUNTIME_OMP.update(_CMAKE_COMMON_LIST)
-
-_DNNL_RUNTIME_THREADPOOL = {
-    "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_THREADPOOL",
-    "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_THREADPOOL",
-}
-
-_DNNL_RUNTIME_THREADPOOL.update(_CMAKE_COMMON_LIST)
-
-expand_template(
-    name = "dnnl_config_h",
-    out = "include/oneapi/dnnl/dnnl_config.h",
-    substitutions = select({
-        "@local_xla//third_party/mkl_dnn:build_with_mkldnn_openmp": _DNNL_RUNTIME_OMP,
-        "//conditions:default": _DNNL_RUNTIME_THREADPOOL,
-    }),
-    template = "include/oneapi/dnnl/dnnl_config.h.in",
-)
-
-# Create the file dnnl_version.h with DNNL version numbers.
-# Currently, the version numbers are hard coded here. If DNNL is upgraded then
-# the version numbers have to be updated manually. The version numbers can be
-# obtained from the PROJECT_VERSION settings in CMakeLists.txt. The variable is
-# set to "version_major.version_minor.version_patch". The git hash version can
-# be set to NA.
-# TODO(agramesh1): Automatically get the version numbers from CMakeLists.txt.
-expand_template(
-    name = "dnnl_version_h",
-    out = "include/oneapi/dnnl/dnnl_version.h",
-    substitutions = {
-        "@DNNL_VERSION_MAJOR@": "3",
-        "@DNNL_VERSION_MINOR@": "3",
-        "@DNNL_VERSION_PATCH@": "0",
-        "@DNNL_VERSION_HASH@": "N/A",
-    },
-    template = "include/oneapi/dnnl/dnnl_version.h.in",
-)
-
-_COPTS_LIST = select({
-    "@local_tsl//tsl:windows": [],
-    "//conditions:default": ["-fexceptions"],
-}) + [
-    "-UUSE_MKL",
-    "-UUSE_CBLAS",
-    "-DDNNL_ENABLE_MAX_CPU_ISA",
-    "-DDNNL_ENABLE_ITT_TASKS",
-] + tf_openmp_copts()
-
-_INCLUDES_LIST = [
-    "include",
-    "src",
-    "src/common",
-    "src/common/ittnotify",
-    "src/cpu",
-    "src/cpu/gemm",
-    "src/cpu/x64/xbyak",
-]
-
-_TEXTUAL_HDRS_LIST = glob([
-    "include/**/*",
-    "src/common/*.hpp",
-    "src/common/ittnotify/**/*.h",
-    "src/cpu/*.hpp",
-    "src/cpu/**/*.hpp",
-    "src/cpu/jit_utils/**/*.hpp",
-    "src/cpu/x64/xbyak/*.h",
-]) + [
-    ":dnnl_config_h",
-    ":dnnl_version_h",
-]
-
-# Large autogen files take too long time to compile with usual optimization
-# flags. These files just generate binary kernels and are not the hot spots,
-# so we factor them out to lower compiler optimizations in ":dnnl_autogen".
-# Using -O1 to enable optimizations to reduce stack consumption. (With -O0,
-# compiler doesn't clean up stack from temporary objects.)
-cc_library(
-    name = "onednn_autogen",
-    srcs = glob(["src/cpu/x64/gemm/**/*_kern_autogen*.cpp"]),
-    copts = [
-        "-O1",
-        "-U_FORTIFY_SOURCE",
-    ] + _COPTS_LIST,
-    includes = _INCLUDES_LIST,
-    textual_hdrs = _TEXTUAL_HDRS_LIST,
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "mkl_dnn",
-    srcs = glob(
-        [
-            "src/common/*.cpp",
-            "src/cpu/*.cpp",
-            "src/cpu/**/*.cpp",
-            "src/common/ittnotify/*.c",
-            "src/cpu/jit_utils/**/*.cpp",
-        ],
-        exclude = [
-            "src/cpu/aarch64/**",
-            "src/cpu/rv64/**",
-            "src/cpu/x64/gemm/**/*_kern_autogen.cpp",
-        ],
-    ),
-    copts = _COPTS_LIST,
-    includes = _INCLUDES_LIST,
-    # TODO(penpornk): Use lrt_if_needed from tensorflow.bzl instead.
-    linkopts = select({
-        "@local_tsl//tsl:linux_aarch64": ["-lrt"],
-        "@local_tsl//tsl:linux_x86_64": ["-lrt"],
-        "@local_tsl//tsl:linux_ppc64le": ["-lrt"],
-        "//conditions:default": [],
-    }),
-    textual_hdrs = _TEXTUAL_HDRS_LIST,
-    visibility = ["//visibility:public"],
-    deps = [":onednn_autogen"] + if_mkl_ml(
-        ["@local_xla//third_party/mkl:intel_binary_blob"],
-        [],
-    ),
-)
diff --git a/third_party/xla/third_party/mkl_dnn/onednn_acl_bf16_capability_detection_for_ubuntu20.04.patch b/third_party/xla/third_party/mkl_dnn/onednn_acl_bf16_capability_detection_for_ubuntu20.04.patch
deleted file mode 100644
index 6d6f0c0..0000000
--- a/third_party/xla/third_party/mkl_dnn/onednn_acl_bf16_capability_detection_for_ubuntu20.04.patch
+++ /dev/null
@@ -1,50 +0,0 @@
-From 9a9430c7db870b78c6402d786a67921af4a66334 Mon Sep 17 00:00:00 2001
-From: Kentaro Kawakami <kawakami.k@fujitsu.com>
-Date: Fri, 26 May 2023 10:58:36 +0900
-Subject: [PATCH] cpu: aarch64: xbyak_aarch64: BF16 capability detection for
- Ubuntu 20.04
-
----
- .../aarch64/xbyak_aarch64/src/util_impl_linux.h   | 15 ++++++++++++---
- 1 file changed, 12 insertions(+), 3 deletions(-)
-
-diff --git a/src/cpu/aarch64/xbyak_aarch64/src/util_impl_linux.h b/src/cpu/aarch64/xbyak_aarch64/src/util_impl_linux.h
-index 743843bae50..3db37e972d1 100644
---- a/src/cpu/aarch64/xbyak_aarch64/src/util_impl_linux.h
-+++ b/src/cpu/aarch64/xbyak_aarch64/src/util_impl_linux.h
-@@ -39,6 +39,13 @@
- #include <asm/hwcap.h>
- #endif
- 
-+/* Linux kernel used in Ubuntu 20.04 does not have HWCAP2_BF16 definition. */
-+#ifdef AT_HWCAP2
-+#ifndef HWCAP2_BF16
-+#define HWCAP2_BF16 (1UL << 14)
-+#endif
-+#endif
-+
- namespace Xbyak_aarch64 {
- namespace util {
- #define XBYAK_AARCH64_ERROR_ fprintf(stderr, "%s, %d, Error occurrs during read cache infomation.\n", __FILE__, __LINE__);
-@@ -383,7 +390,7 @@ class CpuInfoLinux : public CpuInfo {
-   }
- 
-   void setHwCap() {
--    unsigned long hwcap = getauxval(AT_HWCAP);
-+    const unsigned long hwcap = getauxval(AT_HWCAP);
-     if (hwcap & HWCAP_ATOMICS)
-       type_ |= (Type)XBYAK_AARCH64_HWCAP_ATOMIC;
- 
-@@ -391,8 +398,10 @@ class CpuInfoLinux : public CpuInfo {
-       type_ |= (Type)XBYAK_AARCH64_HWCAP_FP;
-     if (hwcap & HWCAP_ASIMD)
-       type_ |= (Type)XBYAK_AARCH64_HWCAP_ADVSIMD;
--#ifdef HWCAP2_BF16
--    if (hwcap & HWCAP2_BF16)
-+
-+#ifdef AT_HWCAP2
-+    const unsigned long hwcap2 = getauxval(AT_HWCAP2);
-+    if (hwcap2 & HWCAP2_BF16)
-       type_ |= (Type)XBYAK_AARCH64_HWCAP_BF16;
- #endif
- 
diff --git a/third_party/xla/third_party/mkl_dnn/onednn_acl_fp32_bf16_reorder.patch b/third_party/xla/third_party/mkl_dnn/onednn_acl_fp32_bf16_reorder.patch
deleted file mode 100644
index 202902a..0000000
--- a/third_party/xla/third_party/mkl_dnn/onednn_acl_fp32_bf16_reorder.patch
+++ /dev/null
@@ -1,111 +0,0 @@
- *******************************************************************************
- Copyright 2023 Arm Limited and affiliates.
- SPDX-License-Identifier: Apache-2.0
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- *******************************************************************************
-diff --git a/src/cpu/aarch64/cpu_isa_traits.hpp b/src/cpu/aarch64/cpu_isa_traits.hpp
-index 4a43b24c5..1a5cfe590 100644
---- a/src/cpu/aarch64/cpu_isa_traits.hpp
-+++ b/src/cpu/aarch64/cpu_isa_traits.hpp
-@@ -1,6 +1,7 @@
- /*******************************************************************************
- * Copyright 2018-2023 Intel Corporation
- * Copyright 2020-2023 FUJITSU LIMITED
-+* Copyright 2023 Arm Ltd. and affiliates
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
-@@ -211,10 +212,10 @@ static inline bool mayiuse_atomic() {
-     return cpu().isAtomicSupported();
- }
- 
--inline bool isa_has_bf16(cpu_isa_t isa) {
--    return false;
-+static inline bool mayiuse_bf16() {
-+    using namespace Xbyak_aarch64::util;
-+    return cpu().isBf16Supported();
- }
--
- } // namespace
- 
- /* whatever is required to generate string literals... */
-diff --git a/src/cpu/aarch64/jit_uni_reorder.cpp b/src/cpu/aarch64/jit_uni_reorder.cpp
-index 6bd259ec2..5541bb702 100644
---- a/src/cpu/aarch64/jit_uni_reorder.cpp
-+++ b/src/cpu/aarch64/jit_uni_reorder.cpp
-@@ -1,7 +1,7 @@
- /*******************************************************************************
- * Copyright 2018-2023 Intel Corporation
- * Copyright 2020-2023 FUJITSU LIMITED
--* Copyright 2022 Arm Ltd. and affiliates
-+* Copyright 2022-2023 Arm Ltd. and affiliates
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
-@@ -163,11 +163,11 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
- 
-         bool ok = true && p.ndims > 0
-                 && utils::one_of(p.itype, f32, s32, data_type::s8, u8)
--                && utils::one_of(p.otype, f32, s32, data_type::s8, u8)
-+                && utils::one_of(p.otype, f32, bf16, s32, data_type::s8, u8)
-                 && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */
-                 && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */
--                && simple_impl_desc_init(p, nullptr)
--                && prb_has_small_strides(p);
-+                && simple_impl_desc_init(p, nullptr) && prb_has_small_strides(p)
-+                && ((p.otype != bf16) || (p.itype == f32 && mayiuse_bf16()));
- 
-         return ok;
-     }
-@@ -648,6 +648,9 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
-                         cvt_v_s32_u8(startIdx, regNum);
-                     if (idt == data_type::s8) cvt_v_s8_u8(startIdx, regNum);
-                     break;
-+                case bf16:
-+                    if (idt == f32) cvt_v_f32_bf16(startIdx, regNum);
-+                    break;
-                 default: assert(!"unreachable");
-             }
-         };
-@@ -1677,6 +1680,10 @@ struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
-         UNROLL_INST(fcvtzs, VReg4S, tmp, tmp);
-     }
- 
-+    void cvt_v_f32_bf16(const size_t startIdx, const size_t regNum) {
-+        UNROLL_INST2(bfcvtn, VReg4H(i), VReg4S(i));
-+    }
-+
-     void cvt_z_s8_s32(const size_t startIdx, const size_t regNum) {
-         cvt_z_b_s(startIdx, regNum);
-         UNROLL_INST(sxtb, ZRegS, tmp, P_ALL_ONE / T_m, tmp);
-diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_bf16.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_bf16.cpp
-index ba5499ba9..d4e21d316 100644
---- a/src/cpu/reorder/cpu_reorder_regular_f32_bf16.cpp
-+++ b/src/cpu/reorder/cpu_reorder_regular_f32_bf16.cpp
-@@ -1,5 +1,6 @@
- /*******************************************************************************
- * Copyright 2020-2022 Intel Corporation
-+* Copyright 2023 Arm Ltd. and affiliates
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
-@@ -34,6 +35,8 @@ const impl_list_map_t &regular_f32_bf16_impl_list_map() {
-             DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, bf16, nChw16c))
-             DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, bf16, nCdhw16c))
- 
-+            DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
-+
-             DNNL_NON_X64_ONLY(REG_SR(f32, oihw, bf16, OIhw8i16o2i, fmt_order::keep))
-             DNNL_NON_X64_ONLY(REG_SR(f32, goihw, bf16, gOIhw8i16o2i, fmt_order::keep))
-             DNNL_NON_X64_ONLY(REG_SR(f32, oihw, bf16, OIhw8o16i2o, fmt_order::keep))
diff --git a/third_party/xla/third_party/mkl_dnn/onednn_acl_reorder.patch b/third_party/xla/third_party/mkl_dnn/onednn_acl_reorder.patch
deleted file mode 100644
index 5da6756..0000000
--- a/third_party/xla/third_party/mkl_dnn/onednn_acl_reorder.patch
+++ /dev/null
@@ -1,371 +0,0 @@
- *******************************************************************************
- Copyright 2023 Arm Limited and affiliates.
- SPDX-License-Identifier: Apache-2.0
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- *******************************************************************************
-diff --git a/src/cpu/aarch64/acl_reorder.cpp b/src/cpu/aarch64/acl_reorder.cpp
-new file mode 100644
-index 000000000..061751b55
---- /dev/null
-+++ b/src/cpu/aarch64/acl_reorder.cpp
-@@ -0,0 +1,52 @@
-+/*******************************************************************************
-+* Copyright 2023 Arm Ltd. and affiliates
-+*
-+* Licensed under the Apache License, Version 2.0 (the "License");
-+* you may not use this file except in compliance with the License.
-+* You may obtain a copy of the License at
-+*
-+*     http://www.apache.org/licenses/LICENSE-2.0
-+*
-+* Unless required by applicable law or agreed to in writing, software
-+* distributed under the License is distributed on an "AS IS" BASIS,
-+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-+* See the License for the specific language governing permissions and
-+* limitations under the License.
-+*******************************************************************************/
-+
-+#include "cpu/aarch64/acl_reorder.hpp"
-+
-+namespace dnnl {
-+namespace impl {
-+namespace cpu {
-+namespace aarch64 {
-+
-+status_t acl_reorder_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
-+    // Lock here is needed because resource_mapper does not support
-+    // concurrent multithreaded access.
-+    std::lock_guard<std::mutex> _lock {this->mtx};
-+
-+    auto src = CTX_IN_MEM(const void *, DNNL_ARG_FROM);
-+    auto dst = CTX_OUT_MEM(void *, DNNL_ARG_TO);
-+
-+    // Retrieve primitive resource and configured Compute Library objects
-+    auto *acl_resource
-+            = ctx.get_resource_mapper()->get<acl_reorder_resource_t>(this);
-+
-+    acl_reorder_obj_t &acl_obj = acl_resource->get_acl_obj();
-+
-+    acl_obj.src_tensor.allocator()->import_memory(const_cast<void *>(src));
-+    acl_obj.dst_tensor.allocator()->import_memory(dst);
-+
-+    acl_obj.reorder.run();
-+
-+    acl_obj.src_tensor.allocator()->free();
-+    acl_obj.dst_tensor.allocator()->free();
-+
-+    return status::success;
-+}
-+
-+} // namespace aarch64
-+} // namespace cpu
-+} // namespace impl
-+} // namespace dnnl
-diff --git a/src/cpu/aarch64/acl_reorder.hpp b/src/cpu/aarch64/acl_reorder.hpp
-new file mode 100644
-index 0000000000..edbc38914d
---- /dev/null
-+++ b/src/cpu/aarch64/acl_reorder.hpp
-@@ -0,0 +1,262 @@
-+/*******************************************************************************
-+* Copyright 2023 Arm Ltd. and affiliates
-+*
-+* Licensed under the Apache License, Version 2.0 (the "License");
-+* you may not use this file except in compliance with the License.
-+* You may obtain a copy of the License at
-+*
-+*     http://www.apache.org/licenses/LICENSE-2.0
-+*
-+* Unless required by applicable law or agreed to in writing, software
-+* distributed under the License is distributed on an "AS IS" BASIS,
-+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-+* See the License for the specific language governing permissions and
-+* limitations under the License.
-+*******************************************************************************/
-+#ifndef CPU_AARCH64_ACL_REORDER_HPP
-+#define CPU_AARCH64_ACL_REORDER_HPP
-+
-+#include "cpu/aarch64/acl_utils.hpp"
-+#include "cpu/reorder/cpu_reorder_pd.hpp"
-+#include "arm_compute/core/Types.h"
-+#include "common/utils.hpp"
-+
-+namespace dnnl {
-+namespace impl {
-+namespace cpu {
-+namespace aarch64 {
-+
-+struct acl_reorder_obj_t {
-+    arm_compute::NEReorderLayer reorder;
-+    arm_compute::Tensor src_tensor;
-+    arm_compute::Tensor dst_tensor;
-+    arm_compute::WeightFormat src_wf;
-+    arm_compute::WeightFormat dst_wf;
-+};
-+
-+struct acl_reorder_conf_t {
-+    arm_compute::TensorInfo src_info;
-+    arm_compute::TensorInfo dst_info;
-+    arm_compute::WeightFormat src_wf;
-+    arm_compute::WeightFormat dst_wf;
-+};
-+
-+struct acl_reorder_resource_t : public resource_t {
-+    acl_reorder_resource_t() : acl_obj_(utils::make_unique<acl_reorder_obj_t>()) {}
-+
-+    status_t configure(const acl_reorder_conf_t &app) {
-+        if (!acl_obj_) return status::out_of_memory;
-+
-+        // Init Compute Library tensors based on info from descriptor
-+        acl_obj_->src_tensor.allocator()->init(app.src_info);
-+        acl_obj_->dst_tensor.allocator()->init(app.dst_info);
-+
-+        // clang-format off
-+        acl_obj_->reorder.configure(
-+            &acl_obj_->src_tensor,
-+            &acl_obj_->dst_tensor,
-+            app.src_wf,
-+            app.dst_wf
-+            );
-+        // clang-format on
-+
-+        return status::success;
-+    }
-+
-+    acl_reorder_obj_t &get_acl_obj() const { return *acl_obj_; }
-+    DNNL_DISALLOW_COPY_AND_ASSIGN(acl_reorder_resource_t);
-+
-+private:
-+    std::unique_ptr<acl_reorder_obj_t> acl_obj_;
-+}; // acl_reorder_resource_t
-+
-+struct acl_reorder_fwd_t : public primitive_t {
-+    using primitive_t::primitive_t;
-+    struct pd_t : public cpu_reorder_pd_t {
-+
-+        using cpu_reorder_pd_t::cpu_reorder_pd_t;
-+
-+        DECLARE_COMMON_PD_T("acl", acl_reorder_fwd_t);
-+
-+        static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
-+                const primitive_attr_t *attr, engine_t *src_engine,
-+                const memory_desc_t *src_md, engine_t *dst_engine,
-+                const memory_desc_t *dst_md) {
-+
-+            using namespace acl_utils;
-+            // using skip_mask_t = dnnl_primitive_attr::skip_mask_t;
-+
-+            bool ok = src_md->data_type
-+                            == dst_md->data_type // ACL only supports matching src/dst data types
-+                    && utils::one_of(src_md->data_type,
-+                            data_type::f32) // Only supports f32 for now
-+                    && attr->has_default_values();
-+            if (!ok) return status::unimplemented;
-+
-+            int mask = -1;
-+            bool is_set = false;
-+            // CHECK(attr->scales_.get(DNNL_ARG_DST, &mask, &is_set));
-+            const memory_desc_wrapper input_d(src_md);
-+            if (input_d.has_runtime_dims_or_strides() && is_set && mask > 0)
-+                return status::unimplemented;
-+
-+            // Create and check primitive descriptor
-+            auto _pd = new pd_t(attr, src_engine->kind(), src_md,
-+                    dst_engine->kind(), dst_md);
-+            if (_pd == nullptr) return status::out_of_memory;
-+            if (_pd->init(engine, src_engine, dst_engine) != status::success) {
-+                delete _pd;
-+                return status::unimplemented;
-+            }
-+
-+            const memory_desc_wrapper src_d(*src_md);
-+            const memory_desc_wrapper dst_d(*dst_md);
-+
-+            const int ndims = src_d.ndims();
-+
-+            auto src_tag = memory_desc_matches_one_of_tag(
-+                            *src_md, format_tag::ba, format_tag::cdba);
-+            ACL_CHECK_SUPPORT(
-+                            utils::one_of(format_tag::undef, src_tag),
-+                            "");
-+
-+            arm_compute::TensorShape acl_tensor_shape_in;
-+            arm_compute::TensorShape acl_tensor_shape_out;
-+            // Dim 0 needs to be a multiple of the blocking for the ACL kernel (e.g. a multiple of 8 rows when blocking by 8)
-+            int dim_0_rounded_up;
-+
-+            // Switch for 2 or 4 dim tensors
-+            switch(ndims)
-+            {
-+                // Currently for Ab4a and Ab8a
-+                // No format_tag for these, have to deduce from stride
-+                case 2:
-+                    {
-+                        if(dst_md->dims[0] == 1 || dst_md->dims[1] == 1){
-+                            return status::unimplemented;
-+                        }
-+                        int dst_dim_1 = dst_md->dims[1];
-+                        int dst_dim_0_stride = dst_md->format_desc.blocking.strides[0];
-+                        int dst_dim_1_stride = dst_md->format_desc.blocking.strides[1];
-+                        // The stride for dim 1 must match an interleave of 4 or 8
-+                        if (dst_dim_1_stride != 4 && dst_dim_1_stride != 8){
-+                            return status::unimplemented;
-+                        }
-+                        // Check to ensure it's a blocking transpose
-+                        if (dst_dim_1 * dst_dim_1_stride != dst_dim_0_stride){
-+                            return status::unimplemented;
-+                        }
-+                        if(dst_dim_1_stride == 4){
-+                            // Set Dest WeightFormat
-+                            _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo4;
-+                            dim_0_rounded_up
-+                                    = utils::rnd_up(src_md->dims[0], 4);
-+                        } else {
-+                            // Set Dest WeightFormat
-+                            _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo8;
-+                            dim_0_rounded_up
-+                                    = utils::rnd_up(src_md->dims[0], 8);
-+                        }
-+                        acl_tensor_shape_in = arm_compute::TensorShape(src_md->dims[1], src_md->dims[0]);
-+                        acl_tensor_shape_out = arm_compute::TensorShape(src_md->dims[1], dim_0_rounded_up);
-+
-+                        break;
-+                    }
-+                // Currently for Acdb4a and Acdb8a
-+                case 4:
-+                    { 
-+
-+                        auto dst_tag = memory_desc_matches_one_of_tag(
-+                            *dst_md, format_tag::Acdb4a, format_tag::Acdb8a);
-+                        ACL_CHECK_SUPPORT(
-+                            utils::one_of(format_tag::undef, dst_tag),
-+                            "");
-+                        if(dst_tag == format_tag::Acdb4a){
-+                            // Set Dest WeightFormat
-+                            _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo4;
-+                            dim_0_rounded_up
-+                                    = utils::rnd_up(src_md->dims[0], 4);
-+                        }
-+                        else{
-+                            // Set Dest WeightFormat
-+                            _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo8;
-+                            dim_0_rounded_up
-+                                    = utils::rnd_up(src_md->dims[0], 8);
-+                        }
-+                        // Currently only supporting AxBx1x1 cases
-+                        if(dst_md->dims[2] != 1 || dst_md->dims[3] != 1){
-+                            return status::unimplemented;
-+                        }
-+                        if(dst_md->dims[0] == 1 || dst_md->dims[1] == 1){
-+                            return status::unimplemented;
-+                        }
-+                        acl_tensor_shape_in = arm_compute::TensorShape(src_md->dims[3], src_md->dims[2], src_md->dims[1], src_md->dims[0]);
-+                        acl_tensor_shape_out = arm_compute::TensorShape(src_md->dims[3], src_md->dims[2], src_md->dims[1], dim_0_rounded_up);
-+                        break;
-+                    }
-+                default:
-+                    return status::unimplemented;
-+            }
-+
-+            // Choose the data layout
-+            // bool is_nspc = utils::one_of(src_tag, format_tag::nhwc);
-+            const auto acl_layout = arm_compute::DataLayout::NCHW;
-+
-+            // Set Source WeightFormat
-+            _pd->app_.src_wf = arm_compute::WeightFormat::OHWI;
-+
-+            // Create ACL tensor infos
-+            const data_type_t data_type = src_d.data_type();
-+            const arm_compute::DataType acl_data_t
-+                    = acl_utils::get_acl_data_t(data_type);
-+            _pd->app_.src_info = arm_compute::TensorInfo(
-+                        acl_tensor_shape_in, 1, acl_data_t, acl_layout);
-+            _pd->app_.dst_info = arm_compute::TensorInfo(
-+                        acl_tensor_shape_out, 1, acl_data_t, acl_layout);
-+
-+            // Init scratch memory, not used so 0 in this implementation
-+            _pd->init_scratchpad_md();
-+
-+            return safe_ptr_assign(*reorder_pd, _pd);
-+        } // create 
-+
-+        friend dnnl::impl::impl_list_item_t;
-+        acl_reorder_conf_t app_;
-+
-+    }; // pd_t
-+
-+    acl_reorder_fwd_t(const pd_t *apd) : primitive_t(apd) {}
-+
-+    status_t create_resource(
-+            engine_t *engine, resource_mapper_t &mapper) const override {
-+        if (mapper.has_resource(this)) return status::success;
-+
-+        auto r = utils::make_unique<acl_reorder_resource_t>();
-+        if (!r) return status::out_of_memory;
-+
-+        // Configure the resource based on information from primitive descriptor
-+        CHECK(r->configure(pd()->app_));
-+
-+        mapper.add(this, std::move(r));
-+        return status::success;
-+    }
-+
-+    status_t execute(const exec_ctx_t &ctx) const override {
-+        return execute_forward(ctx);
-+    }
-+
-+private:
-+    // To guard the const execute_forward, the mutex must be 'mutable'
-+    mutable std::mutex mtx;
-+    status_t execute_forward(const exec_ctx_t &ctx) const;
-+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
-+
-+
-+}; // acl_reorder_fwd_t
-+
-+} // namespace aarch64
-+} // namespace cpu
-+} // namespace impl
-+} // namespace dnnl
-+
-+#endif // CPU_AARCH64_ACL_REORDER_HPP
-diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp
-index a4150b619..f4d6b4de3 100644
---- a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp
-+++ b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp
-@@ -16,6 +16,7 @@
- *******************************************************************************/
- 
- #include "cpu/reorder/cpu_reorder.hpp"
-+#include "cpu/aarch64/acl_reorder.hpp"
- 
- namespace dnnl {
- namespace impl {
-@@ -28,6 +29,7 @@ const impl_list_map_t &regular_f32_f32_impl_list_map() {
-         // f32 -> f32
-         {{f32, f32, 0}, {
-             REG_FAST_DIRECT_COPY_F32_F32
-+            DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::acl_reorder_fwd_t))
- 
-             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_matrix_B_reorder_t))
-             DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
-@@ -69,6 +71,8 @@ const impl_list_map_t &regular_f32_f32_impl_list_map() {
-             nullptr,
-         }},
-         {{f32, f32, 4}, {
-+
-+            DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::acl_reorder_fwd_t))
-             CPU_REORDER_INSTANCE(rnn_weights_reorder_t<f32, f32>)
- 
-             REG_FAST_DIRECT_COPY_F32_F32
diff --git a/third_party/xla/third_party/mkl_dnn/onednn_acl_thread_local_scheduler.patch b/third_party/xla/third_party/mkl_dnn/onednn_acl_thread_local_scheduler.patch
deleted file mode 100644
index 9583308..0000000
--- a/third_party/xla/third_party/mkl_dnn/onednn_acl_thread_local_scheduler.patch
+++ /dev/null
@@ -1,97 +0,0 @@
- *******************************************************************************
- Copyright 2023 Arm Limited and affiliates.
- SPDX-License-Identifier: Apache-2.0
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- *******************************************************************************
-diff --git a/src/cpu/aarch64/acl_thread.cpp b/src/cpu/aarch64/acl_thread.cpp
-index fd2c76d01..bd7bed837 100644
---- a/src/cpu/aarch64/acl_thread.cpp
-+++ b/src/cpu/aarch64/acl_thread.cpp
-@@ -55,14 +55,17 @@ void acl_set_benchmark_scheduler_default() {
- #endif
- 
- #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
--void acl_set_tp_scheduler() {
--    static std::once_flag flag_once;
--    // Create threadpool scheduler
--    std::shared_ptr<arm_compute::IScheduler> threadpool_scheduler
--            = std::make_unique<ThreadpoolScheduler>();
-+void acl_set_tp_scheduler(int intra_threads = 0) {
-+    static thread_local std::once_flag flag_once;
-     // set CUSTOM scheduler in ACL
-     std::call_once(flag_once,
--            [&]() { arm_compute::Scheduler::set(threadpool_scheduler); });
-+            [&]() {
-+                    // Create threadpool scheduler
-+                    std::shared_ptr<arm_compute::IScheduler> threadpool_scheduler
-+                        = std::make_unique<ThreadpoolScheduler>();
-+                    threadpool_scheduler->set_num_threads(intra_threads);
-+
-+                    arm_compute::Scheduler::set(threadpool_scheduler); });
- }
- 
- void acl_set_threadpool_num_threads() {
-@@ -102,14 +105,6 @@ void set_acl_threading() {
-         acl_set_benchmark_scheduler_default();
-     }
- #endif
--#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
--    if (verbose_has_profile_externals()) {
--        acl_set_tp_benchmark_scheduler();
--    } else {
--        acl_set_tp_scheduler();
--    }
--
--#endif
- }
- 
- } // namespace acl_thread_utils
-diff --git a/src/cpu/aarch64/acl_thread.hpp b/src/cpu/aarch64/acl_thread.hpp
-index f073376e6..654a2aa5d 100644
---- a/src/cpu/aarch64/acl_thread.hpp
-+++ b/src/cpu/aarch64/acl_thread.hpp
-@@ -40,7 +40,7 @@ void acl_set_benchmark_scheduler_default();
- 
- #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
- // Retrieve threadpool size during primitive execution and set ThreadpoolScheduler num_threads
--void acl_set_tp_scheduler();
-+void acl_set_tp_scheduler(int intra_threads);
- void acl_set_threadpool_num_threads();
- // Swap BenchmarkScheduler for custom scheduler builds (i.e. ThreadPoolScheduler) for DNNL_VERBOSE=profile,profile_externals
- void acl_set_tp_benchmark_scheduler();
-diff --git a/src/cpu/aarch64/acl_threadpool_scheduler.cpp b/src/cpu/aarch64/acl_threadpool_scheduler.cpp
-index 439ca862e..6656c37a5 100644
---- a/src/cpu/aarch64/acl_threadpool_scheduler.cpp
-+++ b/src/cpu/aarch64/acl_threadpool_scheduler.cpp
-@@ -102,8 +102,6 @@ void ThreadpoolScheduler::schedule_op(ICPPKernel *kernel, const Hints &hints,
- void ThreadpoolScheduler::run_workloads(
-         std::vector<arm_compute::IScheduler::Workload> &workloads) {
- 
--    arm_compute::lock_guard<std::mutex> lock(this->_run_workloads_mutex);
--
-     const unsigned int num_threads
-             = std::min(static_cast<unsigned int>(_num_threads),
-                     static_cast<unsigned int>(workloads.size()));
-diff --git a/src/cpu/cpu_engine.cpp b/src/cpu/cpu_engine.cpp
-index 0bfec3871..7207b2b60 100644
---- a/src/cpu/cpu_engine.cpp
-+++ b/src/cpu/cpu_engine.cpp
-@@ -47,6 +47,7 @@ status_t cpu_engine_t::create_stream(stream_t **stream, unsigned flags) {
- #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
- status_t cpu_engine_t::create_stream(stream_t **stream,
-         dnnl::threadpool_interop::threadpool_iface *threadpool) {
-+    dnnl::impl::cpu::aarch64::acl_thread_utils::acl_set_tp_scheduler(threadpool->get_num_threads());
-     return safe_ptr_assign<stream_t>(
-             *stream, new cpu_stream_t(this, threadpool));
- }
diff --git a/third_party/xla/third_party/mkl_dnn/onednn_acl_threadcap.patch b/third_party/xla/third_party/mkl_dnn/onednn_acl_threadcap.patch
deleted file mode 100644
index 3a33af1..0000000
--- a/third_party/xla/third_party/mkl_dnn/onednn_acl_threadcap.patch
+++ /dev/null
@@ -1,43 +0,0 @@
- *******************************************************************************
- Copyright 2023 Arm Limited and affiliates.
- SPDX-License-Identifier: Apache-2.0
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- *******************************************************************************
-diff --git a/src/cpu/aarch64/acl_thread.cpp b/src/cpu/aarch64/acl_thread.cpp
-index fd2c76d01..2d7c76d48 100644
---- a/src/cpu/aarch64/acl_thread.cpp
-+++ b/src/cpu/aarch64/acl_thread.cpp
-@@ -17,6 +17,8 @@
- #include "cpu/aarch64/acl_thread.hpp"
- #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
- #include "cpu/aarch64/acl_threadpool_scheduler.hpp"
-+#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
-+#include <thread>
- #endif
- #include "cpu/aarch64/acl_benchmark_scheduler.hpp"
- 
-@@ -30,9 +32,10 @@ namespace acl_thread_utils {
- #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
- void acl_thread_bind() {
-     static std::once_flag flag_once;
--    // The threads in Compute Library are bound for the cores 0..max_threads-1
--    // dnnl_get_max_threads() returns OMP_NUM_THREADS
--    const int max_threads = dnnl_get_max_threads();
-+    // Cap the number of threads to 90% of the total core count
-+    // to ensure Compute Library doesn't use too much resource
-+    int capped_threads = (int)std::floor(0.9*std::thread::hardware_concurrency());
-+    const int max_threads = std::min(capped_threads, dnnl_get_max_threads());
-     // arm_compute::Scheduler does not support concurrent access thus a
-     // workaround here restricts it to only one call
-     std::call_once(flag_once, [&]() {
diff --git a/third_party/xla/third_party/nasm/BUILD b/third_party/xla/third_party/nasm/BUILD
deleted file mode 100644
index ed1568c..0000000
--- a/third_party/xla/third_party/nasm/BUILD
+++ /dev/null
@@ -1,3 +0,0 @@
-# Needed to make this a package.
-
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/xla/third_party/nasm/BUILD.system b/third_party/xla/third_party/nasm/BUILD.system
deleted file mode 100644
index 52f6081..0000000
--- a/third_party/xla/third_party/nasm/BUILD.system
+++ /dev/null
@@ -1,18 +0,0 @@
-licenses(["notice"])  # BSD 2-clause
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
-
-genrule(
-    name = "lnnasmlink",
-    outs = ["nasmlink"],
-    cmd = "ln -s $$(which nasm) $@",
-)
-
-sh_binary(
-    name = "nasm",
-    srcs = ["nasmlink"],
-    visibility = ["@libjpeg_turbo//:__pkg__"],
-)
diff --git a/third_party/xla/third_party/nasm/config.h b/third_party/xla/third_party/nasm/config.h
deleted file mode 100644
index 3533280..0000000
--- a/third_party/xla/third_party/nasm/config.h
+++ /dev/null
@@ -1,543 +0,0 @@
-/* config/config.h.  Generated from config.h.in by configure.  */
-/* config/config.h.in.  Generated from configure.ac by autoheader.  */
-
-/* Define to 1 to call abort() on panics (internal errors), for debugging. */
-/* #undef ABORT_ON_PANIC */
-
-/* Define if building universal (internal helper macro) */
-/* #undef AC_APPLE_UNIVERSAL_BUILD */
-
-/* Define to 1 if compiled with the `-fdata-sections' compiler flag */
-/* #undef CFLAG_FDATA_SECTIONS */
-
-/* Define to 1 if compiled with the `-ffunction-sections' compiler flag */
-/* #undef CFLAG_FFUNCTION_SECTIONS */
-
-/* Define to 1 if compiled with the `-fgnu89-inline' compiler flag */
-/* #undef CFLAG_FGNU89_INLINE */
-
-/* Define to 1 if compiled with the `-flto' compiler flag */
-/* #undef CFLAG_FLTO */
-
-/* Define to 1 if compiled with the `-fno-common' compiler flag */
-#define CFLAG_FNO_COMMON 1
-
-/* Define to 1 if compiled with the `-fno-omit-frame-pointer' compiler flag */
-/* #undef CFLAG_FNO_OMIT_FRAME_POINTER */
-
-/* Define to 1 if compiled with the `-fsanitize=address' compiler flag */
-/* #undef CFLAG_FSANITIZE_ADDRESS */
-
-/* Define to 1 if compiled with the `-fsanitize=undefined' compiler flag */
-/* #undef CFLAG_FSANITIZE_UNDEFINED */
-
-/* Define to 1 if compiled with the `-fvisibility=hidden' compiler flag */
-#define CFLAG_FVISIBILITY_HIDDEN 1
-
-/* Define to 1 if compiled with the `-fwrapv' compiler flag */
-#define CFLAG_FWRAPV 1
-
-/* Define to 1 if compiled with the `-ggdb3' compiler flag */
-/* #undef CFLAG_GGDB3 */
-
-/* Define to 1 if compiled with the `-pedantic' compiler flag */
-#define CFLAG_PEDANTIC 1
-
-/* Define to 1 if compiled with the `-U__STRICT_ANSI__' compiler flag */
-#define CFLAG_U_STRICT_ANSI 1
-
-/* Define to 1 if compiled with the `-W' compiler flag */
-#define CFLAG_W 1
-
-/* Define to 1 if compiled with the `-Wall' compiler flag */
-#define CFLAG_WALL 1
-
-/* Define to 1 if compiled with the `-Wc90-c99-compat' compiler flag */
-/* #undef CFLAG_WC90_C99_COMPAT */
-
-/* Define to 1 if compiled with the `-Werror' compiler flag */
-/* #undef CFLAG_WERROR */
-
-/* Define to 1 if compiled with the `-Werror=attributes' compiler flag */
-#define CFLAG_WERROR_ATTRIBUTES 1
-
-/* Define to 1 if compiled with the `-Werror=comment' compiler flag */
-#define CFLAG_WERROR_COMMENT 1
-
-/* Define to 1 if compiled with the `-Werror=implicit' compiler flag */
-#define CFLAG_WERROR_IMPLICIT 1
-
-/* Define to 1 if compiled with the `-Werror=missing-braces' compiler flag */
-#define CFLAG_WERROR_MISSING_BRACES 1
-
-/* Define to 1 if compiled with the `-Werror=missing-declarations' compiler
-   flag */
-#define CFLAG_WERROR_MISSING_DECLARATIONS 1
-
-/* Define to 1 if compiled with the `-Werror=missing-prototypes' compiler flag
- */
-#define CFLAG_WERROR_MISSING_PROTOTYPES 1
-
-/* Define to 1 if compiled with the `-Werror=pointer-arith' compiler flag */
-#define CFLAG_WERROR_POINTER_ARITH 1
-
-/* Define to 1 if compiled with the `-Werror=return-type' compiler flag */
-#define CFLAG_WERROR_RETURN_TYPE 1
-
-/* Define to 1 if compiled with the `-Werror=strict-prototypes' compiler flag
- */
-/* #undef CFLAG_WERROR_STRICT_PROTOTYPES */
-
-/* Define to 1 if compiled with the `-Werror=trigraphs' compiler flag */
-#define CFLAG_WERROR_TRIGRAPHS 1
-
-/* Define to 1 if compiled with the `-Werror=unknown-warning-option' compiler
-   flag */
-/* #undef CFLAG_WERROR_UNKNOWN_WARNING_OPTION */
-
-/* Define to 1 if compiled with the `-Werror=vla' compiler flag */
-#define CFLAG_WERROR_VLA 1
-
-/* Define to 1 if compiled with the `-Wlong-long' compiler flag */
-#define CFLAG_WLONG_LONG 1
-
-/* Define to 1 if compiled with the `-Wl,--gc-sections' compiler flag */
-/* #undef CFLAG_WL_GC_SECTIONS */
-
-/* Define to 1 if compiled with the `-Wpedantic-ms-format' compiler flag */
-/* #undef CFLAG_WPEDANTIC_MS_FORMAT */
-
-/* Define to 1 if compiled with the `-Wshift-negative-value' compiler flag */
-#define CFLAG_WSHIFT_NEGATIVE_VALUE 1
-
-/* Define to 1 if compiled with the `-Wstringop-truncation' compiler flag */
-/* #undef CFLAG_WSTRINGOP_TRUNCATION */
-
-/* Define to 1 if you have the `access' function. */
-#define HAVE_ACCESS 1
-
-/* Define to 1 if you have the `canonicalize_file_name' function. */
-/* #undef HAVE_CANONICALIZE_FILE_NAME */
-
-/* Define to 1 if you have the `cpu_to_le16' intrinsic function. */
-/* #undef HAVE_CPU_TO_LE16 */
-
-/* Define to 1 if you have the `cpu_to_le32' intrinsic function. */
-/* #undef HAVE_CPU_TO_LE32 */
-
-/* Define to 1 if you have the `cpu_to_le64' intrinsic function. */
-/* #undef HAVE_CPU_TO_LE64 */
-
-/* Define to 1 if you have the declaration of `strcasecmp', and to 0 if you
-   don't. */
-#define HAVE_DECL_STRCASECMP 1
-
-/* Define to 1 if you have the declaration of `stricmp', and to 0 if you
-   don't. */
-#define HAVE_DECL_STRICMP 0
-
-/* Define to 1 if you have the declaration of `strlcpy', and to 0 if you
-   don't. */
-#define HAVE_DECL_STRLCPY 0
-
-/* Define to 1 if you have the declaration of `strncasecmp', and to 0 if you
-   don't. */
-#define HAVE_DECL_STRNCASECMP 1
-
-/* Define to 1 if you have the declaration of `strnicmp', and to 0 if you
-   don't. */
-#define HAVE_DECL_STRNICMP 0
-
-/* Define to 1 if you have the declaration of `strnlen', and to 0 if you
-   don't. */
-#define HAVE_DECL_STRNLEN 1
-
-/* Define to 1 if you have the declaration of `strrchrnul', and to 0 if you
-   don't. */
-#define HAVE_DECL_STRRCHRNUL 0
-
-/* Define to 1 if you have the declaration of `strsep', and to 0 if you don't.
- */
-#define HAVE_DECL_STRSEP 1
-
-/* Define to 1 if you have the <endian.h> header file. */
-/* #undef HAVE_ENDIAN_H */
-
-/* Define to 1 if you have the `faccessat' function. */
-#define HAVE_FACCESSAT 1
-
-/* Define to 1 if you have the <fcntl.h> header file. */
-#define HAVE_FCNTL_H 1
-
-/* Define to 1 if you have the `fileno' function. */
-#define HAVE_FILENO 1
-
-/* Define to 1 if fseeko (and presumably ftello) exists and is declared. */
-#define HAVE_FSEEKO 1
-
-/* Define to 1 if you have the `fstat' function. */
-#define HAVE_FSTAT 1
-
-/* Define to 1 if you have the `ftruncate' function. */
-#define HAVE_FTRUNCATE 1
-
-/* Define to 1 if your compiler supports __attribute__((alloc_size)) on
-   functions */
-#define HAVE_FUNC_ATTRIBUTE_ALLOC_SIZE 1
-
-/* Define to 1 if your compiler supports __attribute__((cold)) on functions */
-#define HAVE_FUNC_ATTRIBUTE_COLD 1
-
-/* Define to 1 if your compiler supports __attribute__((const)) on functions
- */
-#define HAVE_FUNC_ATTRIBUTE_CONST 1
-
-/* Define to 1 if your compiler supports __attribute__((error)) on functions
- */
-/* #undef HAVE_FUNC_ATTRIBUTE_ERROR */
-
-/* Define to 1 if your compiler supports __attribute__((format)) on functions
- */
-#define HAVE_FUNC_ATTRIBUTE_FORMAT 1
-
-/* Define to 1 if your compiler supports __attribute__((malloc)) on functions
- */
-#define HAVE_FUNC_ATTRIBUTE_MALLOC 1
-
-/* Define to 1 if your compiler supports __attribute__((noreturn)) on
-   functions */
-#define HAVE_FUNC_ATTRIBUTE_NORETURN 1
-
-/* Define to 1 if your compiler supports __attribute__((pure)) on functions */
-#define HAVE_FUNC_ATTRIBUTE_PURE 1
-
-/* Define to 1 if your compiler supports __attribute__((returns_nonnull)) on
-   functions */
-#define HAVE_FUNC_ATTRIBUTE_RETURNS_NONNULL 1
-
-/* Define to 1 if your compiler supports __attribute__((sentinel)) on
-   functions */
-#define HAVE_FUNC_ATTRIBUTE_SENTINEL 1
-
-/* Define to 1 if you have the `getgid' function. */
-#define HAVE_GETGID 1
-
-/* Define to 1 if you have the `getpagesize' function. */
-#define HAVE_GETPAGESIZE 1
-
-/* Define to 1 if you have the `getuid' function. */
-#define HAVE_GETUID 1
-
-/* Define to 1 if you have the `htole16' intrinsic function. */
-/* #undef HAVE_HTOLE16 */
-
-/* Define to 1 if you have the `htole32' intrinsic function. */
-/* #undef HAVE_HTOLE32 */
-
-/* Define to 1 if you have the `htole64' intrinsic function. */
-/* #undef HAVE_HTOLE64 */
-
-/* Define to 1 if you have the <intrin.h> header file. */
-/* #undef HAVE_INTRIN_H */
-
-/* Define to 1 if you have the <inttypes.h> header file. */
-#define HAVE_INTTYPES_H 1
-
-/* Define to 1 if you have the <io.h> header file. */
-/* #undef HAVE_IO_H */
-
-/* Define to 1 if you have the <machine/endian.h> header file. */
-/* #undef HAVE_MACHINE_ENDIAN_H */
-
-/* Define to 1 if you have the <memory.h> header file. */
-#define HAVE_MEMORY_H 1
-
-/* Define to 1 if you have a working `mmap' system call. */
-#define HAVE_MMAP 1
-
-/* Define to 1 if you have the `pathconf' function. */
-#define HAVE_PATHCONF 1
-
-/* Define to 1 if you have the `realpath' function. */
-#define HAVE_REALPATH 1
-
-/* Define to 1 if you have the `snprintf' function. */
-#define HAVE_SNPRINTF 1
-
-/* Define to 1 if you have the `stat' function. */
-#define HAVE_STAT 1
-
-/* Define to 1 if stdbool.h conforms to C99. */
-#define HAVE_STDBOOL_H 1
-
-/* Define to 1 if your compiler supports C99 extern inline */
-#define HAVE_STDC_INLINE 1
-
-/* Define to 1 if you have the <stdint.h> header file. */
-#define HAVE_STDINT_H 1
-
-/* Define to 1 if you have the <stdlib.h> header file. */
-#define HAVE_STDLIB_H 1
-
-/* Define to 1 if you have the <stdnoreturn.h> header file. */
-#define HAVE_STDNORETURN_H 1
-
-/* Define to 1 if you have the `strcasecmp' function. */
-#define HAVE_STRCASECMP 1
-
-/* Define to 1 if you have the `stricmp' function. */
-/* #undef HAVE_STRICMP */
-
-/* Define to 1 if you have the <strings.h> header file. */
-#define HAVE_STRINGS_H 1
-
-/* Define to 1 if you have the <string.h> header file. */
-#define HAVE_STRING_H 1
-
-/* Define to 1 if you have the `strlcpy' function. */
-/* #undef HAVE_STRLCPY */
-
-/* Define to 1 if you have the `strncasecmp' function. */
-#define HAVE_STRNCASECMP 1
-
-/* Define to 1 if you have the `strnicmp' function. */
-/* #undef HAVE_STRNICMP */
-
-/* Define to 1 if you have the `strnlen' function. */
-#define HAVE_STRNLEN 1
-
-/* Define to 1 if you have the `strrchrnul' function. */
-/* #undef HAVE_STRRCHRNUL */
-
-/* Define to 1 if you have the `strsep' function. */
-#define HAVE_STRSEP 1
-
-/* Define to 1 if the system has the type `struct stat'. */
-#define HAVE_STRUCT_STAT 1
-
-/* Define to 1 if the system has the type `struct _stati64'. */
-/* #undef HAVE_STRUCT__STATI64 */
-
-/* Define to 1 if you have the `sysconf' function. */
-#define HAVE_SYSCONF 1
-
-/* Define to 1 if you have the <sys/endian.h> header file. */
-/* #undef HAVE_SYS_ENDIAN_H */
-
-/* Define to 1 if you have the <sys/mman.h> header file. */
-#define HAVE_SYS_MMAN_H 1
-
-/* Define to 1 if you have the <sys/param.h> header file. */
-#define HAVE_SYS_PARAM_H 1
-
-/* Define to 1 if you have the <sys/stat.h> header file. */
-#define HAVE_SYS_STAT_H 1
-
-/* Define to 1 if you have the <sys/types.h> header file. */
-#define HAVE_SYS_TYPES_H 1
-
-/* Define to 1 if the system has the type `uintptr_t'. */
-#define HAVE_UINTPTR_T 1
-
-/* Define to 1 if you have the <unistd.h> header file. */
-#define HAVE_UNISTD_H 1
-
-/* Define to 1 if you have the `vsnprintf' function. */
-#define HAVE_VSNPRINTF 1
-
-/* Define to 1 if you have the `_access' function. */
-/* #undef HAVE__ACCESS */
-
-/* Define to 1 if you have the `_BitScanReverse' intrinsic function. */
-/* #undef HAVE__BITSCANREVERSE */
-
-/* Define to 1 if you have the `_BitScanReverse64' intrinsic function. */
-/* #undef HAVE__BITSCANREVERSE64 */
-
-/* Define to 1 if the system has the type `_Bool'. */
-#define HAVE__BOOL 1
-
-/* Define to 1 if you have the `_byteswap_uint64' intrinsic function. */
-/* #undef HAVE__BYTESWAP_UINT64 */
-
-/* Define to 1 if you have the `_byteswap_ulong' intrinsic function. */
-/* #undef HAVE__BYTESWAP_ULONG */
-
-/* Define to 1 if you have the `_byteswap_ushort' intrinsic function. */
-/* #undef HAVE__BYTESWAP_USHORT */
-
-/* Define to 1 if you have the `_chsize' function. */
-/* #undef HAVE__CHSIZE */
-
-/* Define to 1 if you have the `_chsize_s' function. */
-/* #undef HAVE__CHSIZE_S */
-
-/* Define to 1 if you have the `_filelengthi64' function. */
-/* #undef HAVE__FILELENGTHI64 */
-
-/* Define to 1 if you have the `_fileno' function. */
-/* #undef HAVE__FILENO */
-
-/* Define to 1 if you have the `_fseeki64' function. */
-/* #undef HAVE__FSEEKI64 */
-
-/* Define to 1 if you have the `_fstati64' function. */
-/* #undef HAVE__FSTATI64 */
-
-/* Define to 1 if you have the `_fullpath' function. */
-/* #undef HAVE__FULLPATH */
-
-/* Define to 1 if you have the `_snprintf' function. */
-/* #undef HAVE__SNPRINTF */
-
-/* Define to 1 if you have the `_stati64' function. */
-/* #undef HAVE__STATI64 */
-
-/* Define to 1 if you have the `_vsnprintf' function. */
-/* #undef HAVE__VSNPRINTF */
-
-/* Define to 1 if you have the `__bswap_16' intrinsic function. */
-/* #undef HAVE___BSWAP_16 */
-
-/* Define to 1 if you have the `__bswap_32' intrinsic function. */
-/* #undef HAVE___BSWAP_32 */
-
-/* Define to 1 if you have the `__bswap_64' intrinsic function. */
-/* #undef HAVE___BSWAP_64 */
-
-/* Define to 1 if you have the `__builtin_bswap16' intrinsic function. */
-#define HAVE___BUILTIN_BSWAP16 1
-
-/* Define to 1 if you have the `__builtin_bswap32' intrinsic function. */
-#define HAVE___BUILTIN_BSWAP32 1
-
-/* Define to 1 if you have the `__builtin_bswap64' intrinsic function. */
-#define HAVE___BUILTIN_BSWAP64 1
-
-/* Define to 1 if you have the `__builtin_clz' intrinsic function. */
-#define HAVE___BUILTIN_CLZ 1
-
-/* Define to 1 if you have the `__builtin_clzl' intrinsic function. */
-#define HAVE___BUILTIN_CLZL 1
-
-/* Define to 1 if you have the `__builtin_clzll' intrinsic function. */
-#define HAVE___BUILTIN_CLZLL 1
-
-/* Define to 1 if you have the `__builtin_constant_p' intrinsic function. */
-#define HAVE___BUILTIN_CONSTANT_P 1
-
-/* Define to 1 if you have the `__builtin_expect' intrinsic function. */
-#define HAVE___BUILTIN_EXPECT 1
-
-/* Define to 1 if you have the `__cpu_to_le16' intrinsic function. */
-/* #undef HAVE___CPU_TO_LE16 */
-
-/* Define to 1 if you have the `__cpu_to_le32' intrinsic function. */
-/* #undef HAVE___CPU_TO_LE32 */
-
-/* Define to 1 if you have the `__cpu_to_le64' intrinsic function. */
-/* #undef HAVE___CPU_TO_LE64 */
-
-/* Define to the address where bug reports for this package should be sent. */
-#define PACKAGE_BUGREPORT ""
-
-/* Define to the full name of this package. */
-#define PACKAGE_NAME ""
-
-/* Define to the full name and version of this package. */
-#define PACKAGE_STRING ""
-
-/* Define to the one symbol short name of this package. */
-#define PACKAGE_TARNAME ""
-
-/* Define to the home page for this package. */
-#define PACKAGE_URL ""
-
-/* Define to the version of this package. */
-#define PACKAGE_VERSION ""
-
-/* Define to 1 if you have the ANSI C header files. */
-#define STDC_HEADERS 1
-
-/* Enable extensions on AIX 3, Interix.  */
-#ifndef _ALL_SOURCE
-#define _ALL_SOURCE 1
-#endif
-/* Enable GNU extensions on systems that have them.  */
-#ifndef _GNU_SOURCE
-#define _GNU_SOURCE 1
-#endif
-/* Enable threading extensions on Solaris.  */
-#ifndef _POSIX_PTHREAD_SEMANTICS
-#define _POSIX_PTHREAD_SEMANTICS 1
-#endif
-/* Enable extensions on HP NonStop.  */
-#ifndef _TANDEM_SOURCE
-#define _TANDEM_SOURCE 1
-#endif
-/* Enable general extensions on Solaris.  */
-#ifndef __EXTENSIONS__
-#define __EXTENSIONS__ 1
-#endif
-
-/* Define to 1 if your processor stores words with the most significant byte
-   first (like Motorola and SPARC, unlike Intel and VAX). */
-/* #undef WORDS_BIGENDIAN */
-
-/* Define to 1 if your processor stores words with the least significant byte
-   first (like Intel and VAX, unlike Motorola and SPARC). */
-#define WORDS_LITTLEENDIAN 1
-
-/* Enable large inode numbers on Mac OS X 10.5.  */
-#ifndef _DARWIN_USE_64_BIT_INODE
-#define _DARWIN_USE_64_BIT_INODE 1
-#endif
-
-/* Number of bits in a file offset, on hosts where this is settable. */
-/* #undef _FILE_OFFSET_BITS */
-
-/* Define to 1 to make fseeko visible on some hosts (e.g. glibc 2.2). */
-/* #undef _LARGEFILE_SOURCE */
-
-/* Define for large files, on AIX-style hosts. */
-/* #undef _LARGE_FILES */
-
-/* Define to 1 if on MINIX. */
-/* #undef _MINIX */
-
-/* Define to 2 if the system does not provide POSIX.1 features except with
-   this defined. */
-/* #undef _POSIX_1_SOURCE */
-
-/* Define to 1 if you need to in order for `stat' and other things to work. */
-/* #undef _POSIX_SOURCE */
-
-/* Define to empty if `const' does not conform to ANSI C. */
-/* #undef const */
-
-/* Define to `__inline__' or `__inline' if that's what the C compiler
-   calls it, or to nothing if 'inline' is not supported under any name.  */
-#ifndef __cplusplus
-/* #undef inline */
-#endif
-
-/* Define to the equivalent of the C99 'restrict' keyword, or to
-   nothing if this is not supported.  Do not define if restrict is
-   supported directly.  */
-#define restrict __restrict
-/* Work around a bug in Sun C++: it does not support _Restrict or
-   __restrict__, even though the corresponding Sun C compiler ends up with
-   "#define restrict _Restrict" or "#define restrict __restrict__" in the
-   previous line.  Perhaps some future version of Sun C++ will work with
-   restrict; if so, hopefully it defines __RESTRICT like Sun C does.  */
-#if defined __SUNPRO_CC && !defined __RESTRICT
-#define _Restrict
-#define __restrict__
-#endif
-
-/* Define to `unsigned int' if <sys/types.h> does not define. */
-/* #undef size_t */
-
-/* Define to the type of an unsigned integer type wide enough to hold a
-   pointer, if such a type exists, and if the system does not define it. */
-/* #undef uintptr_t */
diff --git a/third_party/xla/third_party/nasm/nasm.BUILD b/third_party/xla/third_party/nasm/nasm.BUILD
deleted file mode 100644
index a328d07..0000000
--- a/third_party/xla/third_party/nasm/nasm.BUILD
+++ /dev/null
@@ -1,138 +0,0 @@
-licenses(["notice"])
-
-exports_files(["LICENSE"])
-
-INCLUDES = [
-    ".",
-    "include",
-    "x86",
-    "asm",
-    "disasm",
-    "output",
-]
-
-COPTS = select({
-    ":windows": [],
-    "//conditions:default": [
-        "-w",
-        "-DHAVE_CONFIG_H",
-    ],
-})
-
-cc_library(
-    name = "nasm_2_14_02",
-    srcs = [
-        "asm/assemble.c",
-        "asm/directbl.c",
-        "asm/directiv.c",
-        "asm/error.c",
-        "asm/eval.c",
-        "asm/exprdump.c",
-        "asm/exprlib.c",
-        "asm/float.c",
-        "asm/labels.c",
-        "asm/listing.c",
-        "asm/parser.c",
-        "asm/pptok.c",
-        "asm/pragma.c",
-        "asm/preproc.c",
-        "asm/preproc-nop.c",
-        "asm/quote.c",
-        "asm/rdstrnum.c",
-        "asm/segalloc.c",
-        "asm/stdscan.c",
-        "asm/strfunc.c",
-        "asm/tokhash.c",
-        "common/common.c",
-        "disasm/disasm.c",
-        "disasm/sync.c",
-        "macros/macros.c",
-        "nasmlib/badenum.c",
-        "nasmlib/bsi.c",
-        "nasmlib/crc64.c",
-        "nasmlib/errfile.c",
-        "nasmlib/file.c",
-        "nasmlib/filename.c",
-        "nasmlib/hashtbl.c",
-        "nasmlib/ilog2.c",
-        "nasmlib/malloc.c",
-        "nasmlib/md5c.c",
-        "nasmlib/mmap.c",
-        "nasmlib/path.c",
-        "nasmlib/perfhash.c",
-        "nasmlib/raa.c",
-        "nasmlib/rbtree.c",
-        "nasmlib/readnum.c",
-        "nasmlib/realpath.c",
-        "nasmlib/saa.c",
-        "nasmlib/srcfile.c",
-        "nasmlib/string.c",
-        "nasmlib/strlist.c",
-        "nasmlib/ver.c",
-        "output/codeview.c",
-        "output/legacy.c",
-        "output/nulldbg.c",
-        "output/nullout.c",
-        "output/outaout.c",
-        "output/outas86.c",
-        "output/outbin.c",
-        "output/outcoff.c",
-        "output/outdbg.c",
-        "output/outelf.c",
-        "output/outform.c",
-        "output/outieee.c",
-        "output/outlib.c",
-        "output/outmacho.c",
-        "output/outobj.c",
-        "output/outrdf2.c",
-        "output/strtbl.c",
-        "stdlib/snprintf.c",
-        "stdlib/strlcpy.c",
-        "stdlib/strnlen.c",
-        "stdlib/strrchrnul.c",
-        "stdlib/vsnprintf.c",
-        "x86/disp8.c",
-        "x86/iflag.c",
-        "x86/insnsa.c",
-        "x86/insnsb.c",
-        "x86/insnsd.c",
-        "x86/insnsn.c",
-        "x86/regdis.c",
-        "x86/regflags.c",
-        "x86/regs.c",
-        "x86/regvals.c",
-    ],
-    hdrs = glob([
-        "*.h",
-        "include/*.h",
-        "x86/*.h",
-        "disasm/*.h",
-        "config/*.h",
-        "asm/*.h",
-        "output/*.h",
-        "nasmlib/*.h",
-    ]),
-    copts = COPTS,
-    includes = INCLUDES,
-)
-
-cc_binary(
-    name = "nasm",
-    srcs = [
-        "asm/nasm.c",
-        "nasmlib/zerobuf.c",
-    ],
-    copts = COPTS,
-    includes = INCLUDES,
-    visibility = ["@libjpeg_turbo//:__pkg__"],
-    deps = [
-        ":nasm_2_14_02",
-    ],
-)
-
-config_setting(
-    name = "windows",
-    values = {
-        "cpu": "x64_windows",
-    },
-)
diff --git a/third_party/xla/third_party/nasm/workspace.bzl b/third_party/xla/third_party/nasm/workspace.bzl
deleted file mode 100644
index 5806cba..0000000
--- a/third_party/xla/third_party/nasm/workspace.bzl
+++ /dev/null
@@ -1,18 +0,0 @@
-"""loads the nasm library, used by TF."""
-
-load("//third_party:repo.bzl", "tf_http_archive")
-
-def repo():
-    tf_http_archive(
-        name = "nasm",
-        urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/www.nasm.us/pub/nasm/releasebuilds/2.14.02/nasm-2.14.02.tar.bz2",
-            "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.14.02.tar.bz2/sha512/d7a6b4cee8dfd603d8d4c976e5287b5cc542fa0b466ff989b743276a6e28114e64289bf02a7819eca63142a5278aa6eed57773007e5f589e15768e6456a8919d/nasm-2.14.02.tar.bz2",
-            "http://www.nasm.us/pub/nasm/releasebuilds/2.14.02/nasm-2.14.02.tar.bz2",
-        ],
-        sha256 = "34fd26c70a277a9fdd54cb5ecf389badedaf48047b269d1008fbc819b24e80bc",
-        strip_prefix = "nasm-2.14.02",
-        build_file = "//third_party/nasm:nasm.BUILD",
-        system_build_file = "//third_party/nasm:BUILD.system",
-        link_files = {"//third_party/nasm:config.h": "config/config.h"},
-    )
diff --git a/third_party/xla/third_party/nccl/BUILD b/third_party/xla/third_party/nccl/BUILD
deleted file mode 100644
index e69de29..0000000
--- a/third_party/xla/third_party/nccl/BUILD
+++ /dev/null
diff --git a/third_party/xla/third_party/nccl/LICENSE b/third_party/xla/third_party/nccl/LICENSE
deleted file mode 100644
index b958518..0000000
--- a/third_party/xla/third_party/nccl/LICENSE
+++ /dev/null
@@ -1,30 +0,0 @@
-
- Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved.
-
- Redistribution and use in source and binary forms, with or without
- modification, are permitted provided that the following conditions
- are met:
-  * Redistributions of source code must retain the above copyright
-    notice, this list of conditions and the following disclaimer.
-  * Redistributions in binary form must reproduce the above copyright
-    notice, this list of conditions and the following disclaimer in the
-    documentation and/or other materials provided with the distribution.
-  * Neither the name of NVIDIA CORPORATION, Lawrence Berkeley National
-    Laboratory, the U.S. Department of Energy, nor the names of their
-    contributors may be used to endorse or promote products derived
-    from this software without specific prior written permission.
-
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
- EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
- PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
- CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
- EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
- PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
- PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
- OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
- The U.S. Department of Energy funded the development of this software
- under subcontract 7078610 with Lawrence Berkeley National Laboratory.
diff --git a/third_party/xla/third_party/nccl/archive.BUILD b/third_party/xla/third_party/nccl/archive.BUILD
deleted file mode 100644
index 05293fd..0000000
--- a/third_party/xla/third_party/nccl/archive.BUILD
+++ /dev/null
@@ -1,245 +0,0 @@
-# NVIDIA NCCL 2
-# A package of optimized primitives for collective multi-GPU communication.
-
-licenses(["notice"])
-
-exports_files(["LICENSE.txt"])
-
-load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
-load("@bazel_skylib//rules:write_file.bzl", "write_file")
-load(
-    "@local_config_cuda//cuda:build_defs.bzl",
-    "cuda_library",
-)
-load(
-    "@local_config_nccl//:build_defs.bzl",
-    "cuda_rdc_library",
-    "gen_device_srcs",
-)
-
-NCCL_MAJOR = 2
-
-NCCL_MINOR = 16
-
-NCCL_PATCH = 5
-
-NCCL_VERSION = NCCL_MAJOR * 10000 + NCCL_MINOR * 100 + NCCL_PATCH  # e.g., 21605
-
-expand_template(
-    name = "nccl_header_version",
-    out = "src/nccl.h",
-    substitutions = {
-        "${nccl:Major}": str(NCCL_MAJOR),
-        "${nccl:Minor}": str(NCCL_MINOR),
-        "${nccl:Patch}": str(NCCL_PATCH),
-        "${nccl:Suffix}": "\"\"",
-        "${nccl:Version}": str(NCCL_VERSION),
-    },
-    template = "src/nccl.h.in",
-)
-
-# This additional header allows us to determine the configured NCCL version
-# without including the rest of NCCL.
-write_file(
-    name = "nccl_config_header",
-    out = "nccl_config.h",
-    content = [
-        "#define TF_NCCL_VERSION \"{}\"".format(NCCL_MAJOR),
-    ],
-)
-
-cc_library(
-    name = "nccl_config",
-    hdrs = ["nccl_config.h"],
-    include_prefix = "third_party/nccl",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "src_hdrs",
-    hdrs = [
-        "src/include/collectives.h",
-        "src/nccl.h",
-    ],
-    strip_include_prefix = "src",
-)
-
-cc_library(
-    name = "include_hdrs",
-    hdrs = glob(["src/include/**"]),
-    strip_include_prefix = "src/include",
-    deps = ["@local_config_cuda//cuda:cuda_headers"],
-)
-
-cc_library(
-    name = "device_hdrs",
-    hdrs = glob(["src/collectives/device/*.h"]),
-    strip_include_prefix = "src/collectives/device",
-)
-
-# NCCL compiles the same source files with different NCCL_OP/NCCL_TYPE defines.
-# RDC compilation requires that each compiled module has a unique ID. Clang
-# derives the module ID from the path only so we need to copy the files to get
-# different IDs for different parts of compilation. NVCC does not have that
-# problem because it generates IDs based on preprocessed content.
-gen_device_srcs(
-    name = "device_srcs",
-    srcs = [
-        "src/collectives/device/all_gather.cu.cc",
-        "src/collectives/device/all_reduce.cu.cc",
-        "src/collectives/device/broadcast.cu.cc",
-        "src/collectives/device/reduce.cu.cc",
-        "src/collectives/device/reduce_scatter.cu.cc",
-        "src/collectives/device/sendrecv.cu.cc",
-    ],
-)
-
-cuda_rdc_library(
-    name = "device",
-    srcs = [
-        "src/collectives/device/functions.cu.cc",
-        "src/collectives/device/onerank_reduce.cu.cc",
-        ":device_srcs",
-    ] + glob([
-        # Required for header inclusion checking, see below for details.
-        "src/collectives/device/*.h",
-        "src/nccl.h",
-    ]),
-    deps = [
-        ":device_hdrs",
-        ":include_hdrs",
-        ":src_hdrs",
-        "@local_config_cuda//cuda:cuda_headers",
-    ],
-)
-
-cc_library(
-    name = "net",
-    srcs = [
-        "src/transport/coll_net.cc",
-        "src/transport/net.cc",
-    ],
-    linkopts = ["-lrt"],
-    deps = [
-        ":include_hdrs",
-        ":src_hdrs",
-    ],
-)
-
-cc_library(
-    name = "nccl_via_stub",
-    hdrs = ["src/nccl.h"],
-    include_prefix = "third_party/nccl",
-    strip_include_prefix = "src",
-    visibility = ["//visibility:public"],
-    deps = [
-        "@local_config_cuda//cuda:cuda_headers",
-        "@local_tsl//tsl/cuda:nccl_stub",
-    ],
-)
-
-cc_library(
-    name = "nccl_headers",
-    hdrs = ["src/nccl.h"],
-    include_prefix = "third_party/nccl",
-    strip_include_prefix = "src",
-    visibility = ["//visibility:public"],
-    deps = [
-        "@local_config_cuda//cuda:cuda_headers",
-    ],
-)
-
-cc_library(
-    name = "nccl",
-    srcs = glob(
-        include = [
-            "src/**/*.cc",
-            # Required for header inclusion checking, see below for details.
-            "src/graph/*.h",
-        ],
-        # Exclude device-library code.
-        exclude = [
-            "src/collectives/device/**",
-            "src/transport/coll_net.cc",
-            "src/transport/net.cc",
-            "src/enqueue.cc",
-        ],
-    ) + [
-        # Required for header inclusion checking (see
-        # http://docs.bazel.build/versions/master/be/c-cpp.html#hdrs).
-        # Files in src/ which #include "nccl.h" load it from there rather than
-        # from the virtual includes directory.
-        "src/include/collectives.h",
-        "src/nccl.h",
-    ],
-    hdrs = ["src/nccl.h"],
-    include_prefix = "third_party/nccl",
-    linkopts = ["-lrt"],
-    strip_include_prefix = "src",
-    visibility = ["//visibility:public"],
-    deps = [
-        ":device",
-        ":enqueue",
-        ":include_hdrs",
-        ":net",
-        ":src_hdrs",
-    ],
-)
-
-alias(
-    name = "enqueue",
-    actual = select({
-        "@local_config_cuda//cuda:using_clang": ":enqueue_clang",
-        "@local_config_cuda//cuda:using_nvcc": ":enqueue_nvcc",
-    }),
-)
-
-# Kernels and their names have special treatment in CUDA compilation.
-# Specifically, the host-side kernel launch stub (host-side representation of
-# the kernel) ends up having a name that does not match the actual kernel
-# name. In order to refer to the kernel correctly, the referring code must be
-# compiled as CUDA.
-cuda_library(
-    name = "enqueue_clang",
-    srcs = [
-        "src/enqueue.cc",
-    ],
-    hdrs = ["src/nccl.h"],
-    copts = [
-        "--cuda-host-only",
-    ],
-    include_prefix = "third_party/nccl",
-    linkopts = ["-lrt"],
-    strip_include_prefix = "src",
-    target_compatible_with = select({
-        "@local_config_cuda//cuda:using_clang": [],
-        "//conditions:default": ["@platforms//:incompatible"],
-    }),
-    visibility = ["//visibility:public"],
-    deps = [
-        ":device",
-        ":include_hdrs",
-        ":src_hdrs",
-    ],
-)
-
-cc_library(
-    name = "enqueue_nvcc",
-    srcs = [
-        "src/enqueue.cc",
-    ],
-    hdrs = ["src/nccl.h"],
-    include_prefix = "third_party/nccl",
-    linkopts = ["-lrt"],
-    strip_include_prefix = "src",
-    target_compatible_with = select({
-        "@local_config_cuda//cuda:using_nvcc": [],
-        "//conditions:default": ["@platforms//:incompatible"],
-    }),
-    visibility = ["//visibility:public"],
-    deps = [
-        ":device",
-        ":include_hdrs",
-        ":src_hdrs",
-    ],
-)
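For orientation on the `gen_device_srcs` fan-out used above (its implementation lives in `build_defs.bzl.tpl`, deleted later in this diff): every input source is copied once per (op, type) pair so that Clang's path-derived module IDs stay unique under RDC. A plain-Python sketch of the resulting file names, reusing the op/type lists from that implementation with a hypothetical two-file `srcs` subset:

```python
# Plain-Python sketch of the gen_device_srcs fan-out; the srcs list here is a
# hypothetical two-file subset of the list used above.
ops = ["sum", "prod", "min", "max", "premulsum", "sumpostdiv"]
types = ["i8", "u8", "i32", "u32", "i64", "u64", "f16", "bf16", "f32", "f64"]
srcs = ["all_reduce.cu.cc", "broadcast.cu.cc"]

generated = [
    "%s_%s_%s" % (op, dt, src)
    for op in ops
    for dt in types
    for src in srcs
]
print(len(generated), generated[0])  # 120 sum_i8_all_reduce.cu.cc
```

Each generated copy also gets `#define NCCL_OP <i>` / `#define NCCL_TYPE <j>` appended to the header banner, which is what lets a single source compile into the many specialized device modules that `:device` links together.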
diff --git a/third_party/xla/third_party/nccl/archive.patch b/third_party/xla/third_party/nccl/archive.patch
deleted file mode 100644
index f951a6a..0000000
--- a/third_party/xla/third_party/nccl/archive.patch
+++ /dev/null
@@ -1,59 +0,0 @@
-diff --git a/src/collectives/device/all_gather.cu b/src/collectives/device/all_gather.cu.cc
-similarity index 100%
-rename from src/collectives/device/all_gather.cu
-rename to src/collectives/device/all_gather.cu.cc
-diff --git a/src/collectives/device/all_reduce.cu b/src/collectives/device/all_reduce.cu.cc
-similarity index 100%
-rename from src/collectives/device/all_reduce.cu
-rename to src/collectives/device/all_reduce.cu.cc
-diff --git a/src/collectives/device/broadcast.cu b/src/collectives/device/broadcast.cu.cc
-similarity index 100%
-rename from src/collectives/device/broadcast.cu
-rename to src/collectives/device/broadcast.cu.cc
-diff --git a/src/collectives/device/functions.cu b/src/collectives/device/functions.cu.cc
-similarity index 100%
-rename from src/collectives/device/functions.cu
-rename to src/collectives/device/functions.cu.cc
-diff --git a/src/collectives/device/onerank_reduce.cu b/src/collectives/device/onerank_reduce.cu.cc
-similarity index 100%
-rename from src/collectives/device/onerank_reduce.cu
-rename to src/collectives/device/onerank_reduce.cu.cc
-diff --git a/src/collectives/device/reduce.cu b/src/collectives/device/reduce.cu.cc
-similarity index 100%
-rename from src/collectives/device/reduce.cu
-rename to src/collectives/device/reduce.cu.cc
-diff --git a/src/collectives/device/reduce_scatter.cu b/src/collectives/device/reduce_scatter.cu.cc
-similarity index 100%
-rename from src/collectives/device/reduce_scatter.cu
-rename to src/collectives/device/reduce_scatter.cu.cc
-diff --git a/src/collectives/device/sendrecv.cu b/src/collectives/device/sendrecv.cu.cc
-similarity index 100%
-rename from src/collectives/device/sendrecv.cu
-rename to src/collectives/device/sendrecv.cu.cc
-diff --git a/src/include/nvtx.h b/src/include/nvtx.h
-index 2aeb932..cdc67d2 100644
---- a/src/include/nvtx.h
-+++ b/src/include/nvtx.h
-@@ -37,7 +37,7 @@ struct nccl_domain{static constexpr char const* name{"NCCL"};};
-
- class payload_schema {
-  public:
--  NVTX3_RELAXED_CONSTEXPR explicit payload_schema(const nvtxPayloadSchemaEntry_t entries[], size_t numEntries, const uint64_t schemaId, const char* schemaName = nullptr) noexcept
-+  explicit payload_schema(const nvtxPayloadSchemaEntry_t entries[], size_t numEntries, const uint64_t schemaId, const char* schemaName = nullptr) noexcept
-   {
-     schema_attr.name = schemaName;
-     schema_attr.entries = entries;
-diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h
-index accf8371a..4ab1bfac6 100644
---- a/src/collectives/device/common.h
-+++ b/src/collectives/device/common.h
-@@ -166,7 +166,8 @@ __device__ void ncclKernel(
-       bytes = 0;
-       break;
-     }
--    copyToShmem16(tid%WARP_SIZE, dst, src, bytes);
-+    if (bytes)
-+      copyToShmem16(tid%WARP_SIZE, dst, src, bytes);
-   }
-   __syncthreads(); // publish ncclShmem
- 
\ No newline at end of file
diff --git a/third_party/xla/third_party/nccl/build_defs.bzl.tpl b/third_party/xla/third_party/nccl/build_defs.bzl.tpl
deleted file mode 100644
index 04749be..0000000
--- a/third_party/xla/third_party/nccl/build_defs.bzl.tpl
+++ /dev/null
@@ -1,412 +0,0 @@
-"""Repository rule for NCCL."""
-
-load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "cuda_gpu_architectures")
-load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain")
-
-# CUDA toolkit version as tuple (e.g. '(11, 1)').
-_cuda_version = %{cuda_version}
-_cuda_clang = %{cuda_clang}
-
-def _gen_device_srcs_impl(ctx):
-    ops = ["sum", "prod", "min", "max", "premulsum", "sumpostdiv"]
-    # TF uses CUDA version > 11.0, so enable bf16 type unconditionally.
-    types = ["i8", "u8", "i32", "u32", "i64", "u64", "f16", "bf16", "f32", "f64"]
-    hdr_tail = "****************************************/"
-    defines = "\n\n#define NCCL_OP %d\n#define NCCL_TYPE %d"
-
-    files = []
-    for NCCL_OP, op in enumerate(ops):
-        for NCCL_TYPE, dt in enumerate(types):
-            substitutions = {
-                hdr_tail: hdr_tail + defines % (NCCL_OP, NCCL_TYPE),
-            }
-            for src in ctx.files.srcs:
-                name = "%s_%s_%s" % (op, dt, src.basename)
-                file = ctx.actions.declare_file(name, sibling = src)
-                ctx.actions.expand_template(
-                    output = file,
-                    template = src,
-                    substitutions = substitutions,
-                )
-                files.append(file)
-    return [DefaultInfo(files = depset(files))]
-
-gen_device_srcs = rule(
-    implementation = _gen_device_srcs_impl,
-    attrs = {
-        "srcs": attr.label_list(allow_files = True),
-    },
-)
-"""Adds prefix to each file name in srcs and adds #define NCCL_OP."""
-
-def _rdc_copts():
-    """Returns copts for compiling relocatable device code."""
-
-    # The global functions can not have a lower register count than the
-    # device functions. This is enforced by setting a fixed register count.
-    # https://github.com/NVIDIA/nccl/blob/f93fe9bfd94884cec2ba711897222e0df5569a53/makefiles/common.mk#L48
-    maxrregcount = "-maxrregcount=96"
-
-    return cuda_default_copts() + select({
-        "@local_config_cuda//:is_cuda_compiler_nvcc": [
-            "-nvcc_options",
-            "relocatable-device-code=true",
-            "-nvcc_options",
-            "ptxas-options=" + maxrregcount,
-            "-nvcc_options",
-            "extended-lambda",
-        ],
-        "@local_config_cuda//:is_cuda_compiler_clang": [
-            "-fcuda-rdc",
-            "-Xcuda-ptxas",
-            maxrregcount,
-        ],
-        "//conditions:default": [],
-    })
-
-def _lookup_file(filegroup, path):
-    """Extracts file at (relative) path in filegroup."""
-    for file in filegroup.files:
-        if file.path.endswith(path):
-            return file
-    return None
-
-def _pic_only(files):
-    """Returns the PIC files if there are any in 'files', otherwise 'files'."""
-    pic_only = [f for f in files if f.basename.find(".pic.") >= 0]
-    return pic_only if pic_only else files
-
-def _device_link_impl(ctx):
-    if not ctx.attr.gpu_archs:
-        fail("No GPU architecture specified. NCCL requires --config=cuda or similar.")
-
-    inputs = []
-    for dep in ctx.attr.deps:
-        inputs += dep.files.to_list()
-    inputs = _pic_only(inputs)
-
-    # Device-link to cubins for each architecture.
-    name = ctx.attr.name
-    register_h = None
-    cubins = []
-    images = []
-    for arch in ctx.attr.gpu_archs:
-        arch = arch.replace("compute_", "sm_")  # PTX is JIT-linked at runtime.
-        cubin = ctx.actions.declare_file("%s_%s.cubin" % (name, arch))
-        register_h = ctx.actions.declare_file("%s_register_%s.h" % (name, arch))
-        ctx.actions.run(
-            outputs = [register_h, cubin],
-            inputs = inputs,
-            executable = ctx.file._nvlink,
-            arguments = ctx.attr.nvlink_args + [
-                "--arch=%s" % arch,
-                "--register-link-binaries=%s" % register_h.path,
-                "--output-file=%s" % cubin.path,
-            ] + [file.path for file in inputs],
-            mnemonic = "nvlink",
-            use_default_shell_env = True,
-        )
-        cubins.append(cubin)
-        images.append("--image=profile=%s,file=%s" % (arch, cubin.path))
-
-    # Generate fatbin header from all cubins.
-    tmp_fatbin = ctx.actions.declare_file("%s.fatbin" % name)
-    fatbin_h = ctx.actions.declare_file("%s_fatbin.h" % name)
-    bin2c = ctx.file._bin2c
-    arguments_list = [
-        "-64",
-        "--cmdline=--compile-only",
-        "--link",
-        "--compress-all",
-        "--create=%s" % tmp_fatbin.path,
-        "--embedded-fatbin=%s" % fatbin_h.path,
-    ]
-    if _cuda_version <= (10, 1):
-        arguments_list.append("--bin2c-path=%s" % bin2c.dirname)
-    ctx.actions.run(
-        outputs = [tmp_fatbin, fatbin_h],
-        inputs = cubins,
-        executable = ctx.file._fatbinary,
-        arguments = arguments_list + images,
-        tools = [bin2c],
-        mnemonic = "fatbinary",
-        use_default_shell_env = True,
-    )
-
-    # Generate the source file #including the headers generated above.
-    ctx.actions.expand_template(
-        output = ctx.outputs.out,
-        template = ctx.file._link_stub,
-        substitutions = {
-            "REGISTERLINKBINARYFILE": '"%s"' % register_h.short_path,
-            "FATBINFILE": '"%s"' % fatbin_h.short_path,
-        },
-    )
-
-    return [DefaultInfo(files = depset([register_h, fatbin_h]))]
-
-_device_link = rule(
-    implementation = _device_link_impl,
-    attrs = {
-        "deps": attr.label_list(),
-        "out": attr.output(mandatory = True),
-        "gpu_archs": attr.string_list(),
-        "nvlink_args": attr.string_list(),
-        "_nvlink": attr.label(
-            default = Label("@local_config_cuda//cuda:cuda/bin/nvlink"),
-            allow_single_file = True,
-            executable = True,
-            cfg = "host",
-        ),
-        "_fatbinary": attr.label(
-            default = Label("@local_config_cuda//cuda:cuda/bin/fatbinary"),
-            allow_single_file = True,
-            executable = True,
-            cfg = "host",
-        ),
-        "_bin2c": attr.label(
-            default = Label("@local_config_cuda//cuda:cuda/bin/bin2c"),
-            allow_single_file = True,
-            executable = True,
-            cfg = "host",
-        ),
-        "_link_stub": attr.label(
-            default = Label("@local_config_cuda//cuda:cuda/bin/crt/link.stub"),
-            allow_single_file = True,
-        ),
-    },
-)
-"""Links device code and generates source code for kernel registration."""
-
-def _prune_relocatable_code_impl(ctx):
-    """Clears __nv_relfatbin section containing relocatable device code."""
-
-    if _cuda_version < (11, 3):
-        # -no-relocatable-elf not supported, return unpruned input.
-        return ctx.attr.input[DefaultInfo]
-
-    # nvcc --generate-code options for the active set of cuda architectures.
-    gencodes = []
-    for code in ctx.attr.gpu_archs:
-        arch = code.replace("compute_", "sm_")
-        if code != arch:
-            gencodes.append((arch, arch))
-        gencodes.append((arch, code))
-
-    outputs = []
-    for input in ctx.files.input:
-        output = ctx.actions.declare_file(
-            "pruned_" + input.basename,
-            sibling = input,
-        )
-        arguments = (
-            ["--generate-code=arch=%s,code=%s" % code for code in gencodes] +
-            ["-no-relocatable-elf", "--output-file=%s" % output.path, str(input.path)]
-        )
-        ctx.actions.run(
-            outputs = [output],
-            inputs = [input],
-            executable = ctx.file._nvprune,
-            arguments = arguments,
-            mnemonic = "nvprune",
-            use_default_shell_env = True,
-        )
-        outputs.append(output)
-
-    return DefaultInfo(files = depset(outputs))
-
-_prune_relocatable_code = rule(
-    implementation = _prune_relocatable_code_impl,
-    attrs = {
-        "input": attr.label(mandatory = True, allow_files = True),
-        "gpu_archs": attr.string_list(),
-        "_nvprune": attr.label(
-            default = Label("@local_config_cuda//cuda:cuda/bin/nvprune"),
-            allow_single_file = True,
-            executable = True,
-            cfg = "host",
-        ),
-    },
-)
-
-def _merge_archive_impl(ctx):
-    # Generate an MRI script to merge the archives in srcs and pass it to 'ar'.
-    # See https://stackoverflow.com/a/23621751.
-    files = _pic_only(ctx.files.srcs)
-    mri_script = "create " + ctx.outputs.out.path
-    for f in files:
-        mri_script += r"\naddlib " + f.path
-    mri_script += r"\nsave\nend"
-
-    cc_toolchain = find_cpp_toolchain(ctx)
-    ctx.actions.run_shell(
-        inputs = ctx.files.srcs,  # + ctx.files._crosstool,
-        outputs = [ctx.outputs.out],
-        command = "echo -e \"%s\" | %s -M" % (mri_script, cc_toolchain.ar_executable),
-        use_default_shell_env = True,
-    )
-
-_merge_archive = rule(
-    implementation = _merge_archive_impl,
-    attrs = {
-        "srcs": attr.label_list(mandatory = True, allow_files = True),
-        "_cc_toolchain": attr.label(
-            default = "@bazel_tools//tools/cpp:current_cc_toolchain",
-        ),
-        # "_crosstool": attr.label_list(
-        #     cfg = "host",
-        #     default = ["@bazel_tools//tools/cpp:crosstool"]
-        # ),
-    },
-    outputs = {"out": "lib%{name}.a"},
-)
-"""Merges srcs into a single archive."""
-
-def cuda_rdc_library(name, hdrs = None, copts = None, linkstatic = True, **kwargs):
-    r"""Produces a cuda_library using separate compilation and linking.
-
-    CUDA separate compilation and linking allows device function calls across
-    translation units. This is different from the normal whole program
-    compilation where each translation unit contains all device code. For more
-    background, see
-    https://devblogs.nvidia.com/separate-compilation-linking-cuda-device-code/,
-    https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#nvcc-options-for-separate-compilation
-
-    During separate compilation, the different CUDA source files are compiled
-    to 'relocatable device code' (RDC) and embedded in the host object files.
-    When using nvcc, linking the device code for each supported GPU
-    architecture and generating kernel registration code for the CUDA runtime
-    is handled automatically. Clang supports generating relocatable device
-    code, but it can't link it. We therefore rely on tools provided by the CUDA
-    SDK to link the device code and generate the host code to register the
-    kernels.
-
-    The nvlink tool extracts the RDC code from the object files and links it
-    into cubin files, one per GPU architecture. It also produces a header file
-    with a list of kernel names to register. The cubins are merged into a
-    binary blob using the fatbinary tool, and converted to a C header file with
-    the help of the bin2c tool. The registration header file, the fatbinary
-    header file, and the link.stub file (shipped with the CUDA SDK) are
-    compiled as ordinary host code.
-
-    Here is a diagram of the CUDA separate compilation trajectory:
-
-     x.cu.cc    y.cu.cc
-           \    /            cc_library (compile RDC and archive)
-            xy.a
-           /    \            * nvlink
-    register.h  xy.cubin
-          :      |           * fatbinary and bin2c
-          :     xy.fatbin.h
-          :      :           * #include
-          dlink.cc           * Expanded from crt/dlink.stub template
-             |               cc_library (host compile and archive)
-          dlink.a
-
-    The steps marked with '*' are implemented in the _device_link rule.
-
-    The intermediate relocatable device code in xy.a is no longer needed at
-    this point and the corresponding section is replaced with an empty one using
-    objcopy. We do not remove the section completely because it is referenced by
-    relocations, and removing those as well breaks fatbin registration.
-
-    The object files in both xy.a and dlink.a reference symbols defined in the
-    other archive. The separate archives are a side effect of using two
-    cc_library targets to implement a single compilation trajectory. We could
-    fix this once bazel supports C++ sandwich. For now, we just merge the two
-    archives to avoid unresolved symbols:
-
-                    xy.a
-                     |         objcopy --update-section __nv_relfatbin=''
-    dlink.a     xy_pruned.a
-         \           /         merge archive
-          xy_merged.a
-              |                cc_library (or alternatively, cc_import)
-         final target
-
-    Another complication is that cc_library produces (depending on the
-    configuration) both PIC and non-PIC archives, but the distinction
-    is hidden from Starlark until C++ sandwich becomes available. We work
-    around this by dropping the non-PIC files if PIC files are available.
-
-    Args:
-      name: Target name.
-      hdrs: Header files.
-      copts: Compiler options.
-      linkstatic: Must be true.
-      **kwargs: Any other arguments.
-    """
-
-    if not hdrs:
-        hdrs = []
-    if not copts:
-        copts = []
-
-    # Compile host and device code into library.
-    lib = name + "_lib"
-    native.cc_library(
-        name = lib,
-        hdrs = hdrs,
-        copts = _rdc_copts() + copts,
-        linkstatic = linkstatic,
-        **kwargs
-    )
-
-    # Generate source file containing linked device code.
-    dlink_hdrs = name + "_dlink_hdrs"
-    dlink_cc = name + "_dlink.cc"
-    _device_link(
-        name = dlink_hdrs,
-        deps = [lib],
-        out = dlink_cc,
-        gpu_archs = cuda_gpu_architectures(),
-        nvlink_args = select({
-            "@local_tsl//tsl:linux_x86_64": ["--cpu-arch=X86_64"],
-            "@local_tsl//tsl:linux_ppc64le": ["--cpu-arch=PPC64LE"],
-            "//conditions:default": [],
-        }),
-    )
-
-    # Compile the source file into a library.
-    dlink = name + "_dlink"
-    native.cc_library(
-        name = dlink,
-        srcs = [dlink_cc],
-        textual_hdrs = [dlink_hdrs],
-        deps = [
-            "@local_config_cuda//cuda:cuda_headers",
-        ],
-        defines = [
-            # Silence warning about including internal header.
-            "__CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__",
-            # Macros that need to be defined starting with CUDA 10.
-            "__NV_EXTRA_INITIALIZATION=",
-            "__NV_EXTRA_FINALIZATION=",
-        ],
-        linkstatic = linkstatic,
-    )
-
-    # Remove intermediate relocatable device code.
-    pruned = name + "_pruned"
-    _prune_relocatable_code(
-        name = pruned,
-        input = lib,
-        gpu_archs = cuda_gpu_architectures(),
-    )
-
-    # Repackage the two libs into a single archive. This is required because
-    # both libs reference symbols defined in the other one. For details, see
-    # https://eli.thegreenplace.net/2013/07/09/library-order-in-static-linking
-    merged = name + "_merged"
-    _merge_archive(
-        name = merged,
-        srcs = [pruned, dlink],
-    )
-
-    # Create cc target from archive.
-    native.cc_library(
-        name = name,
-        srcs = [merged],
-        hdrs = hdrs,
-        linkstatic = linkstatic,
-    )
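Taken together with the `:device` target in `archive.BUILD` earlier in this diff, a caller of the macro looks roughly like the following sketch (the file list is abbreviated; `@local_config_nccl` is the repository that `nccl_configure` instantiates):

```python
load("@local_config_nccl//:build_defs.bzl", "cuda_rdc_library")

cuda_rdc_library(
    name = "device",
    srcs = [
        "src/collectives/device/functions.cu.cc",
        "src/collectives/device/onerank_reduce.cu.cc",
    ],
    deps = [
        "@local_config_cuda//cuda:cuda_headers",
    ],
)
```

Behind the macro this expands into the `<name>_lib`, `<name>_dlink`, `<name>_pruned`, and `<name>_merged` intermediates described in the docstring, with the final `cc_library` wrapping the merged archive.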
diff --git a/third_party/xla/third_party/nccl/nccl_configure.bzl b/third_party/xla/third_party/nccl/nccl_configure.bzl
deleted file mode 100644
index 3ae8aba..0000000
--- a/third_party/xla/third_party/nccl/nccl_configure.bzl
+++ /dev/null
@@ -1,212 +0,0 @@
-"""Repository rule for NCCL configuration.
-
-`nccl_configure` depends on the following environment variables:
-
-  * `TF_NCCL_VERSION`: Installed NCCL version or empty to build from source.
-  * `NCCL_INSTALL_PATH` (deprecated): The installation path of the NCCL library.
-  * `NCCL_HDR_PATH` (deprecated): The installation path of the NCCL header 
-    files.
-  * `TF_CUDA_PATHS`: The base paths to look for CUDA and cuDNN. Default is
-    `/usr/local/cuda,usr/`.
-  * `TF_CUDA_CLANG`: "1" if using Clang, "0" if using NVCC.
-  * `TF_NCCL_USE_STUB`: "1" if a NCCL stub that loads NCCL dynamically should
-    be used, "0" if NCCL should be linked in statically.
-
-"""
-
-load(
-    "//third_party/gpus:cuda_configure.bzl",
-    "enable_cuda",
-    "find_cuda_config",
-)
-load(
-    "//third_party/remote_config:common.bzl",
-    "config_repo_label",
-    "get_cpu_value",
-    "get_host_environ",
-)
-
-_CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
-_NCCL_HDR_PATH = "NCCL_HDR_PATH"
-_NCCL_INSTALL_PATH = "NCCL_INSTALL_PATH"
-_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES"
-_TF_NCCL_VERSION = "TF_NCCL_VERSION"
-_TF_NEED_CUDA = "TF_NEED_CUDA"
-_TF_CUDA_PATHS = "TF_CUDA_PATHS"
-_TF_CUDA_CLANG = "TF_CUDA_CLANG"
-_TF_NCCL_USE_STUB = "TF_NCCL_USE_STUB"
-
-_DEFINE_NCCL_MAJOR = "#define NCCL_MAJOR"
-_DEFINE_NCCL_MINOR = "#define NCCL_MINOR"
-_DEFINE_NCCL_PATCH = "#define NCCL_PATCH"
-
-_NCCL_DUMMY_BUILD_CONTENT = """
-filegroup(
-  name = "LICENSE",
-  visibility = ["//visibility:public"],
-)
-
-cc_library(
-  name = "nccl",
-  visibility = ["//visibility:public"],
-)
-
-cc_library(
-  name = "nccl_config",
-  hdrs = ["nccl_config.h"],
-  include_prefix = "third_party/nccl",
-  visibility = ["//visibility:public"],
-)
-"""
-
-_NCCL_ARCHIVE_BUILD_CONTENT = """
-filegroup(
-  name = "LICENSE",
-  data = ["@nccl_archive//:LICENSE.txt"],
-  visibility = ["//visibility:public"],
-)
-
-alias(
-  name = "nccl",
-  actual = "@nccl_archive//:nccl",
-  visibility = ["//visibility:public"],
-)
-
-alias(
-  name = "nccl_config",
-  actual = "@nccl_archive//:nccl_config",
-  visibility = ["//visibility:public"],
-)
-"""
-
-_NCCL_ARCHIVE_STUB_BUILD_CONTENT = """
-filegroup(
-  name = "LICENSE",
-  data = ["@nccl_archive//:LICENSE.txt"],
-  visibility = ["//visibility:public"],
-)
-
-alias(
-  name = "nccl",
-  actual = "@nccl_archive//:nccl_via_stub",
-  visibility = ["//visibility:public"],
-)
-
-alias(
-  name = "nccl_headers",
-  actual = "@nccl_archive//:nccl_headers",
-  visibility = ["//visibility:public"],
-)
-
-alias(
-  name = "nccl_config",
-  actual = "@nccl_archive//:nccl_config",
-  visibility = ["//visibility:public"],
-)
-"""
-
-def _label(file):
-    return Label("//third_party/nccl:{}".format(file))
-
-def _create_local_nccl_repository(repository_ctx):
-    nccl_version = get_host_environ(repository_ctx, _TF_NCCL_VERSION, "")
-    if nccl_version:
-        nccl_version = nccl_version.split(".")[0]
-
-    cuda_config = find_cuda_config(repository_ctx, ["cuda"])
-    cuda_version = cuda_config["cuda_version"].split(".")
-
-    if nccl_version == "":
-        # Alias to open source build from @nccl_archive.
-        if get_host_environ(repository_ctx, _TF_NCCL_USE_STUB, "0") == "0":
-            repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT)
-        else:
-            repository_ctx.file("BUILD", _NCCL_ARCHIVE_STUB_BUILD_CONTENT)
-
-        repository_ctx.template(
-            "build_defs.bzl",
-            _label("build_defs.bzl.tpl"),
-            {
-                "%{cuda_version}": "(%s, %s)" % tuple(cuda_version),
-                "%{cuda_clang}": repr(get_host_environ(repository_ctx, _TF_CUDA_CLANG)),
-            },
-        )
-    else:
-        # Create target for locally installed NCCL.
-        config = find_cuda_config(repository_ctx, ["nccl"])
-        config_wrap = {
-            "%{nccl_version}": config["nccl_version"],
-            "%{nccl_header_dir}": config["nccl_include_dir"],
-            "%{nccl_library_dir}": config["nccl_library_dir"],
-        }
-        repository_ctx.template("BUILD", _label("system.BUILD.tpl"), config_wrap)
-
-def _create_remote_nccl_repository(repository_ctx, remote_config_repo):
-    repository_ctx.template(
-        "BUILD",
-        config_repo_label(remote_config_repo, ":BUILD"),
-        {},
-    )
-
-    nccl_version = get_host_environ(repository_ctx, _TF_NCCL_VERSION, "")
-    if nccl_version == "":
-        repository_ctx.template(
-            "build_defs.bzl",
-            config_repo_label(remote_config_repo, ":build_defs.bzl"),
-            {},
-        )
-
-def _nccl_autoconf_impl(repository_ctx):
-    if (not enable_cuda(repository_ctx) or
-        get_cpu_value(repository_ctx) not in ("Linux", "FreeBSD")):
-        # Add a dummy build file to make bazel query happy.
-        repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT)
-        repository_ctx.file("nccl_config.h", "#define TF_NCCL_VERSION \"\"")
-    elif get_host_environ(repository_ctx, "TF_NCCL_CONFIG_REPO") != None:
-        _create_remote_nccl_repository(repository_ctx, get_host_environ(repository_ctx, "TF_NCCL_CONFIG_REPO"))
-    else:
-        _create_local_nccl_repository(repository_ctx)
-
-_ENVIRONS = [
-    _CUDA_TOOLKIT_PATH,
-    _NCCL_HDR_PATH,
-    _NCCL_INSTALL_PATH,
-    _TF_NCCL_VERSION,
-    _TF_CUDA_COMPUTE_CAPABILITIES,
-    _TF_NEED_CUDA,
-    _TF_CUDA_PATHS,
-    _TF_CUDA_CLANG,
-]
-
-remote_nccl_configure = repository_rule(
-    implementation = _create_local_nccl_repository,
-    environ = _ENVIRONS,
-    remotable = True,
-    attrs = {
-        "environ": attr.string_dict(),
-        "_find_cuda_config": attr.label(
-            default = Label("@local_xla//third_party/gpus:find_cuda_config.py"),
-        ),
-    },
-)
-
-nccl_configure = repository_rule(
-    implementation = _nccl_autoconf_impl,
-    environ = _ENVIRONS,
-    attrs = {
-        "_find_cuda_config": attr.label(
-            default = Label("@local_xla//third_party/gpus:find_cuda_config.py"),
-        ),
-    },
-)
-"""Detects and configures the NCCL configuration.
-
-Add the following to your WORKSPACE FILE:
-
-```python
-nccl_configure(name = "local_config_nccl")
-```
-
-Args:
-  name: A unique name for this workspace rule.
-"""
diff --git a/third_party/xla/third_party/nccl/system.BUILD.tpl b/third_party/xla/third_party/nccl/system.BUILD.tpl
deleted file mode 100644
index 6e2a22a..0000000
--- a/third_party/xla/third_party/nccl/system.BUILD.tpl
+++ /dev/null
@@ -1,62 +0,0 @@
-load("@bazel_skylib//rules:write_file.bzl", "write_file")
-load(
-    "@local_tsl//tsl/platform/default:cuda_build_defs.bzl",
-    "cuda_rpath_flags"
-)
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "nccl",
-    srcs = ["libnccl.so.%{nccl_version}"],
-    hdrs = ["nccl.h"],
-    include_prefix = "third_party/nccl",
-    visibility = ["//visibility:public"],
-    deps = [
-        "@local_config_cuda//cuda:cuda_headers",
-    ],
-    linkopts = cuda_rpath_flags("nvidia/nccl/lib"),
-)
-
-cc_library(
-    name = "nccl_headers",
-    hdrs = ["nccl.h"],
-    include_prefix = "third_party/nccl",
-    visibility = ["//visibility:public"],
-    deps = [
-        "@local_config_cuda//cuda:cuda_headers",
-    ],
-)
-
-genrule(
-    name = "nccl-files",
-    outs = [
-        "libnccl.so.%{nccl_version}",
-        "nccl.h",
-    ],
-    cmd = """
-cp "%{nccl_header_dir}/nccl.h" "$(@D)/nccl.h" &&
-cp "%{nccl_library_dir}/libnccl.so.%{nccl_version}" \
-  "$(@D)/libnccl.so.%{nccl_version}"
-""",
-)
-
-# This additional header allows us to determine the configured NCCL version
-# without including the rest of NCCL.
-write_file(
-    name = "nccl_config_header",
-    out = "nccl_config.h",
-    content = [
-        "#define TF_NCCL_VERSION \"%{nccl_version}\""
-    ]
-)
-
-cc_library(
-    name = "nccl_config",
-    hdrs = ["nccl_config.h"],
-    include_prefix = "third_party/nccl",
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/nsync.patch b/third_party/xla/third_party/nsync.patch
deleted file mode 100644
index b766303..0000000
--- a/third_party/xla/third_party/nsync.patch
+++ /dev/null
@@ -1,9 +0,0 @@
-The "version" file at the root of the nsync source tree conflicts with the C++20
-"version" header in Windows builds.
-diff --git a/VERSION b/VERSION
-deleted file mode 100644
-index 53cc1a6..0000000
---- a/VERSION
-+++ /dev/null
-@@ -1 +0,0 @@
--1.25.0
diff --git a/third_party/xla/third_party/nvtx.BUILD b/third_party/xla/third_party/nvtx.BUILD
deleted file mode 100644
index d3d4258..0000000
--- a/third_party/xla/third_party/nvtx.BUILD
+++ /dev/null
@@ -1,20 +0,0 @@
-# Description: NVIDIA Tools Extension (NVTX) library for adding profiling annotations to applications.
-
-package(
-    default_visibility = ["//visibility:public"],
-)
-
-licenses(["restricted"])  # NVIDIA proprietary license
-
-filegroup(
-    name = "nvtx_header_files",
-    srcs = glob([
-        "nvtx3/**",
-    ]),
-)
-
-cc_library(
-    name = "nvtx",
-    hdrs = [":nvtx_header_files"],
-    include_prefix = "third_party",
-)
diff --git a/third_party/xla/third_party/png.BUILD b/third_party/xla/third_party/png.BUILD
deleted file mode 100644
index 0e2a5d0..0000000
--- a/third_party/xla/third_party/png.BUILD
+++ /dev/null
@@ -1,70 +0,0 @@
-# Description:
-#   libpng is the official PNG reference library.
-
-licenses(["notice"])  # BSD/MIT-like license
-
-exports_files(["LICENSE"])
-
-cc_library(
-    name = "png",
-    srcs = [
-        "png.c",
-        "pngdebug.h",
-        "pngerror.c",
-        "pngget.c",
-        "pnginfo.h",
-        "pnglibconf.h",
-        "pngmem.c",
-        "pngpread.c",
-        "pngpriv.h",
-        "pngread.c",
-        "pngrio.c",
-        "pngrtran.c",
-        "pngrutil.c",
-        "pngset.c",
-        "pngstruct.h",
-        "pngtrans.c",
-        "pngwio.c",
-        "pngwrite.c",
-        "pngwtran.c",
-        "pngwutil.c",
-    ] + select({
-        ":windows": [
-            "intel/filter_sse2_intrinsics.c",
-            "intel/intel_init.c",
-        ],
-        "@local_tsl//tsl:linux_ppc64le": [
-            "powerpc/filter_vsx_intrinsics.c",
-            "powerpc/powerpc_init.c",
-        ],
-        "//conditions:default": [
-        ],
-    }),
-    hdrs = [
-        "png.h",
-        "pngconf.h",
-    ],
-    copts = select({
-        ":windows": ["-DPNG_INTEL_SSE_OPT=1"],
-        "//conditions:default": [],
-    }),
-    includes = ["."],
-    linkopts = select({
-        ":windows": [],
-        "//conditions:default": ["-lm"],
-    }),
-    visibility = ["//visibility:public"],
-    deps = ["@zlib"],
-)
-
-genrule(
-    name = "snappy_stubs_public_h",
-    srcs = ["scripts/pnglibconf.h.prebuilt"],
-    outs = ["pnglibconf.h"],
-    cmd = "sed -e 's/PNG_ZLIB_VERNUM 0/PNG_ZLIB_VERNUM 0x12d0/' $< >$@",
-)
-
-config_setting(
-    name = "windows",
-    values = {"cpu": "x64_windows"},
-)
diff --git a/third_party/xla/third_party/png_fix_rpi.patch b/third_party/xla/third_party/png_fix_rpi.patch
deleted file mode 100644
index df6cfd7..0000000
--- a/third_party/xla/third_party/png_fix_rpi.patch
+++ /dev/null
@@ -1,16 +0,0 @@
-diff -r -u ./scripts/pnglibconf.h.prebuilt ./scripts/pnglibconf.h.prebuilt
---- ./scripts/pnglibconf.h.prebuilt
-+++ ./scripts/pnglibconf.h.prebuilt
-@@ -19,6 +19,12 @@
- #define PNG_ALIGNED_MEMORY_SUPPORTED
- /*#undef PNG_ARM_NEON_API_SUPPORTED*/
- /*#undef PNG_ARM_NEON_CHECK_SUPPORTED*/
-+
-+/* Workaround not having a great build file by forcing
-+ * png filter optimization to be disabled on arm */
-+#define PNG_ARM_NEON_OPT 0
-+
-+
- #define PNG_BENIGN_ERRORS_SUPPORTED
- #define PNG_BENIGN_READ_ERRORS_SUPPORTED
- /*#undef PNG_BENIGN_WRITE_ERRORS_SUPPORTED*/
diff --git a/third_party/xla/third_party/protobuf/BUILD b/third_party/xla/third_party/protobuf/BUILD
deleted file mode 100644
index e69de29..0000000
--- a/third_party/xla/third_party/protobuf/BUILD
+++ /dev/null
diff --git a/third_party/xla/third_party/protobuf/protobuf.patch b/third_party/xla/third_party/protobuf/protobuf.patch
deleted file mode 100644
index 9d928ba..0000000
--- a/third_party/xla/third_party/protobuf/protobuf.patch
+++ /dev/null
@@ -1,141 +0,0 @@
-diff --git a/BUILD.bazel b/BUILD.bazel
---- a/BUILD.bazel	(revision 90b73ac3f0b10320315c2ca0d03a5a9b095d2f66)
-+++ b/BUILD.bazel	(date 1670471682469)
-@@ -68,6 +68,7 @@
-     copts = COPTS,
-     includes = ["src/"],
-     linkopts = LINK_OPTS,
-+    alwayslink = 1,
-     visibility = ["//visibility:public"],
- )
-
-@@ -135,6 +136,7 @@
-     copts = COPTS,
-     includes = ["src/"],
-     linkopts = LINK_OPTS,
-+    alwayslink = 1,
-     visibility = ["//visibility:public"],
-     deps = [":protobuf_lite"] + select({
-         "//build_defs:config_msvc": [],
-diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc
-index 162531226..e93ec4809 100644
---- a/python/google/protobuf/pyext/descriptor.cc
-+++ b/python/google/protobuf/pyext/descriptor.cc
-@@ -58,6 +58,37 @@
-               : 0)                                               \
-        : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
- 
-+#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
-+static PyCodeObject* PyFrame_GetCode(PyFrameObject *frame)
-+{
-+    Py_INCREF(frame->f_code);
-+    return frame->f_code;
-+}
-+
-+static PyFrameObject* PyFrame_GetBack(PyFrameObject *frame)
-+{
-+    Py_XINCREF(frame->f_back);
-+    return frame->f_back;
-+}
-+#endif
-+
-+#if PY_VERSION_HEX < 0x030B00A7 && !defined(PYPY_VERSION)
-+static PyObject* PyFrame_GetLocals(PyFrameObject *frame)
-+{
-+    if (PyFrame_FastToLocalsWithError(frame) < 0) {
-+        return NULL;
-+    }
-+    Py_INCREF(frame->f_locals);
-+    return frame->f_locals;
-+}
-+
-+static PyObject* PyFrame_GetGlobals(PyFrameObject *frame)
-+{
-+    Py_INCREF(frame->f_globals);
-+    return frame->f_globals;
-+}
-+#endif
-+
- namespace google {
- namespace protobuf {
- namespace python {
-@@ -96,48 +127,66 @@ bool _CalledFromGeneratedFile(int stacklevel) {
-   // This check is not critical and is somewhat difficult to implement correctly
-   // in PyPy.
-   PyFrameObject* frame = PyEval_GetFrame();
-+  PyCodeObject* frame_code = nullptr;
-+  PyObject* frame_globals = nullptr;
-+  PyObject* frame_locals = nullptr;
-+  bool result = false;
-+
-   if (frame == nullptr) {
--    return false;
-+    goto exit;
-   }
-+  Py_INCREF(frame);
-   while (stacklevel-- > 0) {
--    frame = frame->f_back;
-+    PyFrameObject* next_frame = PyFrame_GetBack(frame);
-+    Py_DECREF(frame);
-+    frame = next_frame;
-     if (frame == nullptr) {
--      return false;
-+      goto exit;
-     }
-   }
- 
--  if (frame->f_code->co_filename == nullptr) {
--    return false;
-+  frame_code = PyFrame_GetCode(frame);
-+  if (frame_code->co_filename == nullptr) {
-+    goto exit;
-   }
-   char* filename;
-   Py_ssize_t filename_size;
--  if (PyString_AsStringAndSize(frame->f_code->co_filename,
-+  if (PyString_AsStringAndSize(frame_code->co_filename,
-                                &filename, &filename_size) < 0) {
-     // filename is not a string.
-     PyErr_Clear();
--    return false;
-+    goto exit;
-   }
-   if ((filename_size < 3) ||
-       (strcmp(&filename[filename_size - 3], ".py") != 0)) {
-     // Cython's stack does not have .py file name and is not at global module
-     // scope.
--    return true;
-+    result = true;
-+    goto exit;
-   }
-   if (filename_size < 7) {
-     // filename is too short.
--    return false;
-+    goto exit;
-   }
-   if (strcmp(&filename[filename_size - 7], "_pb2.py") != 0) {
-     // Filename is not ending with _pb2.
--    return false;
-+    goto exit;
-   }
- 
--  if (frame->f_globals != frame->f_locals) {
-+  frame_globals = PyFrame_GetGlobals(frame);
-+  frame_locals = PyFrame_GetLocals(frame);
-+  if (frame_globals != frame_locals) {
-     // Not at global module scope
--    return false;
-+    goto exit;
-   }
- #endif
--  return true;
-+  result = true;
-+exit:
-+  Py_XDECREF(frame_globals);
-+  Py_XDECREF(frame_locals);
-+  Py_XDECREF(frame_code);
-+  Py_XDECREF(frame);
-+  return result;
- }
- 
- // If the calling code is not a _pb2.py file, raise AttributeError.
\ No newline at end of file
diff --git a/third_party/xla/third_party/py/ml_dtypes/ml_dtypes.BUILD b/third_party/xla/third_party/py/ml_dtypes/ml_dtypes.BUILD
index ccf607d..a85195e 100644
--- a/third_party/xla/third_party/py/ml_dtypes/ml_dtypes.BUILD
+++ b/third_party/xla/third_party/py/ml_dtypes/ml_dtypes.BUILD
@@ -17,7 +17,7 @@
         ".",
         "ml_dtypes",
     ],
-    deps = ["@org_tensorflow//third_party/eigen3"],
+    deps = ["@eigen_archive//:eigen3"],
 )
 
 cc_library(
@@ -48,7 +48,7 @@
     deps = [
         ":float8",
         ":int4",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
         "@org_tensorflow//third_party/py/numpy:headers",
     ],
 )
diff --git a/third_party/xla/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD b/third_party/xla/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD
index 37cd52d..574659a 100644
--- a/third_party/xla/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD
+++ b/third_party/xla/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD
@@ -55,7 +55,7 @@
         "//:float8",
         "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest_main",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
     ],
 )
 
@@ -66,6 +66,6 @@
     deps = [
         "//:int4",
         "@com_google_googletest//:gtest_main",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
     ],
 )
diff --git a/third_party/xla/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD b/third_party/xla/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD
index ccf607d..a85195e 100644
--- a/third_party/xla/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD
+++ b/third_party/xla/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD
@@ -17,7 +17,7 @@
         ".",
         "ml_dtypes",
     ],
-    deps = ["@org_tensorflow//third_party/eigen3"],
+    deps = ["@eigen_archive//:eigen3"],
 )
 
 cc_library(
@@ -48,7 +48,7 @@
     deps = [
         ":float8",
         ":int4",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
         "@org_tensorflow//third_party/py/numpy:headers",
     ],
 )
diff --git a/third_party/xla/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD b/third_party/xla/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD
index 37cd52d..574659a 100644
--- a/third_party/xla/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD
+++ b/third_party/xla/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD
@@ -55,7 +55,7 @@
         "//:float8",
         "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest_main",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
     ],
 )
 
@@ -66,6 +66,6 @@
     deps = [
         "//:int4",
         "@com_google_googletest//:gtest_main",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
     ],
 )
diff --git a/third_party/xla/third_party/pybind11.BUILD b/third_party/xla/third_party/pybind11.BUILD
deleted file mode 100644
index d2a5d10..0000000
--- a/third_party/xla/third_party/pybind11.BUILD
+++ /dev/null
@@ -1,31 +0,0 @@
-package(default_visibility = ["//visibility:public"])
-
-cc_library(
-    name = "pybind11",
-    hdrs = glob(
-        include = [
-            "include/pybind11/*.h",
-            "include/pybind11/detail/*.h",
-        ],
-        exclude = [
-            "include/pybind11/common.h",
-            "include/pybind11/eigen.h",
-        ],
-    ),
-    copts = [
-        "-fexceptions",
-        "-Wno-undefined-inline",
-        "-Wno-pragma-once-outside-header",
-    ],
-    includes = ["include"],
-    strip_include_prefix = "include",
-    deps = [
-        "@local_xla//third_party/python_runtime:headers",
-    ],
-)
-
-# Needed by pybind11_bazel.
-config_setting(
-    name = "osx",
-    constraint_values = ["@platforms//os:osx"],
-)
diff --git a/third_party/xla/third_party/pybind11_abseil/BUILD b/third_party/xla/third_party/pybind11_abseil/BUILD
deleted file mode 100644
index 3b946e5..0000000
--- a/third_party/xla/third_party/pybind11_abseil/BUILD
+++ /dev/null
@@ -1,3 +0,0 @@
-# Necessary for bazel to recognize this as a package.
-
-# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
diff --git a/third_party/xla/third_party/pybind11_abseil/remove_license.patch b/third_party/xla/third_party/pybind11_abseil/remove_license.patch
deleted file mode 100644
index 91852c2..0000000
--- a/third_party/xla/third_party/pybind11_abseil/remove_license.patch
+++ /dev/null
@@ -1,13 +0,0 @@
-diff --git a/pybind11_abseil/BUILD b/pybind11_abseil/BUILD
-index 41482e7..ed9e4af 100644
---- a/pybind11_abseil/BUILD
-+++ b/pybind11_abseil/BUILD
-@@ -6,8 +6,6 @@ package(default_visibility = ["//visibility:public"])
-
- licenses(["notice"])
-
--exports_files(["LICENSE"])
--
- pybind_library(
-     name = "absl_casters",
-     hdrs = ["absl_casters.h"],
\ No newline at end of file
diff --git a/third_party/xla/third_party/pybind11_abseil/workspace.bzl b/third_party/xla/third_party/pybind11_abseil/workspace.bzl
deleted file mode 100644
index 19c1111..0000000
--- a/third_party/xla/third_party/pybind11_abseil/workspace.bzl
+++ /dev/null
@@ -1,20 +0,0 @@
-"""Provides the repo macro to import pybind11_abseil.
-
-pybind11_abseil requires pybind11 (which is loaded in another rule) and pybind11_bazel.
-See https://github.com/pybind/pybind11_abseil#installation.
-"""
-
-load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
-
-def repo():
-    """Imports pybind11_abseil."""
-    PA_COMMIT = "2c4932ed6f6204f1656e245838f4f5eae69d2e29"
-    PA_SHA256 = "0223b647b8cc817336a51e787980ebc299c8d5e64c069829bf34b69d72337449"
-    tf_http_archive(
-        name = "pybind11_abseil",
-        sha256 = PA_SHA256,
-        strip_prefix = "pybind11_abseil-{commit}".format(commit = PA_COMMIT),
-        urls = tf_mirror_urls("https://github.com/pybind/pybind11_abseil/archive/{commit}.tar.gz".format(commit = PA_COMMIT)),
-        build_file = "//third_party/pybind11_abseil:BUILD",
-        patch_file = ["//third_party/pybind11_abseil:remove_license.patch"],
-    )
diff --git a/third_party/xla/third_party/pybind11_bazel/BUILD b/third_party/xla/third_party/pybind11_bazel/BUILD
deleted file mode 100644
index e69de29..0000000
--- a/third_party/xla/third_party/pybind11_bazel/BUILD
+++ /dev/null
diff --git a/third_party/xla/third_party/pybind11_bazel/pybind11_bazel.patch b/third_party/xla/third_party/pybind11_bazel/pybind11_bazel.patch
deleted file mode 100644
index 74e038d..0000000
--- a/third_party/xla/third_party/pybind11_bazel/pybind11_bazel.patch
+++ /dev/null
@@ -1,37 +0,0 @@
-diff --git a/build_defs.bzl b/build_defs.bzl
-index cde1e93..03f14a5 100644
---- a/build_defs.bzl
-+++ b/build_defs.bzl
-@@ -27,7 +27,9 @@ PYBIND_DEPS = [
- 
- # Builds a Python extension module using pybind11.
- # This can be directly used in python with the import statement.
--# This adds rules for a .so binary file.
-+# This adds rules for .so and .pyd binary files, as well as
-+# a base target that selects between them depending on the platform
-+# (.pyd for windows, .so otherwise).
- def pybind_extension(
-         name,
-         copts = [],
-@@ -59,6 +61,21 @@ def pybind_extension(
-         **kwargs
-     )
- 
-+    native.genrule(
-+        name = name + "_pyd",
-+        srcs = [name + ".so"],
-+        outs = [name + ".pyd"],
-+        cmd = "cp $< $@",
-+    )
-+
-+    native.py_library(
-+        name = name,
-+        data = select({
-+            "@platforms//os:windows": [":" + name + ".pyd"],
-+            "//conditions:default": [":" + name + ".so"],
-+        }),
-+    )
-+
- # Builds a pybind11 compatible library. This can be linked to a pybind_extension.
- def pybind_library(
-         name,
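The patch above teaches `pybind_extension` to emit both a `.so` and a `.pyd` and to wrap them in a `py_library` whose `select()` picks the right one per platform. A minimal consumer sketch, with hypothetical module and test names:

```python
# Sketch of consuming the patched pybind_extension macro; "my_module" and the
# test target are illustrative names, not part of this patch.
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

pybind_extension(
    name = "my_module",        # usable as `import my_module` from Python
    srcs = ["my_module.cc"],
)

py_test(
    name = "my_module_test",
    srcs = ["my_module_test.py"],
    deps = [":my_module"],     # the py_library wrapper added by the patch
)
```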
diff --git a/third_party/xla/third_party/pybind11_bazel/workspace.bzl b/third_party/xla/third_party/pybind11_bazel/workspace.bzl
deleted file mode 100644
index dcc71d3..0000000
--- a/third_party/xla/third_party/pybind11_bazel/workspace.bzl
+++ /dev/null
@@ -1,17 +0,0 @@
-"""Provides the repo macro to import pybind11_bazel.
-
-pybind11_bazel requires pybind11 (which is loaded in another rule).
-"""
-
-load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
-
-def repo():
-    PB_COMMIT = "72cbbf1fbc830e487e3012862b7b720001b70672"
-    PB_SHA256 = "516c1b3a10d87740d2b7de6f121f8e19dde2c372ecbfe59aef44cd1872c10395"
-    tf_http_archive(
-        name = "pybind11_bazel",
-        strip_prefix = "pybind11_bazel-{commit}".format(commit = PB_COMMIT),
-        sha256 = PB_SHA256,
-        patch_file = ["//third_party/pybind11_bazel:pybind11_bazel.patch"],
-        urls = tf_mirror_urls("https://github.com/pybind/pybind11_bazel/archive/{commit}.tar.gz".format(commit = PB_COMMIT)),
-    )
diff --git a/third_party/xla/third_party/remote_config/BUILD b/third_party/xla/third_party/remote_config/BUILD
deleted file mode 100644
index e69de29..0000000
--- a/third_party/xla/third_party/remote_config/BUILD
+++ /dev/null
diff --git a/third_party/xla/third_party/remote_config/BUILD.tpl b/third_party/xla/third_party/remote_config/BUILD.tpl
deleted file mode 100644
index d97eb97..0000000
--- a/third_party/xla/third_party/remote_config/BUILD.tpl
+++ /dev/null
@@ -1,26 +0,0 @@
-# Each platform creates a constraint @<platform>//:platform_constraint that
-# is listed in its constraint_values; rules that want to select a specific
-# platform to run on can put @<platform>//:platform_constraint into their
-# exec_compatible_with attribute.
-# Toolchains can similarly be marked with target_compatible_with or
-# exec_compatible_with to bind them to this platform.
-constraint_setting(
-    name = "platform_setting"
-)
-
-constraint_value(
-    name = "platform_constraint",
-    constraint_setting = ":platform_setting",
-    visibility = ["//visibility:public"],
-)
-
-platform(
-    name = "platform",
-    visibility = ["//visibility:public"],
-    constraint_values = [
-        "@platforms//cpu:%{cpu}",
-        "@platforms//os:%{platform}",
-        ":platform_constraint",
-    ],
-    exec_properties = %{exec_properties},
-)
diff --git a/third_party/xla/third_party/remote_config/common.bzl b/third_party/xla/third_party/remote_config/common.bzl
deleted file mode 100644
index 57fb6fc..0000000
--- a/third_party/xla/third_party/remote_config/common.bzl
+++ /dev/null
@@ -1,327 +0,0 @@
-"""Functions common across configure rules."""
-
-BAZEL_SH = "BAZEL_SH"
-PYTHON_BIN_PATH = "PYTHON_BIN_PATH"
-PYTHON_LIB_PATH = "PYTHON_LIB_PATH"
-TF_PYTHON_CONFIG_REPO = "TF_PYTHON_CONFIG_REPO"
-
-def auto_config_fail(msg):
-    """Output failure message when auto configuration fails."""
-    red = "\033[0;31m"
-    no_color = "\033[0m"
-    fail("%sConfiguration Error:%s %s\n" % (red, no_color, msg))
-
-def which(repository_ctx, program_name, allow_failure = False):
-    """Returns the full path to a program on the execution platform.
-
-    Args:
-      repository_ctx: the repository_ctx
-      program_name: name of the program on the PATH
-
-    Returns:
-      The full path to a program on the execution platform.
-    """
-    if is_windows(repository_ctx):
-        if not program_name.endswith(".exe"):
-            program_name = program_name + ".exe"
-        out = execute(
-            repository_ctx,
-            ["C:\\Windows\\System32\\where.exe", program_name],
-            allow_failure = allow_failure,
-        ).stdout
-        if out != None:
-            out = out.replace("\\", "\\\\").rstrip()
-        return out
-
-    out = execute(
-        repository_ctx,
-        ["which", program_name],
-        allow_failure = allow_failure,
-    ).stdout
-    if out != None:
-        out = out.replace("\\", "\\\\").rstrip()
-    return out
-
-def get_python_bin(repository_ctx):
-    """Gets the python bin path.
-
-    Args:
-      repository_ctx: the repository_ctx
-
-    Returns:
-      The python bin path.
-    """
-    python_bin = get_host_environ(repository_ctx, PYTHON_BIN_PATH)
-    if python_bin:
-        return python_bin
-
-    # First check for an explicit "python3"
-    python_bin = which(repository_ctx, "python3", True)
-    if python_bin:
-        return python_bin
-
-    # Some systems just call python3 "python"
-    python_bin = which(repository_ctx, "python", True)
-    if python_bin:
-        return python_bin
-
-    auto_config_fail("Cannot find python in PATH, please make sure " +
-                     "python is installed and add its directory in PATH, or --define " +
-                     "%s='/something/else'.\nPATH=%s" % (
-                         PYTHON_BIN_PATH,
-                         get_environ(repository_ctx, "PATH"),
-                     ))
-    return python_bin  # unreachable
-
-def get_bash_bin(repository_ctx):
-    """Gets the bash bin path.
-
-    Args:
-      repository_ctx: the repository_ctx
-
-    Returns:
-      The bash bin path.
-    """
-    bash_bin = get_host_environ(repository_ctx, BAZEL_SH)
-    if bash_bin != None:
-        return bash_bin
-    bash_bin_path = which(repository_ctx, "bash")
-    if bash_bin_path == None:
-        auto_config_fail("Cannot find bash in PATH, please make sure " +
-                         "bash is installed and add its directory in PATH, or --define " +
-                         "%s='/path/to/bash'.\nPATH=%s" % (
-                             BAZEL_SH,
-                             get_environ(repository_ctx, "PATH"),
-                         ))
-    return bash_bin_path
-
-def read_dir(repository_ctx, src_dir):
-    """Returns a sorted list with all files in a directory.
-
-    Finds all files inside a directory, traversing subfolders and following
-    symlinks.
-
-    Args:
-      repository_ctx: the repository_ctx
-      src_dir: the directory to traverse
-
-    Returns:
-      A sorted list with all files in a directory.
-    """
-    if is_windows(repository_ctx):
-        src_dir = src_dir.replace("/", "\\")
-        find_result = execute(
-            repository_ctx,
-            ["C:\\Windows\\System32\\cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
-            allow_failure = True,
-        )
-
-        # src_files will be used in genrule.outs where the paths must
-        # use forward slashes.
-        result = find_result.stdout.replace("\\", "/")
-    else:
-        find_result = execute(
-            repository_ctx,
-            ["find", src_dir, "-follow", "-type", "f"],
-            allow_failure = True,
-        )
-        result = find_result.stdout
-    return sorted(result.splitlines())
-
-def get_environ(repository_ctx, name, default_value = None):
-    """Returns the value of an environment variable on the execution platform.
-
-    Args:
-      repository_ctx: the repository_ctx
-      name: the name of the environment variable
-      default_value: the value to return if not set
-
-    Returns:
-      The value of the environment variable 'name' on the execution platform
-      or 'default_value' if it's not set.
-    """
-    if is_windows(repository_ctx):
-        result = execute(
-            repository_ctx,
-            ["C:\\Windows\\System32\\cmd.exe", "/c", "echo", "%" + name + "%"],
-            allow_failure = True,
-        )
-    else:
-        cmd = "echo -n \"$%s\"" % name
-        result = execute(
-            repository_ctx,
-            [get_bash_bin(repository_ctx), "-c", cmd],
-            allow_failure = True,
-        )
-    if len(result.stdout) == 0:
-        return default_value
-    return result.stdout
-
-def get_host_environ(repository_ctx, name, default_value = None):
-    """Returns the value of an environment variable on the host platform.
-
-    The host platform is the machine that Bazel runs on.
-
-    Args:
-      repository_ctx: the repository_ctx
-      name: the name of the environment variable
-      default_value: the value to return if not set
-
-    Returns:
-      The value of the environment variable 'name' on the host platform
-      or 'default_value' if it's not set.
-    """
-    if name in repository_ctx.os.environ:
-        return repository_ctx.os.environ.get(name).strip()
-
-    if hasattr(repository_ctx.attr, "environ") and name in repository_ctx.attr.environ:
-        return repository_ctx.attr.environ.get(name).strip()
-
-    return default_value
-
-def is_windows(repository_ctx):
-    """Returns true if the execution platform is Windows.
-
-    Args:
-      repository_ctx: the repository_ctx
-
-    Returns:
-      True if the execution platform is Windows, False otherwise.
-    """
-    os_name = ""
-    if hasattr(repository_ctx.attr, "exec_properties") and "OSFamily" in repository_ctx.attr.exec_properties:
-        os_name = repository_ctx.attr.exec_properties["OSFamily"]
-    else:
-        os_name = repository_ctx.os.name
-
-    return os_name.lower().find("windows") != -1
-
-def get_cpu_value(repository_ctx):
-    """Returns the name of the host operating system.
-
-    Args:
-      repository_ctx: The repository context.
-    Returns:
-      A string containing the name of the host operating system.
-    """
-    if is_windows(repository_ctx):
-        return "Windows"
-    result = raw_exec(repository_ctx, ["uname", "-s"])
-    return result.stdout.strip()
-
-def execute(
-        repository_ctx,
-        cmdline,
-        error_msg = None,
-        error_details = None,
-        allow_failure = False):
-    """Executes an arbitrary shell command.
-
-    Args:
-      repository_ctx: the repository_ctx object
-      cmdline: list of strings, the command to execute
-      error_msg: string, a summary of the error if the command fails
-      error_details: string, details about the error or steps to fix it
-      allow_failure: bool, if True, an empty stdout result or output to stderr
-        is fine, otherwise either of these is an error
-    Returns:
-      The result of repository_ctx.execute(cmdline)
-    """
-    result = raw_exec(repository_ctx, cmdline)
-    if (result.stderr or not result.stdout) and not allow_failure:
-        fail(
-            "\n".join([
-                error_msg.strip() if error_msg else "Repository command failed",
-                result.stderr.strip(),
-                error_details if error_details else "",
-            ]),
-        )
-    return result
-
-def raw_exec(repository_ctx, cmdline):
-    """Executes a command via repository_ctx.execute() and returns the result.
-
-    This method is useful for debugging purposes, for example, to print all
-    commands executed as well as their return codes.
-
-    Args:
-      repository_ctx: the repository_ctx
-      cmdline: the list of args
-
-    Returns:
-      The 'exec_result' of repository_ctx.execute().
-    """
-    return repository_ctx.execute(cmdline)
-
-def files_exist(repository_ctx, paths, bash_bin = None):
-    """Checks which files in paths exists.
-
-    Args:
-      repository_ctx: the repository_ctx
-      paths: a list of paths
-      bash_bin: path to the bash interpreter
-
-    Returns:
-      A list of bools; True means that the path at the
-      same position in the paths list exists.
-    """
-    if bash_bin == None:
-        bash_bin = get_bash_bin(repository_ctx)
-
-    cmd_tpl = "[ -e \"%s\" ] && echo True || echo False"
-    cmds = [cmd_tpl % path for path in paths]
-    cmd = " ; ".join(cmds)
-
-    stdout = execute(repository_ctx, [bash_bin, "-c", cmd]).stdout.strip()
-    return [val == "True" for val in stdout.splitlines()]
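As a sketch of the intended behaviour, with purely illustrative paths:

    # Hypothetical usage: both paths are checked in a single bash invocation.
    #   files_exist(repository_ctx, ["/usr/include/zlib.h", "/no/such/file.h"])
    # would return [True, False], matching the order of the input paths.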
-
-def realpath(repository_ctx, path, bash_bin = None):
-    """Returns the result of "realpath path".
-
-    Args:
-      repository_ctx: the repository_ctx
-      path: a path on the file system
-      bash_bin: path to the bash interpreter
-
-    Returns:
-      The result of "realpath path".
-    """
-    if bash_bin == None:
-        bash_bin = get_bash_bin(repository_ctx)
-
-    return execute(repository_ctx, [bash_bin, "-c", "realpath \"%s\"" % path]).stdout.strip()
-
-def err_out(result):
-    """Returns stderr if set, else stdout.
-
-    This function is a workaround for a bug in RBE where stderr is returned as stdout.
-    Use err_out(result) instead of result.stderr.
-
-    Args:
-      result: the exec_result.
-
-    Returns:
-      The stderr if set, else stdout
-    """
-    if len(result.stderr) == 0:
-        return result.stdout
-    return result.stderr
-
-def config_repo_label(config_repo, target):
-    """Construct a label from config_repo and target.
-
-    This function exists to ease the migration from preconfig to remote config. In preconfig
-    the TF_*_CONFIG_REPO environment variables are set to packages in the main repo, while in
-    remote config they will point to remote repositories.
-
-    Args:
-      config_repo: a remote repository or package.
-      target: a target
-    Returns:
-      A label constructed from config_repo and target.
-    """
-    if config_repo.startswith("@") and not config_repo.find("//") > 0:
-        # remote config is being used.
-        return Label(config_repo + "//" + target)
-    elif target.startswith(":"):
-        return Label(config_repo + target)
-    else:
-        return Label(config_repo + "/" + target)
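To illustrate the three cases, with made-up repository and target names:

    # Hypothetical examples of the labels this helper constructs:
    #   config_repo_label("@remote_example_config", ":toolchain")  -> Label("@remote_example_config//:toolchain")
    #   config_repo_label("//third_party/example", ":toolchain")   -> Label("//third_party/example:toolchain")
    #   config_repo_label("//third_party/example", "toolchain")    -> Label("//third_party/example/toolchain")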
diff --git a/third_party/xla/third_party/remote_config/remote_platform_configure.bzl b/third_party/xla/third_party/remote_config/remote_platform_configure.bzl
deleted file mode 100644
index 59bedfe..0000000
--- a/third_party/xla/third_party/remote_config/remote_platform_configure.bzl
+++ /dev/null
@@ -1,55 +0,0 @@
-"""Repository rule to create a platform for a docker image to be used with RBE."""
-
-def _remote_platform_configure_impl(repository_ctx):
-    platform = repository_ctx.attr.platform
-    if platform == "local":
-        os = repository_ctx.os.name.lower()
-        if os.startswith("windows"):
-            platform = "windows"
-        elif os.startswith("mac os"):
-            platform = "osx"
-        else:
-            platform = "linux"
-
-    cpu = "x86_64"
-    machine_type = repository_ctx.execute(["bash", "-c", "echo $MACHTYPE"]).stdout
-    if (machine_type.startswith("ppc") or
-        machine_type.startswith("powerpc")):
-        cpu = "ppc"
-    elif machine_type.startswith("s390x"):
-        cpu = "s390x"
-    elif machine_type.startswith("aarch64"):
-        cpu = "aarch64"
-    elif machine_type.startswith("arm64"):
-        cpu = "aarch64"
-    elif machine_type.startswith("arm"):
-        cpu = "arm"
-    elif machine_type.startswith("mips64"):
-        cpu = "mips64"
-    elif machine_type.startswith("riscv64"):
-        cpu = "riscv64"
-
-    exec_properties = repository_ctx.attr.platform_exec_properties
-
-    serialized_exec_properties = "{"
-    for k, v in exec_properties.items():
-        serialized_exec_properties += "\"%s\" : \"%s\"," % (k, v)
-    serialized_exec_properties += "}"
-
-    repository_ctx.template(
-        "BUILD",
-        Label("@local_xla//third_party/remote_config:BUILD.tpl"),
-        {
-            "%{platform}": platform,
-            "%{exec_properties}": serialized_exec_properties,
-            "%{cpu}": cpu,
-        },
-    )
-
-remote_platform_configure = repository_rule(
-    implementation = _remote_platform_configure_impl,
-    attrs = {
-        "platform_exec_properties": attr.string_dict(mandatory = True),
-        "platform": attr.string(default = "linux", values = ["linux", "windows", "local"]),
-    },
-)
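A hypothetical invocation of this rule, with an illustrative repository name and container image:

    # Sketch: describing an RBE docker platform for remote builds.
    remote_platform_configure(
        name = "example_remote_platform",
        platform = "linux",
        platform_exec_properties = {
            "container-image": "docker://gcr.io/example-project/example-image@sha256:...",
            "OSFamily": "Linux",
        },
    )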
diff --git a/third_party/xla/third_party/six.BUILD b/third_party/xla/third_party/six.BUILD
deleted file mode 100644
index d6ac142..0000000
--- a/third_party/xla/third_party/six.BUILD
+++ /dev/null
@@ -1,14 +0,0 @@
-# Description:
-#   Six provides simple utilities for wrapping over differences between Python 2
-#   and Python 3.
-
-licenses(["notice"])  # MIT
-
-exports_files(["LICENSE"])
-
-py_library(
-    name = "six",
-    srcs = ["six.py"],
-    srcs_version = "PY3",
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/snappy.BUILD b/third_party/xla/third_party/snappy.BUILD
deleted file mode 100644
index fbf4e2e..0000000
--- a/third_party/xla/third_party/snappy.BUILD
+++ /dev/null
@@ -1,99 +0,0 @@
-package(default_visibility = ["//visibility:public"])
-
-licenses(["notice"])  # BSD 3-Clause
-
-exports_files(["COPYING"])
-
-cc_library(
-    name = "snappy",
-    srcs = [
-        "config.h",
-        "snappy.cc",
-        "snappy.h",
-        "snappy-internal.h",
-        "snappy-sinksource.cc",
-        "snappy-sinksource.h",
-        "snappy-stubs-internal.cc",
-        "snappy-stubs-internal.h",
-        "snappy-stubs-public.h",
-    ],
-    hdrs = ["snappy.h"],
-    copts = ["-DHAVE_CONFIG_H"] + select({
-        "@local_tsl//tsl:windows": [],
-        "//conditions:default": [
-            "-fno-exceptions",
-            "-Wno-sign-compare",
-            "-Wno-shift-negative-value",
-            "-Wno-implicit-function-declaration",
-        ],
-    }),
-    defines = select({
-        "@local_tsl//tsl:windows": [],
-        "//conditions:default": ["HAVE_SYS_UIO_H"],
-    }),
-)
-
-genrule(
-    name = "config_h",
-    outs = ["config.h"],
-    cmd = "\n".join([
-        "cat <<'EOF' >$@",
-        "#define HAVE_STDDEF_H 1",
-        "#define HAVE_STDINT_H 1",
-        "",
-        "#ifdef __has_builtin",
-        "#  if !defined(HAVE_BUILTIN_EXPECT) && __has_builtin(__builtin_expect)",
-        "#    define HAVE_BUILTIN_EXPECT 1",
-        "#  endif",
-        "#  if !defined(HAVE_BUILTIN_CTZ) && __has_builtin(__builtin_ctzll)",
-        "#    define HAVE_BUILTIN_CTZ 1",
-        "#  endif",
-        "#elif defined(__GNUC__) && (__GNUC__ > 3 || __GNUC__ == 3 && __GNUC_MINOR__ >= 4)",
-        "#  ifndef HAVE_BUILTIN_EXPECT",
-        "#    define HAVE_BUILTIN_EXPECT 1",
-        "#  endif",
-        "#  ifndef HAVE_BUILTIN_CTZ",
-        "#    define HAVE_BUILTIN_CTZ 1",
-        "#  endif",
-        "#endif",
-        "",
-        "#ifdef __has_include",
-        "#  if !defined(HAVE_BYTESWAP_H) && __has_include(<byteswap.h>)",
-        "#    define HAVE_BYTESWAP_H 1",
-        "#  endif",
-        "#  if !defined(HAVE_UNISTD_H) && __has_include(<unistd.h>)",
-        "#    define HAVE_UNISTD_H 1",
-        "#  endif",
-        "#  if !defined(HAVE_SYS_ENDIAN_H) && __has_include(<sys/endian.h>)",
-        "#    define HAVE_SYS_ENDIAN_H 1",
-        "#  endif",
-        "#  if !defined(HAVE_SYS_MMAN_H) && __has_include(<sys/mman.h>)",
-        "#    define HAVE_SYS_MMAN_H 1",
-        "#  endif",
-        "#  if !defined(HAVE_SYS_UIO_H) && __has_include(<sys/uio.h>)",
-        "#    define HAVE_SYS_UIO_H 1",
-        "#  endif",
-        "#endif",
-        "",
-        "#ifndef SNAPPY_IS_BIG_ENDIAN",
-        "#  ifdef __s390x__",
-        "#    define SNAPPY_IS_BIG_ENDIAN 1",
-        "#  elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__",
-        "#    define SNAPPY_IS_BIG_ENDIAN 1",
-        "#  endif",
-        "#endif",
-        "EOF",
-    ]),
-)
-
-genrule(
-    name = "snappy_stubs_public_h",
-    srcs = ["snappy-stubs-public.h.in"],
-    outs = ["snappy-stubs-public.h"],
-    cmd = ("sed " +
-           "-e 's/$${\\(.*\\)_01}/\\1/g' " +
-           "-e 's/$${SNAPPY_MAJOR}/1/g' " +
-           "-e 's/$${SNAPPY_MINOR}/1/g' " +
-           "-e 's/$${SNAPPY_PATCHLEVEL}/4/g' " +
-           "$< >$@"),
-)
diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch
index 5161610..b3cb70f 100644
--- a/third_party/xla/third_party/stablehlo/temporary.patch
+++ b/third_party/xla/third_party/stablehlo/temporary.patch
@@ -181,6 +181,18 @@
  
  #-------------------------------------------------------------------------------
  # Directory setup
+diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/BUILD.bazel b/stablehlo/stablehlo/conversions/tosa/tests/BUILD.bazel
+--- stablehlo/stablehlo/conversions/tosa/tests/BUILD.bazel
++++ stablehlo/stablehlo/conversions/tosa/tests/BUILD.bazel
+@@ -29,7 +29,7 @@
+         "@LLVM_TOOLS_DIR@": package_path("@llvm-project//llvm:BUILD"),
+         "\"@STABLEHLO_TOOLS_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'stablehlo')",
+         "\"@STABLEHLO_SOURCE_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'stablehlo')",
+-    },
++     },
+     template = "lit.site.cfg.py.in",
+ )
+ 
 diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/Base.cpp
 --- stablehlo/stablehlo/dialect/Base.cpp
 +++ stablehlo/stablehlo/dialect/Base.cpp
@@ -980,6 +992,44 @@
 +}  // namespace mlir
 +
 +#endif  // STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H
+diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp
+--- stablehlo/stablehlo/dialect/StablehloOps.cpp
++++ stablehlo/stablehlo/dialect/StablehloOps.cpp
+@@ -1543,6 +1543,7 @@
+     p << " across dimensions = [";
+     llvm::interleaveComma(getDimensions().getValues<int64_t>(), p);
+     p << "]";
++    p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"});
+     p << " : ";
+     p.printFunctionalType(*this);
+   } else {
+@@ -1705,6 +1706,7 @@
+   if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") ||
+       parser.parseEqual() ||
+       parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) ||
++      parser.parseOptionalAttrDict(result.attributes) ||
+       parser.parseColon() || parser.parseType(reduceOpFnType) ||
+       parser.parseOptionalLocationSpecifier(explicitLoc))
+     return failure();
+diff --ruN a/stablehlo/stablehlo/tests/print_reduce.mlir b/stablehlo/stablehlo/tests/print_reduce.mlir
+--- stablehlo/stablehlo/tests/print_reduce.mlir
++++ stablehlo/stablehlo/tests/print_reduce.mlir
+@@ -168,3 +168,15 @@
+ 
+   func.return %0: tensor<4xf32>
+ }
++
++// The test case makes sure any custom attrs set on the reduce-op are
++// printed/parsed when pretty-printed.
++
++// CHECK-LABEL:  func @pretty_print_with_custom_attr
++// CHECK:          applies stablehlo.add across dimensions = [1] {custom_user_attr = 1 : i64}
++
++func.func @pretty_print_with_custom_attr(%arg0: tensor<2x64x13xf32>) -> tensor<2x13xf32> {
++  %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
++  %1 = stablehlo.reduce(%arg0 init: %0) applies stablehlo.add across dimensions = [1] {custom_user_attr = 1 : i64} : (tensor<2x64x13xf32>, tensor<f32>) -> tensor<2x13xf32>
++  return %1 : tensor<2x13xf32>
++}
 diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir
 --- stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir
 +++ stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir
diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl
index 4dc93b8..da650e4 100644
--- a/third_party/xla/third_party/stablehlo/workspace.bzl
+++ b/third_party/xla/third_party/stablehlo/workspace.bzl
@@ -4,8 +4,8 @@
 
 def repo():
     # LINT.IfChange
-    STABLEHLO_COMMIT = "03216ba4f6ead279db5912828f8c94634589007d"
-    STABLEHLO_SHA256 = "84e9624cc61e70586c2e4bb0356da8d7fdbe653d0a015fd67b5c9a56660ba258"
+    STABLEHLO_COMMIT = "5e41e674af78da676652459c2dcf6a0d76e59ddb"
+    STABLEHLO_SHA256 = "02f7db52b6dc6b14d3dcbe8e7982c8cebcc2b91b1843abbe292ab98eef1fc9f2"
     # LINT.ThenChange(Google-internal path)
 
     tf_http_archive(
diff --git a/third_party/xla/third_party/systemlibs/BUILD b/third_party/xla/third_party/systemlibs/BUILD
deleted file mode 100644
index e69de29..0000000
--- a/third_party/xla/third_party/systemlibs/BUILD
+++ /dev/null
diff --git a/third_party/xla/third_party/systemlibs/BUILD.tpl b/third_party/xla/third_party/systemlibs/BUILD.tpl
deleted file mode 100644
index e69de29..0000000
--- a/third_party/xla/third_party/systemlibs/BUILD.tpl
+++ /dev/null
diff --git a/third_party/xla/third_party/systemlibs/absl_py.BUILD b/third_party/xla/third_party/systemlibs/absl_py.BUILD
deleted file mode 100644
index cbe6e10..0000000
--- a/third_party/xla/third_party/systemlibs/absl_py.BUILD
+++ /dev/null
@@ -1,6 +0,0 @@
-licenses(["notice"])  # Apache 2.0
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/absl_py.absl.BUILD b/third_party/xla/third_party/systemlibs/absl_py.absl.BUILD
deleted file mode 100644
index c9a6790..0000000
--- a/third_party/xla/third_party/systemlibs/absl_py.absl.BUILD
+++ /dev/null
@@ -1,7 +0,0 @@
-licenses(["notice"])  # Apache 2.0
-
-package(default_visibility = ["//visibility:public"])
-
-py_library(
-    name = "app",
-)
diff --git a/third_party/xla/third_party/systemlibs/absl_py.absl.flags.BUILD b/third_party/xla/third_party/systemlibs/absl_py.absl.flags.BUILD
deleted file mode 100644
index d92f494..0000000
--- a/third_party/xla/third_party/systemlibs/absl_py.absl.flags.BUILD
+++ /dev/null
@@ -1,11 +0,0 @@
-licenses(["notice"])  # Apache 2.0
-
-package(default_visibility = ["//visibility:public"])
-
-py_library(
-    name = "flags",
-)
-
-py_library(
-    name = "argparse_flags",
-)
diff --git a/third_party/xla/third_party/systemlibs/absl_py.absl.logging.BUILD b/third_party/xla/third_party/systemlibs/absl_py.absl.logging.BUILD
deleted file mode 100644
index 98d136e..0000000
--- a/third_party/xla/third_party/systemlibs/absl_py.absl.logging.BUILD
+++ /dev/null
@@ -1,7 +0,0 @@
-licenses(["notice"])  # Apache 2.0
-
-package(default_visibility = ["//visibility:public"])
-
-py_library(
-    name = "logging",
-)
diff --git a/third_party/xla/third_party/systemlibs/absl_py.absl.testing.BUILD b/third_party/xla/third_party/systemlibs/absl_py.absl.testing.BUILD
deleted file mode 100644
index ee810f8..0000000
--- a/third_party/xla/third_party/systemlibs/absl_py.absl.testing.BUILD
+++ /dev/null
@@ -1,16 +0,0 @@
-licenses(["notice"])  # Apache 2.0
-
-py_library(
-    name = "parameterized",
-    visibility = ["//visibility:public"],
-)
-
-py_library(
-    name = "absltest",
-    visibility = ["//visibility:public"],
-)
-
-py_library(
-    name = "flagsaver",
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/boringssl.BUILD b/third_party/xla/third_party/systemlibs/boringssl.BUILD
deleted file mode 100644
index bc4c533..0000000
--- a/third_party/xla/third_party/systemlibs/boringssl.BUILD
+++ /dev/null
@@ -1,21 +0,0 @@
-licenses(["notice"])
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "crypto",
-    linkopts = ["-lcrypto"],
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "ssl",
-    linkopts = ["-lssl"],
-    visibility = ["//visibility:public"],
-    deps = [
-        ":crypto",
-    ],
-)
diff --git a/third_party/xla/third_party/systemlibs/build_defs.bzl.tpl b/third_party/xla/third_party/systemlibs/build_defs.bzl.tpl
deleted file mode 100644
index 3faa46c..0000000
--- a/third_party/xla/third_party/systemlibs/build_defs.bzl.tpl
+++ /dev/null
@@ -1,32 +0,0 @@
-# -*- Python -*-
-"""Skylark macros for system libraries.
-"""
-
-SYSTEM_LIBS_ENABLED = %{syslibs_enabled}
-
-SYSTEM_LIBS_LIST = [
-%{syslibs_list}
-]
-
-
-def if_any_system_libs(a, b=[]):
-  """Conditional which evaluates to 'a' if any system libraries are configured."""
-  if SYSTEM_LIBS_ENABLED:
-    return a
-  else:
-    return b
-
-
-def if_system_lib(lib, a, b=[]):
-  """Conditional which evaluates to 'a' if we're using the system version of lib"""
-
-  if SYSTEM_LIBS_ENABLED and lib in SYSTEM_LIBS_LIST:
-    return a
-  else:
-    return b
-
-
-def if_not_system_lib(lib, a, b=[]):
-  """Conditional which evaluates to 'a' if we're using the system version of lib"""
-
-  return if_system_lib(lib, b, a)
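For illustration, a consuming BUILD macro might select dependencies with these conditionals roughly as follows; the target and library names are assumptions:

    # Sketch: pick the system zlib when it is configured, otherwise the bundled copy.
    cc_library(
        name = "compression",
        deps = if_system_lib(
            "zlib",
            ["@zlib//:zlib_system"],  # used when "zlib" is in SYSTEM_LIBS_LIST
            ["@zlib//:zlib"],         # bundled fallback otherwise
        ),
    )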
diff --git a/third_party/xla/third_party/systemlibs/curl.BUILD b/third_party/xla/third_party/systemlibs/curl.BUILD
deleted file mode 100644
index c5f125c..0000000
--- a/third_party/xla/third_party/systemlibs/curl.BUILD
+++ /dev/null
@@ -1,12 +0,0 @@
-licenses(["notice"])  # MIT/X derivative license
-
-filegroup(
-    name = "COPYING",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "curl",
-    linkopts = ["-lcurl"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/cython.BUILD b/third_party/xla/third_party/systemlibs/cython.BUILD
deleted file mode 100644
index 1d52587..0000000
--- a/third_party/xla/third_party/systemlibs/cython.BUILD
+++ /dev/null
@@ -1,13 +0,0 @@
-licenses(["notice"])  # Apache-2.0
-
-genrule(
-    name = "lncython",
-    outs = ["cython"],
-    cmd = "ln -s $$(which cython) $@",
-)
-
-sh_binary(
-    name = "cython_binary",
-    srcs = ["cython"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/double_conversion.BUILD b/third_party/xla/third_party/systemlibs/double_conversion.BUILD
deleted file mode 100644
index 5684601..0000000
--- a/third_party/xla/third_party/systemlibs/double_conversion.BUILD
+++ /dev/null
@@ -1,12 +0,0 @@
-licenses(["notice"])
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "double-conversion",
-    linkopts = ["-ldouble-conversion"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/gif.BUILD b/third_party/xla/third_party/systemlibs/gif.BUILD
deleted file mode 100644
index 5eb2c91..0000000
--- a/third_party/xla/third_party/systemlibs/gif.BUILD
+++ /dev/null
@@ -1,12 +0,0 @@
-licenses(["notice"])  # MIT
-
-filegroup(
-    name = "COPYING",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "gif",
-    linkopts = ["-lgif"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/google_cloud_cpp.BUILD b/third_party/xla/third_party/systemlibs/google_cloud_cpp.BUILD
deleted file mode 100644
index cbe6e10..0000000
--- a/third_party/xla/third_party/systemlibs/google_cloud_cpp.BUILD
+++ /dev/null
@@ -1,6 +0,0 @@
-licenses(["notice"])  # Apache 2.0
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD b/third_party/xla/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
deleted file mode 100644
index b59d565..0000000
--- a/third_party/xla/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
+++ /dev/null
@@ -1,7 +0,0 @@
-licenses(["notice"])  # Apache 2.0
-
-cc_library(
-    name = "bigtable_client",
-    linkopts = ["-lbigtable_client"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/grpc.BUILD b/third_party/xla/third_party/systemlibs/grpc.BUILD
deleted file mode 100644
index 8b703f1..0000000
--- a/third_party/xla/third_party/systemlibs/grpc.BUILD
+++ /dev/null
@@ -1,76 +0,0 @@
-licenses(["notice"])  # Apache v2
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "grpc",
-    linkopts = [
-        "-lgrpc",
-        "-lgpr",
-    ],
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "grpc++",
-    linkopts = [
-        "-lgrpc++",
-        "-lgpr",
-    ],
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "grpc++_public_hdrs",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "grpc++_codegen_proto",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "grpc_unsecure",
-    linkopts = [
-        "-lgrpc_unsecure",
-        "-lgpr",
-    ],
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "grpc++_unsecure",
-    linkopts = [
-        "-lgrpc++_unsecure",
-        "-lgpr",
-    ],
-    visibility = ["//visibility:public"],
-)
-
-genrule(
-    name = "ln_grpc_cpp_plugin",
-    outs = ["grpc_cpp_plugin.bin"],
-    cmd = "ln -s $$(which grpc_cpp_plugin) $@",
-)
-
-sh_binary(
-    name = "grpc_cpp_plugin",
-    srcs = ["grpc_cpp_plugin.bin"],
-    visibility = ["//visibility:public"],
-)
-
-genrule(
-    name = "ln_grpc_python_plugin",
-    outs = ["grpc_python_plugin.bin"],
-    cmd = "ln -s $$(which grpc_python_plugin) $@",
-)
-
-sh_binary(
-    name = "grpc_python_plugin",
-    srcs = ["grpc_python_plugin.bin"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/grpc.bazel.cc_grpc_library.bzl b/third_party/xla/third_party/systemlibs/grpc.bazel.cc_grpc_library.bzl
deleted file mode 100644
index e427328..0000000
--- a/third_party/xla/third_party/systemlibs/grpc.bazel.cc_grpc_library.bzl
+++ /dev/null
@@ -1,105 +0,0 @@
-"""Generates and compiles C++ grpc stubs from proto_library rules."""
-
-load("@com_github_grpc_grpc//bazel:generate_cc.bzl", "generate_cc")
-load("@com_github_grpc_grpc//bazel:protobuf.bzl", "well_known_proto_libs")
-
-def cc_grpc_library(
-        name,
-        srcs,
-        deps,
-        proto_only = False,
-        well_known_protos = False,
-        generate_mocks = False,
-        use_external = False,
-        grpc_only = False,
-        **kwargs):
-    """Generates C++ grpc classes for services defined in a proto file.
-
-    If grpc_only is True, this rule is compatible with proto_library and
-    cc_proto_library native rules such that it expects proto_library target
-    as the srcs argument and generates only grpc library classes, expecting the
-    protobuf message classes library (cc_proto_library target) to be passed in
-    the deps argument. By default grpc_only is False, which makes this rule
-    behave in a backwards-compatible mode (trying to generate both proto and grpc
-    classes).
-
-    Assumes the generated classes will be used in cc_api_version = 2.
-
-    Args:
-        name (str): Name of rule.
-        srcs (list): A single .proto file which contains service definitions,
-          or, if the grpc_only parameter is True, a single proto_library which
-          contains service descriptors.
-        deps (list): A list of C++ proto_library (or cc_proto_library) which
-          provides the compiled code of any message that the services depend on.
-        proto_only (bool): If True, create only C++ proto classes library,
-          avoid creating C++ grpc classes library (expect it in deps).
-          Deprecated, use native cc_proto_library instead. False by default.
-        well_known_protos (bool): Should this library additionally depend on
-          well known protos. Deprecated, the well known protos should be
-          specified as explicit dependencies of the proto_library target
-          (passed in srcs parameter) instead. False by default.
-        generate_mocks (bool): when True, Google Mock code for client stub is
-          generated. False by default.
-        use_external (bool): Not used.
-        grpc_only (bool): if True, generate only the grpc library, expecting the
-          protobuf message library (cc_proto_library target) to be passed as
-          deps. False by default (will become True by default eventually).
-        **kwargs: rest of arguments, e.g., compatible_with and visibility
-    """
-    if len(srcs) > 1:
-        fail("Only one srcs value supported", "srcs")
-    if grpc_only and proto_only:
-        fail("A mutualy exclusive configuration is specified: grpc_only = True and proto_only = True")
-
-    extra_deps = []
-    proto_targets = []
-
-    if not grpc_only:
-        proto_target = "_" + name + "_only"
-        cc_proto_target = name if proto_only else "_" + name + "_cc_proto"
-
-        proto_deps = ["_" + dep + "_only" for dep in deps if dep.find(":") == -1]
-        proto_deps += [dep.split(":")[0] + ":" + "_" + dep.split(":")[1] + "_only" for dep in deps if dep.find(":") != -1]
-        if well_known_protos:
-            proto_deps += well_known_proto_libs()
-
-        native.proto_library(
-            name = proto_target,
-            srcs = srcs,
-            deps = proto_deps,
-            **kwargs
-        )
-
-        native.cc_proto_library(
-            name = cc_proto_target,
-            deps = [":" + proto_target],
-            **kwargs
-        )
-        extra_deps.append(":" + cc_proto_target)
-        proto_targets.append(proto_target)
-    else:
-        if not srcs:
-            fail("srcs cannot be empty", "srcs")
-        proto_targets += srcs
-
-    if not proto_only:
-        codegen_grpc_target = "_" + name + "_grpc_codegen"
-        generate_cc(
-            name = codegen_grpc_target,
-            srcs = proto_targets,
-            plugin = "@com_github_grpc_grpc//src/compiler:grpc_cpp_plugin",
-            well_known_protos = well_known_protos,
-            generate_mocks = generate_mocks,
-            **kwargs
-        )
-
-        native.cc_library(
-            name = name,
-            srcs = [":" + codegen_grpc_target],
-            hdrs = [":" + codegen_grpc_target],
-            deps = deps +
-                   extra_deps +
-                   ["@com_github_grpc_grpc//:grpc++_codegen_proto"],
-            **kwargs
-        )
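A hypothetical BUILD usage in grpc_only mode, assuming a proto_library named helloworld_proto and a matching cc_proto_library:

    # Sketch: generate only the grpc classes; message classes come from deps.
    cc_grpc_library(
        name = "helloworld_cc_grpc",
        srcs = [":helloworld_proto"],     # a single proto_library target
        grpc_only = True,
        deps = [":helloworld_cc_proto"],  # compiled message classes
    )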
diff --git a/third_party/xla/third_party/systemlibs/grpc.bazel.generate_cc.bzl b/third_party/xla/third_party/systemlibs/grpc.bazel.generate_cc.bzl
deleted file mode 100644
index 3c3f20c..0000000
--- a/third_party/xla/third_party/systemlibs/grpc.bazel.generate_cc.bzl
+++ /dev/null
@@ -1,187 +0,0 @@
-"""Generates C++ grpc stubs from proto_library rules.
-
-This is an internal rule used by cc_grpc_library, and shouldn't be used
-directly.
-"""
-
-load(
-    "@com_github_grpc_grpc//bazel:protobuf.bzl",
-    "get_include_directory",
-    "get_plugin_args",
-    "get_proto_root",
-    "proto_path_to_generated_filename",
-)
-
-_GRPC_PROTO_HEADER_FMT = "{}.grpc.pb.h"
-_GRPC_PROTO_SRC_FMT = "{}.grpc.pb.cc"
-_GRPC_PROTO_MOCK_HEADER_FMT = "{}_mock.grpc.pb.h"
-_PROTO_HEADER_FMT = "{}.pb.h"
-_PROTO_SRC_FMT = "{}.pb.cc"
-
-def _strip_package_from_path(label_package, file):
-    prefix_len = 0
-    if not file.is_source and file.path.startswith(file.root.path):
-        prefix_len = len(file.root.path) + 1
-
-    path = file.path
-    if len(label_package) == 0:
-        return path
-    if not path.startswith(label_package + "/", prefix_len):
-        fail("'{}' does not lie within '{}'.".format(path, label_package))
-    return path[prefix_len + len(label_package + "/"):]
-
-def _get_srcs_file_path(file):
-    if not file.is_source and file.path.startswith(file.root.path):
-        return file.path[len(file.root.path) + 1:]
-    return file.path
-
-def _join_directories(directories):
-    massaged_directories = [directory for directory in directories if len(directory) != 0]
-    return "/".join(massaged_directories)
-
-def generate_cc_impl(ctx):
-    """Implementation of the generate_cc rule."""
-    protos = [f for src in ctx.attr.srcs for f in src[ProtoInfo].check_deps_sources.to_list()]
-    includes = [
-        f
-        for src in ctx.attr.srcs
-        for f in src[ProtoInfo].transitive_imports.to_list()
-    ]
-    outs = []
-    proto_root = get_proto_root(
-        ctx.label.workspace_root,
-    )
-
-    label_package = _join_directories([ctx.label.workspace_root, ctx.label.package])
-    if ctx.executable.plugin:
-        outs += [
-            proto_path_to_generated_filename(
-                _strip_package_from_path(label_package, proto),
-                _GRPC_PROTO_HEADER_FMT,
-            )
-            for proto in protos
-        ]
-        outs += [
-            proto_path_to_generated_filename(
-                _strip_package_from_path(label_package, proto),
-                _GRPC_PROTO_SRC_FMT,
-            )
-            for proto in protos
-        ]
-        if ctx.attr.generate_mocks:
-            outs += [
-                proto_path_to_generated_filename(
-                    _strip_package_from_path(label_package, proto),
-                    _GRPC_PROTO_MOCK_HEADER_FMT,
-                )
-                for proto in protos
-            ]
-    else:
-        outs += [
-            proto_path_to_generated_filename(
-                _strip_package_from_path(label_package, proto),
-                _PROTO_HEADER_FMT,
-            )
-            for proto in protos
-        ]
-        outs += [
-            proto_path_to_generated_filename(
-                _strip_package_from_path(label_package, proto),
-                _PROTO_SRC_FMT,
-            )
-            for proto in protos
-        ]
-    out_files = [ctx.actions.declare_file(out) for out in outs]
-    dir_out = str(ctx.genfiles_dir.path + proto_root)
-
-    arguments = []
-    if ctx.executable.plugin:
-        arguments += get_plugin_args(
-            ctx.executable.plugin,
-            ctx.attr.flags,
-            dir_out,
-            ctx.attr.generate_mocks,
-        )
-        tools = [ctx.executable.plugin]
-    else:
-        arguments += ["--cpp_out=" + ",".join(ctx.attr.flags) + ":" + dir_out]
-        tools = []
-
-    arguments += [
-        "--proto_path={}".format(get_include_directory(i))
-        for i in includes
-    ]
-
-    # Include the output directory so that protoc puts the generated code in the
-    # right directory.
-    arguments += ["--proto_path={0}{1}".format(dir_out, proto_root)]
-    arguments += [_get_srcs_file_path(proto) for proto in protos]
-
-    # create a list of well known proto files if the argument is non-None
-    well_known_proto_files = []
-    if ctx.attr.well_known_protos:
-        f = ctx.attr.well_known_protos.files.to_list()[0].dirname
-        if f != "external/com_google_protobuf/src/google/protobuf":
-            print(
-                "Error: Only @com_google_protobuf//:well_known_protos is supported",
-            )
-        else:
-            # f points to "external/com_google_protobuf/src/google/protobuf"
-            # add -I argument to protoc so it knows where to look for the proto files.
-            arguments += ["-I{0}".format(f + "/../..")]
-            well_known_proto_files = [
-                f
-                for f in ctx.attr.well_known_protos.files.to_list()
-            ]
-
-    ctx.actions.run(
-        inputs = protos + includes + well_known_proto_files,
-        tools = tools,
-        outputs = out_files,
-        executable = ctx.executable._protoc,
-        arguments = arguments,
-        use_default_shell_env = True,
-    )
-
-    return struct(files = depset(out_files))
-
-_generate_cc = rule(
-    attrs = {
-        "srcs": attr.label_list(
-            mandatory = True,
-            allow_empty = False,
-            providers = [ProtoInfo],
-        ),
-        "plugin": attr.label(
-            executable = True,
-            providers = ["files_to_run"],
-            cfg = "exec",
-        ),
-        "flags": attr.string_list(
-            mandatory = False,
-            allow_empty = True,
-        ),
-        "well_known_protos": attr.label(mandatory = False),
-        "generate_mocks": attr.bool(
-            default = False,
-            mandatory = False,
-        ),
-        "_protoc": attr.label(
-            default = Label("@com_google_protobuf//:protoc"),
-            executable = True,
-            cfg = "exec",
-        ),
-    },
-    # We generate .h files, so we need to output to genfiles.
-    output_to_genfiles = True,
-    implementation = generate_cc_impl,
-)
-
-def generate_cc(well_known_protos, **kwargs):
-    if well_known_protos:
-        _generate_cc(
-            well_known_protos = "@com_google_protobuf//:well_known_protos",
-            **kwargs
-        )
-    else:
-        _generate_cc(**kwargs)
diff --git a/third_party/xla/third_party/systemlibs/grpc.bazel.grpc_deps.bzl b/third_party/xla/third_party/systemlibs/grpc.bazel.grpc_deps.bzl
deleted file mode 100644
index dd389c6..0000000
--- a/third_party/xla/third_party/systemlibs/grpc.bazel.grpc_deps.bzl
+++ /dev/null
@@ -1,6 +0,0 @@
-"""Load dependencies needed to compile and test the grpc library as a 3rd-party consumer."""
-
-def grpc_deps():
-    """Loads dependencies need to compile and test the grpc library."""
-
-    pass
diff --git a/third_party/xla/third_party/systemlibs/grpc.bazel.grpc_extra_deps.bzl b/third_party/xla/third_party/systemlibs/grpc.bazel.grpc_extra_deps.bzl
deleted file mode 100644
index 631c93a..0000000
--- a/third_party/xla/third_party/systemlibs/grpc.bazel.grpc_extra_deps.bzl
+++ /dev/null
@@ -1,4 +0,0 @@
-"""Stub version of @com_github_grpc_grpc//bazel:grpc_extra_deps.bzl necessary for TF system libs"""
-
-def grpc_extra_deps():
-    pass
diff --git a/third_party/xla/third_party/systemlibs/grpc.bazel.protobuf.bzl b/third_party/xla/third_party/systemlibs/grpc.bazel.protobuf.bzl
deleted file mode 100644
index 3eca97d..0000000
--- a/third_party/xla/third_party/systemlibs/grpc.bazel.protobuf.bzl
+++ /dev/null
@@ -1,244 +0,0 @@
-"""Utility functions for generating protobuf code."""
-
-_PROTO_EXTENSION = ".proto"
-_VIRTUAL_IMPORTS = "/_virtual_imports/"
-
-def well_known_proto_libs():
-    return [
-        "@com_google_protobuf//:any_proto",
-        "@com_google_protobuf//:api_proto",
-        "@com_google_protobuf//:compiler_plugin_proto",
-        "@com_google_protobuf//:descriptor_proto",
-        "@com_google_protobuf//:duration_proto",
-        "@com_google_protobuf//:empty_proto",
-        "@com_google_protobuf//:field_mask_proto",
-        "@com_google_protobuf//:source_context_proto",
-        "@com_google_protobuf//:struct_proto",
-        "@com_google_protobuf//:timestamp_proto",
-        "@com_google_protobuf//:type_proto",
-        "@com_google_protobuf//:wrappers_proto",
-    ]
-
-def get_proto_root(workspace_root):
-    """Gets the root protobuf directory.
-
-    Args:
-      workspace_root: context.label.workspace_root
-
-    Returns:
-      The directory relative to which generated include paths should be.
-    """
-    if workspace_root:
-        return "/{}".format(workspace_root)
-    else:
-        return ""
-
-def _strip_proto_extension(proto_filename):
-    if not proto_filename.endswith(_PROTO_EXTENSION):
-        fail('"{}" does not end with "{}"'.format(
-            proto_filename,
-            _PROTO_EXTENSION,
-        ))
-    return proto_filename[:-len(_PROTO_EXTENSION)]
-
-def proto_path_to_generated_filename(proto_path, fmt_str):
-    """Calculates the name of a generated file for a protobuf path.
-
-    For example, "examples/protos/helloworld.proto" might map to
-      "helloworld.pb.h".
-
-    Args:
-      proto_path: The path to the .proto file.
-      fmt_str: A format string used to calculate the generated filename. For
-        example, "{}.pb.h" might be used to calculate a C++ header filename.
-
-    Returns:
-      The generated filename.
-    """
-    return fmt_str.format(_strip_proto_extension(proto_path))
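Two illustrative mappings, with made-up paths:

    # proto_path_to_generated_filename("helloworld.proto", "{}.grpc.pb.h") -> "helloworld.grpc.pb.h"
    # proto_path_to_generated_filename("a/b/c.proto", "{}.pb.cc")          -> "a/b/c.pb.cc"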
-
-def get_include_directory(source_file):
-    """Returns the include directory path for the source_file.
-
-    I.e. all of the include statements within the given source_file
-    are calculated relative to the directory returned by this method.
-
-    The returned directory path can be used as the "--proto_path=" argument
-    value.
-
-    Args:
-      source_file: A proto file.
-
-    Returns:
-      The include directory path for the source_file.
-    """
-    directory = source_file.path
-    prefix_len = 0
-
-    if is_in_virtual_imports(source_file):
-        root, relative = source_file.path.split(_VIRTUAL_IMPORTS, 2)
-        result = root + _VIRTUAL_IMPORTS + relative.split("/", 1)[0]
-        return result
-
-    if not source_file.is_source and directory.startswith(source_file.root.path):
-        prefix_len = len(source_file.root.path) + 1
-
-    if directory.startswith("external", prefix_len):
-        external_separator = directory.find("/", prefix_len)
-        repository_separator = directory.find("/", external_separator + 1)
-        return directory[:repository_separator]
-    else:
-        return source_file.root.path if source_file.root.path else "."
-
-def get_plugin_args(
-        plugin,
-        flags,
-        dir_out,
-        generate_mocks,
-        plugin_name = "PLUGIN"):
-    """Returns arguments configuring protoc to use a plugin for a language.
-
-    Args:
-      plugin: An executable file to run as the protoc plugin.
-      flags: The plugin flags to be passed to protoc.
-      dir_out: The output directory for the plugin.
-      generate_mocks: A bool indicating whether to generate mocks.
-      plugin_name: The name of the plugin; it must be unique when more than one
-        plugin is used in a single protoc command.
-    Returns:
-      A list of protoc arguments configuring the plugin.
-    """
-    augmented_flags = list(flags)
-    if generate_mocks:
-        augmented_flags.append("generate_mock_code=true")
-
-    augmented_dir_out = dir_out
-    if augmented_flags:
-        augmented_dir_out = ",".join(augmented_flags) + ":" + dir_out
-
-    return [
-        "--plugin=protoc-gen-{plugin_name}={plugin_path}".format(
-            plugin_name = plugin_name,
-            plugin_path = plugin.path,
-        ),
-        "--{plugin_name}_out={dir_out}".format(
-            plugin_name = plugin_name,
-            dir_out = augmented_dir_out,
-        ),
-    ]
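As an illustrative example of the returned arguments, assuming flags = ["services_namespace=grpc_gen"], dir_out = "bazel-out/genfiles" and generate_mocks = True:

    # Roughly:
    #   ["--plugin=protoc-gen-PLUGIN=<plugin path>",
    #    "--PLUGIN_out=services_namespace=grpc_gen,generate_mock_code=true:bazel-out/genfiles"]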
-
-def _get_staged_proto_file(context, source_file):
-    if (source_file.dirname == context.label.package or
-        is_in_virtual_imports(source_file)):
-        return source_file
-    else:
-        copied_proto = context.actions.declare_file(source_file.basename)
-        context.actions.run_shell(
-            inputs = [source_file],
-            outputs = [copied_proto],
-            command = "cp {} {}".format(source_file.path, copied_proto.path),
-            mnemonic = "CopySourceProto",
-        )
-        return copied_proto
-
-def protos_from_context(context):
-    """Copies proto files to the appropriate location.
-
-    Args:
-      context: The ctx object for the rule.
-
-    Returns:
-      A list of the protos.
-    """
-    protos = []
-    for src in context.attr.deps:
-        for file in src[ProtoInfo].direct_sources:
-            protos.append(_get_staged_proto_file(context, file))
-    return protos
-
-def includes_from_deps(deps):
-    """Get includes from rule dependencies."""
-    return [
-        file
-        for src in deps
-        for file in src[ProtoInfo].transitive_imports.to_list()
-    ]
-
-def get_proto_arguments(protos, genfiles_dir_path):
-    """Get the protoc arguments specifying which protos to compile."""
-    arguments = []
-    for proto in protos:
-        strip_prefix_len = 0
-        if is_in_virtual_imports(proto):
-            incl_directory = get_include_directory(proto)
-            if proto.path.startswith(incl_directory):
-                strip_prefix_len = len(incl_directory) + 1
-        elif proto.path.startswith(genfiles_dir_path):
-            strip_prefix_len = len(genfiles_dir_path) + 1
-
-        arguments.append(proto.path[strip_prefix_len:])
-
-    return arguments
-
-def declare_out_files(protos, context, generated_file_format):
-    """Declares and returns the files to be generated."""
-
-    out_file_paths = []
-    for proto in protos:
-        if not is_in_virtual_imports(proto):
-            out_file_paths.append(proto.basename)
-        else:
-            path = proto.path[proto.path.index(_VIRTUAL_IMPORTS) + 1:]
-            out_file_paths.append(path)
-
-    return [
-        context.actions.declare_file(
-            proto_path_to_generated_filename(
-                out_file_path,
-                generated_file_format,
-            ),
-        )
-        for out_file_path in out_file_paths
-    ]
-
-def get_out_dir(protos, context):
-    """ Returns the calculated value for --<lang>_out= protoc argument based on
-    the input source proto files and current context.
-
-    Args:
-        protos: A list of protos to be used as source files in protoc command
-        context: A ctx object for the rule.
-    Returns:
-        The value of --<lang>_out= argument.
-    """
-    at_least_one_virtual = 0
-    for proto in protos:
-        if is_in_virtual_imports(proto):
-            at_least_one_virtual = True
-        elif at_least_one_virtual:
-            fail("Proto sources must be either all virtual imports or all real")
-    if at_least_one_virtual:
-        out_dir = get_include_directory(protos[0])
-        ws_root = protos[0].owner.workspace_root
-        if ws_root and out_dir.find(ws_root) >= 0:
-            out_dir = "".join(out_dir.rsplit(ws_root, 1))
-        return struct(
-            path = out_dir,
-            import_path = out_dir[out_dir.find(_VIRTUAL_IMPORTS) + 1:],
-        )
-    return struct(path = context.genfiles_dir.path, import_path = None)
-
-def is_in_virtual_imports(source_file, virtual_folder = _VIRTUAL_IMPORTS):
-    """Determines if source_file is virtual (is placed in _virtual_imports
-    subdirectory). The output of all proto_library targets which use
-    import_prefix  and/or strip_import_prefix arguments is placed under
-    _virtual_imports directory.
-
-    Args:
-        source_file: A proto file.
-        virtual_folder: The virtual folder name (is set to "_virtual_imports"
-            by default)
-    Returns:
-        True if source_file is located under _virtual_imports, False otherwise.
-    """
-    return not source_file.is_source and virtual_folder in source_file.path
diff --git a/third_party/xla/third_party/systemlibs/jsoncpp.BUILD b/third_party/xla/third_party/systemlibs/jsoncpp.BUILD
deleted file mode 100644
index b5951e3..0000000
--- a/third_party/xla/third_party/systemlibs/jsoncpp.BUILD
+++ /dev/null
@@ -1,12 +0,0 @@
-licenses(["unencumbered"])  # Public Domain or MIT
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "jsoncpp",
-    linkopts = ["-ljsoncpp"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/lmdb.BUILD b/third_party/xla/third_party/systemlibs/lmdb.BUILD
deleted file mode 100644
index 6177b09..0000000
--- a/third_party/xla/third_party/systemlibs/lmdb.BUILD
+++ /dev/null
@@ -1,12 +0,0 @@
-licenses(["notice"])  # OpenLDAP Public License
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "lmdb",
-    linkopts = ["-llmdb"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/nsync.BUILD b/third_party/xla/third_party/systemlibs/nsync.BUILD
deleted file mode 100644
index c5d4ad0..0000000
--- a/third_party/xla/third_party/systemlibs/nsync.BUILD
+++ /dev/null
@@ -1,23 +0,0 @@
-licenses(["notice"])  # BSD 3-Clause
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "nsync_headers",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "nsync",
-    linkopts = ["-lnsync"],
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "nsync_cpp",
-    linkopts = ["-lnsync_cpp"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/png.BUILD b/third_party/xla/third_party/systemlibs/png.BUILD
deleted file mode 100644
index fc6b6f2..0000000
--- a/third_party/xla/third_party/systemlibs/png.BUILD
+++ /dev/null
@@ -1,12 +0,0 @@
-licenses(["notice"])  # BSD/MIT-like license
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "png",
-    linkopts = ["-lpng"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/protobuf.BUILD b/third_party/xla/third_party/systemlibs/protobuf.BUILD
deleted file mode 100644
index 4d05ab2..0000000
--- a/third_party/xla/third_party/systemlibs/protobuf.BUILD
+++ /dev/null
@@ -1,113 +0,0 @@
-load("@rules_proto//proto:defs.bzl", "proto_library")
-load(
-    "@com_google_protobuf//:protobuf.bzl",
-    "cc_proto_library",
-    "proto_gen",
-    "py_proto_library",
-)
-
-licenses(["notice"])
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
-
-# Map of all well known protos.
-# name => (include path, imports)
-WELL_KNOWN_PROTO_MAP = {
-    "any": ("google/protobuf/any.proto", []),
-    "api": (
-        "google/protobuf/api.proto",
-        [
-            "source_context",
-            "type",
-        ],
-    ),
-    "compiler_plugin": (
-        "google/protobuf/compiler/plugin.proto",
-        ["descriptor"],
-    ),
-    "descriptor": ("google/protobuf/descriptor.proto", []),
-    "duration": ("google/protobuf/duration.proto", []),
-    "empty": ("google/protobuf/empty.proto", []),
-    "field_mask": ("google/protobuf/field_mask.proto", []),
-    "source_context": ("google/protobuf/source_context.proto", []),
-    "struct": ("google/protobuf/struct.proto", []),
-    "timestamp": ("google/protobuf/timestamp.proto", []),
-    "type": (
-        "google/protobuf/type.proto",
-        [
-            "any",
-            "source_context",
-        ],
-    ),
-    "wrappers": ("google/protobuf/wrappers.proto", []),
-}
-
-RELATIVE_WELL_KNOWN_PROTOS = [proto[1][0] for proto in WELL_KNOWN_PROTO_MAP.items()]
-
-genrule(
-    name = "link_proto_files",
-    outs = RELATIVE_WELL_KNOWN_PROTOS,
-    cmd = """
-      for i in $(OUTS); do
-        f=$${i#$(@D)/}
-        mkdir -p $(@D)/$${f%/*}
-        ln -sf $(PROTOBUF_INCLUDE_PATH)/$$f $(@D)/$$f
-      done
-    """,
-)
-
-cc_library(
-    name = "protobuf",
-    linkopts = ["-lprotobuf"],
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "protobuf_headers",
-    linkopts = ["-lprotobuf"],
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "protoc_lib",
-    linkopts = ["-lprotoc"],
-    visibility = ["//visibility:public"],
-)
-
-genrule(
-    name = "protoc",
-    outs = ["protoc.bin"],
-    cmd = "ln -s $$(which protoc) $@",
-    executable = 1,
-    visibility = ["//visibility:public"],
-)
-
-cc_proto_library(
-    name = "cc_wkt_protos",
-    internal_bootstrap_hack = 1,
-    protoc = ":protoc",
-    visibility = ["//visibility:public"],
-)
-
-proto_gen(
-    name = "protobuf_python_genproto",
-    includes = ["."],
-    protoc = "@com_google_protobuf//:protoc",
-    visibility = ["//visibility:public"],
-)
-
-py_library(
-    name = "protobuf_python",
-    srcs_version = "PY3",
-    visibility = ["//visibility:public"],
-)
-
-[proto_library(
-    name = proto[0] + "_proto",
-    srcs = [proto[1][0]],
-    visibility = ["//visibility:public"],
-    deps = [dep + "_proto" for dep in proto[1][1]],
-) for proto in WELL_KNOWN_PROTO_MAP.items()]
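For illustration, the comprehension above expands the "api" entry into a target equivalent to the following sketch:

    # proto_library(
    #     name = "api_proto",
    #     srcs = ["google/protobuf/api.proto"],
    #     visibility = ["//visibility:public"],
    #     deps = [
    #         "source_context_proto",
    #         "type_proto",
    #     ],
    # )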
diff --git a/third_party/xla/third_party/systemlibs/protobuf.bzl b/third_party/xla/third_party/systemlibs/protobuf.bzl
deleted file mode 100644
index 66a0630..0000000
--- a/third_party/xla/third_party/systemlibs/protobuf.bzl
+++ /dev/null
@@ -1,430 +0,0 @@
-def _GetPath(ctx, path):
-    if ctx.label.workspace_root:
-        return ctx.label.workspace_root + "/" + path
-    else:
-        return path
-
-def _IsNewExternal(ctx):
-    # Bazel 0.4.4 and older have genfiles paths that look like:
-    #   bazel-out/local-fastbuild/genfiles/external/repo/foo
-    # After the exec root rearrangement, they look like:
-    #   ../repo/bazel-out/local-fastbuild/genfiles/foo
-    return ctx.label.workspace_root.startswith("../")
-
-def _GenDir(ctx):
-    if _IsNewExternal(ctx):
-        # We are using the fact that Bazel 0.4.4+ provides repository-relative paths
-        # for ctx.genfiles_dir.
-        return ctx.genfiles_dir.path + (
-            "/" + ctx.attr.includes[0] if ctx.attr.includes and ctx.attr.includes[0] else ""
-        )
-
-    # This means that we're either in the old version OR the new version in the local repo.
-    # Either way, appending the source path to the genfiles dir works.
-    return ctx.var["GENDIR"] + "/" + _SourceDir(ctx)
-
-def _SourceDir(ctx):
-    if not ctx.attr.includes:
-        return ctx.label.workspace_root
-    if not ctx.attr.includes[0]:
-        return _GetPath(ctx, ctx.label.package)
-    if not ctx.label.package:
-        return _GetPath(ctx, ctx.attr.includes[0])
-    return _GetPath(ctx, ctx.label.package + "/" + ctx.attr.includes[0])
-
-def _CcHdrs(srcs, use_grpc_plugin = False):
-    ret = [s[:-len(".proto")] + ".pb.h" for s in srcs]
-    if use_grpc_plugin:
-        ret += [s[:-len(".proto")] + ".grpc.pb.h" for s in srcs]
-    return ret
-
-def _CcSrcs(srcs, use_grpc_plugin = False):
-    ret = [s[:-len(".proto")] + ".pb.cc" for s in srcs]
-    if use_grpc_plugin:
-        ret += [s[:-len(".proto")] + ".grpc.pb.cc" for s in srcs]
-    return ret
-
-def _CcOuts(srcs, use_grpc_plugin = False):
-    return _CcHdrs(srcs, use_grpc_plugin) + _CcSrcs(srcs, use_grpc_plugin)
-
-def _PyOuts(srcs, use_grpc_plugin = False):
-    ret = [s[:-len(".proto")] + "_pb2.py" for s in srcs]
-    if use_grpc_plugin:
-        ret += [s[:-len(".proto")] + "_pb2_grpc.py" for s in srcs]
-    return ret
-
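Illustrative outputs of the helpers above, for a made-up greeter.proto:

    # _CcOuts(["greeter.proto"], use_grpc_plugin = True)
    #   -> ["greeter.pb.h", "greeter.grpc.pb.h", "greeter.pb.cc", "greeter.grpc.pb.cc"]
    # _PyOuts(["greeter.proto"], use_grpc_plugin = False)
    #   -> ["greeter_pb2.py"]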
-def _RelativeOutputPath(path, include, dest = ""):
-    if include == None:
-        return path
-
-    if not path.startswith(include):
-        fail("Include path %s isn't part of the path %s." % (include, path))
-
-    if include and include[-1] != "/":
-        include = include + "/"
-    if dest and dest[-1] != "/":
-        dest = dest + "/"
-
-    path = path[len(include):]
-    return dest + path
-
-def _proto_gen_impl(ctx):
-    """General implementation for generating protos"""
-    srcs = ctx.files.srcs
-    deps = []
-    deps += ctx.files.srcs
-    source_dir = _SourceDir(ctx)
-    gen_dir = _GenDir(ctx)
-    if source_dir:
-        import_flags = ["-I" + source_dir, "-I" + gen_dir]
-    else:
-        import_flags = ["-I."]
-
-    for dep in ctx.attr.deps:
-        import_flags += dep.proto.import_flags
-        deps += dep.proto.deps
-    import_flags = depset(import_flags).to_list()
-    deps = depset(deps).to_list()
-
-    args = []
-    if ctx.attr.gen_cc:
-        args += ["--cpp_out=" + gen_dir]
-    if ctx.attr.gen_py:
-        args += ["--python_out=" + gen_dir]
-
-    inputs = srcs + deps
-    tools = [ctx.executable.protoc]
-    if ctx.executable.plugin:
-        plugin = ctx.executable.plugin
-        lang = ctx.attr.plugin_language
-        if not lang and plugin.basename.startswith("protoc-gen-"):
-            lang = plugin.basename[len("protoc-gen-"):]
-        if not lang:
-            fail("cannot infer the target language of plugin", "plugin_language")
-
-        outdir = gen_dir
-        if ctx.attr.plugin_options:
-            outdir = ",".join(ctx.attr.plugin_options) + ":" + outdir
-        args += ["--plugin=protoc-gen-%s=%s" % (lang, plugin.path)]
-        args += ["--%s_out=%s" % (lang, outdir)]
-        tools.append(plugin)
-
-    if args:
-        ctx.actions.run(
-            inputs = inputs,
-            outputs = ctx.outputs.outs,
-            arguments = args + import_flags + [s.path for s in srcs],
-            executable = ctx.executable.protoc,
-            mnemonic = "ProtoCompile",
-            tools = tools,
-            use_default_shell_env = True,
-        )
-
-    return struct(
-        proto = struct(
-            srcs = srcs,
-            import_flags = import_flags,
-            deps = deps,
-        ),
-    )
-
-proto_gen = rule(
-    attrs = {
-        "srcs": attr.label_list(allow_files = True),
-        "deps": attr.label_list(providers = ["proto"]),
-        "includes": attr.string_list(),
-        "protoc": attr.label(
-            cfg = "host",
-            executable = True,
-            allow_single_file = True,
-            mandatory = True,
-        ),
-        "plugin": attr.label(
-            cfg = "host",
-            allow_files = True,
-            executable = True,
-        ),
-        "plugin_language": attr.string(),
-        "plugin_options": attr.string_list(),
-        "gen_cc": attr.bool(),
-        "gen_py": attr.bool(),
-        "outs": attr.output_list(),
-    },
-    output_to_genfiles = True,
-    implementation = _proto_gen_impl,
-)
-"""Generates codes from Protocol Buffers definitions.
-
-This rule helps you implement Starlark macros specific to the target
-language. Prefer the more specific `cc_proto_library`,
-`py_proto_library`, and similar macros unless you are adding such wrapper macros.
-
-Args:
-  srcs: Protocol Buffers definition files (.proto) to run the protocol compiler
-    against.
-  deps: a list of dependency labels; must be other proto libraries.
-  includes: a list of include paths to .proto files.
-  protoc: the label of the protocol compiler to generate the sources.
-  plugin: the label of the protocol compiler plugin to be passed to the protocol
-    compiler.
-  plugin_language: the language of the generated sources.
-  plugin_options: a list of options to be passed to the plugin.
-  gen_cc: generates C++ sources in addition to the ones from the plugin.
-  gen_py: generates Python sources in addition to the ones from the plugin.
-  outs: a list of labels of the expected outputs from the protocol compiler.
-"""
-
-def cc_proto_library(
-        name,
-        srcs = [],
-        deps = [],
-        cc_libs = [],
-        include = None,
-        protoc = "@com_google_protobuf//:protoc",
-        internal_bootstrap_hack = False,
-        use_grpc_plugin = False,
-        default_runtime = "@com_google_protobuf//:protobuf",
-        **kwargs):
-    """Bazel rule to create a C++ protobuf library from proto source files
-
-    NOTE: the rule is only an internal workaround to generate protos. The
-    interface may change and the rule may be removed when bazel has introduced
-    the native rule.
-
-    Args:
-      name: the name of the cc_proto_library.
-      srcs: the .proto files of the cc_proto_library.
-      deps: a list of dependency labels; must be cc_proto_library.
-      cc_libs: a list of other cc_library targets depended on by the generated
-          cc_library.
-      include: a string indicating the include path of the .proto files.
-      protoc: the label of the protocol compiler to generate the sources.
-      internal_bootstrap_hack: a flag indicating if the cc_proto_library is used only
-          for bootstrapping. When it is set to True, no files will be generated.
-          The rule will simply be a provider for .proto files, so that other
-          cc_proto_library targets can depend on it.
-      use_grpc_plugin: a flag to indicate whether to call the grpc C++ plugin
-          when processing the proto files.
-      default_runtime: the implicit default runtime which will be depended on by
-          the generated cc_library target.
-      **kwargs: other keyword arguments that are passed to cc_library.
-
-    """
-
-    includes = []
-    if include != None:
-        includes = [include]
-
-    if internal_bootstrap_hack:
-        # For pre-checked-in generated files, we add the internal_bootstrap_hack
-        # which will skip the codegen action.
-        proto_gen(
-            name = name + "_genproto",
-            srcs = srcs,
-            deps = [s + "_genproto" for s in deps],
-            includes = includes,
-            protoc = protoc,
-            visibility = ["//visibility:public"],
-        )
-
-        # An empty cc_library to make rule dependency consistent.
-        native.cc_library(
-            name = name,
-            **kwargs
-        )
-        return
-
-    grpc_cpp_plugin = None
-    if use_grpc_plugin:
-        grpc_cpp_plugin = "//external:grpc_cpp_plugin"
-
-    gen_srcs = _CcSrcs(srcs, use_grpc_plugin)
-    gen_hdrs = _CcHdrs(srcs, use_grpc_plugin)
-    outs = gen_srcs + gen_hdrs
-
-    proto_gen(
-        name = name + "_genproto",
-        srcs = srcs,
-        deps = [s + "_genproto" for s in deps],
-        includes = includes,
-        protoc = protoc,
-        plugin = grpc_cpp_plugin,
-        plugin_language = "grpc",
-        gen_cc = 1,
-        outs = outs,
-        visibility = ["//visibility:public"],
-    )
-
-    if default_runtime and not default_runtime in cc_libs:
-        cc_libs = cc_libs + [default_runtime]
-    if use_grpc_plugin:
-        cc_libs = cc_libs + ["//external:grpc_lib"]
-
-    native.cc_library(
-        name = name,
-        srcs = gen_srcs,
-        hdrs = gen_hdrs,
-        deps = cc_libs + deps,
-        includes = includes,
-        alwayslink = 1,
-        **kwargs
-    )
-
-def internal_gen_well_known_protos_java(srcs):
-    """Bazel rule to generate the gen_well_known_protos_java genrule
-
-    Args:
-      srcs: the well-known protos.
-    """
-    root = Label("%s//protobuf_java" % (native.repository_name())).workspace_root
-    pkg = native.package_name() + "/" if native.package_name() else ""
-    if root == "":
-        include = " -I%ssrc " % pkg
-    else:
-        include = " -I%s/%ssrc " % (root, pkg)
-    native.genrule(
-        name = "gen_well_known_protos_java",
-        srcs = srcs,
-        outs = [
-            "wellknown.srcjar",
-        ],
-        cmd = "$(location :protoc) --java_out=$(@D)/wellknown.jar" +
-              " %s $(SRCS) " % include +
-              " && mv $(@D)/wellknown.jar $(@D)/wellknown.srcjar",
-        tools = [":protoc"],
-    )
-
-def internal_copied_filegroup(name, srcs, strip_prefix, dest, **kwargs):
-    """Macro to copy files to a different directory and then create a filegroup.
-
-    This is used by the //:protobuf_python py_proto_library target to work around
-    an issue caused by Python source files that are part of the same Python
-    package being in separate directories.
-
-    Args:
-      srcs: The source files to copy and add to the filegroup.
-      strip_prefix: Path to the root of the files to copy.
-      dest: The directory to copy the source files into.
-      **kwargs: extra arguments that will be passed to the filegroup.
-    """
-    outs = [_RelativeOutputPath(s, strip_prefix, dest) for s in srcs]
-
-    native.genrule(
-        name = name + "_genrule",
-        srcs = srcs,
-        outs = outs,
-        cmd = " && ".join(
-            ["cp $(location %s) $(location %s)" %
-             (s, _RelativeOutputPath(s, strip_prefix, dest)) for s in srcs],
-        ),
-    )
-
-    native.filegroup(
-        name = name,
-        srcs = outs,
-        **kwargs
-    )
-
-def py_proto_library(
-        name,
-        srcs = [],
-        deps = [],
-        py_libs = [],
-        py_extra_srcs = [],
-        include = None,
-        default_runtime = "@com_google_protobuf//:protobuf_python",
-        protoc = "@com_google_protobuf//:protoc",
-        use_grpc_plugin = False,
-        **kwargs):
-    """Bazel rule to create a Python protobuf library from proto source files
-
-    NOTE: the rule is only an internal workaround to generate protos. The
-    interface may change and the rule may be removed when bazel has introduced
-    the native rule.
-
-    Args:
-      name: the name of the py_proto_library.
-      srcs: the .proto files of the py_proto_library.
-      deps: a list of dependency labels; must be py_proto_library.
-      py_libs: a list of other py_library targets depended on by the generated
-          py_library.
-      py_extra_srcs: extra source files that will be added to the output
-          py_library. This attribute is used for internal bootstrapping.
-      include: a string indicating the include path of the .proto files.
-      default_runtime: the implicit default runtime which will be depended on by
-          the generated py_library target.
-      protoc: the label of the protocol compiler to generate the sources.
-      use_grpc_plugin: a flag to indicate whether to call the grpc Python plugin
-          when processing the proto files.
-      **kwargs: other keyword arguments that are passed to py_library.
-
-    """
-    outs = _PyOuts(srcs, use_grpc_plugin)
-
-    includes = []
-    if include != None:
-        includes = [include]
-
-    grpc_python_plugin = None
-    if use_grpc_plugin:
-        grpc_python_plugin = "//external:grpc_python_plugin"
-        # Note: generated grpc code depends on the Python grpc module. This
-        # dependency is not explicitly listed in py_libs. Instead, the host
-        # system is assumed to have grpc installed.
-
-    proto_gen(
-        name = name + "_genproto",
-        srcs = srcs,
-        deps = [s + "_genproto" for s in deps],
-        includes = includes,
-        protoc = protoc,
-        gen_py = 1,
-        outs = outs,
-        visibility = ["//visibility:public"],
-        plugin = grpc_python_plugin,
-        plugin_language = "grpc",
-    )
-
-    if default_runtime and not default_runtime in py_libs + deps:
-        py_libs = py_libs + [default_runtime]
-
-    native.py_library(
-        name = name,
-        srcs = outs + py_extra_srcs,
-        deps = py_libs + deps,
-        imports = includes,
-        **kwargs
-    )
-
-def internal_protobuf_py_tests(
-        name,
-        modules = [],
-        **kwargs):
-    """Bazel rules to create batch tests for protobuf internal.
-
-    Args:
-      name: the name of the rule.
-      modules: a list of modules for tests. The macro will create a py_test for
-          each module, with the source "python/google/protobuf/internal/%s.py".
-      **kwargs: extra parameters that will be passed into the py_test.
-
-    """
-    for m in modules:
-        s = "python/google/protobuf/internal/%s.py" % m
-        native.py_test(
-            name = "py_%s" % m,
-            srcs = [s],
-            main = s,
-            **kwargs
-        )
-
-def check_protobuf_required_bazel_version():
-    """For WORKSPACE files, to check the installed version of bazel.
-
-    This ensures bazel supports our approach to proto_library() depending on a
-    copied filegroup. (Fixed in bazel 0.5.4)
-    """
-    expected = apple_common.dotted_version("0.5.4")
-    current = apple_common.dotted_version(native.bazel_version)
-    if current.compare_to(expected) < 0:
-        fail("Bazel must be newer than 0.5.4")
diff --git a/third_party/xla/third_party/systemlibs/protobuf_deps.bzl b/third_party/xla/third_party/systemlibs/protobuf_deps.bzl
deleted file mode 100644
index aafd89b..0000000
--- a/third_party/xla/third_party/systemlibs/protobuf_deps.bzl
+++ /dev/null
@@ -1,2 +0,0 @@
-def protobuf_deps():
-    pass
diff --git a/third_party/xla/third_party/systemlibs/pybind11.BUILD b/third_party/xla/third_party/systemlibs/pybind11.BUILD
deleted file mode 100644
index 9ea6b41..0000000
--- a/third_party/xla/third_party/systemlibs/pybind11.BUILD
+++ /dev/null
@@ -1,8 +0,0 @@
-package(default_visibility = ["//visibility:public"])
-
-cc_library(
-    name = "pybind11",
-    deps = [
-        "@local_xla//third_party/python_runtime:headers",
-    ],
-)
diff --git a/third_party/xla/third_party/systemlibs/re2.BUILD b/third_party/xla/third_party/systemlibs/re2.BUILD
deleted file mode 100644
index c18e252..0000000
--- a/third_party/xla/third_party/systemlibs/re2.BUILD
+++ /dev/null
@@ -1,12 +0,0 @@
-licenses(["notice"])  # BSD/MIT-like license
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "re2",
-    linkopts = ["-lre2"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/six.BUILD b/third_party/xla/third_party/systemlibs/six.BUILD
deleted file mode 100644
index ff9b1a5..0000000
--- a/third_party/xla/third_party/systemlibs/six.BUILD
+++ /dev/null
@@ -1,11 +0,0 @@
-licenses(["notice"])  # MIT
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
-
-py_library(
-    name = "six",
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/snappy.BUILD b/third_party/xla/third_party/systemlibs/snappy.BUILD
deleted file mode 100644
index fd2db9e..0000000
--- a/third_party/xla/third_party/systemlibs/snappy.BUILD
+++ /dev/null
@@ -1,12 +0,0 @@
-licenses(["notice"])  # BSD 3-Clause
-
-filegroup(
-    name = "COPYING",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "snappy",
-    linkopts = ["-lsnappy"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/sqlite.BUILD b/third_party/xla/third_party/systemlibs/sqlite.BUILD
deleted file mode 100644
index 88a84a9..0000000
--- a/third_party/xla/third_party/systemlibs/sqlite.BUILD
+++ /dev/null
@@ -1,15 +0,0 @@
-licenses(["unencumbered"])  # Public Domain
-
-# Production build of SQLite library that's baked into TensorFlow.
-cc_library(
-    name = "org_sqlite",
-    linkopts = ["-lsqlite3"],
-    visibility = ["//visibility:public"],
-)
-
-# This is a Copybara sync helper for Google.
-py_library(
-    name = "python",
-    srcs_version = "PY3",
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/syslibs_configure.bzl b/third_party/xla/third_party/systemlibs/syslibs_configure.bzl
deleted file mode 100644
index 822ecda..0000000
--- a/third_party/xla/third_party/systemlibs/syslibs_configure.bzl
+++ /dev/null
@@ -1,170 +0,0 @@
-"""Repository rule for system library autoconfiguration.
-
-`syslibs_configure` depends on the following environment variables:
-
-  * `TF_SYSTEM_LIBS`: comma-separated list of third-party dependencies that
-    should use the system version instead of the bundled one
-"""
-
-_TF_SYSTEM_LIBS = "TF_SYSTEM_LIBS"
-
-VALID_LIBS = [
-    "absl_py",
-    "astor_archive",
-    "astunparse_archive",
-    "boringssl",
-    "com_github_googlecloudplatform_google_cloud_cpp",
-    "com_github_grpc_grpc",
-    "com_google_absl",
-    "com_google_protobuf",
-    "com_googlesource_code_re2",
-    "curl",
-    "cython",
-    "dill_archive",
-    "double_conversion",
-    "flatbuffers",
-    "functools32_archive",
-    "gast_archive",
-    "gif",
-    "hwloc",
-    "icu",
-    "jsoncpp_git",
-    "libjpeg_turbo",
-    "nasm",
-    "nsync",
-    "org_sqlite",
-    "pasta",
-    "png",
-    "pybind11",
-    "six_archive",
-    "snappy",
-    "tblib_archive",
-    "termcolor_archive",
-    "typing_extensions_archive",
-    "wrapt",
-    "zlib",
-]
-
-def auto_configure_fail(msg):
-    """Output failure message when syslibs configuration fails."""
-    red = "\033[0;31m"
-    no_color = "\033[0m"
-    fail("\n%sSystem Library Configuration Error:%s %s\n" % (red, no_color, msg))
-
-def _is_windows(repository_ctx):
-    """Returns true if the host operating system is windows."""
-    os_name = repository_ctx.os.name.lower()
-    if os_name.find("windows") != -1:
-        return True
-    return False
-
-def _enable_syslibs(repository_ctx):
-    s = repository_ctx.os.environ.get(_TF_SYSTEM_LIBS, "").strip()
-    if not _is_windows(repository_ctx) and s != None and s != "":
-        return True
-    return False
-
-def _get_system_lib_list(repository_ctx):
-    """Gets the list of deps that should use the system lib.
-
-    Args:
-      repository_ctx: The repository context.
-
-    Returns:
-      A list of the names of deps that should use the system lib.
-    """
-    if _TF_SYSTEM_LIBS not in repository_ctx.os.environ:
-        return []
-
-    libenv = repository_ctx.os.environ[_TF_SYSTEM_LIBS].strip()
-    libs = []
-
-    for lib in list(libenv.split(",")):
-        lib = lib.strip()
-        if lib == "":
-            continue
-        if lib not in VALID_LIBS:
-            auto_configure_fail("Invalid system lib set: %s" % lib)
-            return []
-        libs.append(lib)
-
-    return libs
-
-def _format_system_lib_list(repository_ctx):
-    """Formats the list of deps that should use the system lib.
-
-    Args:
-      repository_ctx: The repository context.
-
-    Returns:
-      A string with one quoted dep name per line, for template substitution.
-    """
-    libs = _get_system_lib_list(repository_ctx)
-    ret = ""
-    for lib in libs:
-        ret += "'%s',\n" % lib
-
-    return ret
-
-def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
-    if not out:
-        out = tpl.replace(":", "")
-    repository_ctx.template(
-        out,
-        Label("//third_party/systemlibs%s.tpl" % tpl),
-        substitutions,
-        False,
-    )
-
-def _create_dummy_repository(repository_ctx):
-    """Creates the dummy repository to build with all bundled libraries."""
-
-    _tpl(repository_ctx, ":BUILD")
-    _tpl(
-        repository_ctx,
-        ":build_defs.bzl",
-        {
-            "%{syslibs_enabled}": "False",
-            "%{syslibs_list}": "",
-        },
-    )
-
-def _create_local_repository(repository_ctx):
-    """Creates the repository to build with system libraries."""
-
-    _tpl(repository_ctx, ":BUILD")
-    _tpl(
-        repository_ctx,
-        ":build_defs.bzl",
-        {
-            "%{syslibs_enabled}": "True",
-            "%{syslibs_list}": _format_system_lib_list(repository_ctx),
-        },
-    )
-
-def _syslibs_autoconf_impl(repository_ctx):
-    """Implementation of the syslibs_configure repository rule."""
-    if not _enable_syslibs(repository_ctx):
-        _create_dummy_repository(repository_ctx)
-    else:
-        _create_local_repository(repository_ctx)
-
-syslibs_configure = repository_rule(
-    implementation = _syslibs_autoconf_impl,
-    environ = [
-        _TF_SYSTEM_LIBS,
-    ],
-)
-
-"""Configures the build to link to system libraries
-instead of using bundled versions.
-
-Add the following to your WORKSPACE file:
-
-```python
-syslibs_configure(name = "local_config_syslibs")
-```
-
-Args:
-  name: A unique name for this workspace rule.
-"""
diff --git a/third_party/xla/third_party/systemlibs/tblib.BUILD b/third_party/xla/third_party/systemlibs/tblib.BUILD
deleted file mode 100644
index ac411ce..0000000
--- a/third_party/xla/third_party/systemlibs/tblib.BUILD
+++ /dev/null
@@ -1,12 +0,0 @@
-licenses(["notice"])  # BSD 3-clause
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
-
-py_library(
-    name = "tblib",
-    srcs_version = "PY3",
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/typing_extensions.BUILD b/third_party/xla/third_party/systemlibs/typing_extensions.BUILD
deleted file mode 100644
index dc5d58a..0000000
--- a/third_party/xla/third_party/systemlibs/typing_extensions.BUILD
+++ /dev/null
@@ -1,16 +0,0 @@
-# Description:
-#   Backports for the typing module to older Python versions. See
-#   https://github.com/python/typing/blob/master/typing_extensions/README.rst
-
-licenses(["notice"])  # PSF
-
-py_library(
-    name = "typing_extensions",
-    srcs_version = "PY3",
-    visibility = ["//visibility:public"],
-)
-
-filegroup(
-    name = "LICENSE",
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/wrapt.BUILD b/third_party/xla/third_party/systemlibs/wrapt.BUILD
deleted file mode 100644
index 633feb2..0000000
--- a/third_party/xla/third_party/systemlibs/wrapt.BUILD
+++ /dev/null
@@ -1,4 +0,0 @@
-py_library(
-    name = "wrapt",
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/systemlibs/zlib.BUILD b/third_party/xla/third_party/systemlibs/zlib.BUILD
deleted file mode 100644
index 69462ae..0000000
--- a/third_party/xla/third_party/systemlibs/zlib.BUILD
+++ /dev/null
@@ -1,12 +0,0 @@
-licenses(["notice"])  # BSD/MIT-like license (for zlib)
-
-filegroup(
-    name = "zlib.h",
-    visibility = ["//visibility:public"],
-)
-
-cc_library(
-    name = "zlib",
-    linkopts = ["-lz"],
-    visibility = ["//visibility:public"],
-)
diff --git a/third_party/xla/third_party/tensorrt/BUILD b/third_party/xla/third_party/tensorrt/BUILD
deleted file mode 100644
index e69de29..0000000
--- a/third_party/xla/third_party/tensorrt/BUILD
+++ /dev/null
diff --git a/third_party/xla/third_party/tensorrt/BUILD.tpl b/third_party/xla/third_party/tensorrt/BUILD.tpl
deleted file mode 100644
index 7fa5935..0000000
--- a/third_party/xla/third_party/tensorrt/BUILD.tpl
+++ /dev/null
@@ -1,60 +0,0 @@
-# NVIDIA TensorRT
-# A high-performance deep learning inference optimizer and runtime.
-
-licenses(["notice"])
-
-load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts")
-load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
-
-package(default_visibility = ["//visibility:public"])
-
-exports_files(["LICENSE"])
-
-config_setting(
-    name = "use_static_tensorrt",
-    define_values = {"TF_TENSORRT_STATIC": "1"},
-)
-
-cc_library(
-    name = "tensorrt_headers",
-    hdrs = [
-        "tensorrt/include/tensorrt_config.h",
-        ":tensorrt_include"
-    ],
-    include_prefix = "third_party/tensorrt",
-    strip_include_prefix = "tensorrt/include",
-)
-
-cc_library(
-    name = "tensorrt",
-    srcs = select({
-        ":use_static_tensorrt": [":tensorrt_static_lib"],
-        "//conditions:default": [":tensorrt_lib"],
-    }),
-    copts = cuda_default_copts(),
-    data = select({
-        ":use_static_tensorrt": [],
-        "//conditions:default": [":tensorrt_lib"],
-    }),
-    linkstatic = 1,
-    deps = [
-        ":tensorrt_headers",
-        # TODO(b/174608722): fix this line.
-        "@local_config_cuda//cuda",
-    ],
-)
-
-bzl_library(
-    name = "build_defs_bzl",
-    srcs = ["build_defs.bzl"],
-    deps = [
-        "@bazel_skylib//lib:selects",
-    ],
-)
-
-py_library(
-    name = "tensorrt_config_py",
-    srcs = ["tensorrt/tensorrt_config.py"]
-)
-
-%{copy_rules}
diff --git a/third_party/xla/third_party/tensorrt/LICENSE b/third_party/xla/third_party/tensorrt/LICENSE
deleted file mode 100644
index 146d9b7..0000000
--- a/third_party/xla/third_party/tensorrt/LICENSE
+++ /dev/null
@@ -1,203 +0,0 @@
-Copyright 2018 The TensorFlow Authors.  All rights reserved.
-
-                                 Apache License
-                           Version 2.0, January 2004
-                        http://www.apache.org/licenses/
-
-   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
-   1. Definitions.
-
-      "License" shall mean the terms and conditions for use, reproduction,
-      and distribution as defined by Sections 1 through 9 of this document.
-
-      "Licensor" shall mean the copyright owner or entity authorized by
-      the copyright owner that is granting the License.
-
-      "Legal Entity" shall mean the union of the acting entity and all
-      other entities that control, are controlled by, or are under common
-      control with that entity. For the purposes of this definition,
-      "control" means (i) the power, direct or indirect, to cause the
-      direction or management of such entity, whether by contract or
-      otherwise, or (ii) ownership of fifty percent (50%) or more of the
-      outstanding shares, or (iii) beneficial ownership of such entity.
-
-      "You" (or "Your") shall mean an individual or Legal Entity
-      exercising permissions granted by this License.
-
-      "Source" form shall mean the preferred form for making modifications,
-      including but not limited to software source code, documentation
-      source, and configuration files.
-
-      "Object" form shall mean any form resulting from mechanical
-      transformation or translation of a Source form, including but
-      not limited to compiled object code, generated documentation,
-      and conversions to other media types.
-
-      "Work" shall mean the work of authorship, whether in Source or
-      Object form, made available under the License, as indicated by a
-      copyright notice that is included in or attached to the work
-      (an example is provided in the Appendix below).
-
-      "Derivative Works" shall mean any work, whether in Source or Object
-      form, that is based on (or derived from) the Work and for which the
-      editorial revisions, annotations, elaborations, or other modifications
-      represent, as a whole, an original work of authorship. For the purposes
-      of this License, Derivative Works shall not include works that remain
-      separable from, or merely link (or bind by name) to the interfaces of,
-      the Work and Derivative Works thereof.
-
-      "Contribution" shall mean any work of authorship, including
-      the original version of the Work and any modifications or additions
-      to that Work or Derivative Works thereof, that is intentionally
-      submitted to Licensor for inclusion in the Work by the copyright owner
-      or by an individual or Legal Entity authorized to submit on behalf of
-      the copyright owner. For the purposes of this definition, "submitted"
-      means any form of electronic, verbal, or written communication sent
-      to the Licensor or its representatives, including but not limited to
-      communication on electronic mailing lists, source code control systems,
-      and issue tracking systems that are managed by, or on behalf of, the
-      Licensor for the purpose of discussing and improving the Work, but
-      excluding communication that is conspicuously marked or otherwise
-      designated in writing by the copyright owner as "Not a Contribution."
-
-      "Contributor" shall mean Licensor and any individual or Legal Entity
-      on behalf of whom a Contribution has been received by Licensor and
-      subsequently incorporated within the Work.
-
-   2. Grant of Copyright License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      copyright license to reproduce, prepare Derivative Works of,
-      publicly display, publicly perform, sublicense, and distribute the
-      Work and such Derivative Works in Source or Object form.
-
-   3. Grant of Patent License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      (except as stated in this section) patent license to make, have made,
-      use, offer to sell, sell, import, and otherwise transfer the Work,
-      where such license applies only to those patent claims licensable
-      by such Contributor that are necessarily infringed by their
-      Contribution(s) alone or by combination of their Contribution(s)
-      with the Work to which such Contribution(s) was submitted. If You
-      institute patent litigation against any entity (including a
-      cross-claim or counterclaim in a lawsuit) alleging that the Work
-      or a Contribution incorporated within the Work constitutes direct
-      or contributory patent infringement, then any patent licenses
-      granted to You under this License for that Work shall terminate
-      as of the date such litigation is filed.
-
-   4. Redistribution. You may reproduce and distribute copies of the
-      Work or Derivative Works thereof in any medium, with or without
-      modifications, and in Source or Object form, provided that You
-      meet the following conditions:
-
-      (a) You must give any other recipients of the Work or
-          Derivative Works a copy of this License; and
-
-      (b) You must cause any modified files to carry prominent notices
-          stating that You changed the files; and
-
-      (c) You must retain, in the Source form of any Derivative Works
-          that You distribute, all copyright, patent, trademark, and
-          attribution notices from the Source form of the Work,
-          excluding those notices that do not pertain to any part of
-          the Derivative Works; and
-
-      (d) If the Work includes a "NOTICE" text file as part of its
-          distribution, then any Derivative Works that You distribute must
-          include a readable copy of the attribution notices contained
-          within such NOTICE file, excluding those notices that do not
-          pertain to any part of the Derivative Works, in at least one
-          of the following places: within a NOTICE text file distributed
-          as part of the Derivative Works; within the Source form or
-          documentation, if provided along with the Derivative Works; or,
-          within a display generated by the Derivative Works, if and
-          wherever such third-party notices normally appear. The contents
-          of the NOTICE file are for informational purposes only and
-          do not modify the License. You may add Your own attribution
-          notices within Derivative Works that You distribute, alongside
-          or as an addendum to the NOTICE text from the Work, provided
-          that such additional attribution notices cannot be construed
-          as modifying the License.
-
-      You may add Your own copyright statement to Your modifications and
-      may provide additional or different license terms and conditions
-      for use, reproduction, or distribution of Your modifications, or
-      for any such Derivative Works as a whole, provided Your use,
-      reproduction, and distribution of the Work otherwise complies with
-      the conditions stated in this License.
-
-   5. Submission of Contributions. Unless You explicitly state otherwise,
-      any Contribution intentionally submitted for inclusion in the Work
-      by You to the Licensor shall be under the terms and conditions of
-      this License, without any additional terms or conditions.
-      Notwithstanding the above, nothing herein shall supersede or modify
-      the terms of any separate license agreement you may have executed
-      with Licensor regarding such Contributions.
-
-   6. Trademarks. This License does not grant permission to use the trade
-      names, trademarks, service marks, or product names of the Licensor,
-      except as required for reasonable and customary use in describing the
-      origin of the Work and reproducing the content of the NOTICE file.
-
-   7. Disclaimer of Warranty. Unless required by applicable law or
-      agreed to in writing, Licensor provides the Work (and each
-      Contributor provides its Contributions) on an "AS IS" BASIS,
-      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-      implied, including, without limitation, any warranties or conditions
-      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
-      PARTICULAR PURPOSE. You are solely responsible for determining the
-      appropriateness of using or redistributing the Work and assume any
-      risks associated with Your exercise of permissions under this License.
-
-   8. Limitation of Liability. In no event and under no legal theory,
-      whether in tort (including negligence), contract, or otherwise,
-      unless required by applicable law (such as deliberate and grossly
-      negligent acts) or agreed to in writing, shall any Contributor be
-      liable to You for damages, including any direct, indirect, special,
-      incidental, or consequential damages of any character arising as a
-      result of this License or out of the use or inability to use the
-      Work (including but not limited to damages for loss of goodwill,
-      work stoppage, computer failure or malfunction, or any and all
-      other commercial damages or losses), even if such Contributor
-      has been advised of the possibility of such damages.
-
-   9. Accepting Warranty or Additional Liability. While redistributing
-      the Work or Derivative Works thereof, You may choose to offer,
-      and charge a fee for, acceptance of support, warranty, indemnity,
-      or other liability obligations and/or rights consistent with this
-      License. However, in accepting such obligations, You may act only
-      on Your own behalf and on Your sole responsibility, not on behalf
-      of any other Contributor, and only if You agree to indemnify,
-      defend, and hold each Contributor harmless for any liability
-      incurred by, or claims asserted against, such Contributor by reason
-      of your accepting any such warranty or additional liability.
-
-   END OF TERMS AND CONDITIONS
-
-   APPENDIX: How to apply the Apache License to your work.
-
-      To apply the Apache License to your work, attach the following
-      boilerplate notice, with the fields enclosed by brackets "[]"
-      replaced with your own identifying information. (Don't include
-      the brackets!)  The text should be enclosed in the appropriate
-      comment syntax for the file format. We also recommend that a
-      file or class name and description of purpose be included on the
-      same "printed page" as the copyright notice for easier
-      identification within third-party archives.
-
-   Copyright 2018, The TensorFlow Authors.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License.
diff --git a/third_party/xla/third_party/tensorrt/build_defs.bzl.tpl b/third_party/xla/third_party/tensorrt/build_defs.bzl.tpl
deleted file mode 100644
index 83fcc7d..0000000
--- a/third_party/xla/third_party/tensorrt/build_defs.bzl.tpl
+++ /dev/null
@@ -1,9 +0,0 @@
-# Build configurations for TensorRT.
-
-def if_tensorrt(if_true, if_false=[]):
-  """Tests whether TensorRT was enabled during the configure process."""
-  return %{if_tensorrt}
-
-def if_tensorrt_exec(if_true, if_false=[]):
-  """Synonym for if_tensorrt."""
-  return %{if_tensorrt}
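
The `%{if_tensorrt}` placeholder above is substituted with either `if_true` or `if_false` when the repository is configured. A hedged sketch of how such a selector is commonly used in a BUILD file follows; the target name, source file, and copt are illustrative assumptions, not part of this change.

```python
# Illustrative BUILD usage of the generated if_tensorrt() selector
# (target name, source, and define are assumptions).
load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")

cc_library(
    name = "trt_backend",
    srcs = ["trt_backend.cc"],
    copts = if_tensorrt(["-DGOOGLE_TENSORRT=1"]),  # only set when TensorRT is enabled
    deps = if_tensorrt(["@local_config_tensorrt//:tensorrt"]),
)
```
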
diff --git a/third_party/xla/third_party/tensorrt/plugin.BUILD.tpl b/third_party/xla/third_party/tensorrt/plugin.BUILD.tpl
deleted file mode 100644
index 92028b5..0000000
--- a/third_party/xla/third_party/tensorrt/plugin.BUILD.tpl
+++ /dev/null
@@ -1,6 +0,0 @@
-load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts")
-load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
-
-package(default_visibility = ["//visibility:public"])
-
-%{oss_rules}
diff --git a/third_party/xla/third_party/tensorrt/plugin/BUILD b/third_party/xla/third_party/tensorrt/plugin/BUILD
deleted file mode 100644
index 2c76d3d..0000000
--- a/third_party/xla/third_party/tensorrt/plugin/BUILD
+++ /dev/null
@@ -1,64 +0,0 @@
-# NVIDIA TensorRT Open Source Plugins
-# This package contains build targets for select TensorRT plugins included in the
-# TensorRT open source repository.
-load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "cuda_library")
-
-exports_files(
-    ["LICENSE"],
-    visibility = ["//visibility:public"],
-)
-
-cuda_library(
-    name = "plugin_common",
-    srcs = [
-        "plugin/common/kernels/common.cu.cc",
-    ],
-    hdrs = [
-        "plugin/common/bboxUtils.h",
-        "plugin/common/checkMacrosPlugin.h",
-        "plugin/common/plugin.h",
-    ],
-    strip_include_prefix = "plugin/common",
-    visibility = ["//visibility:public"],
-    deps = [
-        "@local_config_tensorrt//:tensorrt",
-        "@local_config_tensorrt//:tensorrt_headers",
-    ],
-)
-
-cc_library(
-    name = "nms_plugin_hdrs",
-    hdrs = [
-        "plugin/efficientNMSPlugin/efficientNMSInference.h",
-        "plugin/efficientNMSPlugin/efficientNMSParameters.h",
-        "plugin/efficientNMSPlugin/efficientNMSPlugin.h",
-    ],
-    visibility = ["//visibility:public"],
-)
-
-cuda_library(
-    name = "nvinfer_plugin_nms",
-    srcs = [
-        "plugin/efficientNMSPlugin/efficientNMSInference.cu.cc",
-        "plugin/efficientNMSPlugin/efficientNMSInference.cu.h",
-        "plugin/efficientNMSPlugin/efficientNMSInference.h",
-        "plugin/efficientNMSPlugin/efficientNMSParameters.h",
-        "plugin/efficientNMSPlugin/efficientNMSPlugin.cpp",
-        "plugin/efficientNMSPlugin/efficientNMSPlugin.h",
-    ],
-    hdrs = [
-        "plugin/efficientNMSPlugin/efficientNMSInference.h",
-        "plugin/efficientNMSPlugin/efficientNMSParameters.h",
-        "plugin/efficientNMSPlugin/efficientNMSPlugin.h",
-    ],
-    copts = cuda_default_copts(),
-    include_prefix = "third_party/tensorrt/plugin/efficientNMSPlugin",
-    strip_include_prefix = "plugin/efficientNMSPlugin",
-    visibility = ["//visibility:public"],
-    deps = [
-        ":nms_plugin_hdrs",
-        ":plugin_common",
-        "@local_config_tensorrt//:tensorrt",
-        "@local_config_tensorrt//:tensorrt_headers",
-    ],
-)
diff --git a/third_party/xla/third_party/tensorrt/plugin/tensorrt_oss.patch b/third_party/xla/third_party/tensorrt/plugin/tensorrt_oss.patch
deleted file mode 100644
index 3380879..0000000
--- a/third_party/xla/third_party/tensorrt/plugin/tensorrt_oss.patch
+++ /dev/null
@@ -1,144 +0,0 @@
-diff --git a/plugin/common/checkMacrosPlugin.h b/plugin/common/checkMacrosPlugin.h
-index 2cff9f8..a803765 100644
---- a/plugin/common/checkMacrosPlugin.h
-+++ b/plugin/common/checkMacrosPlugin.h
-@@ -16,7 +16,7 @@
- #ifndef CHECK_MACROS_PLUGIN_H
- #define CHECK_MACROS_PLUGIN_H
- 
--#include "NvInfer.h"
-+#include "third_party/tensorrt/NvInfer.h"
- #include <sstream>
- 
- #ifndef TRT_CHECK_MACROS_H
-diff --git a/plugin/common/kernels/common.cu b/plugin/common/kernels/common.cu.cc
-similarity index 87%
-rename from plugin/common/kernels/common.cu
-rename to plugin/common/kernels/common.cu.cc
-index 7c8922a..9818a30 100755
---- a/plugin/common/kernels/common.cu
-+++ b/plugin/common/kernels/common.cu.cc
-@@ -18,7 +18,6 @@
- #include "cublas_v2.h"
- #include <cub/cub.cuh>
- #include <stdint.h>
--#include "kernel.h"
- #include "bboxUtils.h"
- 
- #define CUDA_MEM_ALIGN 256
-@@ -26,28 +25,7 @@
- // HASH
- unsigned int hash(const void* array_, size_t size)
- {
--    // Apply hashing only when debugging RPN codes.
--    if (DEBUG_ENABLE)
--    {
--        const char* array_const;
--        char* array;
--        cudaMallocHost((void**) &array, size);
--        cudaMemcpy(array, array_, size, cudaMemcpyDeviceToHost);
--        array_const = array;
--        unsigned int hash = 45599;
--        for (size_t i = 0; i < size; i++)
--        {
--            unsigned int value = array_const[i];
--            hash = hash * 1487 + value;
--            hash = hash * 317;
--            hash = hash % 105359;
--        }
--        return hash;
--    }
--    else
--    {
--        return 0;
--    }
-+    return 0;
- }
- 
- // ALIGNPTR
-diff --git a/plugin/common/plugin.h b/plugin/common/plugin.h
-index 27a1fb7..f056255 100644
---- a/plugin/common/plugin.h
-+++ b/plugin/common/plugin.h
-@@ -17,7 +17,7 @@
- #define TRT_PLUGIN_H
- #include "checkMacrosPlugin.h"
- 
--#include "NvInferPlugin.h"
-+#include "third_party/tensorrt/NvInferPlugin.h"
- #include <cstring>
- #include <cuda_runtime.h>
- #include <iostream>
-diff --git a/plugin/efficientNMSPlugin/efficientNMSInference.cu b/plugin/efficientNMSPlugin/efficientNMSInference.cu.cc
-similarity index 99%
-rename from plugin/efficientNMSPlugin/efficientNMSInference.cu
-rename to plugin/efficientNMSPlugin/efficientNMSInference.cu.cc
-index f02a2f8..44fa20b 100644
---- a/plugin/efficientNMSPlugin/efficientNMSInference.cu
-+++ b/plugin/efficientNMSPlugin/efficientNMSInference.cu.cc
-@@ -18,7 +18,7 @@
- #include "cub/cub.cuh"
- #include "cuda_runtime_api.h"
- 
--#include "efficientNMSInference.cuh"
-+#include "efficientNMSInference.cu.h"
- #include "efficientNMSInference.h"
- 
- using namespace nvinfer1;
-diff --git a/plugin/efficientNMSPlugin/efficientNMSInference.cuh b/plugin/efficientNMSPlugin/efficientNMSInference.cu.h
-similarity index 100%
-rename from plugin/efficientNMSPlugin/efficientNMSInference.cuh
-rename to plugin/efficientNMSPlugin/efficientNMSInference.cu.h
-diff --git a/plugin/efficientNMSPlugin/efficientNMSPlugin.cpp b/plugin/efficientNMSPlugin/efficientNMSPlugin.cpp
-index 2d05c5c..acda183 100644
---- a/plugin/efficientNMSPlugin/efficientNMSPlugin.cpp
-+++ b/plugin/efficientNMSPlugin/efficientNMSPlugin.cpp
-@@ -31,11 +31,6 @@ const char* EFFICIENT_NMS_ONNX_PLUGIN_VERSION{"1"};
- const char* EFFICIENT_NMS_ONNX_PLUGIN_NAME{"EfficientNMS_ONNX_TRT"};
- } // namespace
- 
--PluginFieldCollection EfficientNMSPluginCreator::mFC{};
--PluginFieldCollection EfficientNMSONNXPluginCreator::mFC{};
--std::vector<PluginField> EfficientNMSPluginCreator::mPluginAttributes;
--std::vector<PluginField> EfficientNMSONNXPluginCreator::mPluginAttributes;
--
- EfficientNMSPlugin::EfficientNMSPlugin(EfficientNMSParameters param)
-     : mParam(param)
- {
-@@ -386,7 +381,7 @@ EfficientNMSPluginCreator::EfficientNMSPluginCreator()
-     mPluginAttributes.emplace_back(PluginField("max_output_boxes", nullptr, PluginFieldType::kINT32, 1));
-     mPluginAttributes.emplace_back(PluginField("background_class", nullptr, PluginFieldType::kINT32, 1));
-     mPluginAttributes.emplace_back(PluginField("score_activation", nullptr, PluginFieldType::kINT32, 1));
--    mPluginAttributes.emplace_back(PluginField("box_coding", nullptr, PluginFieldType::kINT32, 1));
-+    mPluginAttributes.emplace_back(PluginField("box_coding", nullptr, PluginFieldType::kINT32, 1));    
-     mFC.nbFields = mPluginAttributes.size();
-     mFC.fields = mPluginAttributes.data();
- }
-diff --git a/plugin/efficientNMSPlugin/efficientNMSPlugin.h b/plugin/efficientNMSPlugin/efficientNMSPlugin.h
-index b342b09..84d5e69 100644
---- a/plugin/efficientNMSPlugin/efficientNMSPlugin.h
-+++ b/plugin/efficientNMSPlugin/efficientNMSPlugin.h
-@@ -85,9 +85,9 @@ public:
-         const char* name, const void* serialData, size_t serialLength) noexcept override;
- 
- protected:
--    static PluginFieldCollection mFC;
-+    PluginFieldCollection mFC;
-     EfficientNMSParameters mParam;
--    static std::vector<PluginField> mPluginAttributes;
-+    std::vector<PluginField> mPluginAttributes;
-     std::string mPluginName;
- };
- 
-@@ -107,9 +107,9 @@ public:
-         const char* name, const void* serialData, size_t serialLength) noexcept override;
- 
- protected:
--    static PluginFieldCollection mFC;
-+    PluginFieldCollection mFC;
-     EfficientNMSParameters mParam;
--    static std::vector<PluginField> mPluginAttributes;
-+    std::vector<PluginField> mPluginAttributes;
-     std::string mPluginName;
- };
- 
diff --git a/third_party/xla/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl b/third_party/xla/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl
deleted file mode 100644
index f9a09cf..0000000
--- a/third_party/xla/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORRT_TENSORRT_INCLUDE_CONFIG_H_
-#define TENSORRT_TENSORRT_INCLUDE_CONFIG_H_
-
-#define TF_TENSORRT_VERSION "%{tensorrt_version}"
-
-#endif  // TENSORRT_TENSORRT_INCLUDE_CONFIG_H_
diff --git a/third_party/xla/third_party/tensorrt/tensorrt/tensorrt_config.py.tpl b/third_party/xla/third_party/tensorrt/tensorrt/tensorrt_config.py.tpl
deleted file mode 100644
index 88933dc..0000000
--- a/third_party/xla/third_party/tensorrt/tensorrt/tensorrt_config.py.tpl
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-config = %{tensorrt_config}
diff --git a/third_party/xla/third_party/tensorrt/tensorrt_configure.bzl b/third_party/xla/third_party/tensorrt/tensorrt_configure.bzl
deleted file mode 100644
index 3706b29..0000000
--- a/third_party/xla/third_party/tensorrt/tensorrt_configure.bzl
+++ /dev/null
@@ -1,337 +0,0 @@
-"""Repository rule for TensorRT configuration.
-
-`tensorrt_configure` depends on the following environment variables:
-
-  * `TF_TENSORRT_VERSION`: The TensorRT libnvinfer version.
-  * `TENSORRT_INSTALL_PATH`: The installation path of the TensorRT library.
-"""
-
-load(
-    "//third_party/gpus:cuda_configure.bzl",
-    "find_cuda_config",
-    "lib_name",
-    "make_copy_files_rule",
-)
-load(
-    "//third_party/remote_config:common.bzl",
-    "config_repo_label",
-    "get_cpu_value",
-    "get_host_environ",
-)
-
-_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
-_TF_TENSORRT_STATIC_PATH = "TF_TENSORRT_STATIC_PATH"
-_TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO"
-_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
-_TF_NEED_TENSORRT = "TF_NEED_TENSORRT"
-
-_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin"]
-_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"]
-_TF_TENSORRT_HEADERS_V6 = [
-    "NvInfer.h",
-    "NvUtils.h",
-    "NvInferPlugin.h",
-    "NvInferVersion.h",
-    "NvInferRuntime.h",
-    "NvInferRuntimeCommon.h",
-    "NvInferPluginUtils.h",
-]
-_TF_TENSORRT_HEADERS_V8 = [
-    "NvInfer.h",
-    "NvInferLegacyDims.h",
-    "NvInferImpl.h",
-    "NvUtils.h",
-    "NvInferPlugin.h",
-    "NvInferVersion.h",
-    "NvInferRuntime.h",
-    "NvInferRuntimeCommon.h",
-    "NvInferPluginUtils.h",
-]
-_TF_TENSORRT_HEADERS_V8_6 = [
-    "NvInfer.h",
-    "NvInferConsistency.h",
-    "NvInferConsistencyImpl.h",
-    "NvInferImpl.h",
-    "NvInferLegacyDims.h",
-    "NvInferPlugin.h",
-    "NvInferPluginUtils.h",
-    "NvInferRuntime.h",
-    "NvInferRuntimeBase.h",
-    "NvInferRuntimeCommon.h",
-    "NvInferRuntimePlugin.h",
-    "NvInferSafeRuntime.h",
-    "NvInferVersion.h",
-    "NvUtils.h",
-]
-
-_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
-_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
-_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"
-
-_TENSORRT_OSS_DUMMY_BUILD_CONTENT = """
-cc_library(
-  name = "nvinfer_plugin_nms",
-  visibility = ["//visibility:public"],
-)
-"""
-
-_TENSORRT_OSS_ARCHIVE_BUILD_CONTENT = """
-alias(
-  name = "nvinfer_plugin_nms",
-  actual = "@tensorrt_oss_archive//:nvinfer_plugin_nms",
-  visibility = ["//visibility:public"],
-)
-"""
-
-def _at_least_version(actual_version, required_version):
-    actual = [int(v) for v in actual_version.split(".")]
-    required = [int(v) for v in required_version.split(".")]
-    return actual >= required
-
-def _get_tensorrt_headers(tensorrt_version):
-    if _at_least_version(tensorrt_version, "8.6"):
-        return _TF_TENSORRT_HEADERS_V8_6
-    if _at_least_version(tensorrt_version, "8"):
-        return _TF_TENSORRT_HEADERS_V8
-    if _at_least_version(tensorrt_version, "6"):
-        return _TF_TENSORRT_HEADERS_V6
-    return _TF_TENSORRT_HEADERS
-
-def _tpl_path(repository_ctx, filename):
-    return repository_ctx.path(Label("//third_party/tensorrt:%s.tpl" % filename))
-
-def _tpl(repository_ctx, tpl, substitutions):
-    repository_ctx.template(
-        tpl,
-        _tpl_path(repository_ctx, tpl),
-        substitutions,
-    )
-
-def _create_dummy_repository(repository_ctx):
-    """Create a dummy TensorRT repository."""
-    _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_false"})
-    _tpl(repository_ctx, "BUILD", {
-        "%{copy_rules}": "",
-        "\":tensorrt_include\"": "",
-        "\":tensorrt_lib\"": "",
-        "%{oss_rules}": _TENSORRT_OSS_DUMMY_BUILD_CONTENT,
-    })
-    _tpl(repository_ctx, "tensorrt/include/tensorrt_config.h", {
-        "%{tensorrt_version}": "",
-    })
-
-    # Copy license file in non-remote build.
-    repository_ctx.template(
-        "LICENSE",
-        Label("//third_party/tensorrt:LICENSE"),
-        {},
-    )
-
-    # Set up tensorrt_config.py, which is used by gen_build_info to provide
-    # build environment info to the API
-    _tpl(
-        repository_ctx,
-        "tensorrt/tensorrt_config.py",
-        _py_tmpl_dict({}),
-    )
-
-def enable_tensorrt(repository_ctx):
-    """Returns whether to build with TensorRT support."""
-    return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False))
-
-def _get_tensorrt_static_path(repository_ctx):
-    """Returns the path for TensorRT static libraries."""
-    return get_host_environ(repository_ctx, _TF_TENSORRT_STATIC_PATH, None)
-
-def _get_tensorrt_full_version(repository_ctx):
-    """Returns the full version for TensorRT."""
-    return get_host_environ(repository_ctx, _TF_TENSORRT_VERSION, None)
-
-def _create_local_tensorrt_repository(repository_ctx):
-    tpl_paths = {
-        "build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
-        "BUILD": _tpl_path(repository_ctx, "BUILD"),
-        "tensorrt/include/tensorrt_config.h": _tpl_path(repository_ctx, "tensorrt/include/tensorrt_config.h"),
-        "tensorrt/tensorrt_config.py": _tpl_path(repository_ctx, "tensorrt/tensorrt_config.py"),
-        "plugin.BUILD": _tpl_path(repository_ctx, "plugin.BUILD"),
-    }
-
-    config = find_cuda_config(repository_ctx, ["cuda", "tensorrt"])
-    cuda_version = config["cuda_version"]
-    cuda_library_path = config["cuda_library_dir"] + "/"
-    trt_version = config["tensorrt_version"]
-    trt_full_version = _get_tensorrt_full_version(repository_ctx)
-    cpu_value = get_cpu_value(repository_ctx)
-
-    # Copy the library and header files.
-    libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
-
-    library_dir = config["tensorrt_library_dir"] + "/"
-    headers = _get_tensorrt_headers(trt_version)
-    include_dir = config["tensorrt_include_dir"] + "/"
-    copy_rules = [
-        make_copy_files_rule(
-            repository_ctx,
-            name = "tensorrt_lib",
-            srcs = [library_dir + library for library in libraries],
-            outs = ["tensorrt/lib/" + library for library in libraries],
-        ),
-        make_copy_files_rule(
-            repository_ctx,
-            name = "tensorrt_include",
-            srcs = [include_dir + header for header in headers],
-            outs = ["tensorrt/include/" + header for header in headers],
-        ),
-    ]
-
-    tensorrt_static_path = _get_tensorrt_static_path(repository_ctx)
-    if tensorrt_static_path:
-        tensorrt_static_path = tensorrt_static_path + "/"
-        if _at_least_version(trt_full_version, "8.4.1") and _at_least_version(cuda_version, "11.4"):
-            raw_static_library_names = _TF_TENSORRT_LIBS
-            nvrtc_ptxjit_static_raw_names = ["nvrtc", "nvrtc-builtins", "nvptxcompiler"]
-            nvrtc_ptxjit_static_names = ["%s_static" % name for name in nvrtc_ptxjit_static_raw_names]
-            nvrtc_ptxjit_static_libraries = [lib_name(lib, cpu_value, trt_version, static = True) for lib in nvrtc_ptxjit_static_names]
-        elif _at_least_version(trt_version, "8"):
-            raw_static_library_names = _TF_TENSORRT_LIBS
-            nvrtc_ptxjit_static_libraries = []
-        else:
-            raw_static_library_names = _TF_TENSORRT_LIBS + ["nvrtc", "myelin_compiler", "myelin_executor", "myelin_pattern_library", "myelin_pattern_runtime"]
-            nvrtc_ptxjit_static_libraries = []
-        static_library_names = ["%s_static" % name for name in raw_static_library_names]
-        static_libraries = [lib_name(lib, cpu_value, trt_version, static = True) for lib in static_library_names]
-        copy_rules = copy_rules + [
-            make_copy_files_rule(
-                repository_ctx,
-                name = "tensorrt_static_lib",
-                srcs = [tensorrt_static_path + library for library in static_libraries] +
-                       [cuda_library_path + library for library in nvrtc_ptxjit_static_libraries],
-                outs = ["tensorrt/lib/" + library for library in static_libraries] +
-                       ["tensorrt/lib/" + library for library in nvrtc_ptxjit_static_libraries],
-            ),
-        ]
-
-    # Set up config file.
-    repository_ctx.template(
-        "build_defs.bzl",
-        tpl_paths["build_defs.bzl"],
-        {"%{if_tensorrt}": "if_true"},
-    )
-
-    # Set up BUILD file.
-    repository_ctx.template(
-        "BUILD",
-        tpl_paths["BUILD"],
-        {
-            "%{copy_rules}": "\n".join(copy_rules),
-        },
-    )
-
-    # Set up the plugins folder BUILD file.
-    repository_ctx.template(
-        "plugin/BUILD",
-        tpl_paths["plugin.BUILD"],
-        {
-            "%{oss_rules}": _TENSORRT_OSS_ARCHIVE_BUILD_CONTENT,
-        },
-    )
-
-    # Copy license file in non-remote build.
-    repository_ctx.template(
-        "LICENSE",
-        Label("//third_party/tensorrt:LICENSE"),
-        {},
-    )
-
-    # Set up tensorrt_config.h, which is used by
-    # tensorflow/compiler/xla/stream_executor/dso_loader.cc.
-    repository_ctx.template(
-        "tensorrt/include/tensorrt_config.h",
-        tpl_paths["tensorrt/include/tensorrt_config.h"],
-        {"%{tensorrt_version}": trt_version},
-    )
-
-    # Set up tensorrt_config.py, which is used by gen_build_info to provide
-    # build environment info to the API
-    repository_ctx.template(
-        "tensorrt/tensorrt_config.py",
-        tpl_paths["tensorrt/tensorrt_config.py"],
-        _py_tmpl_dict({
-            "tensorrt_version": trt_version,
-        }),
-    )
-
-def _py_tmpl_dict(d):
-    return {"%{tensorrt_config}": str(d)}
-
-def _tensorrt_configure_impl(repository_ctx):
-    """Implementation of the tensorrt_configure repository rule."""
-
-    if get_host_environ(repository_ctx, _TF_TENSORRT_CONFIG_REPO) != None:
-        # Forward to the pre-configured remote repository.
-        remote_config_repo = repository_ctx.os.environ[_TF_TENSORRT_CONFIG_REPO]
-        repository_ctx.template("BUILD", config_repo_label(remote_config_repo, ":BUILD"), {})
-        repository_ctx.template(
-            "build_defs.bzl",
-            config_repo_label(remote_config_repo, ":build_defs.bzl"),
-            {},
-        )
-        repository_ctx.template(
-            "tensorrt/include/tensorrt_config.h",
-            config_repo_label(remote_config_repo, ":tensorrt/include/tensorrt_config.h"),
-            {},
-        )
-        repository_ctx.template(
-            "tensorrt/tensorrt_config.py",
-            config_repo_label(remote_config_repo, ":tensorrt/tensorrt_config.py"),
-            {},
-        )
-        repository_ctx.template(
-            "LICENSE",
-            config_repo_label(remote_config_repo, ":LICENSE"),
-            {},
-        )
-        return
-
-    if not enable_tensorrt(repository_ctx):
-        _create_dummy_repository(repository_ctx)
-        return
-
-    _create_local_tensorrt_repository(repository_ctx)
-
-_ENVIRONS = [
-    _TENSORRT_INSTALL_PATH,
-    _TF_TENSORRT_VERSION,
-    _TF_NEED_TENSORRT,
-    _TF_TENSORRT_STATIC_PATH,
-    "TF_CUDA_PATHS",
-]
-
-remote_tensorrt_configure = repository_rule(
-    implementation = _create_local_tensorrt_repository,
-    environ = _ENVIRONS,
-    remotable = True,
-    attrs = {
-        "environ": attr.string_dict(),
-        "_find_cuda_config": attr.label(default = "@local_xla//third_party/gpus:find_cuda_config.py"),
-    },
-)
-
-tensorrt_configure = repository_rule(
-    implementation = _tensorrt_configure_impl,
-    environ = _ENVIRONS + [_TF_TENSORRT_CONFIG_REPO],
-    attrs = {
-        "_find_cuda_config": attr.label(default = "@local_xla//third_party/gpus:find_cuda_config.py"),
-    },
-)
-"""Detects and configures the local CUDA toolchain.
-
-Add the following to your WORKSPACE FILE:
-
-```python
-tensorrt_configure(name = "local_config_tensorrt")
-```
-
-Args:
-  name: A unique name for this workspace rule.
-"""
diff --git a/third_party/xla/third_party/tensorrt/workspace.bzl b/third_party/xla/third_party/tensorrt/workspace.bzl
deleted file mode 100644
index be383ee..0000000
--- a/third_party/xla/third_party/tensorrt/workspace.bzl
+++ /dev/null
@@ -1,21 +0,0 @@
-"""Provides the repository macro to import TensorRT Open Source components."""
-
-load("//third_party:repo.bzl", "tf_http_archive")
-
-def repo(name = "tensorrt_oss_archive"):
-    """Imports TensorRT Open Source Components."""
-    TRT_OSS_COMMIT = "9ec6eb6db39188c9f3d25f49c8ee3a9721636b56"
-    TRT_OSS_SHA256 = "4fa2a712a5f2350b81df01d55c1dc17451e09efd4b2a53322b0433721009e1c7"
-
-    tf_http_archive(
-        name = name,
-        sha256 = TRT_OSS_SHA256,
-        strip_prefix = "TensorRT-{commit}".format(commit = TRT_OSS_COMMIT),
-        urls = [
-            # TODO: Google Mirror "https://storage.googleapis.com/...."
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/NVIDIA/TensorRT/archive/{commit}.tar.gz".format(commit = TRT_OSS_COMMIT),
-            "https://github.com/NVIDIA/TensorRT/archive/{commit}.tar.gz".format(commit = TRT_OSS_COMMIT),
-        ],
-        build_file = "//third_party/tensorrt/plugin:BUILD",
-        patch_file = ["//third_party/tensorrt/plugin:tensorrt_oss.patch"],
-    )
diff --git a/third_party/xla/third_party/tf_runtime/BUILD b/third_party/xla/third_party/tf_runtime/BUILD
deleted file mode 100644
index e69de29..0000000
--- a/third_party/xla/third_party/tf_runtime/BUILD
+++ /dev/null
diff --git a/third_party/xla/third_party/tf_runtime/tf_runtime.patch b/third_party/xla/third_party/tf_runtime/tf_runtime.patch
deleted file mode 100644
index a9a9d4a..0000000
--- a/third_party/xla/third_party/tf_runtime/tf_runtime.patch
+++ /dev/null
@@ -1,84 +0,0 @@
-Intermittent patch to TFRT to submit a TF/TFRT cross-cutting change.
-This patch will be applied only until TF's TFRT commit is automatically bumped.
-
----
-
-diff --git a/backends/gpu/include/tfrt/gpu/gpu_types.h b/backends/gpu/include/tfrt/gpu/gpu_types.h
-index 3d311c3..a216716 100644
---- a/backends/gpu/include/tfrt/gpu/gpu_types.h
-+++ b/backends/gpu/include/tfrt/gpu/gpu_types.h
-@@ -295,11 +295,7 @@
-       wrapper::CurrentContext current, wrapper::Stream stream,
-       wrapper::CclComm comm)>;
- 
--  explicit GpuCclHandle(AsyncValueRef<GpuContext> context,
--                        wrapper::OwningCclComm comm, int num_ranks);
--  // TODO(hanbinyoon): Remove after transitioning to the above constructor.
--  explicit GpuCclHandle(AsyncValueRef<GpuContext> context,
--                        wrapper::OwningCclComm comm);
-+  GpuCclHandle(AsyncValueRef<GpuContext> context, wrapper::OwningCclComm comm);
-   ~GpuCclHandle();
- 
-   GpuCclHandle(GpuCclHandle&&) = default;
-@@ -311,8 +307,6 @@
-   llvm::Error ExecuteCallbacks(wrapper::CurrentContext current,
-                                wrapper::Stream stream);
- 
--  int num_ranks() const { return num_ranks_; }
--
-   const wrapper::OwningCclComm& operator->() const { return comm_; }
-   wrapper::CclComm get() const { return comm_.get(); }
-   wrapper::CclComm release();
-@@ -322,7 +316,6 @@
-  private:
-   AsyncValueRef<GpuContext> context_;
-   wrapper::OwningCclComm comm_;
--  int num_ranks_;
-   std::vector<Callback> callbacks_;
- };
- 
-diff --git a/backends/gpu/lib/gpu_types.cc b/backends/gpu/lib/gpu_types.cc
-index 38529bc..01e3dba 100644
---- a/backends/gpu/lib/gpu_types.cc
-+++ b/backends/gpu/lib/gpu_types.cc
-@@ -214,15 +214,8 @@
- GpuBlasHandle::~GpuBlasHandle() = default;
- 
- GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context,
--                           wrapper::OwningCclComm comm, int num_ranks)
--    : context_(std::move(context)),
--      comm_(std::move(comm)),
--      num_ranks_(num_ranks) {}
--
--// TODO(hanbinyoon): Remove after transitioning to the above constructor.
--GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context,
-                            wrapper::OwningCclComm comm)
--    : context_(std::move(context)), comm_(std::move(comm)), num_ranks_(0) {}
-+    : context_(std::move(context)), comm_(std::move(comm)) {}
- 
- GpuCclHandle::~GpuCclHandle() = default;
- 
-diff --git a/backends/gpu/lib/kernels/ccl_kernels.cc b/backends/gpu/lib/kernels/ccl_kernels.cc
-index 52ce820..9cfc1de 100644
---- a/backends/gpu/lib/kernels/ccl_kernels.cc
-+++ b/backends/gpu/lib/kernels/ccl_kernels.cc
-@@ -107,8 +107,6 @@
-   auto width = ToWidthInBytes(type);
-   if (!width) return width.takeError();
-   assert(*width != 0);
--  if (input->size() != output->size() * handle->num_ranks())
--    return MakeStringError("Input size must be output size times ranks.");
- 
-   handle->AddCallback([input = input.ValueRef(), output = output.ValueRef(),
-                        recvcount = output->size() / *width, type,
-@@ -116,6 +114,10 @@
-                           wrapper::CurrentContext current,
-                           wrapper::Stream stream,
-                           wrapper::CclComm comm) -> llvm::Error {
-+    auto count = wrapper::CclCommCount(comm);
-+    if (!count) return count.takeError();
-+    if (input->size() != output->size() * *count)
-+      return MakeStringError("Input size must be output size times ranks.");
-     return wrapper::CclReduceScatter(current, input->pointer(),
-                                      output->pointer(), recvcount, type, op,
-                                      comm, stream);
diff --git a/third_party/xla/third_party/tf_runtime/tf_runtime_clangcl.patch b/third_party/xla/third_party/tf_runtime/tf_runtime_clangcl.patch
deleted file mode 100644
index ce1859d..0000000
--- a/third_party/xla/third_party/tf_runtime/tf_runtime_clangcl.patch
+++ /dev/null
@@ -1,14 +0,0 @@
-diff --git a/include/tfrt/support/std_mutex.h b/include/tfrt/support/std_mutex.h
-index 6238d097..9fb24279 100644
---- a/include/tfrt/support/std_mutex.h
-+++ b/include/tfrt/support/std_mutex.h
-@@ -50,7 +50,7 @@ class TFRT_CAPABILITY("mutex") mutex {
- 
-  private:
-   friend class mutex_lock;
--  std::mutex mu_;
-+  std::mutex mu_{};
- };
-
- // Wrap std::unique_lock<std::mutex> with support for thread annotations.
- 
diff --git a/third_party/xla/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tf_runtime/workspace.bzl
deleted file mode 100644
index 44d692c..0000000
--- a/third_party/xla/third_party/tf_runtime/workspace.bzl
+++ /dev/null
@@ -1,20 +0,0 @@
-"""Provides the repository macro to import TFRT."""
-
-load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
-
-def repo():
-    """Imports TFRT."""
-
-    # Attention: tools parse and update these lines.
-    TFRT_COMMIT = "6d71fa4816fafb69ee1caac955b2b3844290d577"
-    TFRT_SHA256 = "37da0c18b558e85e8e9c9e482c217221762679265245893ab87c136df7265446"
-
-    tf_http_archive(
-        name = "tf_runtime",
-        sha256 = TFRT_SHA256,
-        strip_prefix = "runtime-{commit}".format(commit = TFRT_COMMIT),
-        urls = tf_mirror_urls("https://github.com/tensorflow/runtime/archive/{commit}.tar.gz".format(commit = TFRT_COMMIT)),
-        # A patch file can be provided for atomic commits to both TF and TFRT.
-        # The job that bumps the TFRT_COMMIT also resets patch_file to 'None'.
-        patch_file = None,
-    )
diff --git a/third_party/xla/third_party/triton/cl568176943.patch b/third_party/xla/third_party/triton/cl568176943.patch
index d91e505..c187e67 100644
--- a/third_party/xla/third_party/triton/cl568176943.patch
+++ b/third_party/xla/third_party/triton/cl568176943.patch
@@ -1,8 +1,16 @@
 diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
-index d2a3f7c74..cb668303a 100644
+index e78e7298c..a4685653c 100644
 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp
 +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
-@@ -273,8 +273,10 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
+@@ -40,7 +40,6 @@
+ #include "llvm/Support/SourceMgr.h"
+ #include "llvm/Target/TargetMachine.h"
+ #include "llvm/Transforms/InstCombine/InstCombine.h"
+-#include "third_party/py/triton/google/find_cuda.h"
+ #include <optional>
+ #ifdef _WIN32
+ #define WIN32_LEAN_AND_MEAN
+@@ -277,8 +276,10 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
      // Search for libdevice relative to its library path if used from Python
      // Then native code is in `triton/_C/libtriton.so` and libdevice in
      // `triton/third_party/cuda/lib/libdevice.10.bc`
diff --git a/third_party/xla/third_party/triton/cl576548341.patch b/third_party/xla/third_party/triton/cl576548341.patch
new file mode 100644
index 0000000..3efc805
--- /dev/null
+++ b/third_party/xla/third_party/triton/cl576548341.patch
@@ -0,0 +1,16 @@
+diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+--- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp
++++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+@@ -362,8 +362,10 @@ Value addStringToModule(Location loc, Co
+   }
+ 
+   Value zero = i32_val(0);
+-  Value globalPtr =
+-      rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(ctx), global);
++  Type globalPtrType =
++      LLVM::LLVMPointerType::get(globalType, global.getAddrSpace());
++  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
++      UnknownLoc::get(ctx), globalPtrType, global.getSymName());
+   Value stringStart =
+       rewriter.create<LLVM::GEPOp>(UnknownLoc::get(ctx), ptr_ty(i8_ty),
+                                    globalPtr, SmallVector<Value>({zero, zero}));
diff --git a/third_party/xla/third_party/triton/cl577369732.patch b/third_party/xla/third_party/triton/cl577369732.patch
new file mode 100644
index 0000000..e63b9f3
--- /dev/null
+++ b/third_party/xla/third_party/triton/cl577369732.patch
@@ -0,0 +1,116 @@
+==== triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp#19 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp ====
+# action=edit type=text
+--- triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp	2023-10-19 14:55:11.000000000 -0700
++++ triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp	2023-10-27 20:17:46.000000000 -0700
+@@ -759,7 +759,7 @@
+   OpBuilder builder(forOp);
+   // Get init operands for loop carried values
+   for (BlockArgument &arg : forOp.getRegionIterArgs()) {
+-    OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
++    OpOperand &operand = *forOp.getTiedLoopInit(arg);
+     setValueMapping(arg, operand.get(), 0);
+   }
+ 
+==== triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp#10 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp ====
+# action=edit type=text
+--- triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp	2023-10-19 14:55:11.000000000 -0700
++++ triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp	2023-10-27 20:17:46.000000000 -0700
+@@ -188,7 +188,7 @@
+   auto getIncomingOp = [this](Value v) -> Value {
+     if (auto arg = v.dyn_cast<BlockArgument>())
+       if (arg.getOwner()->getParentOp() == forOp.getOperation())
+-        return forOp.getOpOperandForRegionIterArg(arg).get();
++        return forOp.getTiedLoopInit(arg)->get();
+     return Value();
+   };
+ 
+@@ -298,10 +298,10 @@
+       Operation *firstDot = builder.clone(*dot, mapping);
+       if (Value a = operand2headPrefetch.lookup(dot.getA()))
+         firstDot->setOperand(
+-            0, newForOp.getRegionIterArgForOpOperand(*a.use_begin()));
++            0, newForOp.getTiedLoopRegionIterArg(&*a.use_begin()));
+       if (Value b = operand2headPrefetch.lookup(dot.getB()))
+         firstDot->setOperand(
+-            1, newForOp.getRegionIterArgForOpOperand(*b.use_begin()));
++            1, newForOp.getTiedLoopRegionIterArg(&*b.use_begin()));
+ 
+       // remaining part
+       int64_t kOff = prefetchWidth;
+==== triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp#18 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp ====
+# action=edit type=text
+--- triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp	2023-10-24 18:31:01.000000000 -0700
++++ triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp	2023-10-27 20:17:46.000000000 -0700
+@@ -245,7 +245,7 @@
+   for (OpOperand &use : value.getUses()) {
+     Operation *user = use.getOwner();
+     if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+-      Value arg = forOp.getRegionIterArgForOpOperand(use);
++      Value arg = forOp.getTiedLoopRegionIterArg(&use);
+       Value result = forOp.getResultForOpOperand(use);
+       setEncoding({arg, result}, info, changed, user);
+       continue;
+@@ -767,7 +767,7 @@
+       SmallVector<Value> newOperands;
+       for (auto arg : forOp.getRegionIterArgs()) {
+         if (slice.count(arg)) {
+-          OpOperand &initVal = forOp.getOpOperandForRegionIterArg(arg);
++          OpOperand &initVal = *forOp.getTiedLoopInit(arg);
+           argMapping.push_back(std::make_pair(
+               forOp.getResultForOpOperand(initVal).getResultNumber(),
+               forOp.getInitArgs().size() + newOperands.size()));
+==== triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp#16 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp ====
+# action=edit type=text
+--- triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp	2023-10-24 18:31:01.000000000 -0700
++++ triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp	2023-10-27 20:17:46.000000000 -0700
+@@ -430,10 +430,10 @@
+     Block *block = blockArg.getOwner();
+     Operation *parentOp = block->getParentOp();
+     if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
+-      OpOperand &initOperand = forOp.getOpOperandForRegionIterArg(blockArg);
++      OpOperand *initOperand = forOp.getTiedLoopInit(blockArg);
+       Value yieldOperand = forOp.getBody()->getTerminator()->getOperand(
+           blockArg.getArgNumber() - forOp.getNumInductionVars());
+-      queue.push_back({initOperand.get(), encoding});
++      queue.push_back({initOperand->get(), encoding});
+       queue.push_back({yieldOperand, encoding});
+       continue;
+     }
+==== triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp#1 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp ====
+# action=edit type=text
+--- triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp	2023-10-12 01:35:16.000000000 -0700
++++ triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp	2023-10-27 20:17:46.000000000 -0700
+@@ -88,9 +88,8 @@
+     auto parentOp = blockArg.getOwner()->getParentOp();
+     if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
+       if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) {
+-        if (failed(getDependentPointers(
+-                forOp.getOpOperandForRegionIterArg(blockArg).get(),
+-                dependentSet, processedSet)))
++        if (failed(getDependentPointers(forOp.getTiedLoopInit(blockArg)->get(),
++                                        dependentSet, processedSet)))
+           return failure();
+ 
+         unsigned operandIdx =
+@@ -383,7 +382,7 @@
+       if (failed(addControlOperandsForForOp(forOp)))
+         return failure();
+       if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) {
+-        Value operand = forOp.getOpOperandForRegionIterArg(blockArg).get();
++        Value operand = forOp.getTiedLoopInit(blockArg)->get();
+         if (failed(tryInsertAndPropagate(operand)))
+           return failure();
+ 
+==== triton/test/lib/Analysis/TestAlias.cpp#5 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/test/lib/Analysis/TestAlias.cpp ====
+# action=edit type=text
+--- triton/test/lib/Analysis/TestAlias.cpp	2023-10-19 14:55:11.000000000 -0700
++++ triton/test/lib/Analysis/TestAlias.cpp	2023-10-27 20:17:47.000000000 -0700
+@@ -87,7 +87,7 @@
+       }
+       if (auto forOp = dyn_cast<scf::ForOp>(op)) {
+         for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) {
+-          auto operand = forOp.getOpOperandForRegionIterArg(arg.value()).get();
++          auto operand = forOp.getTiedLoopInit(arg.value())->get();
+           auto opNames = getAllocOpNames(operand);
+           auto argName = getValueOperandName(arg.value(), state);
+           print(argName, opNames, os);
diff --git a/third_party/xla/third_party/triton/cl577379396.patch b/third_party/xla/third_party/triton/cl577379396.patch
new file mode 100644
index 0000000..ee569f9
--- /dev/null
+++ b/third_party/xla/third_party/triton/cl577379396.patch
@@ -0,0 +1,33 @@
+diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
+--- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
++++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
+@@ -246,7 +246,7 @@ SmallVector<Value> LayoutPropagation::pr
+     Operation *user = use.getOwner();
+     if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+       Value arg = forOp.getTiedLoopRegionIterArg(&use);
+-      Value result = forOp.getResultForOpOperand(use);
++      Value result = forOp.getTiedLoopResult(&use);
+       setEncoding({arg, result}, info, changed, user);
+       continue;
+     }
+@@ -769,7 +769,7 @@ static void rewriteSlice(SetVector<Value
+         if (slice.count(arg)) {
+           OpOperand &initVal = *forOp.getTiedLoopInit(arg);
+           argMapping.push_back(std::make_pair(
+-              forOp.getResultForOpOperand(initVal).getResultNumber(),
++              forOp.getTiedLoopResult(&initVal).getResultNumber(),
+               forOp.getInitArgs().size() + newOperands.size()));
+           newOperands.push_back(mapping.lookup(initVal.get()));
+         }
+diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
++++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+@@ -545,7 +545,7 @@ struct ForOpDeadArgElimination : public 
+       Value value = queue.pop_back_val();
+       if (auto nestedFor = value.getDefiningOp<scf::ForOp>()) {
+         auto result = value.cast<OpResult>();
+-        OpOperand &forOperand = nestedFor.getOpOperandForResult(result);
++        OpOperand &forOperand = *nestedFor.getTiedLoopInit(result);
+         markLive(forOperand.get());
+         auto nestedYieldOp =
+             cast<scf::YieldOp>(nestedFor.getBody()->getTerminator());
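Two of the new Triton patches above (cl577369732 and cl577379396) perform one mechanical migration: the upstream MLIR `scf::ForOp` accessors that traffic in `OpOperand` references are replaced by the newer `getTiedLoop*` family, which works through pointers. A minimal sketch collecting the renames exactly as they appear in the hunks above; the include path and the wrapping helper function are assumptions for illustration only, not part of the patches.

```cpp
#include "mlir/Dialect/SCF/IR/SCF.h"

// Hypothetical helper: each line pairs the old accessor (comment) with the
// replacement spelling used in the hunks above.
static void forOpAccessorMigration(mlir::scf::ForOp forOp,
                                   mlir::BlockArgument iterArg,
                                   mlir::OpOperand &use,
                                   mlir::OpResult result) {
  // was: forOp.getOpOperandForRegionIterArg(iterArg)
  mlir::OpOperand &init = *forOp.getTiedLoopInit(iterArg);
  // was: forOp.getRegionIterArgForOpOperand(use)
  mlir::Value regionArg = forOp.getTiedLoopRegionIterArg(&use);
  // was: forOp.getResultForOpOperand(use)
  mlir::Value tiedResult = forOp.getTiedLoopResult(&use);
  // was: forOp.getOpOperandForResult(result)
  mlir::OpOperand &initForResult = *forOp.getTiedLoopInit(result);
  (void)init; (void)regionArg; (void)tiedResult; (void)initForResult;
}
```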
diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl
index 7e15f39..9ca9639 100644
--- a/third_party/xla/third_party/triton/workspace.bzl
+++ b/third_party/xla/third_party/triton/workspace.bzl
@@ -5,8 +5,8 @@
 def repo():
     """Imports Triton."""
 
-    TRITON_COMMIT = "cl568176943"
-    TRITON_SHA256 = "5ffa5b538641fa306c8a24010438294ce7f43f80a462fe373a7cf747afde18b5"
+    TRITON_COMMIT = "cl575842988"
+    TRITON_SHA256 = "caa815ec863182eb3745fdc0884f521d622aa2b37be521b850f7ea330cadc923"
 
     tf_http_archive(
         name = "triton",
@@ -17,5 +17,8 @@
         patch_file = [
             "//third_party/triton:cl568176943.patch",
             "//third_party/triton:b304456327.patch",
+            "//third_party/triton:cl576548341.patch",
+            "//third_party/triton:cl577369732.patch",
+            "//third_party/triton:cl577379396.patch",
         ],
     )
diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc
index 7035378..e9fc2d4 100644
--- a/third_party/xla/third_party/tsl/.bazelrc
+++ b/third_party/xla/third_party/tsl/.bazelrc
@@ -55,6 +55,7 @@
 #
 #     rbe_linux_cpu:                  RBE options to build with only CPU support.
 #     rbe_linux_cuda:                 RBE options to build with GPU support using clang.
+#     rbe_linux_cuda_nvcc:            RBE options to build with GPU support using nvcc.
 #
 #     rbe_win_py39: Windows Python 3.9 RBE config
 #
@@ -237,9 +238,12 @@
 # Select supported compute capabilities (supported graphics cards).
 # This is the same as the official TensorFlow builds.
 # See https://developer.nvidia.com/cuda-gpus#compute
-# TODO(angerson, perfinion): What does sm_ vs compute_ mean? How can users
-# select a good value for this? See go/tf-pip-cuda
-build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_75,compute_80"
+# `compute_XY` enables PTX embedding in addition to SASS. PTX
+# is forward compatible beyond the current compute capability major
+# release while SASS is only forward compatible inside the current
+# major release. Example: sm_80 kernels can run on sm_89 GPUs but
+# not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs.
+build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
 
 # Set up compilation CUDA version and paths and use the CUDA Clang toolchain.
 build:cuda_clang_official --config=cuda_clang
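The compute-capability comment added in the hunk above states a rule worth spelling out: `sm_XY` entries embed SASS, which only runs on GPUs of the same major architecture at capability XY or higher, while `compute_XY` entries embed PTX, which the driver can JIT-compile for any GPU at capability XY or higher. A purely illustrative toy (not part of the build; the helper name is made up) that encodes the examples given in that comment:

```cpp
#include <cassert>
#include <string>

// Toy encoding of the rule described above: SASS ("sm_XY") is forward
// compatible only within one major architecture, PTX ("compute_XY") is
// forward compatible across major architectures.
static bool KernelRunsOn(const std::string& kernel_arch, int gpu_capability) {
  const bool is_ptx = kernel_arch.rfind("compute_", 0) == 0;
  const int built_for =
      std::stoi(kernel_arch.substr(kernel_arch.find('_') + 1));
  if (is_ptx) return gpu_capability >= built_for;
  return gpu_capability >= built_for &&
         gpu_capability / 10 == built_for / 10;
}

int main() {
  assert(KernelRunsOn("sm_80", 89));       // sm_80 SASS runs on sm_89
  assert(!KernelRunsOn("sm_80", 90));      // ...but not on sm_90
  assert(KernelRunsOn("compute_80", 90));  // compute_80 PTX does run on sm_90
  return 0;
}
```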
@@ -249,7 +253,7 @@
 build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc"
 build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-17/bin/clang"
 build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
-build:cuda_clang_official --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain"
+build:cuda_clang_official --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain"
 
 # Debug config
 build:dbg -c dbg
@@ -482,12 +486,12 @@
 
 build:rbe_linux_cpu --config=rbe_linux
 # Linux cpu and cuda builds share the same toolchain now.
-build:rbe_linux_cpu --host_crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain"
-build:rbe_linux_cpu --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain"
-build:rbe_linux_cpu --extra_toolchains="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain-linux-x86_64"
-build:rbe_linux_cpu --extra_execution_platforms="@sigbuild-r2.14-clang_config_platform//:platform"
-build:rbe_linux_cpu --host_platform="@sigbuild-r2.14-clang_config_platform//:platform"
-build:rbe_linux_cpu --platforms="@sigbuild-r2.14-clang_config_platform//:platform"
+build:rbe_linux_cpu --host_crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain"
+build:rbe_linux_cpu --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain"
+build:rbe_linux_cpu --extra_toolchains="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain-linux-x86_64"
+build:rbe_linux_cpu --extra_execution_platforms="@sigbuild-r2.16-clang_config_platform//:platform"
+build:rbe_linux_cpu --host_platform="@sigbuild-r2.16-clang_config_platform//:platform"
+build:rbe_linux_cpu --platforms="@sigbuild-r2.16-clang_config_platform//:platform"
 # This is needed for all Clang17 builds but must not be present in GCC builds.
 build:rbe_linux_cpu --copt=-Wno-error=unused-command-line-argument
 # This was added in clang-16 by https://reviews.llvm.org/D133574.
@@ -496,7 +500,7 @@
 # See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183.
 build:rbe_linux_cpu --copt=-Wno-gnu-offsetof-extensions
 # Python config is the same across all containers because the binary is the same
-build:rbe_linux_cpu --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.14-clang_config_python"
+build:rbe_linux_cpu --repo_env=TF_PYTHON_CONFIG_REPO="@sigbuild-r2.16-clang_config_python"
 build:rbe_linux_cpu --python_path="/usr/bin/python3"
 # These you may need to change for your own GCP project.
 common:rbe_linux_cpu --remote_instance_name=projects/tensorflow-testing/instances/default_instance
@@ -517,11 +521,40 @@
 build:rbe_linux_cuda --config=rbe_linux_cpu
 # For Remote build execution -- GPU configuration
 build:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1
-build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.14-clang_config_cuda"
-build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.14-clang_config_tensorrt"
-build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.14-clang_config_nccl"
+build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.16-clang_config_cuda"
+build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.16-clang_config_tensorrt"
+build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_config_nccl"
 test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
 
+build:rbe_linux_cuda_nvcc --config=cuda
+build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1
+build:rbe_linux_cuda_nvcc --@local_xla//xla/python:enable_gpu=true
+build:rbe_linux_cuda_nvcc --@local_xla//xla/python:jax_cuda_pip_rpaths=true
+build:rbe_linux_cuda_nvcc --define=xla_python_enable_gpu=true
+build:rbe_linux_cuda_nvcc --config=tensorrt
+build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_75,compute_80"
+build:rbe_linux_cuda_nvcc --action_env=TF_CUDA_VERSION="12"
+build:rbe_linux_cuda_nvcc --action_env=TF_CUDNN_VERSION="8"
+build:rbe_linux_cuda_nvcc --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.2"
+build:rbe_linux_cuda_nvcc --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc"
+build:rbe_linux_cuda_nvcc --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
+build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain"
+build:rbe_linux_cuda_nvcc --config=rbe_linux
+build:rbe_linux_cuda_nvcc --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain"
+build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain-linux-x86_64"
+build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_python3.9"
+build:rbe_linux_cuda_nvcc --python_path="/usr/bin/python3"
+# These you may need to change for your own GCP project.
+common:rbe_linux_cuda_nvcc --remote_instance_name=projects/tensorflow-testing/instances/default_instance
+build:rbe_linux_cuda_nvcc --repo_env=REMOTE_GPU_TESTING=1
+build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_cuda"
+build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_tensorrt"
+build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_nccl"
+test:rbe_linux_cuda_nvcc --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
+
 # TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed
 build:rbe_win --config=rbe_base
 build:rbe_win --crosstool_top="//tensorflow/tools/toolchains/win/tf_win_05022023:toolchain"
@@ -576,8 +609,6 @@
 # Here are bazelrc configs for release builds
 # Build TensorFlow v2.
 test:release_base --test_size_filters=small,medium
-# TODO(b/294367488) disable after 2.15 brancut
-test:release_base --flaky_test_attempts=3
 
 # Target the AVX instruction set
 build:release_linux_base --config=avx_linux
@@ -615,7 +646,7 @@
 
 # Use the Clang toolchain to compile
 build:release_cpu_linux --config=release_linux_base
-build:release_cpu_linux --crosstool_top="@sigbuild-r2.14-clang_config_cuda//crosstool:toolchain"
+build:release_cpu_linux --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain"
 
 build:release_gpu_linux --config=release_cpu_linux
 # Set up compilation CUDA version and paths and use the CUDA Clang toolchain.
@@ -684,7 +715,7 @@
 build:macos   --config=no_tfrt
 build:windows --config=no_tfrt
 build:rocm --config=no_tfrt
-build:no_tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/ir,tensorflow/compiler/mlir/tfrt/ir/mlrt,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/mlrt,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/compiler/mlir/tfrt/transforms/mlrt,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/runtime_fallback/test,tensorflow/core/runtime_fallback/test/gpu,tensorflow/core/runtime_fallback/test/saved_model,tensorflow/core/runtime_fallback/test/testdata,tensorflow/core/tfrt/stubs,tensorflow/core/tfrt/tfrt_session,tensorflow/core/tfrt/mlrt,tensorflow/core/tfrt/mlrt/attribute,tensorflow/core/tfrt/mlrt/kernel,tensorflow/core/tfrt/mlrt/bytecode,tensorflow/core/tfrt/mlrt/interpreter,tensorflow/compiler/mlir/tfrt/translate/mlrt,tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug,tensorflow/core/tfrt/saved_model/python,tensorflow/core/tfrt/graph_executor/python,tensorflow/core/tfrt/saved_model/utils
+build:no_tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/ir,tensorflow/compiler/mlir/tfrt/ir/mlrt,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ifrt,tensorflow/compiler/mlir/tfrt/tests/mlrt,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/compiler/mlir/tfrt/transforms/mlrt,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/runtime_fallback/test,tensorflow/core/runtime_fallback/test/gpu,tensorflow/core/runtime_fallback/test/saved_model,tensorflow/core/runtime_fallback/test/testdata,tensorflow/core/tfrt/stubs,tensorflow/core/tfrt/tfrt_session,tensorflow/core/tfrt/mlrt,tensorflow/core/tfrt/mlrt/attribute,tensorflow/core/tfrt/mlrt/kernel,tensorflow/core/tfrt/mlrt/bytecode,tensorflow/core/tfrt/mlrt/interpreter,tensorflow/compiler/mlir/tfrt/translate/mlrt,tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug,tensorflow/core/tfrt/saved_model/python,tensorflow/core/tfrt/graph_executor/python,tensorflow/core/tfrt/saved_model/utils
 
 # BEGIN TF CACHE HELPER OPTIONS
 # Options when using remote execution
diff --git a/third_party/xla/third_party/tsl/opensource_only.files b/third_party/xla/third_party/tsl/opensource_only.files
index 37d7129..e4974e7 100644
--- a/third_party/xla/third_party/tsl/opensource_only.files
+++ b/third_party/xla/third_party/tsl/opensource_only.files
@@ -8,22 +8,15 @@
 third_party/compute_library/build_defs.bzl:
 third_party/curl.BUILD:
 third_party/cython.BUILD:
+third_party/ducc/BUILD:
+third_party/ducc/ducc0_custom_lowlevel_threading.h:
+third_party/ducc/fft.cc:
+third_party/ducc/fft.h:
+third_party/ducc/threading.cc:
+third_party/ducc/threading.h:
 third_party/eigen3/BUILD:
-third_party/eigen3/Eigen/Cholesky:
-third_party/eigen3/Eigen/Core:
-third_party/eigen3/Eigen/Eigenvalues:
-third_party/eigen3/Eigen/LU:
-third_party/eigen3/Eigen/OrderingMethods:
-third_party/eigen3/Eigen/QR:
-third_party/eigen3/Eigen/SVD:
-third_party/eigen3/Eigen/SparseCholesky:
-third_party/eigen3/Eigen/SparseCore:
 third_party/eigen3/LICENSE:
 third_party/eigen3/eigen_archive.BUILD:
-third_party/eigen3/unsupported/Eigen/CXX11/Tensor:
-third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool:
-third_party/eigen3/unsupported/Eigen/MatrixFunctions:
-third_party/eigen3/unsupported/Eigen/SpecialFunctions:
 third_party/gif.BUILD:
 third_party/gif_fix_strtok_r.patch:
 third_party/git/BUILD.tpl:
diff --git a/third_party/xla/third_party/tsl/third_party/absl/absl_designated_initializers.patch b/third_party/xla/third_party/tsl/third_party/absl/absl_designated_initializers.patch
deleted file mode 100644
index 6ee2322..0000000
--- a/third_party/xla/third_party/tsl/third_party/absl/absl_designated_initializers.patch
+++ /dev/null
@@ -1,65 +0,0 @@
-diff --git a/absl/crc/internal/crc_memcpy_x86_64.cc b/absl/crc/internal/crc_memcpy_x86_64.cc
-index 66f784de..ff424c54 100644
---- a/absl/crc/internal/crc_memcpy_x86_64.cc
-+++ b/absl/crc/internal/crc_memcpy_x86_64.cc
-@@ -359,18 +359,18 @@ CrcMemcpy::ArchSpecificEngines CrcMemcpy::GetArchSpecificEngines() {
-     case CpuType::kIntelHaswell:
-     case CpuType::kIntelIvybridge:
-       return {
--          .temporal = new FallbackCrcMemcpyEngine(),
--          .non_temporal = new CrcNonTemporalMemcpyAVXEngine(),
-+          /*.temporal=*/new FallbackCrcMemcpyEngine(),
-+          /*.non_temporal=*/new CrcNonTemporalMemcpyAVXEngine(),
-       };
-     // INTEL_SANDYBRIDGE performs better with SSE than AVX.
-     case CpuType::kIntelSandybridge:
-       return {
--          .temporal = new FallbackCrcMemcpyEngine(),
--          .non_temporal = new CrcNonTemporalMemcpyEngine(),
-+          /*.temporal=*/new FallbackCrcMemcpyEngine(),
-+          /*.non_temporal=*/new CrcNonTemporalMemcpyEngine(),
-       };
-     default:
--      return {.temporal = new FallbackCrcMemcpyEngine(),
--              .non_temporal = new FallbackCrcMemcpyEngine()};
-+      return {/*.temporal=*/new FallbackCrcMemcpyEngine(),
-+              /*.non_temporal=*/new FallbackCrcMemcpyEngine()};
-   }
- #else
-   // Get the underlying architecture.
-@@ -388,8 +388,8 @@ CrcMemcpy::ArchSpecificEngines CrcMemcpy::GetArchSpecificEngines() {
-     case CpuType::kAmdRome:
-     case CpuType::kAmdNaples:
-       return {
--          .temporal = new AcceleratedCrcMemcpyEngine<1, 2>(),
--          .non_temporal = new CrcNonTemporalMemcpyAVXEngine(),
-+          /*.temporal=*/new AcceleratedCrcMemcpyEngine<1, 2>(),
-+          /*.non_temporal=*/new CrcNonTemporalMemcpyAVXEngine(),
-       };
-     // PCLMULQDQ is slow and we don't have wide enough issue width to take
-     // advantage of it.  For an unknown architecture, don't risk using CLMULs.
-@@ -400,18 +400,18 @@ CrcMemcpy::ArchSpecificEngines CrcMemcpy::GetArchSpecificEngines() {
-     case CpuType::kIntelHaswell:
-     case CpuType::kIntelIvybridge:
-       return {
--          .temporal = new AcceleratedCrcMemcpyEngine<3, 0>(),
--          .non_temporal = new CrcNonTemporalMemcpyAVXEngine(),
-+          /*.temporal=*/new AcceleratedCrcMemcpyEngine<3, 0>(),
-+          /*.non_temporal=*/new CrcNonTemporalMemcpyAVXEngine(),
-       };
-     // INTEL_SANDYBRIDGE performs better with SSE than AVX.
-     case CpuType::kIntelSandybridge:
-       return {
--          .temporal = new AcceleratedCrcMemcpyEngine<3, 0>(),
--          .non_temporal = new CrcNonTemporalMemcpyEngine(),
-+          /*.temporal=*/new AcceleratedCrcMemcpyEngine<3, 0>(),
-+          /*.non_temporal=*/new CrcNonTemporalMemcpyEngine(),
-       };
-     default:
--      return {.temporal = new FallbackCrcMemcpyEngine(),
--              .non_temporal = new FallbackCrcMemcpyEngine()};
-+      return {/*.temporal=*/new FallbackCrcMemcpyEngine(),
-+              /*.non_temporal=*/new FallbackCrcMemcpyEngine()};
-   }
- #endif  // UNDEFINED_BEHAVIOR_SANITIZER
- }
diff --git a/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl b/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl
index 07f49ce..06f7516 100644
--- a/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl
+++ b/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl
@@ -7,8 +7,8 @@
 
     # Attention: tools parse and update these lines.
     # LINT.IfChange
-    ABSL_COMMIT = "b971ac5250ea8de900eae9f95e06548d14cd95fe"
-    ABSL_SHA256 = "8eeec9382fc0338ef5c60053f3a4b0e0708361375fe51c9e65d0ce46ccfe55a7"
+    ABSL_COMMIT = "fb3621f4f897824c0dbe0615fa94543df6192f30"
+    ABSL_SHA256 = "0320586856674d16b0b7a4d4afb22151bdc798490bb7f295eddd8f6a62b46fea"
     # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/abseil-cpp.cmake)
 
     SYS_DIRS = [
@@ -42,9 +42,6 @@
         build_file = "//third_party/absl:com_google_absl.BUILD",
         system_build_file = "//third_party/absl:system.BUILD",
         system_link_files = SYS_LINKS,
-        # This patch pulls in a fix for designated initializers that MSVC
-        # complains about. It shouldn't be necessary at the next LTS release.
-        patch_file = ["//third_party/absl:absl_designated_initializers.patch"],
         strip_prefix = "abseil-cpp-{commit}".format(commit = ABSL_COMMIT),
         urls = tf_mirror_urls("https://github.com/abseil/abseil-cpp/archive/{commit}.tar.gz".format(commit = ABSL_COMMIT)),
     )
diff --git a/third_party/xla/third_party/tsl/third_party/ducc/BUILD b/third_party/xla/third_party/tsl/third_party/ducc/BUILD
new file mode 100644
index 0000000..073696a
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/ducc/BUILD
@@ -0,0 +1 @@
+# DUCC FFT library (https://gitlab.mpcdf.mpg.de/mtr/ducc).
diff --git a/third_party/xla/third_party/tsl/third_party/ducc/ducc.BUILD b/third_party/xla/third_party/tsl/third_party/ducc/ducc.BUILD
new file mode 100644
index 0000000..8d71392
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/ducc/ducc.BUILD
@@ -0,0 +1,76 @@
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],
+)
+
+exports_files(["LICENSE"])
+
+DUCC_COPTS = [
+    "-frtti",
+    "-fexceptions",
+    "-ffp-contract=fast",
+]
+
+# This library exposes the raw DUCC fft API.  It should be used
+# with caution, since inclusion of the headers will require any
+# dependent targets to be built with exceptions and RTTI enabled.
+# For a better-isolated target, use ":fft_wrapper".
+cc_library(
+    name = "fft",
+    srcs = [
+        "google/ducc0_custom_lowlevel_threading.h",
+        "google/threading.cc",
+        "src/ducc0/infra/aligned_array.h",
+        "src/ducc0/infra/error_handling.h",
+        "src/ducc0/infra/misc_utils.h",
+        "src/ducc0/infra/simd.h",
+        "src/ducc0/infra/threading.cc",
+        "src/ducc0/infra/useful_macros.h",
+        "src/ducc0/math/cmplx.h",
+        "src/ducc0/math/unity_roots.h",
+    ],
+    hdrs = [
+        "google/threading.h",
+        "src/ducc0/fft/fft.h",
+        "src/ducc0/fft/fft1d_impl.h",
+        "src/ducc0/fft/fftnd_impl.h",
+        "src/ducc0/infra/mav.h",
+        "src/ducc0/infra/threading.h",
+    ],
+    copts = DUCC_COPTS,
+    defines = [
+        # Use custom TSL/Eigen threading.
+        "DUCC0_CUSTOM_LOWLEVEL_THREADING=1",
+    ],
+    features = ["-use_header_modules"],
+    include_prefix = "ducc",
+    includes = [
+        ".",  # Needed for google/-relative paths.
+        "google",  # Needed for finding ducc0_custom_lowlevel_threading.h.
+        "src",  # Needed for internal headers.
+    ],
+    # The DUCC FFT source files are dual-licensed as BSD 3 clause and GPLv2.
+    # We choose BSD 3 clause.
+    licenses = ["notice"],
+    visibility = ["//visibility:private"],
+    deps = [
+        # Required for custom threadpool usage:
+        "@eigen_archive//:eigen3",
+        "@local_tsl//tsl/platform:mutex",
+    ],
+)
+
+cc_library(
+    name = "fft_wrapper",
+    srcs = ["google/fft.cc"],
+    hdrs = ["google/fft.h"],
+    copts = DUCC_COPTS,
+    features = ["-use_header_modules"],
+    include_prefix = "ducc",
+    licenses = ["notice"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":fft",
+        "@eigen_archive//:eigen3",
+    ],
+)
diff --git a/third_party/xla/third_party/tsl/third_party/ducc/ducc0_custom_lowlevel_threading.h b/third_party/xla/third_party/tsl/third_party/ducc/ducc0_custom_lowlevel_threading.h
new file mode 100644
index 0000000..688efe7
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/ducc/ducc0_custom_lowlevel_threading.h
@@ -0,0 +1,35 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_DUCC_GOOGLE_DUCC0_CUSTOM_LOWLEVEL_THREADING_H_
+#define THIRD_PARTY_DUCC_GOOGLE_DUCC0_CUSTOM_LOWLEVEL_THREADING_H_
+
+#include "tsl/platform/mutex.h"
+
+namespace ducc0 {
+namespace detail_threading {
+
+using Mutex = tsl::mutex;
+using UniqueLock = tsl::mutex_lock;
+using LockGuard = tsl::mutex_lock;
+using CondVar = tsl::condition_variable;
+
+// Missing variable used by DUCC threading.cc.
+extern thread_local bool in_parallel_region;
+
+}  // namespace detail_threading
+}  // namespace ducc0
+
+#endif  // THIRD_PARTY_DUCC_GOOGLE_DUCC0_CUSTOM_LOWLEVEL_THREADING_H_
diff --git a/third_party/xla/third_party/tsl/third_party/ducc/fft.cc b/third_party/xla/third_party/tsl/third_party/ducc/fft.cc
new file mode 100644
index 0000000..ec3c66f
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/ducc/fft.cc
@@ -0,0 +1,148 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "ducc/google/fft.h"
+
+#include <complex>
+#include <cstddef>
+#include <cstdlib>
+#include <exception>
+#include <iostream>
+#include <ostream>
+#include <vector>
+
+#include "ducc/google/threading.h"
+#include "ducc/src/ducc0/fft/fft.h"
+#include "ducc/src/ducc0/fft/fft1d_impl.h"  // IWYU pragma: keep, DUCC definitions.
+#include "ducc/src/ducc0/fft/fftnd_impl.h"  // IWYU pragma: keep, DUCC definitions.
+#include "ducc/src/ducc0/infra/mav.h"
+#include "ducc/src/ducc0/infra/threading.h"
+#include "unsupported/Eigen/CXX11/ThreadPool"
+
+namespace ducc0 {
+
+// Wrappers around DUCC calls.
+namespace google {
+
+using Shape = std::vector<std::size_t>;
+using Stride = std::vector<std::ptrdiff_t>;
+
+template <typename RealScalar>
+void c2c(const std::complex<RealScalar>* in, const Shape& in_shape,
+         const Stride& in_stride, std::complex<RealScalar>* out,
+         const Shape& out_shape, const Stride& out_stride, const Shape& axes,
+         bool forward, RealScalar scale,
+         Eigen::ThreadPoolInterface* thread_pool) {
+  ducc0::cfmav<std::complex<RealScalar>> m_in(in, in_shape, in_stride);
+  ducc0::vfmav<std::complex<RealScalar>> m_out(out, out_shape, out_stride);
+
+  try {
+    if (thread_pool == nullptr) {
+      // Use a fake threadpool.
+      ducc0::google::NoThreadPool no_thread_pool;
+      ducc0::detail_threading::ScopedUseThreadPool thread_pool_guard(
+          no_thread_pool);
+      ducc0::c2c(m_in, m_out, axes, forward, scale, 1);
+    } else {
+      EigenThreadPool eigen_thread_pool(*thread_pool);
+      ducc0::detail_threading::ScopedUseThreadPool thread_pool_guard(
+          eigen_thread_pool);
+      ducc0::c2c(m_in, m_out, axes, forward, scale,
+                 eigen_thread_pool.nthreads());
+    }
+  } catch (const std::exception& ex) {
+    std::cerr << "DUCC FFT c2c failed: " << ex.what() << std::endl;
+    std::abort();
+  }
+}
+
+template <typename RealScalar>
+void r2c(const RealScalar* in, const Shape& in_shape, const Stride& in_stride,
+         std::complex<RealScalar>* out, const Shape& out_shape,
+         const Stride& out_stride, const Shape& axes, bool forward,
+         RealScalar scale, Eigen::ThreadPoolInterface* thread_pool) {
+  ducc0::cfmav<RealScalar> m_in(in, in_shape, in_stride);
+  ducc0::vfmav<std::complex<RealScalar>> m_out(out, out_shape, out_stride);
+
+  try {
+    if (thread_pool == nullptr) {
+      // Use a fake threadpool.
+      ducc0::google::NoThreadPool no_thread_pool;
+      ducc0::detail_threading::ScopedUseThreadPool thread_pool_guard(
+          no_thread_pool);
+      ducc0::r2c(m_in, m_out, axes, forward, scale, 1);
+    } else {
+      EigenThreadPool eigen_thread_pool(*thread_pool);
+      ducc0::detail_threading::ScopedUseThreadPool thread_pool_guard(
+          eigen_thread_pool);
+      ducc0::r2c(m_in, m_out, axes, forward, scale,
+                 eigen_thread_pool.nthreads());
+    }
+  } catch (const std::exception& ex) {
+    std::cerr << "DUCC FFT r2c failed: " << ex.what() << std::endl;
+    std::abort();
+  }
+}
+
+template <typename RealScalar>
+void c2r(const std::complex<RealScalar>* in, const Shape& in_shape,
+         const Stride& in_stride, RealScalar* out, const Shape& out_shape,
+         const Stride& out_stride, const Shape& axes, bool forward,
+         RealScalar scale, Eigen::ThreadPoolInterface* thread_pool) {
+  ducc0::cfmav<std::complex<RealScalar>> m_in(in, in_shape, in_stride);
+  ducc0::vfmav<RealScalar> m_out(out, out_shape, out_stride);
+
+  try {
+    if (thread_pool == nullptr) {
+      // Use a fake threadpool.
+      ducc0::google::NoThreadPool no_thread_pool;
+      ducc0::detail_threading::ScopedUseThreadPool thread_pool_guard(
+          no_thread_pool);
+      ducc0::c2r(m_in, m_out, axes, forward, scale, 1);
+    } else {
+      EigenThreadPool eigen_thread_pool(*thread_pool);
+      ducc0::detail_threading::ScopedUseThreadPool thread_pool_guard(
+          eigen_thread_pool);
+      ducc0::c2r(m_in, m_out, axes, forward, scale,
+                 eigen_thread_pool.nthreads());
+    }
+  } catch (const std::exception& ex) {
+    std::cerr << "DUCC FFT c2r failed: " << ex.what() << std::endl;
+    std::abort();
+  }
+}
+
+#define FFT_DEFINITIONS(RealScalar)                                            \
+  template void c2c<RealScalar>(                                               \
+      const std::complex<RealScalar>* in, const Shape& in_shape,               \
+      const Stride& in_stride, std::complex<RealScalar>* out,                  \
+      const Shape& out_shape, const Stride& out_stride, const Shape& axes,     \
+      bool forward, RealScalar scale,                                          \
+      Eigen::ThreadPoolInterface* thread_pool);                                \
+  template void r2c<RealScalar>(                                               \
+      const RealScalar* in, const Shape& in_shape, const Stride& in_stride,    \
+      std::complex<RealScalar>* out, const Shape& out_shape,                   \
+      const Stride& out_stride, const Shape& axes, bool forward,               \
+      RealScalar scale, Eigen::ThreadPoolInterface* thread_pool);              \
+  template void c2r(const std::complex<RealScalar>* in, const Shape& in_shape, \
+                    const Stride& in_stride, RealScalar* out,                  \
+                    const Shape& out_shape, const Stride& out_stride,          \
+                    const Shape& axes, bool forward, RealScalar scale,         \
+                    Eigen::ThreadPoolInterface* thread_pool)
+FFT_DEFINITIONS(float);
+FFT_DEFINITIONS(double);
+#undef FFT_DEFINITIONS
+
+}  // namespace google
+}  // namespace ducc0
\ No newline at end of file
diff --git a/third_party/xla/third_party/tsl/third_party/ducc/fft.h b/third_party/xla/third_party/tsl/third_party/ducc/fft.h
new file mode 100644
index 0000000..8c1691d
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/ducc/fft.h
@@ -0,0 +1,77 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_DUCC_GOOGLE_FFT_H_
+#define THIRD_PARTY_DUCC_GOOGLE_FFT_H_
+
+// Wrapper around the DUCC FFT library to isolate usage of exceptions
+// and RTTI.  Eliminates all direct usage of DUCC headers.
+
+#include <complex>
+#include <cstddef>
+#include <vector>
+
+#include "unsupported/Eigen/CXX11/ThreadPool"
+
+namespace ducc0 {
+namespace google {
+
+using Shape = std::vector<std::size_t>;
+using Stride = std::vector<std::ptrdiff_t>;
+
+template <typename RealScalar>
+void c2c(const std::complex<RealScalar>* in, const Shape& in_shape,
+         const Stride& in_stride, std::complex<RealScalar>* out,
+         const Shape& out_shape, const Stride& out_stride, const Shape& axes,
+         bool forward, RealScalar scale,
+         Eigen::ThreadPoolInterface* thread_pool);
+
+template <typename RealScalar>
+void r2c(const RealScalar* in, const Shape& in_shape, const Stride& in_stride,
+         std::complex<RealScalar>* out, const Shape& out_shape,
+         const Stride& out_stride, const Shape& axes, bool forward,
+         RealScalar scale, Eigen::ThreadPoolInterface* thread_pool);
+
+template <typename RealScalar>
+void c2r(const std::complex<RealScalar>* in, const Shape& in_shape,
+         const Stride& in_stride, RealScalar* out, const Shape& out_shape,
+         const Stride& out_stride, const Shape& axes, bool forward,
+         RealScalar scale, Eigen::ThreadPoolInterface* thread_pool);
+
+#define FFT_DECLARATIONS(RealScalar)                                        \
+  extern template void c2c<RealScalar>(                                     \
+      const std::complex<RealScalar>* in, const Shape& in_shape,            \
+      const Stride& in_stride, std::complex<RealScalar>* out,               \
+      const Shape& out_shape, const Stride& out_stride, const Shape& axes,  \
+      bool forward, RealScalar scale,                                       \
+      Eigen::ThreadPoolInterface* thread_pool);                             \
+  extern template void r2c<RealScalar>(                                     \
+      const RealScalar* in, const Shape& in_shape, const Stride& in_stride, \
+      std::complex<RealScalar>* out, const Shape& out_shape,                \
+      const Stride& out_stride, const Shape& axes, bool forward,            \
+      RealScalar scale, Eigen::ThreadPoolInterface* thread_pool);           \
+  extern template void c2r(                                                 \
+      const std::complex<RealScalar>* in, const Shape& in_shape,            \
+      const Stride& in_stride, RealScalar* out, const Shape& out_shape,     \
+      const Stride& out_stride, const Shape& axes, bool forward,            \
+      RealScalar scale, Eigen::ThreadPoolInterface* thread_pool)
+FFT_DECLARATIONS(float);
+FFT_DECLARATIONS(double);
+#undef FFT_DECLARATIONS
+
+}  // namespace google
+}  // namespace ducc0
+
+#endif  // THIRD_PARTY_DUCC_GOOGLE_FFT_H_
\ No newline at end of file
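fft.h above is the only header a client is expected to include; exceptions and RTTI stay behind the `:fft` target per the ducc.BUILD comments. A minimal usage sketch under stated assumptions: the caller depends on `:fft_wrapper`, shapes and strides are given in elements (matching the `cfmav`/`vfmav` construction in fft.cc), and a null thread pool selects the single-threaded `NoThreadPool` path.

```cpp
#include <complex>
#include <cstddef>
#include <vector>

#include "ducc/google/fft.h"  // Exposed by the ":fft_wrapper" target.

int main() {
  constexpr std::size_t kN = 8;
  std::vector<std::complex<float>> in(kN, std::complex<float>(1.0f, 0.0f));
  std::vector<std::complex<float>> out(kN);

  ducc0::google::Shape shape = {kN};
  ducc0::google::Stride stride = {1};  // Contiguous; element strides assumed.
  ducc0::google::Shape axes = {0};     // Transform along the single axis.

  // A null thread pool runs the single-threaded path in fft.cc.
  ducc0::google::c2c<float>(in.data(), shape, stride, out.data(), shape,
                            stride, axes, /*forward=*/true, /*scale=*/1.0f,
                            /*thread_pool=*/nullptr);

  // For an all-ones input, out[0] holds the DC term (kN, 0); the rest are ~0.
  return 0;
}
```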
diff --git a/third_party/xla/third_party/tsl/third_party/ducc/threading.cc b/third_party/xla/third_party/tsl/third_party/ducc/threading.cc
new file mode 100644
index 0000000..d079398
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/ducc/threading.cc
@@ -0,0 +1,68 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "ducc/google/threading.h"
+
+#include <thread>
+#include <utility>
+
+#include "ducc/src/ducc0/infra/threading.h"
+#include "unsupported/Eigen/CXX11/ThreadPool"
+
+namespace ducc0 {
+
+namespace google {
+
+namespace {
+
+// Default shared global pool.  It is created on first use.
+EigenThreadPool* GetGlobalThreadPoolSingleton() {
+  static Eigen::ThreadPool* eigen_pool =
+      new Eigen::ThreadPool(std::thread::hardware_concurrency());
+  static EigenThreadPool* pool = new EigenThreadPool(*eigen_pool);
+  return pool;
+}
+
+// Thread-local active pool for current execution.
+ducc0::detail_threading::thread_pool*& GetActiveThreadPoolSingleton() {
+  thread_local thread_pool* active_pool = nullptr;
+  return active_pool;
+}
+
+}  // namespace
+}  // namespace google
+
+// Implementations required by ducc0.
+namespace detail_threading {
+
+// Missing variable used by DUCC threading.cc.
+thread_local bool in_parallel_region = false;
+
+thread_pool* set_active_pool(thread_pool* new_pool) {
+  return std::exchange(ducc0::google::GetActiveThreadPoolSingleton(), new_pool);
+}
+
+thread_pool* get_active_pool() {
+  thread_pool* pool = google::GetActiveThreadPoolSingleton();
+  if (pool == nullptr) {
+    // Set to use a global pool.  This may trigger threadpool creation.
+    // Since the active pool is thread-local, this is thread-safe.
+    pool = google::GetGlobalThreadPoolSingleton();
+    set_active_pool(pool);
+  }
+  return pool;
+}
+
+}  // namespace detail_threading
+}  // namespace ducc0
\ No newline at end of file
diff --git a/third_party/xla/third_party/tsl/third_party/ducc/threading.h b/third_party/xla/third_party/tsl/third_party/ducc/threading.h
new file mode 100644
index 0000000..a374e3d
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/ducc/threading.h
@@ -0,0 +1,60 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_DUCC_GOOGLE_THREADING_H_
+#define THIRD_PARTY_DUCC_GOOGLE_THREADING_H_
+
+#include "ducc/src/ducc0/infra/threading.h"
+#include "unsupported/Eigen/CXX11/ThreadPool"
+
+namespace ducc0 {
+namespace google {
+
+using std::size_t;
+
+// Pseudo thread-pool for single-threaded execution.
+class NoThreadPool : public ducc0::detail_threading::thread_pool {
+ public:
+  size_t nthreads() const override { return 1; }
+  size_t adjust_nthreads(size_t nthreads_in) const override { return 1; }
+  void submit(std::function<void()> work) override { work(); }
+};
+
+// Thread-pool wrapper around Eigen's ThreadPool.
+class EigenThreadPool : public ducc0::detail_threading::thread_pool {
+ public:
+  EigenThreadPool(Eigen::ThreadPoolInterface& pool) : pool_{&pool} {}
+  size_t nthreads() const override { return pool_->NumThreads(); }
+  size_t adjust_nthreads(size_t nthreads_in) const override {
+    // If called from a thread already in the pool, run single-threaded to
+    // avoid nested parallelism.
+    if (pool_->CurrentThreadId() >= 0) {
+      return 1;
+    } else if (nthreads_in == 0) {
+      return pool_->NumThreads();
+    }
+    return std::min<size_t>(nthreads_in, pool_->NumThreads());
+  }
+  void submit(std::function<void()> work) override {
+    pool_->Schedule(std::move(work));
+  }
+
+ private:
+  Eigen::ThreadPoolInterface* pool_;
+};
+
+}  // namespace google
+}  // namespace ducc0
+
+#endif  // THIRD_PARTY_DUCC_GOOGLE_THREADING_H_
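
A minimal sketch (not part of the patch) of how a caller could route DUCC's parallel work through an existing Eigen thread pool via the adapters added above. It assumes, as the definitions in threading.cc imply, that set_active_pool/get_active_pool are declared by DUCC's threading headers; `RunWithPool` and `my_pool` are hypothetical names.

#include "ducc/google/threading.h"
#include "ducc/src/ducc0/infra/threading.h"
#include "unsupported/Eigen/CXX11/ThreadPool"

void RunWithPool(Eigen::ThreadPoolInterface& my_pool) {
  // Wrap the Eigen pool in the DUCC thread_pool interface.
  ducc0::google::EigenThreadPool adapter(my_pool);
  // Install the adapter for this thread and remember the previous pool.
  ducc0::detail_threading::thread_pool* previous =
      ducc0::detail_threading::set_active_pool(&adapter);
  // ... DUCC FFT routines invoked here pick the pool up via get_active_pool() ...
  ducc0::detail_threading::set_active_pool(previous);  // Restore.
}

If no pool is installed, get_active_pool() lazily falls back to the shared global Eigen pool created in threading.cc, so callers that never call set_active_pool still get multithreaded execution.
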
diff --git a/third_party/xla/third_party/tsl/third_party/ducc/workspace.bzl b/third_party/xla/third_party/tsl/third_party/ducc/workspace.bzl
new file mode 100644
index 0000000..1475579
--- /dev/null
+++ b/third_party/xla/third_party/tsl/third_party/ducc/workspace.bzl
@@ -0,0 +1,21 @@
+"""Distinctly Useful Code Collection (DUCC) - CPU FFT Module"""
+
+load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
+
+def repo():
+    DUCC_COMMIT = "3d28aadfd8bb0219e3df188613dbbcdfffccc3cd"
+    DUCC_SHA256 = "eb044dd11374ed894d67081109d4aa7ed55c29fe3286b116f13db70da6af336c"
+    tf_http_archive(
+        name = "ducc",
+        strip_prefix = "ducc-{commit}".format(commit = DUCC_COMMIT),
+        sha256 = DUCC_SHA256,
+        urls = tf_mirror_urls("https://gitlab.mpcdf.mpg.de/mtr/ducc/-/archive/{commit}/ducc-{commit}.tar.gz".format(commit = DUCC_COMMIT)),
+        build_file = "//third_party/ducc:ducc.BUILD",
+        link_files = {
+            "//third_party/ducc:ducc0_custom_lowlevel_threading.h": "google/ducc0_custom_lowlevel_threading.h",
+            "//third_party/ducc:fft.h": "google/fft.h",
+            "//third_party/ducc:fft.cc": "google/fft.cc",
+            "//third_party/ducc:threading.cc": "google/threading.cc",
+            "//third_party/ducc:threading.h": "google/threading.h",
+        },
+    )
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/BUILD b/third_party/xla/third_party/tsl/third_party/eigen3/BUILD
index 631cc89..84a4205 100644
--- a/third_party/xla/third_party/tsl/third_party/eigen3/BUILD
+++ b/third_party/xla/third_party/tsl/third_party/eigen3/BUILD
@@ -1,71 +1,3 @@
 # Description:
 #   Eigen is a C++ template library for linear algebra: vectors,
 #   matrices, and related algorithms.
-# This is the BUILD file with extra code to patch into @eigen_archive.
-
-load("//third_party/mkl:build_defs.bzl", "if_mkl")
-
-licenses([
-    # Note: Eigen is an MPL2 library that includes GPL v3 and LGPL v2.1+ code.
-    #       We've taken special care to not reference any restricted code.
-    "reciprocal",  # MPL2
-    "notice",  # Portions BSD
-])
-
-exports_files(
-    ["LICENSE"],
-    visibility = ["//visibility:public"],
-)
-
-EIGEN3_THIRD_PARTY_HEADERS = [
-    "Eigen/Core",
-    "Eigen/LU",
-    "Eigen/Cholesky",
-    "Eigen/Eigenvalues",
-    "Eigen/OrderingMethods",
-    "Eigen/QR",
-    "Eigen/SparseCholesky",
-    "Eigen/SparseCore",
-    "Eigen/SVD",
-    "unsupported/Eigen/MatrixFunctions",
-    "unsupported/Eigen/SpecialFunctions",
-    "unsupported/Eigen/CXX11/ThreadPool",
-    "unsupported/Eigen/CXX11/Tensor",
-]
-
-cc_library(
-    name = "eigen3",
-    hdrs = EIGEN3_THIRD_PARTY_HEADERS,
-    includes = if_mkl(["./mkl_include"]),
-    visibility = ["//visibility:public"],
-    deps = [
-        "@eigen_archive//:eigen3_internal",
-    ],
-)
-
-filegroup(
-    name = "eigen_third_party_header_files",
-    srcs = EIGEN3_THIRD_PARTY_HEADERS,
-    visibility = ["//visibility:public"],
-)
-
-genrule(
-    name = "install_eigen_headers",
-    srcs = [
-        "@eigen_archive//:eigen_header_files",
-        "@eigen_archive//:eigen_source_files",
-        ":eigen_third_party_header_files",
-    ],
-    outs = ["include"],
-    cmd = """
-    mkdir $@
-    for f in $(SRCS); do
-      d="$${f%/*}"
-      d="$${d#*external/eigen_archive/}"
-
-      mkdir -p "$@/$${d}"
-      cp "$${f}" "$@/$${d}/"
-    done
-    """,
-    tags = ["manual"],
-)
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/Cholesky b/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/Cholesky
deleted file mode 100644
index c199a025..0000000
--- a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/Cholesky
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/Cholesky"
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/Core b/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/Core
deleted file mode 100644
index d4b0367..0000000
--- a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/Core
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/Core"
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/Eigenvalues b/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/Eigenvalues
deleted file mode 100644
index bf739b9..0000000
--- a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/Eigenvalues
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/Eigenvalues"
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/LU b/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/LU
deleted file mode 100644
index 536149c..0000000
--- a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/LU
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/LU"
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/OrderingMethods b/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/OrderingMethods
deleted file mode 100644
index 190fc22..0000000
--- a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/OrderingMethods
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/OrderingMethods"
\ No newline at end of file
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/QR b/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/QR
deleted file mode 100644
index be067d3..0000000
--- a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/QR
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/QR"
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/SVD b/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/SVD
deleted file mode 100644
index eecf47c..0000000
--- a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/SVD
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/SVD"
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/SparseCholesky b/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/SparseCholesky
deleted file mode 100644
index a6d362b..0000000
--- a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/SparseCholesky
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/SparseCholesky"
\ No newline at end of file
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/SparseCore b/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/SparseCore
deleted file mode 100644
index 3c60745..0000000
--- a/third_party/xla/third_party/tsl/third_party/eigen3/Eigen/SparseCore
+++ /dev/null
@@ -1 +0,0 @@
-#include "Eigen/SparseCore"
\ No newline at end of file
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/eigen_archive.BUILD b/third_party/xla/third_party/tsl/third_party/eigen3/eigen_archive.BUILD
index a2a0cc9..78b1fc8 100644
--- a/third_party/xla/third_party/tsl/third_party/eigen3/eigen_archive.BUILD
+++ b/third_party/xla/third_party/tsl/third_party/eigen3/eigen_archive.BUILD
@@ -4,8 +4,6 @@
 # This is the BUILD file used for the @eigen_archive external repository.
 
 licenses([
-    # Note: Although Eigen also includes GPL V3 and LGPL v2.1+ code, TensorFlow
-    #       has taken special care to not reference any restricted code.
     "reciprocal",  # MPL2
     "notice",  # Portions BSD
 ])
@@ -26,38 +24,29 @@
     ] + ALL_FILES_WITH_EXTENSIONS,
 )
 
-# Internal eigen headers, known to be under an MPL2 license.
-EIGEN_MPL2_SOURCES = glob(
+# Internal eigen headers.
+EIGEN_SOURCES = glob(
     [
         "Eigen/**/src/**/*.h",
         "Eigen/**/src/**/*.inc",
         "unsupported/Eigen/**/src/**/*.h",
         "unsupported/Eigen/**/src/**/*.inc",
     ],
-    exclude = [
-        # This guarantees that any file depending on non MPL2 licensed code
-        # will not compile.
-        "Eigen/src/Core/util/NonMPL2.h",
-    ],
-)
-
-alias(
-    name = "eigen3",
-    actual = "@local_tsl//third_party/eigen3",
-    visibility = ["//visibility:public"],
 )
 
 cc_library(
-    name = "eigen3_internal",
-    srcs = EIGEN_MPL2_SOURCES,
+    name = "eigen3",
+    srcs = EIGEN_SOURCES,
     hdrs = EIGEN_HEADERS,
     defines = [
-        # This define (mostly) guarantees we don't link any problematic
-        # code. We use it, but we do not rely on it, as evidenced above.
-        "EIGEN_MPL2_ONLY",
         "EIGEN_MAX_ALIGN_BYTES=64",
+        "EIGEN_ALLOW_UNALIGNED_SCALARS",  # TODO(b/296071640): Remove when underlying bugs are fixed.
+        "EIGEN_USE_AVX512_GEMM_KERNELS=0",  # TODO(b/238649163): Remove this once no longer necessary.
     ],
-    includes = ["."],
+    includes = [
+        ".",  # Third-party libraries include eigen relative to its root.
+        "./mkl_include",  # For using MKL backend for Eigen when available.
+    ],
     visibility = ["//visibility:public"],
 )
 
@@ -69,6 +58,6 @@
 
 filegroup(
     name = "eigen_source_files",
-    srcs = EIGEN_MPL2_SOURCES,
+    srcs = EIGEN_SOURCES,
     visibility = ["//visibility:public"],
 )
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/xla/third_party/tsl/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
deleted file mode 100644
index 41db119..0000000
--- a/third_party/xla/third_party/tsl/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
+++ /dev/null
@@ -1 +0,0 @@
-#include "unsupported/Eigen/CXX11/Tensor"
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool b/third_party/xla/third_party/tsl/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
deleted file mode 100644
index d2639af..0000000
--- a/third_party/xla/third_party/tsl/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
+++ /dev/null
@@ -1 +0,0 @@
-#include "unsupported/Eigen/CXX11/ThreadPool"
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/unsupported/Eigen/MatrixFunctions b/third_party/xla/third_party/tsl/third_party/eigen3/unsupported/Eigen/MatrixFunctions
deleted file mode 100644
index 314b325..0000000
--- a/third_party/xla/third_party/tsl/third_party/eigen3/unsupported/Eigen/MatrixFunctions
+++ /dev/null
@@ -1 +0,0 @@
-#include "unsupported/Eigen/MatrixFunctions"
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/unsupported/Eigen/SpecialFunctions b/third_party/xla/third_party/tsl/third_party/eigen3/unsupported/Eigen/SpecialFunctions
deleted file mode 100644
index ad13359..0000000
--- a/third_party/xla/third_party/tsl/third_party/eigen3/unsupported/Eigen/SpecialFunctions
+++ /dev/null
@@ -1 +0,0 @@
-#include "unsupported/Eigen/SpecialFunctions"
diff --git a/third_party/xla/third_party/tsl/third_party/eigen3/workspace.bzl b/third_party/xla/third_party/tsl/third_party/eigen3/workspace.bzl
index d1d8d4a..027454e 100644
--- a/third_party/xla/third_party/tsl/third_party/eigen3/workspace.bzl
+++ b/third_party/xla/third_party/tsl/third_party/eigen3/workspace.bzl
@@ -7,8 +7,8 @@
 
     # Attention: tools parse and update these lines.
     # LINT.IfChange
-    EIGEN_COMMIT = "66e8f38891841bf88ee976a316c0c78a52f0cee5"
-    EIGEN_SHA256 = "01fcd68409c038bbcfd16394274c2bf71e2bb6dda89a2319e23fc59a2da17210"
+    EIGEN_COMMIT = "aa6964bf3a34fd607837dd8123bc42465185c4f8"
+    EIGEN_SHA256 = "35ba771e30c735a4215ed784d7e032086cf89fe6622dce4d793c45dd74373362"
     # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/eigen.cmake)
 
     tf_http_archive(
diff --git a/third_party/xla/third_party/tsl/third_party/nccl/archive.BUILD b/third_party/xla/third_party/tsl/third_party/nccl/archive.BUILD
index 05293fd..1813bac 100644
--- a/third_party/xla/third_party/tsl/third_party/nccl/archive.BUILD
+++ b/third_party/xla/third_party/tsl/third_party/nccl/archive.BUILD
@@ -19,7 +19,7 @@
 
 NCCL_MAJOR = 2
 
-NCCL_MINOR = 16
+NCCL_MINOR = 18
 
 NCCL_PATCH = 5
 
@@ -210,6 +210,10 @@
     ],
     include_prefix = "third_party/nccl",
     linkopts = ["-lrt"],
+    # The following define is needed to enable format macros such as PRIx64
+    # from <inttypes.h>, because the TensorFlow docker image uses an old
+    # version of glibc.
+    local_defines = ["__STDC_FORMAT_MACROS"],
     strip_include_prefix = "src",
     target_compatible_with = select({
         "@local_config_cuda//cuda:using_clang": [],
diff --git a/third_party/xla/third_party/tsl/third_party/nccl/archive.patch b/third_party/xla/third_party/tsl/third_party/nccl/archive.patch
index f951a6a..8ef0af9 100644
--- a/third_party/xla/third_party/tsl/third_party/nccl/archive.patch
+++ b/third_party/xla/third_party/tsl/third_party/nccl/archive.patch
@@ -30,19 +30,6 @@
 similarity index 100%
 rename from src/collectives/device/sendrecv.cu
 rename to src/collectives/device/sendrecv.cu.cc
-diff --git a/src/include/nvtx.h b/src/include/nvtx.h
-index 2aeb932..cdc67d2 100644
---- a/src/include/nvtx.h
-+++ b/src/include/nvtx.h
-@@ -37,7 +37,7 @@ struct nccl_domain{static constexpr char const* name{"NCCL"};};
-
- class payload_schema {
-  public:
--  NVTX3_RELAXED_CONSTEXPR explicit payload_schema(const nvtxPayloadSchemaEntry_t entries[], size_t numEntries, const uint64_t schemaId, const char* schemaName = nullptr) noexcept
-+  explicit payload_schema(const nvtxPayloadSchemaEntry_t entries[], size_t numEntries, const uint64_t schemaId, const char* schemaName = nullptr) noexcept
-   {
-     schema_attr.name = schemaName;
-     schema_attr.entries = entries;
 diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h
 index accf8371a..4ab1bfac6 100644
 --- a/src/collectives/device/common.h
diff --git a/third_party/xla/third_party/tsl/third_party/py/ml_dtypes/ml_dtypes.BUILD b/third_party/xla/third_party/tsl/third_party/py/ml_dtypes/ml_dtypes.BUILD
index ccf607d..a85195e 100644
--- a/third_party/xla/third_party/tsl/third_party/py/ml_dtypes/ml_dtypes.BUILD
+++ b/third_party/xla/third_party/tsl/third_party/py/ml_dtypes/ml_dtypes.BUILD
@@ -17,7 +17,7 @@
         ".",
         "ml_dtypes",
     ],
-    deps = ["@org_tensorflow//third_party/eigen3"],
+    deps = ["@eigen_archive//:eigen3"],
 )
 
 cc_library(
@@ -48,7 +48,7 @@
     deps = [
         ":float8",
         ":int4",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
         "@org_tensorflow//third_party/py/numpy:headers",
     ],
 )
diff --git a/third_party/xla/third_party/tsl/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD b/third_party/xla/third_party/tsl/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD
index 37cd52d..574659a 100644
--- a/third_party/xla/third_party/tsl/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD
+++ b/third_party/xla/third_party/tsl/third_party/py/ml_dtypes/ml_dtypes.tests.BUILD
@@ -55,7 +55,7 @@
         "//:float8",
         "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest_main",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
     ],
 )
 
@@ -66,6 +66,6 @@
     deps = [
         "//:int4",
         "@com_google_googletest//:gtest_main",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
     ],
 )
diff --git a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD b/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD
index ccf607d..a85195e 100644
--- a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD
+++ b/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD
@@ -17,7 +17,7 @@
         ".",
         "ml_dtypes",
     ],
-    deps = ["@org_tensorflow//third_party/eigen3"],
+    deps = ["@eigen_archive//:eigen3"],
 )
 
 cc_library(
@@ -48,7 +48,7 @@
     deps = [
         ":float8",
         ":int4",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
         "@org_tensorflow//third_party/py/numpy:headers",
     ],
 )
diff --git a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD b/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD
index 37cd52d..574659a 100644
--- a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD
+++ b/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD
@@ -55,7 +55,7 @@
         "//:float8",
         "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest_main",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
     ],
 )
 
@@ -66,6 +66,6 @@
     deps = [
         "//:int4",
         "@com_google_googletest//:gtest_main",
-        "@org_tensorflow//third_party/eigen3",
+        "@eigen_archive//:eigen3",
     ],
 )
diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/tf_runtime.patch b/third_party/xla/third_party/tsl/third_party/tf_runtime/tf_runtime.patch
deleted file mode 100644
index a9a9d4a..0000000
--- a/third_party/xla/third_party/tsl/third_party/tf_runtime/tf_runtime.patch
+++ /dev/null
@@ -1,84 +0,0 @@
-Intermittent patch to TFRT to submit a TF/TFRT cross-cutting change.
-This patch will be applied only until TF's TFRT commit is automatically bumped.
-
----
-
-diff --git a/backends/gpu/include/tfrt/gpu/gpu_types.h b/backends/gpu/include/tfrt/gpu/gpu_types.h
-index 3d311c3..a216716 100644
---- a/backends/gpu/include/tfrt/gpu/gpu_types.h
-+++ b/backends/gpu/include/tfrt/gpu/gpu_types.h
-@@ -295,11 +295,7 @@
-       wrapper::CurrentContext current, wrapper::Stream stream,
-       wrapper::CclComm comm)>;
- 
--  explicit GpuCclHandle(AsyncValueRef<GpuContext> context,
--                        wrapper::OwningCclComm comm, int num_ranks);
--  // TODO(hanbinyoon): Remove after transitioning to the above constructor.
--  explicit GpuCclHandle(AsyncValueRef<GpuContext> context,
--                        wrapper::OwningCclComm comm);
-+  GpuCclHandle(AsyncValueRef<GpuContext> context, wrapper::OwningCclComm comm);
-   ~GpuCclHandle();
- 
-   GpuCclHandle(GpuCclHandle&&) = default;
-@@ -311,8 +307,6 @@
-   llvm::Error ExecuteCallbacks(wrapper::CurrentContext current,
-                                wrapper::Stream stream);
- 
--  int num_ranks() const { return num_ranks_; }
--
-   const wrapper::OwningCclComm& operator->() const { return comm_; }
-   wrapper::CclComm get() const { return comm_.get(); }
-   wrapper::CclComm release();
-@@ -322,7 +316,6 @@
-  private:
-   AsyncValueRef<GpuContext> context_;
-   wrapper::OwningCclComm comm_;
--  int num_ranks_;
-   std::vector<Callback> callbacks_;
- };
- 
-diff --git a/backends/gpu/lib/gpu_types.cc b/backends/gpu/lib/gpu_types.cc
-index 38529bc..01e3dba 100644
---- a/backends/gpu/lib/gpu_types.cc
-+++ b/backends/gpu/lib/gpu_types.cc
-@@ -214,15 +214,8 @@
- GpuBlasHandle::~GpuBlasHandle() = default;
- 
- GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context,
--                           wrapper::OwningCclComm comm, int num_ranks)
--    : context_(std::move(context)),
--      comm_(std::move(comm)),
--      num_ranks_(num_ranks) {}
--
--// TODO(hanbinyoon): Remove after transitioning to the above constructor.
--GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context,
-                            wrapper::OwningCclComm comm)
--    : context_(std::move(context)), comm_(std::move(comm)), num_ranks_(0) {}
-+    : context_(std::move(context)), comm_(std::move(comm)) {}
- 
- GpuCclHandle::~GpuCclHandle() = default;
- 
-diff --git a/backends/gpu/lib/kernels/ccl_kernels.cc b/backends/gpu/lib/kernels/ccl_kernels.cc
-index 52ce820..9cfc1de 100644
---- a/backends/gpu/lib/kernels/ccl_kernels.cc
-+++ b/backends/gpu/lib/kernels/ccl_kernels.cc
-@@ -107,8 +107,6 @@
-   auto width = ToWidthInBytes(type);
-   if (!width) return width.takeError();
-   assert(*width != 0);
--  if (input->size() != output->size() * handle->num_ranks())
--    return MakeStringError("Input size must be output size times ranks.");
- 
-   handle->AddCallback([input = input.ValueRef(), output = output.ValueRef(),
-                        recvcount = output->size() / *width, type,
-@@ -116,6 +114,10 @@
-                           wrapper::CurrentContext current,
-                           wrapper::Stream stream,
-                           wrapper::CclComm comm) -> llvm::Error {
-+    auto count = wrapper::CclCommCount(comm);
-+    if (!count) return count.takeError();
-+    if (input->size() != output->size() * *count)
-+      return MakeStringError("Input size must be output size times ranks.");
-     return wrapper::CclReduceScatter(current, input->pointer(),
-                                      output->pointer(), recvcount, type, op,
-                                      comm, stream);
diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/tf_runtime_clangcl.patch b/third_party/xla/third_party/tsl/third_party/tf_runtime/tf_runtime_clangcl.patch
deleted file mode 100644
index ce1859d..0000000
--- a/third_party/xla/third_party/tsl/third_party/tf_runtime/tf_runtime_clangcl.patch
+++ /dev/null
@@ -1,14 +0,0 @@
-diff --git a/include/tfrt/support/std_mutex.h b/include/tfrt/support/std_mutex.h
-index 6238d097..9fb24279 100644
---- a/include/tfrt/support/std_mutex.h
-+++ b/include/tfrt/support/std_mutex.h
-@@ -50,7 +50,7 @@ class TFRT_CAPABILITY("mutex") mutex {
- 
-  private:
-   friend class mutex_lock;
--  std::mutex mu_;
-+  std::mutex mu_{};
- };
-
- // Wrap std::unique_lock<std::mutex> with support for thread annotations.
- 
diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl
index 44d692c..f4c03be 100644
--- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl
+++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl
@@ -6,8 +6,8 @@
     """Imports TFRT."""
 
     # Attention: tools parse and update these lines.
-    TFRT_COMMIT = "6d71fa4816fafb69ee1caac955b2b3844290d577"
-    TFRT_SHA256 = "37da0c18b558e85e8e9c9e482c217221762679265245893ab87c136df7265446"
+    TFRT_COMMIT = "bc45a6d53a5554e3b12fd42d0e0b4862cf2cef92"
+    TFRT_SHA256 = "e77b6fd0de15ff2e3e0cccaa78f645aac6674d495b9f9a90fe348ce40a233c6b"
 
     tf_http_archive(
         name = "tf_runtime",
diff --git a/third_party/xla/third_party/tsl/tools/def_file_filter/def_file_filter.py.tpl b/third_party/xla/third_party/tsl/tools/def_file_filter/def_file_filter.py.tpl
index 18426a4..4091a57 100644
--- a/third_party/xla/third_party/tsl/tools/def_file_filter/def_file_filter.py.tpl
+++ b/third_party/xla/third_party/tsl/tools/def_file_filter/def_file_filter.py.tpl
@@ -300,8 +300,8 @@
     def_fp.write("\t ??_7ConfigProto@tensorflow@@6B@\n") # for _pywrap_tfe
     def_fp.write("\t ??_7CoordinatedTask@tensorflow@@6B@\n") # for _pywrap_tfe
     def_fp.write("\t ?InternalSwap@CoordinatedTask@tensorflow@@AEAAXPEAV12@@Z\n") # for _pywrap_tfe
-    def_fp.write("\t ?kSeed@MixingHashState@hash_internal@lts_20230125@absl@@0QEBXEB\n") # for _pywrap_tfcompile
-    def_fp.write("\t ?kEmptyGroup@container_internal@lts_20230125@absl@@3QBW4ctrl_t@123@B\n") # for _pywrap_tfcompile
+    def_fp.write("\t ?kSeed@MixingHashState@hash_internal@lts_20230802@absl@@0QEBXEB\n") # for _pywrap_tfcompile
+    def_fp.write("\t ?kEmptyGroup@container_internal@lts_20230802@absl@@3QBW4ctrl_t@123@B\n") # for _pywrap_tfcompile
     def_fp.write("\t ??_7GraphDef@tensorflow@@6B@\n")
     def_fp.write("\t ??_7DeviceProperties@tensorflow@@6B@\n")
     def_fp.write("\t ??_7MetaGraphDef@tensorflow@@6B@\n")
@@ -310,7 +310,7 @@
     def_fp.write("\t ??1CoordinatedTask@tensorflow@@UEAA@XZ\n") # for _pywrap_tfe
     def_fp.write("\t ?CopyFrom@CoordinatedTask@tensorflow@@QEAAXAEBV12@@Z\n") # for _pywrap_tfe
     def_fp.write("\t ??0CoordinatedTask@tensorflow@@IEAA@PEAVArena@protobuf@google@@_N@Z\n") # for _pywrap_tfe
-    def_fp.write("\t ?MaybeTrackCordImpl@CordzInfo@cord_internal@lts_20230125@absl@@CAXAEAVInlineData@234@AEBV5234@W4MethodIdentifier@CordzUpdateTracker@234@@Z\n") # for tensorflow::Status usage of absl::Cord
+    def_fp.write("\t ?MaybeTrackCordImpl@CordzInfo@cord_internal@lts_20230802@absl@@CAXAEAVInlineData@234@AEBV5234@W4MethodIdentifier@CordzUpdateTracker@234@@Z\n") # for tensorflow::Status usage of absl::Cord
 
 
     # Each symbols returned by undname matches the same position in candidates.
diff --git a/third_party/xla/third_party/tsl/tools/toolchains/python/python_repo.bzl b/third_party/xla/third_party/tsl/tools/toolchains/python/python_repo.bzl
index 59be9f6..77011b2 100644
--- a/third_party/xla/third_party/tsl/tools/toolchains/python/python_repo.bzl
+++ b/third_party/xla/third_party/tsl/tools/toolchains/python/python_repo.bzl
@@ -5,7 +5,7 @@
 """
 
 VERSIONS = ["3.9", "3.10", "3.11", "3.12"]
-DEFAULT_VERSION = "3.10"
+DEFAULT_VERSION = "3.11"
 WARNING = """
 TF_PYTHON_VERSION environment variable was not set correctly; using Python {}.
 
@@ -13,6 +13,11 @@
 export TF_PYTHON_VERSION=3.11
 """.format(DEFAULT_VERSION)
 
+content = """
+TF_PYTHON_VERSION = "{}"
+HERMETIC_PYTHON_VERSION = "{}"
+"""
+
 def _python_repository_impl(repository_ctx):
     repository_ctx.file("BUILD", "")
     version = repository_ctx.os.environ.get("TF_PYTHON_VERSION", "")
@@ -21,8 +26,7 @@
         version = DEFAULT_VERSION
     repository_ctx.file(
         "py_version.bzl",
-        "HERMETIC_PYTHON_VERSION = \"%s\"" %
-        version,
+        content.format(version, version),
     )
 
 python_repository = repository_rule(
diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl
index 3b63e3a..4554463 100644
--- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl
+++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl
@@ -602,3 +602,82 @@
             "TF_TENSORRT_VERSION": "8.6",
         },
     )
+
+    sigbuild_tf_configs(
+        name_container_map = {
+            "sigbuild-r2.16": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505",
+            "sigbuild-r2.16-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505",
+            "sigbuild-r2.16-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:93c234df4c781af6974d86e9d1dd2e19ce0845b1b662c38e9a30d1de64eab3b0",
+            "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:d0a91705406aad65a79011683b8f7d4b8131625ea26a6d08aa7c6eb6955873a2",
+            "sigbuild-r2.16-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:ed7313f95bce391cbf3b498ff6c534d163cc2bb91ca1d6ef6363bde4fd9e0cfc",
+        },
+        # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12
+        # and manylinux2014 is 2.17.
+        env = {
+            "ABI_LIBC_VERSION": "glibc_2.19",
+            "ABI_VERSION": "gcc",
+            "BAZEL_COMPILER": "/dt9/usr/bin/gcc",
+            "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu",
+            "BAZEL_TARGET_CPU": "k8",
+            "BAZEL_TARGET_LIBC": "glibc_2.19",
+            "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu",
+            "CC": "/dt9/usr/bin/gcc",
+            "CC_TOOLCHAIN_NAME": "linux_gnu_x86",
+            "CLEAR_CACHE": "1",
+            "CUDNN_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu",
+            "GCC_HOST_COMPILER_PATH": "/dt9/usr/bin/gcc",
+            "GCC_HOST_COMPILER_PREFIX": "/usr/bin",
+            "HOST_CXX_COMPILER": "/dt9/usr/bin/gcc",
+            "HOST_C_COMPILER": "/dt9/usr/bin/gcc",
+            "PYTHON_BIN_PATH": "/usr/bin/python3",
+            "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu",
+            "TF_CUDA_CLANG": "0",
+            "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0",
+            "TF_CUDA_VERSION": "12.2",
+            "TF_CUDNN_VERSION": "8.9",
+            "TF_ENABLE_XLA": "1",
+            "TF_NEED_CUDA": "1",
+            "TF_NEED_TENSORRT": "1",
+            "TF_SYSROOT": "/dt9",
+            "TF_TENSORRT_VERSION": "8.6",
+        },
+    )
+
+    sigbuild_tf_configs(
+        name_container_map = {
+            "sigbuild-r2.16-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505",
+            "sigbuild-r2.16-clang-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505",
+            "sigbuild-r2.16-clang-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:93c234df4c781af6974d86e9d1dd2e19ce0845b1b662c38e9a30d1de64eab3b0",
+            "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:d0a91705406aad65a79011683b8f7d4b8131625ea26a6d08aa7c6eb6955873a2",
+            "sigbuild-r2.16-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:ed7313f95bce391cbf3b498ff6c534d163cc2bb91ca1d6ef6363bde4fd9e0cfc",
+        },
+        # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12
+        # and manylinux2014 is 2.17.
+        env = {
+            "ABI_LIBC_VERSION": "glibc_2.19",
+            "ABI_VERSION": "gcc",
+            "BAZEL_COMPILER": "/usr/lib/llvm-17/bin/clang",
+            "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu",
+            "BAZEL_TARGET_CPU": "k8",
+            "BAZEL_TARGET_LIBC": "glibc_2.19",
+            "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu",
+            "CC": "/usr/lib/llvm-17/bin/clang",
+            "CC_TOOLCHAIN_NAME": "linux_gnu_x86",
+            "CLEAR_CACHE": "1",
+            "CUDNN_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu",
+            "CLANG_CUDA_COMPILER_PATH": "/usr/lib/llvm-17/bin/clang",
+            "HOST_CXX_COMPILER": "/usr/lib/llvm-17/bin/clang",
+            "HOST_C_COMPILER": "/usr/lib/llvm-17/bin/clang",
+            "PYTHON_BIN_PATH": "/usr/bin/python3",
+            "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu",
+            "TF_CUDA_CLANG": "1",
+            "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0",
+            "TF_CUDA_VERSION": "12.2",
+            "TF_CUDNN_VERSION": "8.9",
+            "TF_ENABLE_XLA": "1",
+            "TF_NEED_CUDA": "1",
+            "TF_NEED_TENSORRT": "1",
+            "TF_SYSROOT": "/dt9",
+            "TF_TENSORRT_VERSION": "8.6",
+        },
+    )
diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/containers.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/containers.bzl
index 1b540ed..bfb4634 100644
--- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/containers.bzl
+++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/containers.bzl
@@ -1,11 +1,11 @@
 """Docker images used with remote config and RBE."""
 
-"""SHA 256 values for each image."""
+# SHA 256 values for each image.
 container_digests = {
     # TF now uses only this container
     "cuda11.2-cudnn8.1-ubuntu20.04-manylinux2014-multipython": "sha256:48612bd85709cd014711d0b0f87e0806f3567d06d2e81c6e860516b87498b821",
     # JAX manylinux2014 configs.
-    "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:77234e5750afcf85c08e8980eff2e8c58ba207a0c32b06a372cafb687d144d2b",
+    "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:ab39410baf2fc1d31d50540acec7640d7f4814fa694e2421b696b6f0a058d645",
     "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:b699d6ae235ac601dc3e62391ac7c4606cb10331f8141983858c1580f5e74ddb",
     # ROCM, probably not all of them still in use
     "rocm-ubuntu18.04-manylinux2010-multipython": "sha256:6e953a09b145df338bcb03e9e36f99b291140c29b72d0a048fb6c5905ccad5eb",
diff --git a/third_party/xla/third_party/tsl/tsl/framework/contraction/BUILD b/third_party/xla/third_party/tsl/tsl/framework/contraction/BUILD
index 3ea59f3..cadbfbf 100644
--- a/third_party/xla/third_party/tsl/tsl/framework/contraction/BUILD
+++ b/third_party/xla/third_party/tsl/tsl/framework/contraction/BUILD
@@ -77,11 +77,6 @@
     name = "eigen_contraction_kernel",
     hdrs = ["eigen_contraction_kernel.h"],
     compatible_with = get_compatible_with_portable(),
-    # Hack to disable breaking AVX512 special GemmKernel. There is a conflicting
-    # specialization there causing build breakages.  This must be added here
-    # as "defines" so that the header is excluded in all dependent targets.
-    # TODO(b/238649163): remove this once no longer necessary.
-    defines = ["EIGEN_USE_AVX512_GEMM_KERNELS=0"],
     visibility = ["//visibility:public"],
     deps = select({
         ":no_mkldnn_contraction_kernel": [":eigen_contraction_kernel_no_mkl"],
diff --git a/third_party/xla/third_party/tsl/tsl/framework/convolution/BUILD b/third_party/xla/third_party/tsl/tsl/framework/convolution/BUILD
index b483d53..40be57d 100644
--- a/third_party/xla/third_party/tsl/tsl/framework/convolution/BUILD
+++ b/third_party/xla/third_party/tsl/tsl/framework/convolution/BUILD
@@ -19,11 +19,6 @@
         "eigen_spatial_convolutions-inl.h",
     ],
     compatible_with = get_compatible_with_portable(),
-    # Hack to disable breaking AVX512 special GemmKernel. There is a conflicting
-    # specialization there causing build breakages.  This must be added here
-    # as "defines" so that the header is excluded in all dependent targets.
-    # TODO(b/238649163): remove this once no longer necessary.
-    defines = ["EIGEN_USE_AVX512_GEMM_KERNELS=0"],
     visibility = ["//visibility:public"],
     deps = [
         "//tsl/framework/convolution:eigen_convolution_helpers",
diff --git a/third_party/xla/third_party/tsl/tsl/lib/core/BUILD b/third_party/xla/third_party/tsl/tsl/lib/core/BUILD
index d8c95c3..bae6b50 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/core/BUILD
+++ b/third_party/xla/third_party/tsl/tsl/lib/core/BUILD
@@ -4,6 +4,7 @@
 #   The libraries in this package are not allowed to have ANY dependencies
 #   to other TF components outside of TSL.
 
+load("//tsl/platform:build_config.bzl", "tsl_cc_test")
 load("//tsl:tsl.bzl", "set_external_visibility")
 load("//tsl:tsl.default.bzl", "get_compatible_with_portable")
 load(
@@ -96,5 +97,18 @@
     deps = [
         "//tsl/platform:logging",
         "//tsl/platform:types",
+        "@com_google_absl//absl/numeric:bits",
+    ],
+)
+
+tsl_cc_test(
+    name = "bits_test",
+    size = "small",
+    srcs = ["bits_test.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":bits",
+        "//tsl/platform:test",
+        "//tsl/platform:test_main",
     ],
 )
diff --git a/third_party/xla/third_party/tsl/tsl/lib/core/bits.h b/third_party/xla/third_party/tsl/tsl/lib/core/bits.h
index e4e2f70..9a31ae5 100644
--- a/third_party/xla/third_party/tsl/tsl/lib/core/bits.h
+++ b/third_party/xla/third_party/tsl/tsl/lib/core/bits.h
@@ -16,6 +16,9 @@
 #ifndef TENSORFLOW_TSL_LIB_CORE_BITS_H_
 #define TENSORFLOW_TSL_LIB_CORE_BITS_H_
 
+#include <cstdint>
+
+#include "absl/numeric/bits.h"
 #include "tsl/platform/logging.h"
 #include "tsl/platform/types.h"
 
@@ -104,6 +107,14 @@
   return 1LL << exponent;
 }
 
+inline int64_t NextPowerOfTwoS64(int64_t value) {
+  constexpr int64_t kMaxRepresentablePowerOfTwo =
+      static_cast<int64_t>(uint64_t{1} << 62);
+  DCHECK_GE(value, 0);
+  DCHECK_LE(value, kMaxRepresentablePowerOfTwo);
+  return static_cast<int64_t>(absl::bit_ceil(static_cast<uint64_t>(value)));
+}
+
 }  // namespace tsl
 
 #endif  // TENSORFLOW_TSL_LIB_CORE_BITS_H_
diff --git a/third_party/xla/third_party/tsl/tsl/lib/core/bits_test.cc b/third_party/xla/third_party/tsl/tsl/lib/core/bits_test.cc
new file mode 100644
index 0000000..2c38e1c
--- /dev/null
+++ b/third_party/xla/third_party/tsl/tsl/lib/core/bits_test.cc
@@ -0,0 +1,39 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tsl/lib/core/bits.h"
+
+#include <cstdint>
+
+#include "tsl/platform/test.h"
+
+namespace tsl {
+namespace {
+
+TEST(BitsTest, NextPowerOfTwoS64) {
+  constexpr int64_t kMaxRepresentablePowerOfTwo =
+      static_cast<int64_t>(uint64_t{1} << 62);
+  EXPECT_EQ(NextPowerOfTwoS64(0), 1);
+  EXPECT_EQ(NextPowerOfTwoS64(1), 1);
+  EXPECT_EQ(NextPowerOfTwoS64(2), 2);
+  EXPECT_EQ(NextPowerOfTwoS64(3), 4);
+  EXPECT_EQ(NextPowerOfTwoS64(kMaxRepresentablePowerOfTwo - 1),
+            kMaxRepresentablePowerOfTwo);
+  EXPECT_EQ(NextPowerOfTwoS64(kMaxRepresentablePowerOfTwo),
+            kMaxRepresentablePowerOfTwo);
+}
+
+}  // namespace
+}  // namespace tsl
diff --git a/third_party/xla/third_party/tsl/tsl/platform/BUILD b/third_party/xla/third_party/tsl/tsl/platform/BUILD
index 527bf3f..6b41e3b 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/BUILD
+++ b/third_party/xla/third_party/tsl/tsl/platform/BUILD
@@ -1149,7 +1149,7 @@
     compatible_with = get_compatible_with_portable(),
     visibility = ["//visibility:public"],
     deps = [
-        ":platform",
+        "@com_google_absl//absl/base:prefetch",
     ],
 )
 
diff --git a/third_party/xla/third_party/tsl/tsl/platform/prefetch.h b/third_party/xla/third_party/tsl/tsl/platform/prefetch.h
index 451c26d..d883529 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/prefetch.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/prefetch.h
@@ -16,41 +16,27 @@
 #ifndef TENSORFLOW_TSL_PLATFORM_PREFETCH_H_
 #define TENSORFLOW_TSL_PLATFORM_PREFETCH_H_
 
-#include "tsl/platform/platform.h"
+#include "absl/base/prefetch.h"
 
 namespace tsl {
 namespace port {
 
-// Prefetching support
-//
-// Defined behavior on some of the uarchs:
-// PREFETCH_HINT_T0:
-//   prefetch to all levels of the hierarchy (except on p4: prefetch to L2)
-// PREFETCH_HINT_NTA:
-//   p4: fetch to L2, but limit to 1 way (out of the 8 ways)
-//   core: skip L2, go directly to L1
-//   k8 rev E and later: skip L2, can go to either of the 2-ways in L1
+// Prefetching support.
+// Deprecated. Prefer to call absl::Prefetch* directly.
+
 enum PrefetchHint {
-  PREFETCH_HINT_T0 = 3,  // More temporal locality
-  PREFETCH_HINT_T1 = 2,
-  PREFETCH_HINT_T2 = 1,  // Less temporal locality
+  PREFETCH_HINT_T0 = 3,  // Temporal locality
   PREFETCH_HINT_NTA = 0  // No temporal locality
 };
-template <PrefetchHint hint>
-void prefetch(const void* x);
 
-// ---------------------------------------------------------------------------
-// Inline implementation
-// ---------------------------------------------------------------------------
 template <PrefetchHint hint>
-inline void prefetch(const void* x) {
-// Check of COMPILER_GCC macro below is kept only for backward-compatibility
-// reasons. COMPILER_GCC3 is the macro that actually enables prefetch.
-#if defined(__llvm__) || defined(COMPILER_GCC) || defined(COMPILER_GCC3)
-  __builtin_prefetch(x, 0, hint);
-#else
-// You get no effect.  Feel free to add more sections above.
-#endif
+void prefetch(const void* x) {
+  absl::PrefetchToLocalCache(x);
+}
+
+template <>
+inline void prefetch<PREFETCH_HINT_NTA>(const void* x) {
+  absl::PrefetchToLocalCacheNta(x);
 }
 
 }  // namespace port
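
A minimal sketch (not part of the patch) showing the intended use of the retained tsl::port::prefetch wrapper, which now simply forwards to absl::Prefetch*. The function and constant names are hypothetical.

#include <cstddef>
#include "tsl/platform/prefetch.h"

float SumWithPrefetch(const float* data, std::size_t n) {
  float sum = 0.0f;
  constexpr std::size_t kAhead = 16;  // Prefetch distance, in elements.
  for (std::size_t i = 0; i < n; ++i) {
    if (i + kAhead < n) {
      // Keeps the upcoming element cache-resident; maps to
      // absl::PrefetchToLocalCache under the new implementation.
      tsl::port::prefetch<tsl::port::PREFETCH_HINT_T0>(data + i + kAhead);
    }
    sum += data[i];
  }
  return sum;
}
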
diff --git a/third_party/xla/third_party/tsl/tsl/platform/test.cc b/third_party/xla/third_party/tsl/tsl/platform/test.cc
index 129c7db..e8f4102 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/test.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/test.cc
@@ -62,11 +62,10 @@
 std::string XlaSrcRoot() {
   std::string workspace = GetEnvVarOrDie("TEST_WORKSPACE");
   std::string srcdir = GetEnvVarOrDie("TEST_SRCDIR");
-  const char* xla_path = "tensorflow/compiler/xla";
 
-  return kIsOpenSource
-             ? io::JoinPath(srcdir, workspace, xla_path)
-             : io::JoinPath(srcdir, workspace, "third_party", xla_path);
+  return kIsOpenSource ? io::JoinPath(srcdir, workspace, "xla")
+                       : io::JoinPath(srcdir, workspace,
+                                      "third_party/tensorflow/compiler/xla");
 }
 
 std::string TslSrcRoot() {
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD
index 0c3e736..99d4b2f 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD
@@ -119,7 +119,6 @@
         "//tsl/lib/gtl:map_util",
         "//tsl/platform:logging",
         "//tsl/platform:macros",
-        "//tsl/platform:regexp",
         "//tsl/platform:types",
         "//tsl/profiler/lib:context_types_hdrs",
         "@com_google_absl//absl/container:flat_hash_map",
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.h
index 6ccfadb..2cbc96b 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.h
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.h
@@ -27,6 +27,7 @@
 
 #include "absl/hash/hash.h"
 #include "absl/memory/memory.h"
+#include "absl/strings/match.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/optional.h"
 #include "tsl/profiler/lib/context_types.h"
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h
index d258098..90cee79 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h
@@ -29,8 +29,10 @@
 // Support up to 500 accelerator devices.
 constexpr uint32 kFirstDeviceId = 1;
 constexpr uint32 kLastDeviceId = 500;
+// Support up to 200 custom planes.
+constexpr uint32 kCustomPlaneDeviceId = kLastDeviceId + 1;
 // Host threads are shown as a single fake device.
-constexpr uint32 kHostThreadsDeviceId = kLastDeviceId + 1;
+constexpr uint32 kHostThreadsDeviceId = kCustomPlaneDeviceId + 200;
 
 // Constants used as trace_viewer TID (resource_id in trace_events.proto).
 constexpr int kThreadIdDerivedMin = 0xdeadbeef;
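
The device-ID layout implied by the constants above: kFirstDeviceId = 1, kLastDeviceId = 500, kCustomPlaneDeviceId = 501, and kHostThreadsDeviceId = 501 + 200 = 701, so custom planes occupy IDs 501..700 and the host-threads fake device moves to 701. A small sanity-check sketch (not part of the patch; assumes the constants live in the tsl::profiler namespace as usual for trace_utils.h):

#include "tsl/profiler/utils/trace_utils.h"

namespace tsl {
namespace profiler {
static_assert(kCustomPlaneDeviceId == 501,
              "custom planes start right after the accelerator device range");
static_assert(kHostThreadsDeviceId == 701,
              "host threads follow the 200 reserved custom-plane IDs");
}  // namespace profiler
}  // namespace tsl
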
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc
index 7702cc9..2f7eb63 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc
@@ -23,9 +23,6 @@
 #include "absl/strings/string_view.h"
 #include "absl/types/optional.h"
 #include "tsl/lib/gtl/map_util.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/regexp.h"
-#include "tsl/platform/types.h"
 #include "tsl/profiler/utils/tf_op_utils.h"
 
 namespace tsl {
@@ -70,6 +67,10 @@
 constexpr int kNumStatTypes =
     StatType::kLastStatType - StatType::kFirstStatType + 1;
 
+constexpr int kNumMegaScaleStatTypes =
+    MegaScaleStatType::kLastMegaScaleStatType -
+    MegaScaleStatType::kFirstMegaScaleStatType + 1;
+
 constexpr int kNumLineIdTypes =
     LineIdType::kLastLineIdType - LineIdType::kFirstLineIdType + 1;
 
@@ -78,6 +79,10 @@
     absl::flat_hash_map<HostEventType, absl::string_view>;
 using StatTypeMap = absl::flat_hash_map<absl::string_view, StatType>;
 using StatTypeStrMap = absl::flat_hash_map<StatType, absl::string_view>;
+using MegaScaleStatTypeMap =
+    absl::flat_hash_map<absl::string_view, MegaScaleStatType>;
+using MegaScaleStatTypeStrMap =
+    absl::flat_hash_map<MegaScaleStatType, absl::string_view>;
 using LineIdTypeMap = absl::flat_hash_map<absl::string_view, LineIdType>;
 using LineIdTypeStrMap = absl::flat_hash_map<LineIdType, absl::string_view>;
 
@@ -276,6 +281,7 @@
       {"Hlo Proto", kHloProto},
       {"EdgeTPU Model information", kEdgeTpuModelInfo},
       {"EdgeTPU Model Profile information", kEdgeTpuModelProfileInfo},
+      {"EdgeTPU MLIR", kEdgeTpuMlir},
       // Device capability related.
       {"clock_rate", kDevCapClockRateKHz},
       {"core_count", kDevCapCoreCount},
@@ -327,6 +333,32 @@
   return *stat_type_map;
 }
 
+const MegaScaleStatTypeMap& GetMegaScaleStatTypeMap() {
+  static auto* stat_type_map = new MegaScaleStatTypeMap({
+      {"graph_key", kMegaScaleGraphKey},
+      {"local_device_id", kMegaScaleLocalDeviceId},
+      {"num_actions", kMegaScaleNumActions},
+      {"collective_type", kMegaScaleCollectiveType},
+      {"input_size", kMegaScaleInputSize},
+      {"slack_us", kMegaScaleSlackUs},
+      {"action_type", kMegaScaleActionType},
+      {"start_end_type", kMegaScaleStartEndType},
+      {"action_index", kMegaScaleActionIndex},
+      {"action_duration_ns", kMegaScaleActionDurationNs},
+      {"action_inputs", kMegaScaleActionInputs},
+      {"transfer_source", kMegaScaleTransferSource},
+      {"transfer_destinations", kMegaScaleTransferDestinations},
+      {"buffer_sizes", kMegaScaleBufferSizes},
+      {"compute_operation", kMegaScaleComputeOperation},
+      {"chunk", kMegaScaleChunk},
+      {"launch_id", kMegaScaleLaunchId},
+      {"loop_iteration", kMegaScaleLoopIteration},
+      {"graph_protos", kMegaScaleGraphProtos},
+  });
+  DCHECK_EQ(stat_type_map->size(), kNumMegaScaleStatTypes);
+  return *stat_type_map;
+}
+
 const LineIdTypeMap& GetLineIdTypeMap() {
   static auto* line_id_type_map = new LineIdTypeMap({
       {"UnknownLineIdType", kUnknownLineIdType},
@@ -349,6 +381,12 @@
   return *stat_type_str_map;
 }
 
+const MegaScaleStatTypeStrMap& GetMegaScaleStatTypeStrMap() {
+  static auto* stat_type_str_map = new MegaScaleStatTypeStrMap(
+      gtl::ReverseMap<MegaScaleStatTypeStrMap>(GetMegaScaleStatTypeMap()));
+  return *stat_type_str_map;
+}
+
 const LineIdTypeStrMap& GetLineIdTypeStrMap() {
   static auto* line_id_type_str_map = new LineIdTypeStrMap(
       gtl::ReverseMap<LineIdTypeStrMap>(GetLineIdTypeMap()));
@@ -392,6 +430,17 @@
   return std::nullopt;
 }
 
+absl::string_view GetMegaScaleStatTypeStr(MegaScaleStatType stat_type) {
+  return GetMegaScaleStatTypeStrMap().at(stat_type);
+}
+
+std::optional<int64_t> FindMegaScaleStatType(absl::string_view stat_name) {
+  if (auto stat_type = gtl::FindOrNull(GetMegaScaleStatTypeMap(), stat_name)) {
+    return *stat_type;
+  }
+  return std::nullopt;
+}
+
 absl::string_view GetLineIdTypeStr(LineIdType line_id_type) {
   return GetLineIdTypeStrMap().at(line_id_type);
 }
diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h
index 4642705..8fa3207 100644
--- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h
+++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h
@@ -22,7 +22,6 @@
 #include <string>
 
 #include "absl/hash/hash.h"
-#include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/optional.h"
@@ -312,7 +311,32 @@
   kDcnLoopIndex,
   kEdgeTpuModelInfo,
   kEdgeTpuModelProfileInfo,
-  kLastStatType = kEdgeTpuModelProfileInfo,
+  kEdgeTpuMlir,
+  kLastStatType = kEdgeTpuMlir,
+};
+
+enum MegaScaleStatType : uint8_t {
+  kMegaScaleGraphKey,
+  kFirstMegaScaleStatType = kMegaScaleGraphKey,
+  kMegaScaleLocalDeviceId,
+  kMegaScaleNumActions,
+  kMegaScaleCollectiveType,
+  kMegaScaleInputSize,
+  kMegaScaleSlackUs,
+  kMegaScaleActionType,
+  kMegaScaleStartEndType,
+  kMegaScaleActionIndex,
+  kMegaScaleActionDurationNs,
+  kMegaScaleActionInputs,
+  kMegaScaleTransferSource,
+  kMegaScaleTransferDestinations,
+  kMegaScaleBufferSizes,
+  kMegaScaleComputeOperation,
+  kMegaScaleChunk,
+  kMegaScaleLaunchId,
+  kMegaScaleLoopIteration,
+  kMegaScaleGraphProtos,
+  kLastMegaScaleStatType = kMegaScaleGraphProtos,
 };
 
 static constexpr uint32_t kLineIdOffset = 10000;
@@ -360,6 +384,15 @@
 
 std::optional<int64_t> FindStatType(absl::string_view stat_name);
 
+absl::string_view GetMegaScaleStatTypeStr(MegaScaleStatType stat_type);
+
+inline bool IsMegaScaleStatType(MegaScaleStatType stat_type,
+                                absl::string_view stat_name) {
+  return GetMegaScaleStatTypeStr(stat_type) == stat_name;
+}
+
+std::optional<int64_t> FindMegaScaleStatType(absl::string_view stat_name);
+
 // Returns true if the given event shouldn't be shown in the trace viewer.
 bool IsInternalEvent(std::optional<int64_t> event_type);
 
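
A minimal sketch (not part of the patch) of the new MegaScale stat-type lookups, assuming the declarations live in the tsl::profiler namespace like the rest of xplane_schema.h; `Demo` is a hypothetical function.

#include <cstdint>
#include <optional>

#include "absl/strings/string_view.h"
#include "tsl/profiler/utils/xplane_schema.h"

void Demo() {
  using tsl::profiler::FindMegaScaleStatType;
  using tsl::profiler::GetMegaScaleStatTypeStr;
  using tsl::profiler::kMegaScaleGraphKey;

  // String -> enum: returns std::nullopt for unknown stat names.
  std::optional<int64_t> type = FindMegaScaleStatType("graph_key");
  // Enum -> string: yields "graph_key" per the map in xplane_schema.cc.
  absl::string_view name = GetMegaScaleStatTypeStr(kMegaScaleGraphKey);
  (void)type;
  (void)name;
}
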
diff --git a/third_party/xla/third_party/tsl/tsl/util/command_line_flags.cc b/third_party/xla/third_party/tsl/tsl/util/command_line_flags.cc
index ee812ef..520962f 100644
--- a/third_party/xla/third_party/tsl/tsl/util/command_line_flags.cc
+++ b/third_party/xla/third_party/tsl/tsl/util/command_line_flags.cc
@@ -92,11 +92,14 @@
       *value_parsing_ok = hook(true);
       return true;
     }
-
-    if (absl::EqualsIgnoreCase(arg, "=true")) {
+    // Not this flag: likely another flag whose name begins with this one's.
+    if (!absl::ConsumePrefix(&arg, "=")) {
+      return false;
+    }
+    if (absl::EqualsIgnoreCase(arg, "true")) {
       *value_parsing_ok = hook(true);
       return true;
-    } else if (absl::EqualsIgnoreCase(arg, "=false")) {
+    } else if (absl::EqualsIgnoreCase(arg, "false")) {
       *value_parsing_ok = hook(false);
       return true;
     } else {
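
A standalone sketch (not the tsl::Flags API, and not part of the patch) of the boolean-flag parsing rule changed above: a bare "--name" means true, "--name=true"/"--name=false" set the value case-insensitively, and "--namesuffix" is rejected as a different flag. Only absl string helpers are used.

#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"

// Returns true if `arg` matched the `--<name>` form; sets *value accordingly.
bool ParseBoolFlagSketch(absl::string_view arg, absl::string_view name,
                         bool* value) {
  if (!absl::ConsumePrefix(&arg, "--") || !absl::ConsumePrefix(&arg, name)) {
    return false;
  }
  if (arg.empty()) {  // "--name" alone means true.
    *value = true;
    return true;
  }
  // Anything other than "=..." is a longer, different flag name.
  if (!absl::ConsumePrefix(&arg, "=")) return false;
  if (absl::EqualsIgnoreCase(arg, "true")) {
    *value = true;
    return true;
  }
  if (absl::EqualsIgnoreCase(arg, "false")) {
    *value = false;
    return true;
  }
  return false;  // Matched the flag but the value did not parse.
}
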
diff --git a/third_party/xla/third_party/tsl/workspace1.bzl b/third_party/xla/third_party/tsl/workspace1.bzl
index 4cfb6da..2495080 100644
--- a/third_party/xla/third_party/tsl/workspace1.bzl
+++ b/third_party/xla/third_party/tsl/workspace1.bzl
@@ -3,7 +3,6 @@
 load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
 load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
 load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
-load("@rules_cuda//cuda:dependencies.bzl", "rules_cuda_dependencies")
 load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")
 
 # buildifier: disable=unnamed-macro
@@ -14,7 +13,6 @@
       with_rules_cc: whether to load and patch rules_cc repository.
     """
     native.register_toolchains("@local_config_python//:py_toolchain")
-    rules_cuda_dependencies(with_rules_cc)
     rules_pkg_dependencies()
 
     closure_repositories()
diff --git a/third_party/xla/third_party/tsl/workspace2.bzl b/third_party/xla/third_party/tsl/workspace2.bzl
index 895c1b1..1343f9e 100644
--- a/third_party/xla/third_party/tsl/workspace2.bzl
+++ b/third_party/xla/third_party/tsl/workspace2.bzl
@@ -13,6 +13,7 @@
 load("//third_party/absl:workspace.bzl", absl = "repo")
 load("//third_party/benchmark:workspace.bzl", benchmark = "repo")
 load("//third_party/clang_toolchain:cc_configure_clang.bzl", "cc_download_clang_toolchain")
+load("//third_party/ducc:workspace.bzl", ducc = "repo")
 load("//third_party/eigen3:workspace.bzl", eigen3 = "repo")
 load("//third_party/farmhash:workspace.bzl", farmhash = "repo")
 load("//third_party/gemmlowp:workspace.bzl", gemmlowp = "repo")
@@ -44,6 +45,7 @@
     """ Load third party repositories.  See above load() statements. """
     absl()
     benchmark()
+    ducc()
     eigen3()
     farmhash()
     gemmlowp()
@@ -424,9 +426,9 @@
         name = "nccl_archive",
         build_file = "//third_party:nccl/archive.BUILD",
         patch_file = ["//third_party/nccl:archive.patch"],
-        sha256 = "0e3d7b6295beed81dc15002e88abf7a3b45b5c686b13b779ceac056f5612087f",
-        strip_prefix = "nccl-2.16.5-1",
-        urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.16.5-1.tar.gz"),
+        sha256 = "16ac98f3e926c024ce48e10ab220e19ce734adc48c423cfd55ad6f509bd1179f",
+        strip_prefix = "nccl-2.18.5-1",
+        urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.18.5-1.tar.gz"),
     )
 
     java_import_external(
diff --git a/third_party/xla/third_party/zlib.BUILD b/third_party/xla/third_party/zlib.BUILD
deleted file mode 100644
index b8ca17d..0000000
--- a/third_party/xla/third_party/zlib.BUILD
+++ /dev/null
@@ -1,43 +0,0 @@
-package(default_visibility = ["//visibility:public"])
-
-licenses(["notice"])  # BSD/MIT-like license (for zlib)
-
-cc_library(
-    name = "zlib",
-    srcs = [
-        "adler32.c",
-        "compress.c",
-        "crc32.c",
-        "crc32.h",
-        "deflate.c",
-        "deflate.h",
-        "gzclose.c",
-        "gzguts.h",
-        "gzlib.c",
-        "gzread.c",
-        "gzwrite.c",
-        "infback.c",
-        "inffast.c",
-        "inffast.h",
-        "inffixed.h",
-        "inflate.c",
-        "inflate.h",
-        "inftrees.c",
-        "inftrees.h",
-        "trees.c",
-        "trees.h",
-        "uncompr.c",
-        "zconf.h",
-        "zutil.c",
-        "zutil.h",
-    ],
-    hdrs = ["zlib.h"],
-    copts = select({
-        "@local_tsl//tsl:windows": [],
-        "//conditions:default": [
-            "-Wno-shift-negative-value",
-            "-DZ_HAVE_UNISTD_H",
-        ],
-    }),
-    includes = ["."],
-)
diff --git a/third_party/xla/tools/toolchains/python/python_repo.bzl b/third_party/xla/tools/toolchains/python/python_repo.bzl
index 59be9f6..77011b2 100644
--- a/third_party/xla/tools/toolchains/python/python_repo.bzl
+++ b/third_party/xla/tools/toolchains/python/python_repo.bzl
@@ -5,7 +5,7 @@
 """
 
 VERSIONS = ["3.9", "3.10", "3.11", "3.12"]
-DEFAULT_VERSION = "3.10"
+DEFAULT_VERSION = "3.11"
 WARNING = """
 TF_PYTHON_VERSION environment variable was not set correctly; using Python {}.
 
@@ -13,6 +13,11 @@
 export TF_PYTHON_VERSION=3.11
 """.format(DEFAULT_VERSION)
 
+content = """
+TF_PYTHON_VERSION = "{}"
+HERMETIC_PYTHON_VERSION = "{}"
+"""
+
 def _python_repository_impl(repository_ctx):
     repository_ctx.file("BUILD", "")
     version = repository_ctx.os.environ.get("TF_PYTHON_VERSION", "")
@@ -21,8 +26,7 @@
         version = DEFAULT_VERSION
     repository_ctx.file(
         "py_version.bzl",
-        "HERMETIC_PYTHON_VERSION = \"%s\"" %
-        version,
+        content.format(version, version),
     )
 
 python_repository = repository_rule(
diff --git a/third_party/xla/tools/toolchains/remote_config/configs.bzl b/third_party/xla/tools/toolchains/remote_config/configs.bzl
index 3b63e3a..4554463 100644
--- a/third_party/xla/tools/toolchains/remote_config/configs.bzl
+++ b/third_party/xla/tools/toolchains/remote_config/configs.bzl
@@ -602,3 +602,82 @@
             "TF_TENSORRT_VERSION": "8.6",
         },
     )
+
+    sigbuild_tf_configs(
+        name_container_map = {
+            "sigbuild-r2.16": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505",
+            "sigbuild-r2.16-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505",
+            "sigbuild-r2.16-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:93c234df4c781af6974d86e9d1dd2e19ce0845b1b662c38e9a30d1de64eab3b0",
+            "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:d0a91705406aad65a79011683b8f7d4b8131625ea26a6d08aa7c6eb6955873a2",
+            "sigbuild-r2.16-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:ed7313f95bce391cbf3b498ff6c534d163cc2bb91ca1d6ef6363bde4fd9e0cfc",
+        },
+        # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12
+        # and manylinux2014 is 2.17.
+        env = {
+            "ABI_LIBC_VERSION": "glibc_2.19",
+            "ABI_VERSION": "gcc",
+            "BAZEL_COMPILER": "/dt9/usr/bin/gcc",
+            "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu",
+            "BAZEL_TARGET_CPU": "k8",
+            "BAZEL_TARGET_LIBC": "glibc_2.19",
+            "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu",
+            "CC": "/dt9/usr/bin/gcc",
+            "CC_TOOLCHAIN_NAME": "linux_gnu_x86",
+            "CLEAR_CACHE": "1",
+            "CUDNN_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu",
+            "GCC_HOST_COMPILER_PATH": "/dt9/usr/bin/gcc",
+            "GCC_HOST_COMPILER_PREFIX": "/usr/bin",
+            "HOST_CXX_COMPILER": "/dt9/usr/bin/gcc",
+            "HOST_C_COMPILER": "/dt9/usr/bin/gcc",
+            "PYTHON_BIN_PATH": "/usr/bin/python3",
+            "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu",
+            "TF_CUDA_CLANG": "0",
+            "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0",
+            "TF_CUDA_VERSION": "12.2",
+            "TF_CUDNN_VERSION": "8.9",
+            "TF_ENABLE_XLA": "1",
+            "TF_NEED_CUDA": "1",
+            "TF_NEED_TENSORRT": "1",
+            "TF_SYSROOT": "/dt9",
+            "TF_TENSORRT_VERSION": "8.6",
+        },
+    )
+
+    sigbuild_tf_configs(
+        name_container_map = {
+            "sigbuild-r2.16-clang": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505",
+            "sigbuild-r2.16-clang-python3.9": "docker://gcr.io/tensorflow-sigs/build@sha256:c13559bbf5df818bb586ad0880b29c409398b56fd8cc122ab0b31dc2b2416505",
+            "sigbuild-r2.16-clang-python3.10": "docker://gcr.io/tensorflow-sigs/build@sha256:93c234df4c781af6974d86e9d1dd2e19ce0845b1b662c38e9a30d1de64eab3b0",
+            "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:d0a91705406aad65a79011683b8f7d4b8131625ea26a6d08aa7c6eb6955873a2",
+            "sigbuild-r2.16-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:ed7313f95bce391cbf3b498ff6c534d163cc2bb91ca1d6ef6363bde4fd9e0cfc",
+        },
+        # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12
+        # and manylinux2014 is 2.17.
+        env = {
+            "ABI_LIBC_VERSION": "glibc_2.19",
+            "ABI_VERSION": "gcc",
+            "BAZEL_COMPILER": "/usr/lib/llvm-17/bin/clang",
+            "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu",
+            "BAZEL_TARGET_CPU": "k8",
+            "BAZEL_TARGET_LIBC": "glibc_2.19",
+            "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu",
+            "CC": "/usr/lib/llvm-17/bin/clang",
+            "CC_TOOLCHAIN_NAME": "linux_gnu_x86",
+            "CLEAR_CACHE": "1",
+            "CUDNN_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu",
+            "CLANG_CUDA_COMPILER_PATH": "/usr/lib/llvm-17/bin/clang",
+            "HOST_CXX_COMPILER": "/usr/lib/llvm-17/bin/clang",
+            "HOST_C_COMPILER": "/usr/lib/llvm-17/bin/clang",
+            "PYTHON_BIN_PATH": "/usr/bin/python3",
+            "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu",
+            "TF_CUDA_CLANG": "1",
+            "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0",
+            "TF_CUDA_VERSION": "12.2",
+            "TF_CUDNN_VERSION": "8.9",
+            "TF_ENABLE_XLA": "1",
+            "TF_NEED_CUDA": "1",
+            "TF_NEED_TENSORRT": "1",
+            "TF_SYSROOT": "/dt9",
+            "TF_TENSORRT_VERSION": "8.6",
+        },
+    )
diff --git a/third_party/xla/tools/toolchains/remote_config/containers.bzl b/third_party/xla/tools/toolchains/remote_config/containers.bzl
index 1b540ed..bfb4634 100644
--- a/third_party/xla/tools/toolchains/remote_config/containers.bzl
+++ b/third_party/xla/tools/toolchains/remote_config/containers.bzl
@@ -1,11 +1,11 @@
 """Docker images used with remote config and RBE."""
 
-"""SHA 256 values for each image."""
+# SHA 256 values for each image.
 container_digests = {
     # TF now uses only this container
     "cuda11.2-cudnn8.1-ubuntu20.04-manylinux2014-multipython": "sha256:48612bd85709cd014711d0b0f87e0806f3567d06d2e81c6e860516b87498b821",
     # JAX manylinux2014 configs.
-    "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:77234e5750afcf85c08e8980eff2e8c58ba207a0c32b06a372cafb687d144d2b",
+    "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:ab39410baf2fc1d31d50540acec7640d7f4814fa694e2421b696b6f0a058d645",
     "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:b699d6ae235ac601dc3e62391ac7c4606cb10331f8141983858c1580f5e74ddb",
     # ROCM, probably not all of them still in use
     "rocm-ubuntu18.04-manylinux2010-multipython": "sha256:6e953a09b145df338bcb03e9e36f99b291140c29b72d0a048fb6c5905ccad5eb",
diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD
index bc05eaa..24100c8 100644
--- a/third_party/xla/xla/BUILD
+++ b/third_party/xla/xla/BUILD
@@ -216,7 +216,9 @@
     name = "types_test",
     size = "small",
     srcs = ["types_test.cc"],
-    visibility = [":friends"],
+    visibility = [
+        "//visibility:private",  # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private.
+    ],
     deps = [
         ":test",
         ":types",
@@ -431,6 +433,7 @@
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/functional:function_ref",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
@@ -1075,6 +1078,7 @@
             "@com_google_absl//absl/strings",
             "@com_google_absl//absl/strings:str_format",
             "@local_tsl//tsl/platform:logging",
+            "@local_tsl//tsl/platform:protobuf",
             "@local_tsl//tsl/util:command_line_flags",
         ],
 )
diff --git a/third_party/xla/xla/array.h b/third_party/xla/xla/array.h
index 93c956b..eebc13b 100644
--- a/third_party/xla/xla/array.h
+++ b/third_party/xla/xla/array.h
@@ -438,6 +438,8 @@
 
     OwnedBuffer<int64_t> sizes(starts.size());
     for (int64_t i = 0; i < starts.size(); ++i) {
+      CHECK_GE(starts[i], 0);
+      CHECK_LE(limits[i], dim(i));
       sizes[i] = limits[i] - starts[i];
     }
     Array<T> result(sizes.span());
diff --git a/third_party/xla/xla/autotuning.proto b/third_party/xla/xla/autotuning.proto
index d3bad5a..b177fe3 100644
--- a/third_party/xla/xla/autotuning.proto
+++ b/third_party/xla/xla/autotuning.proto
@@ -73,6 +73,8 @@
     string exec_plan_id = 1;
   }
 
+  // Unless you specifically need a proto in your code, please use
+  // TritonGemmConfig rather than this proto directly.
   message TritonGemmKey {
     int64 block_m = 1;
     int64 block_n = 2;
diff --git a/third_party/xla/xla/backends/interpreter/BUILD b/third_party/xla/xla/backends/interpreter/BUILD
index 3c592c8..085ce03 100644
--- a/third_party/xla/xla/backends/interpreter/BUILD
+++ b/third_party/xla/xla/backends/interpreter/BUILD
@@ -31,38 +31,37 @@
     deps = [
         ":executable",
         ":platform_id",
+        "//xla:literal",
         "//xla:status",
         "//xla:status_macros",
         "//xla:statusor",
+        "//xla:util",
+        "//xla/hlo/evaluator:hlo_evaluator",
         "//xla/hlo/ir:hlo",
-        "//xla/service:algebraic_simplifier",
+        "//xla/hlo/ir:hlo_module_group",
         "//xla/service:batchnorm_expander",
         "//xla/service:cholesky_expander",
         "//xla/service:comparison_expander",
         "//xla/service:compiler",
         "//xla/service:computation_placer",
         "//xla/service:custom_call_target_registry",
+        "//xla/service:dynamic_dimension_inference",
         "//xla/service:dynamic_index_splitter",
         "//xla/service:eigh_expander",
         "//xla/service:executable",
-        "//xla/service:flatten_call_graph",
-        "//xla/service:hlo_constant_folding",
         "//xla/service:hlo_cost_analysis",
-        "//xla/service:hlo_cse",
-        "//xla/service:hlo_dce",
         "//xla/service:hlo_module_config",
-        "//xla/service:hlo_pass",
         "//xla/service:hlo_pass_pipeline",
         "//xla/service:layout_assignment",
-        "//xla/service:map_inliner",
         "//xla/service:qr_expander",
-        "//xla/service:reshape_mover",
         "//xla/service:topk_rewriter",
         "//xla/service:triangular_solve_expander",
-        "//xla/service:while_loop_simplifier",
         "//xla/stream_executor",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
     ],
     alwayslink = True,  # Contains compiler registration
 )
diff --git a/third_party/xla/xla/backends/interpreter/compiler.cc b/third_party/xla/xla/backends/interpreter/compiler.cc
index 930e2c5..864d98a 100644
--- a/third_party/xla/xla/backends/interpreter/compiler.cc
+++ b/third_party/xla/xla/backends/interpreter/compiler.cc
@@ -18,31 +18,41 @@
 #include <memory>
 #include <string>
 #include <utility>
+#include <vector>
 
+#include "absl/log/log.h"
+#include "absl/types/span.h"
 #include "xla/backends/interpreter/executable.h"
-#include "xla/service/algebraic_simplifier.h"
+#include "xla/backends/interpreter/platform_id.h"
+#include "xla/hlo/evaluator/hlo_evaluator.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_module_group.h"
+#include "xla/literal.h"
 #include "xla/service/batchnorm_expander.h"
 #include "xla/service/cholesky_expander.h"
 #include "xla/service/comparison_expander.h"
+#include "xla/service/compiler.h"
 #include "xla/service/computation_placer.h"
 #include "xla/service/custom_call_target_registry.h"
+#include "xla/service/dynamic_dimension_inference.h"
 #include "xla/service/dynamic_index_splitter.h"
 #include "xla/service/eigh_expander.h"
-#include "xla/service/flatten_call_graph.h"
-#include "xla/service/hlo_constant_folding.h"
-#include "xla/service/hlo_cse.h"
-#include "xla/service/hlo_dce.h"
-#include "xla/service/hlo_pass_fix.h"
+#include "xla/service/executable.h"
+#include "xla/service/hlo_cost_analysis.h"
 #include "xla/service/hlo_pass_pipeline.h"
 #include "xla/service/layout_assignment.h"
-#include "xla/service/map_inliner.h"
 #include "xla/service/qr_expander.h"
-#include "xla/service/reshape_mover.h"
 #include "xla/service/topk_rewriter.h"
 #include "xla/service/triangular_solve_expander.h"
-#include "xla/service/while_loop_simplifier.h"
+#include "xla/status.h"
 #include "xla/status_macros.h"
+#include "xla/statusor.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/stream_executor_pimpl.h"
+#include "xla/util.h"
 #include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 namespace interpreter {
@@ -114,8 +124,13 @@
 
   VLOG(1) << "Run backend " << hlo_module->name();
 
-  TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
-                      DynamicDimensionInference::Run(hlo_module.get()));
+  TF_ASSIGN_OR_RETURN(
+      DynamicDimensionInference dynamic_dimension_inference,
+      DynamicDimensionInference::Run(
+          hlo_module.get(),
+          /*op_supports_dynamism_handler=*/[&](HloInstruction* hlo) {
+            return OpDynamismSupport::kOptional;
+          }));
 
   auto evaluator = std::make_unique<HloEvaluator>();
   evaluator->set_use_fast_path(
diff --git a/third_party/xla/xla/backends/profiler/plugin/profiler_c_api.h b/third_party/xla/xla/backends/profiler/plugin/profiler_c_api.h
index 643f943..6f54825 100644
--- a/third_party/xla/xla/backends/profiler/plugin/profiler_c_api.h
+++ b/third_party/xla/xla/backends/profiler/plugin/profiler_c_api.h
@@ -16,21 +16,21 @@
 #ifndef XLA_BACKENDS_PROFILER_PLUGIN_PROFILER_C_API_H_
 #define XLA_BACKENDS_PROFILER_PLUGIN_PROFILER_C_API_H_
 
-#include <cstddef>
-#include <cstdint>
+#include <stddef.h>
+#include <stdint.h>
 
 #define PROFILER_STRUCT_SIZE(struct_type, last_field) \
   offsetof(struct_type, last_field) + sizeof(((struct_type*)0)->last_field)
 
 #define PROFILER_DEFINE_STRUCT_TRAITS(sname, last_field) \
   typedef struct sname sname;                            \
-  const size_t sname##_STRUCT_SIZE = PROFILER_STRUCT_SIZE(sname, last_field);
+  enum { sname##_STRUCT_SIZE = PROFILER_STRUCT_SIZE(sname, last_field) }
 
 #ifdef __cplusplus
 extern "C" {
 #endif
 
-#define PLUGIN_PROFILER_VERSION 0
+#define PLUGIN_PROFILER_VERSION 1
 
 typedef struct PLUGIN_Profiler PLUGIN_Profiler;
 typedef struct PLUGIN_Profiler_Error PLUGIN_Profiler_Error;
diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD
index 047f259..ca9d8a9 100644
--- a/third_party/xla/xla/client/BUILD
+++ b/third_party/xla/xla/client/BUILD
@@ -1,9 +1,9 @@
 # Description:
 #   XLA client libraries.
 
+load("//xla:xla.bzl", "xla_cc_test")
 load("@local_tsl//tsl:tsl.default.bzl", "filegroup")
 load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library")
-load("//xla:xla.bzl", "xla_cc_test")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -256,6 +256,10 @@
         ":padding",
         ":sharding_builder",
         ":xla_computation",
+        "//xla:array",
+        "//xla:array2d",
+        "//xla:array3d",
+        "//xla:array4d",
         "//xla:comparison_util",
         "//xla:literal",
         "//xla:literal_util",
@@ -274,12 +278,16 @@
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/functional:function_ref",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/lib/core:bitmap",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:stacktrace",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
diff --git a/third_party/xla/xla/client/lib/BUILD b/third_party/xla/xla/client/lib/BUILD
index d54549b..134c8c4 100644
--- a/third_party/xla/xla/client/lib/BUILD
+++ b/third_party/xla/xla/client/lib/BUILD
@@ -30,8 +30,7 @@
     deps = [
         ":constants",
         "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla:types",
+        "//xla:statusor",
         "//xla:xla_data_proto_cc",
         "//xla/client:xla_builder",
         "//xla/client:xla_computation",
@@ -209,6 +208,10 @@
 xla_test(
     name = "math_test",
     srcs = ["math_test.cc"],
+    backend_tags = {
+        # Times out.
+        "ghostfish_iss": ["noasan"],
+    },
     deps = [
         ":constants",
         ":math",
@@ -248,6 +251,7 @@
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
     ],
 )
 
@@ -430,6 +434,9 @@
     deps = [
         ":arithmetic",
         ":constants",
+        "//xla:shape_util",
+        "//xla:status_macros",
+        "//xla:statusor",
         "//xla:types",
         "//xla:util",
         "//xla/client:xla_builder",
diff --git a/third_party/xla/xla/client/lib/arithmetic.cc b/third_party/xla/xla/client/lib/arithmetic.cc
index 6ca80dd..eda6145 100644
--- a/third_party/xla/xla/client/lib/arithmetic.cc
+++ b/third_party/xla/xla/client/lib/arithmetic.cc
@@ -15,6 +15,7 @@
 
 #include "xla/client/lib/arithmetic.h"
 
+#include <cstdint>
 #include <memory>
 #include <numeric>
 #include <string>
@@ -24,9 +25,10 @@
 #include "xla/client/lib/constants.h"
 #include "xla/client/xla_builder.h"
 #include "xla/client/xla_computation.h"
+#include "xla/primitive_util.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/types.h"
+#include "xla/statusor.h"
 #include "xla/xla_data.pb.h"
 
 namespace xla {
@@ -151,8 +153,8 @@
     int64_t dimension_size = input_shape.dimensions(axis);
     auto index_type = dimension_size <= INT32_MAX ? S32 : output_type;
     XlaOp index_init_value = Zero(builder, index_type);
-    auto iota_shape = input_shape;
-    iota_shape.set_element_type(index_type);
+    auto iota_shape =
+        ShapeUtil::MakeShape(index_type, input_shape.dimensions());
     XlaOp iota = Iota(builder, iota_shape, axis);
 
     XlaComputation reducer = CreateMinMaxComputation(
diff --git a/third_party/xla/xla/client/lib/matrix.cc b/third_party/xla/xla/client/lib/matrix.cc
index ccc24b1..e5b060a 100644
--- a/third_party/xla/xla/client/lib/matrix.cc
+++ b/third_party/xla/xla/client/lib/matrix.cc
@@ -17,11 +17,12 @@
 
 #include <algorithm>
 #include <array>
-#include <limits>
+#include <cstdint>
 #include <map>
 #include <numeric>
 #include <optional>
 #include <string>
+#include <tuple>
 #include <utility>
 #include <vector>
 
@@ -29,7 +30,7 @@
 #include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/strings/ascii.h"
-#include "absl/strings/str_join.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/str_split.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
@@ -39,12 +40,14 @@
 #include "xla/client/xla_builder.h"
 #include "xla/literal.h"
 #include "xla/primitive_util.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/status.h"
 #include "xla/status_macros.h"
 #include "xla/statusor.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
 
 namespace xla {
 
@@ -296,8 +299,7 @@
   XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
-    Shape iota_shape = x_shape;
-    iota_shape.set_element_type(S32);
+    Shape iota_shape = ShapeUtil::MakeShape(S32, x_shape.dimensions());
     XlaOp mask = ConstantR0(builder, true);
 
     for (auto label = config.begin(); label != config.end(); ++label) {
@@ -385,25 +387,27 @@
                   xla::XlaOp y, absl::Span<const int64_t> y_config,
                   absl::Span<const int64_t> output_config,
                   xla::PrecisionConfig::Precision precision,
-                  std::optional<PrimitiveType> preferred_element_type) {
+                  std::optional<PrimitiveType> preferred_element_type,
+                  bool grad_x, bool grad_y) {
   XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     auto x_diagonal_labels = EinsumDiagonalLabels(x_config);
     if (x_diagonal_labels) {
       return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels->at(0), y,
-                    y_config, output_config, precision, preferred_element_type);
+                    y_config, output_config, precision, preferred_element_type,
+                    grad_x, grad_y);
     }
     auto y_diagonal_labels = EinsumDiagonalLabels(y_config);
     if (y_diagonal_labels) {
       return Einsum(x, x_config, EinsumDiagonal(y, y_config),
                     y_diagonal_labels->at(0), output_config, precision,
-                    preferred_element_type);
+                    preferred_element_type, grad_x, grad_y);
     }
     auto output_diagonal_labels = EinsumDiagonalLabels(output_config);
     if (output_diagonal_labels) {
       return EinsumInverseDiagonal(
           Einsum(x, x_config, y, y_config, output_diagonal_labels->at(0),
-                 precision, preferred_element_type),
+                 precision, preferred_element_type, grad_x, grad_y),
           output_config);
     }
 
@@ -547,6 +551,11 @@
     precision_proto.add_operand_precision(precision);
     auto dot =
         DotGeneral(x, y, dnums, &precision_proto, preferred_element_type);
+
+    TF_RETURN_IF_ERROR(builder->SetInstructionFrontendAttribute(
+        dot, "grad_x", (grad_x ? "true" : "false")));
+    TF_RETURN_IF_ERROR(builder->SetInstructionFrontendAttribute(
+        dot, "grad_y", (grad_y ? "true" : "false")));
     dot = Transpose(dot, transpose_dims);
     if (transpose_rank == output_rank) {
       return dot;
@@ -578,7 +587,8 @@
 
 XlaOp BatchDot(XlaOp x, bool transpose_x, XlaOp y, bool transpose_y,
                PrecisionConfig::Precision precision,
-               std::optional<PrimitiveType> preferred_element_type) {
+               std::optional<PrimitiveType> preferred_element_type, bool grad_x,
+               bool grad_y) {
   XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     std::string string("...mk,...kn->...mn");
@@ -588,7 +598,8 @@
     if (transpose_y) {
       std::swap(string[6 + 3], string[6 + 4]);
     }
-    return Einsum(x, y, string, precision, preferred_element_type);
+    return Einsum(x, y, string, precision, preferred_element_type, grad_x,
+                  grad_y);
   });
 }
 
@@ -709,12 +720,14 @@
 
 XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config,
              PrecisionConfig::Precision precision,
-             std::optional<PrimitiveType> preferred_element_type) {
+             std::optional<PrimitiveType> preferred_element_type, bool grad_x,
+             bool grad_y) {
   XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     auto new_config = NormalizeEinsumString(einsum_config);
     if (!new_config.empty()) {
-      return Einsum(x, y, new_config, precision, preferred_element_type);
+      return Einsum(x, y, new_config, precision, preferred_element_type, grad_x,
+                    grad_y);
     }
     TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
     TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
@@ -722,7 +735,8 @@
         auto einsum_config_numeric,
         ParseEinsumString(einsum_config, x_shape.rank(), y_shape.rank()));
     return Einsum(x, einsum_config_numeric[0], y, einsum_config_numeric[1],
-                  einsum_config_numeric[2], precision, preferred_element_type);
+                  einsum_config_numeric[2], precision, preferred_element_type,
+                  grad_x, grad_y);
   });
 }
 
diff --git a/third_party/xla/xla/client/lib/matrix.h b/third_party/xla/xla/client/lib/matrix.h
index bead189..48f75b2 100644
--- a/third_party/xla/xla/client/lib/matrix.h
+++ b/third_party/xla/xla/client/lib/matrix.h
@@ -97,7 +97,8 @@
 xla::XlaOp BatchDot(
     xla::XlaOp x, bool transpose_x, xla::XlaOp y, bool transpose_y,
     xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT,
-    std::optional<PrimitiveType> preferred_element_type = std::nullopt);
+    std::optional<PrimitiveType> preferred_element_type = std::nullopt,
+    bool grad_x = false, bool grad_y = false);
 
 // Parse an einsum string into dimension numbers:
 //   "ab,cb->ac"
@@ -128,12 +129,12 @@
 xla::XlaOp Einsum(
     xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config,
     xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT,
-    std::optional<PrimitiveType> preferred_element_type = std::nullopt);
+    std::optional<PrimitiveType> preferred_element_type = std::nullopt,
+    bool grad_x = false, bool grad_y = false);
 xla::XlaOp Einsum(
     xla::XlaOp x, absl::string_view einsum_config,
     xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
 
-
 // Same as above but supporting numeric labels on dimensions. So "ab,cb->ac"
 // becomes:
 //   x_config = {0, 1}
@@ -143,7 +144,8 @@
     xla::XlaOp x, absl::Span<const int64_t> x_config, xla::XlaOp y,
     absl::Span<const int64_t> y_config, absl::Span<const int64_t> output_config,
     xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT,
-    std::optional<PrimitiveType> preferred_element_type = std::nullopt);
+    std::optional<PrimitiveType> preferred_element_type = std::nullopt,
+    bool grad_x = false, bool grad_y = false);
 
 // Transposes a stack of matrices `x` by swapping the last two dimensions.
 xla::XlaOp TransposeInMinorDims(xla::XlaOp x);
diff --git a/third_party/xla/xla/client/lib/slicing.cc b/third_party/xla/xla/client/lib/slicing.cc
index a9da8af..914e41d 100644
--- a/third_party/xla/xla/client/lib/slicing.cc
+++ b/third_party/xla/xla/client/lib/slicing.cc
@@ -16,13 +16,19 @@
 #include "xla/client/lib/slicing.h"
 
 #include <algorithm>
+#include <cstdint>
 #include <functional>
 #include <limits>
 #include <vector>
 
+#include "absl/types/span.h"
 #include "xla/client/lib/arithmetic.h"
 #include "xla/client/lib/constants.h"
 #include "xla/client/xla_builder.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/statusor.h"
 #include "xla/util.h"
 
 namespace xla {
@@ -287,8 +293,9 @@
       ShapeUtil::AppendMajorDimension(1, &index_shape);
       std::vector<XlaOp> to_concat;
       to_concat.reserve(batch_dims + 1);
+      xla::Shape iota_shape = xla::ShapeUtil::MakeStaticShape(index_shape);
       for (int64_t batch_dim = 0; batch_dim < batch_dims; ++batch_dim) {
-        to_concat.push_back(Iota(builder, index_shape, batch_dim));
+        to_concat.push_back(Iota(builder, iota_shape, batch_dim));
       }
       to_concat.push_back(Reshape(index, index_shape.dimensions()));
       index = ConcatInDim(builder, to_concat, gather_dnums.index_vector_dim());
diff --git a/third_party/xla/xla/client/xla_builder.cc b/third_party/xla/xla/client/xla_builder.cc
index 6540d6d..172ded2 100644
--- a/third_party/xla/xla/client/xla_builder.cc
+++ b/third_party/xla/xla/client/xla_builder.cc
@@ -15,6 +15,8 @@
 
 #include "xla/client/xla_builder.h"
 
+#include <cstddef>
+#include <cstdint>
 #include <functional>
 #include <iterator>
 #include <memory>
@@ -29,18 +31,28 @@
 #include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/functional/function_ref.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
+#include "xla/array.h"
+#include "xla/client/padding.h"
 #include "xla/client/sharding_builder.h"
 #include "xla/client/xla_computation.h"
 #include "xla/comparison_util.h"
 #include "xla/hlo/ir/hlo_input_output_alias_config.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/layout.h"
 #include "xla/layout_util.h"
+#include "xla/literal.h"
+#include "xla/literal_util.h"
 #include "xla/permutation_util.h"
 #include "xla/primitive_util.h"
 #include "xla/service/hlo.pb.h"
@@ -50,10 +62,13 @@
 #include "xla/sharding_op_util.h"
 #include "xla/status.h"
 #include "xla/status_macros.h"
+#include "xla/statusor.h"
 #include "xla/util.h"
 #include "xla/window_util.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/errors.h"
+#include "tsl/platform/stacktrace.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 
@@ -860,6 +875,22 @@
     instr.add_dimensions(dim);
   }
 
+  TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
+  for (int64_t i = 0; i < shape.rank(); i++) {
+    if (auto it = absl::c_find(broadcast_dimensions, i);
+        it != broadcast_dimensions.end()) {
+      // Broadcast dimensions are permitted to be dynamic iff the operand
+      // dimension is dynamic.
+      TF_RET_CHECK(operand_shape->is_dynamic_dimension(
+                       it - broadcast_dimensions.begin()) ==
+                   shape.is_dynamic_dimension(i))
+          << " i: " << i << ", shape: " << shape.ToString()
+          << ", operand_shape: " << operand_shape->ToString();
+    } else {
+      // Non-broadcast dimensions must not be dynamic.
+      TF_RET_CHECK(!shape.is_dynamic_dimension(i));
+    }
+  }
   return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand});
 }
 
@@ -876,35 +907,34 @@
 
   // Do explicit broadcast for scalar.
   if (ShapeUtil::IsScalar(*operand_shape)) {
-    return InDimBroadcast(broadcast_shape, operand, {});
+    return InDimBroadcast(ShapeUtil::MakeStaticShape(broadcast_shape), operand,
+                          {});
   }
 
   // Do explicit broadcast for degenerate broadcast.
   std::vector<int64_t> broadcast_dimensions;
   std::vector<int64_t> reshaped_dimensions;
+  std::vector<bool> reshaped_dynamic_dimensions;
   for (int i = 0; i < operand_shape->rank(); i++) {
     if (operand_shape->dimensions(i) == output_shape.dimensions(i)) {
       broadcast_dimensions.push_back(i);
       reshaped_dimensions.push_back(operand_shape->dimensions(i));
+      reshaped_dynamic_dimensions.push_back(
+          operand_shape->is_dynamic_dimension(i));
     } else {
-      TF_RET_CHECK(operand_shape->dimensions(i) == 1)
+      TF_RET_CHECK(operand_shape->dimensions(i) == 1 &&
+                   !operand_shape->is_dynamic_dimension(i))
           << "An explicit broadcast sequence requires the broadcasted "
              "dimensions to be trivial; operand shape: "
           << *operand_shape << "; output_shape: " << output_shape;
     }
+    broadcast_shape.set_dynamic_dimension(
+        i, operand_shape->is_dynamic_dimension(i));
   }
 
   Shape reshaped_shape =
-      ShapeUtil::MakeShape(operand_shape->element_type(), reshaped_dimensions);
-
-  std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
-      ShapeUtil::DimensionsUnmodifiedByReshape(*operand_shape, reshaped_shape);
-
-  for (auto& unmodified : unmodified_dims) {
-    if (operand_shape->is_dynamic_dimension(unmodified.first)) {
-      reshaped_shape.set_dynamic_dimension(unmodified.second, true);
-    }
-  }
+      ShapeUtil::MakeShape(operand_shape->element_type(), reshaped_dimensions,
+                           reshaped_dynamic_dimensions);
 
   // Eliminate the size one dimensions.
   TF_ASSIGN_OR_RETURN(
@@ -952,7 +982,7 @@
       to_size_is_dynamic.reserve(rank);
       for (int i = 0; i < rank; i++) {
         to_size.push_back(shape.dimensions(i));
-        to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i));
+        to_size_is_dynamic.push_back(false);
       }
       for (int64_t from_dim = 0; from_dim < from_shape.rank(); from_dim++) {
         int64_t to_dim = broadcast_dimensions[from_dim];
@@ -1055,7 +1085,7 @@
                          shape->dimensions())
                 << "Unimplemented implicit broadcast.";
           } else {
-            non_scalar_shape = *shape;
+            non_scalar_shape = ShapeUtil::MakeStaticShape(*shape);
           }
         }
       }
@@ -1114,6 +1144,11 @@
 
 XlaOp XlaBuilder::Iota(const Shape& shape, int64_t iota_dimension) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    if (!shape.is_static()) {
+      return InvalidArgument(
+          "The output of iota must not have dynamic dimensions: %s",
+          shape.ToString());
+    }
     HloInstructionProto instr;
     *instr.mutable_shape() = shape.ToProto();
     instr.add_dimensions(iota_dimension);
@@ -1223,11 +1258,14 @@
                            *operand_shape, output_shape, broadcast_dimensions)
                            .status());
     std::vector<int64_t> in_dim_size(out_dim_size.begin(), out_dim_size.end());
+    std::vector<bool> in_dim_dynamic(out_dim_size.size(), false);
     for (int i = 0; i < broadcast_rank; i++) {
       in_dim_size[broadcast_dimensions[i]] = operand_shape->dimensions(i);
+      in_dim_dynamic[broadcast_dimensions[i]] =
+          operand_shape->is_dynamic_dimension(i);
     }
-    const auto& in_dim_shape =
-        ShapeUtil::MakeShape(operand_shape->element_type(), in_dim_size);
+    const auto& in_dim_shape = ShapeUtil::MakeShape(
+        operand_shape->element_type(), in_dim_size, in_dim_dynamic);
     TF_ASSIGN_OR_RETURN(
         XlaOp in_dim_broadcast,
         InDimBroadcast(in_dim_shape, operand, broadcast_dimensions));
diff --git a/third_party/xla/xla/client/xla_builder.h b/third_party/xla/xla/client/xla_builder.h
index 00e288c..cbc0259 100644
--- a/third_party/xla/xla/client/xla_builder.h
+++ b/third_party/xla/xla/client/xla_builder.h
@@ -19,6 +19,7 @@
 #include <cstdint>
 #include <deque>
 #include <functional>
+#include <initializer_list>
 #include <map>
 #include <memory>
 #include <optional>
@@ -31,20 +32,31 @@
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/functional/function_ref.h"
+#include "absl/log/check.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
+#include "xla/array.h"
+#include "xla/array2d.h"
+#include "xla/array3d.h"
+#include "xla/array4d.h"
 #include "xla/client/padding.h"
 #include "xla/client/xla_computation.h"
 #include "xla/comparison_util.h"
 #include "xla/hlo/ir/dynamic_parameter_binding.h"
 #include "xla/hlo/ir/hlo_input_output_alias_config.h"
 #include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
 #include "xla/service/hlo.pb.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
+#include "xla/status.h"
 #include "xla/statusor.h"
+#include "xla/util.h"
 #include "xla/xla_data.pb.h"
+#include "tsl/lib/core/bitmap.h"
+#include "tsl/platform/errors.h"
 #include "tsl/platform/stacktrace.h"
 
 namespace xla {
@@ -1050,6 +1062,9 @@
 
   // Internal helper method that creates a sequence of instructions that
   // performs an explicit broadcast of the operand to the target shape.
+  // All dimensions of the operand must either be equal to the corresponding
+  // output shape dimension, or be exactly 1.  (Such dimensions are the
+  // degenerate dimensions.)
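+  // For example, an operand of shape f32[3,1] can be explicitly broadcast to
+  // an output shape of f32[3,4]; dimension 1 is the degenerate dimension.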
   StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
                                        XlaOp operand);
 
diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc
index 4400144..008cfab 100644
--- a/third_party/xla/xla/debug_options_flags.cc
+++ b/third_party/xla/xla/debug_options_flags.cc
@@ -25,11 +25,16 @@
 #include "absl/base/call_once.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/node_hash_map.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
 #include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
 #include "xla/debug_options_parsers.h"
 #include "xla/parse_flags_from_env.h"
 #include "xla/xla.pb.h"
+#include "tsl/platform/protobuf.h"  // IWYU pragma: keep
 #include "tsl/util/command_line_flags.h"
 
 namespace xla {
@@ -93,9 +98,8 @@
   // flag.
   opts.set_xla_gpu_enable_cublaslt(false);
 
-  // TODO(b/258036887): Create separate flags for enabling cuBLAS, cuDNN, and
-  // NCCL in GPU graphs.
-  opts.set_xla_gpu_graph_level(1);
+  opts.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
+  opts.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLAS);
   opts.set_xla_gpu_graph_num_runs_to_instantiate(-1);
   opts.set_xla_gpu_enable_persistent_temp_buffers(false);
   opts.set_xla_gpu_graph_min_graph_size(5);
@@ -150,6 +154,7 @@
   opts.set_xla_gpu_lhs_enable_gpu_async_tracker(true);
   opts.set_xla_gpu_enable_analytical_latency_estimator(false);
   opts.set_xla_gpu_pgle_profile_file_or_directory_path("");
+  opts.set_xla_gpu_memory_limit_slop_factor(95);
   opts.set_xla_gpu_enable_highest_priority_async_stream(true);
 
   opts.set_xla_gpu_enable_pipelined_collectives(false);
@@ -203,6 +208,7 @@
   opts.set_xla_gpu_ensure_minor_dot_contraction_dims(false);
   opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(true);
   opts.set_xla_gpu_llvm_verification_level(0);
+  opts.set_xla_gpu_enable_cub_radix_sort(true);
 
   return opts;
 }
@@ -355,6 +361,46 @@
         return true;
       };
 
+  // Custom "sub-parser" lambda for xla_gpu_graph_level.
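+  // For example, --xla_gpu_graph_level=2 enables the FUSION and CUBLAS
+  // command buffer types.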
+  auto setter_for_xla_gpu_graph_level = [debug_options](const int32_t level) {
+    debug_options->clear_xla_gpu_enable_command_buffer();
+    if (level >= 1) {
+      debug_options->add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
+    }
+    if (level >= 2) {
+      debug_options->add_xla_gpu_enable_command_buffer(DebugOptions::CUBLAS);
+    }
+    if (level >= 3) {
+      debug_options->add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN);
+    }
+    return true;
+  };
+
+  auto command_types_to_string =
+      [](tsl::protobuf::RepeatedField<int> command_types) -> std::string {
+    struct Formatter {
+      void operator()(std::string* out, int type) const {
+        absl::StrAppend(out, DebugOptions::CommandBufferCmdType_Name(type));
+      }
+    };
+    return absl::StrJoin(command_types, ", ", Formatter());
+  };
+
+  // Custom "sub-parser" lambda for xla_gpu_enable_command_buffer.
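+  // Accepts a comma-separated, case-insensitive list of command types,
+  // e.g. --xla_gpu_enable_command_buffer=fusion,cublas.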
+  auto setter_for_xla_gpu_enable_command_buffer =
+      [debug_options](const std::string& values) {
+        debug_options->clear_xla_gpu_enable_command_buffer();
+        for (const absl::string_view value : absl::StrSplit(values, ',')) {
+          DebugOptions::CommandBufferCmdType cmd_type;
+          if (!DebugOptions::CommandBufferCmdType_Parse(
+                  absl::AsciiStrToUpper(value), &cmd_type)) {
+            return false;
+          }
+          debug_options->add_xla_gpu_enable_command_buffer(cmd_type);
+        }
+        return true;
+      };
+
   // Custom "sub-parser" for xla_fuel.  Note that ConsumeFuel does not do any
   // locking on the fuel global variables.  This means that it's
   // illegal/undefined behavior to modify this flag value while the compiler is
@@ -943,11 +989,14 @@
                 debug_options->xla_gpu_enable_cublaslt(),
                 "Use cuBLASLt for GEMMs when possible."));
   flag_list->push_back(tsl::Flag(
-      "xla_gpu_graph_level",
-      int32_setter_for(&DebugOptions::set_xla_gpu_graph_level),
-      debug_options->xla_gpu_graph_level(),
-      "Set GPU graph level. 0 = off; 1 = capture fusions and memcpys; 2 = "
-      "capture gemms; 3 = capture convolutions."));
+      "xla_gpu_graph_level", setter_for_xla_gpu_graph_level, 1,
+      "The legacy flag for setting GPU graph level. Use "
+      "xla_gpu_enable_command_buffer in new use cases. 0 = off; 1 = capture "
+      "fusions and memcpys; 2 = capture gemms; 3 = capture convolutions."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_enable_command_buffer", setter_for_xla_gpu_enable_command_buffer,
+      command_types_to_string(debug_options->xla_gpu_enable_command_buffer()),
+      "The types of the commands that are recorded into command buffers"));
   flag_list->push_back(tsl::Flag(
       "xla_gpu_graph_num_runs_to_instantiate",
       int32_setter_for(
@@ -1132,6 +1181,11 @@
       debug_options->xla_gpu_lhs_enable_gpu_async_tracker(),
       "Enable GPU async tracker for latency-hiding scheduler in XLA:GPU"));
   flag_list->push_back(tsl::Flag(
+      "xla_gpu_memory_limit_slop_factor",
+      int32_setter_for(&DebugOptions::set_xla_gpu_memory_limit_slop_factor),
+      debug_options->xla_gpu_memory_limit_slop_factor(),
+      "Slop factor for memory limits in XLA:GPU"));
+  flag_list->push_back(tsl::Flag(
       "xla_gpu_enable_highest_priority_async_stream",
       bool_setter_for(
           &DebugOptions::set_xla_gpu_enable_highest_priority_async_stream),
@@ -1342,6 +1396,11 @@
       debug_options->xla_gpu_llvm_verification_level(),
       "Sets how often we verify the generated llvm modules. Higher "
       "levels mean more frequent verification. Currently supported: 0, 1."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_enable_cub_radix_sort",
+      bool_setter_for(&DebugOptions::set_xla_gpu_enable_cub_radix_sort),
+      debug_options->xla_gpu_enable_cub_radix_sort(),
+      "Enable radix sort using CUB for simple shapes"));
 }  // NOLINT(readability/fn_size)
 
 // Allocates flag_values and flag_objects; this function must not be called more
diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD
new file mode 100644
index 0000000..af72601
--- /dev/null
+++ b/third_party/xla/xla/ffi/BUILD
@@ -0,0 +1,75 @@
+load("//xla:xla.bzl", "xla_cc_test")
+load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library")
+
+package(
+    default_visibility = ["//visibility:public"],
+)
+
+cc_library(
+    name = "api",
+    hdrs = ["//xla/ffi/api:api_headers"],
+    visibility = ["//visibility:public"],
+    deps = ["//xla/ffi/api:c_api"],
+)
+
+cc_library(
+    name = "call_frame",
+    srcs = ["call_frame.cc"],
+    hdrs = ["call_frame.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//xla:types",
+        "//xla:xla_data_proto_cc",
+        "//xla/ffi/api:c_api",
+        "//xla/ffi/api:c_api_internal",
+        "//xla/stream_executor:device_memory",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "ffi",
+    srcs = ["ffi.cc"],
+    hdrs = ["ffi.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":api",
+        ":call_frame",
+        "//xla:status",
+        "//xla:statusor",
+        "//xla:types",
+        "//xla:xla_data_proto_cc",
+        "//xla/ffi/api:c_api",
+        "//xla/ffi/api:c_api_internal",
+        "//xla/runtime:memref_view",
+        "//xla/service:executable",
+        "//xla/stream_executor:device_memory",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+    ],
+)
+
+xla_cc_test(
+    name = "ffi_test",
+    srcs = ["ffi_test.cc"],
+    deps = [
+        ":api",
+        ":call_frame",
+        ":ffi",
+        "//xla:xla_data_proto_cc",
+        "//xla/ffi/api:c_api",
+        "//xla/service:executable",
+        "//xla/stream_executor:device_memory",
+        "@com_google_absl//absl/status",
+        "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
diff --git a/third_party/xla/xla/ffi/README.md b/third_party/xla/xla/ffi/README.md
new file mode 100644
index 0000000..cea97f3
--- /dev/null
+++ b/third_party/xla/xla/ffi/README.md
@@ -0,0 +1,23 @@
+# XLA FFI
+
+This is the next generation of XLA custom calls with rich type-safe APIs.
+
+https://en.wikipedia.org/wiki/Foreign_function_interface
+
+```
+A foreign function interface (FFI) is a mechanism by which a program written in
+one programming language can call routines or make use of services written or
+compiled in another one. An FFI is often used in contexts where calls are made
+into a binary dynamic-link library.
+```
+
+XLA FFI is a mechanism by which an XLA program can call functions compiled with
+another programming language using a stable C API (which guarantees ABI
+compatibility between XLA and external functions). XLA FFI also provides a C++
+header-only library that hides all the details of the underlying C API from
+the user.
+
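+As a rough sketch, the header-only library lets users describe their types to
+XLA FFI by specializing decoding templates such as `ArgDecoding` from
+`xla/ffi/api/api.h`. The `MyType` type below is purely illustrative; real
+decoding types are provided by the `ffi.h` extension headers.
+
+```c++
+#include <optional>
+
+#include "xla/ffi/api/api.h"
+
+struct MyType {};
+
+namespace xla::ffi {
+template <>
+struct ArgDecoding<MyType> {
+  // A real implementation would decode `arg` based on `type`.
+  static std::optional<MyType> Decode(XLA_FFI_ArgType type, void* arg) {
+    return MyType{};
+  }
+};
+}  // namespace xla::ffi
+```
+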
+**WARNING:** Under construction. We already have a rich type-safe custom call
+mechanism for XLA runtime. However, it doesn't provide a stable C API. XLA FFI
+aims to replicate the usability of XLA runtime's custom calls with a stable
+C API.
\ No newline at end of file
diff --git a/third_party/xla/xla/ffi/api/BUILD b/third_party/xla/xla/ffi/api/BUILD
new file mode 100644
index 0000000..d18f89e
--- /dev/null
+++ b/third_party/xla/xla/ffi/api/BUILD
@@ -0,0 +1,48 @@
+load("@local_tsl//tsl:tsl.default.bzl", "filegroup")
+load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library")
+
+package(
+    default_visibility = ["//visibility:public"],
+)
+
+filegroup(
+    name = "api_headers",
+    srcs = ["api.h"],
+    visibility = ["//visibility:public"],
+)
+
+filegroup(
+    name = "c_api_headers",
+    srcs = ["c_api.h"],
+    visibility = ["//visibility:public"],
+)
+
+cc_library(
+    name = "api",
+    hdrs = [":api_headers"],
+    visibility = ["//visibility:public"],
+    deps = [":c_api"],
+)
+
+cc_library(
+    name = "c_api",
+    hdrs = ["c_api.h"],
+    visibility = ["//visibility:public"],
+)
+
+cc_library(
+    name = "c_api_internal",
+    hdrs = ["c_api_internal.h"],
+    visibility = ["//visibility:public"],
+    deps = [":c_api"],
+)
+
+cc_library(
+    name = "ffi",
+    hdrs = ["ffi.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":api",
+        ":c_api",
+    ],
+)
diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h
new file mode 100644
index 0000000..46105cf
--- /dev/null
+++ b/third_party/xla/xla/ffi/api/api.h
@@ -0,0 +1,610 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_FFI_API_API_H_
+#define XLA_FFI_API_API_H_
+
+#include <algorithm>
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <iterator>
+#include <memory>
+#include <optional>
+#include <sstream>
+#include <string>
+#include <string_view>
+#include <tuple>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+// This is a header-only base C++ library that defines templates for decoding
+// XLA FFI call frames and invoking corresponding C++ functions. This must have
+// no dependencies outside of the C++ standard library.
+//
+// There are two extensions to this base library:
+//
+//   (1) xla/ffi/api/ffi.h for defining "external" FFI handlers loaded from
+//       dynamic libraries potentially built with different toolchains and/or
+//       a different XLA commit. It is a header-only library without any
+//       dependencies.
+//
+//   (2) xla/ffi/ffi.h for defining "internal" FFI handlers that must be
+//       statically linked into the binary and must be built from the same
+//       commit using the same toolchain, as it provides access to XLA
+//       implementation details (e.g. ServiceExecutableOptions) and C++ ABI
+//       across different libraries is hard.
+//
+// Extensions define template specializations for argument-decoding hooks
+// defined in this file.
+
+#include "xla/ffi/api/c_api.h"
+
+namespace xla::ffi {
+
+// Forward declare template defined below.
+template <typename... Ts>
+class Binding;
+
+// Forward declare template defined below.
+template <typename Fn, typename... Ts>
+class Handler;
+
+//===----------------------------------------------------------------------===//
+// XLA FFI virtual base for implementing FFI handlers
+//===----------------------------------------------------------------------===//
+
+class Ffi {
+ public:
+  static Binding<> Bind();
+
+  virtual ~Ffi() = default;
+  virtual XLA_FFI_Error* Call(const XLA_FFI_CallFrame* call_frame) const = 0;
+
+  // Registers handler with an XLA runtime under the given name.
+  static inline XLA_FFI_Error* RegisterStaticHandler(const XLA_FFI_Api* api,
+                                                     std::string_view name,
+                                                     XLA_FFI_Handler* handler);
+
+ protected:
+  template <typename... Args>
+  static std::string StrCat(Args... args);
+
+  static inline XLA_FFI_Error* MakeError(const XLA_FFI_Api* api,
+                                         XLA_FFI_Error_Code errc,
+                                         std::string message);
+
+  static inline XLA_FFI_Error* InvalidArgument(const XLA_FFI_Api* api,
+                                               std::string message);
+
+  static inline XLA_FFI_Error* CheckStructSize(const XLA_FFI_Api* api,
+                                               std::string_view struct_name,
+                                               size_t expected, size_t actual);
+};
+
+XLA_FFI_Error* Ffi::RegisterStaticHandler(const XLA_FFI_Api* api,
+                                          std::string_view name,
+                                          XLA_FFI_Handler* handler) {
+  std::string name_str(name);  // make a copy to guarantee it's null terminated
+
+  XLA_FFI_Handler_Register_Args args;
+  args.struct_size = XLA_FFI_Handler_Register_Args_STRUCT_SIZE;
+  args.priv = nullptr;
+  args.name = name_str.c_str();
+  args.handler = handler;
+  return api->XLA_FFI_Handler_Register(&args);
+}
+
+template <typename... Args>
+std::string Ffi::StrCat(Args... args) {
+  std::stringstream ss;
+  (ss << ... << args);
+  return ss.str();
+}
+
+XLA_FFI_Error* Ffi::MakeError(const XLA_FFI_Api* api, XLA_FFI_Error_Code errc,
+                              std::string message) {
+  XLA_FFI_Error_Create_Args args;
+  args.struct_size = XLA_FFI_Error_Create_Args_STRUCT_SIZE;
+  args.priv = nullptr;
+  args.errc = errc;
+  args.message = message.c_str();
+  return api->XLA_FFI_Error_Create(&args);
+}
+
+XLA_FFI_Error* Ffi::InvalidArgument(const XLA_FFI_Api* api,
+                                    std::string message) {
+  return MakeError(api, XLA_FFI_Error_Code_INVALID_ARGUMENT,
+                   std::move(message));
+}
+
+XLA_FFI_Error* Ffi::CheckStructSize(const XLA_FFI_Api* api,
+                                    std::string_view struct_name,
+                                    size_t expected, size_t actual) {
+  if (expected != actual) {
+    return InvalidArgument(
+        api, StrCat("Unexpected ", struct_name, " size: expected ", expected,
+                    " got ", actual, ". Check installed software versions."));
+  }
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// Type tags for distinguishing handler argument types
+//===----------------------------------------------------------------------===//
+
+namespace internal {
+
+// A type tag to distinguish arguments tied to the attributes in the
+// `Binding` variadic template argument.
+template <typename T>
+struct AttrTag {};
+
+// A type tag to distinguish arguments extracted from an execution context.
+template <typename T>
+struct CtxTag {};
+
+}  // namespace internal
+
+//===----------------------------------------------------------------------===//
+// Binding variadic template defines FFI handler signature
+//===----------------------------------------------------------------------===//
+
+template <typename... Ts>
+class Binding {
+ public:
+  template <typename T>
+  Binding<Ts..., T> Arg() && {
+    return {std::move(*this)};
+  }
+
+  template <typename T>
+  Binding<Ts..., internal::CtxTag<T>> Ctx() && {
+    return {std::move(*this)};
+  }
+
+  template <typename T>
+  Binding<Ts..., internal::AttrTag<T>> Attr(std::string attr) && {
+    attrs_.push_back(std::move(attr));
+    return {std::move(*this)};
+  }
+
+  template <typename Fn>
+  std::unique_ptr<Handler<Fn, Ts...>> To(Fn fn) {
+    return std::unique_ptr<Handler<Fn, Ts...>>(
+        new Handler<Fn, Ts...>(std::forward<Fn>(fn), std::move(attrs_)));
+  }
+
+ private:
+  template <typename...>
+  friend class Binding;
+  friend class Ffi;
+
+  explicit Binding() {
+    static_assert(sizeof...(Ts) == 0, "arguments must be empty");
+  }
+
+  template <typename... TTs>
+  Binding(Binding<TTs...>&& other)  // NOLINT
+      : attrs_(std::move(other.attrs_)) {}
+
+  Binding(Binding&) = delete;
+
+  std::vector<std::string> attrs_;  // names of bound attributes
+};
+
+inline Binding<> Ffi::Bind() { return xla::ffi::Binding<>(); }
+
+//===----------------------------------------------------------------------===//
+// Arguments decoding implementation
+//===----------------------------------------------------------------------===//
+
+// XLA FFI arguments decoding must be defined by specializing this template.
+//
+// Example: decoding for the `MyType` arguments
+//
+//   template <>
+//   struct ArgDecoding<MyType> {
+//     static std::optional<MyType> Decode(XLA_FFI_ArgType type, void* arg);
+//   };
+//
+// If an argument can't be decoded, Decode should return an empty optional.
+template <typename T>
+struct ArgDecoding;
+
+//===----------------------------------------------------------------------===//
+// Attributes decoding implementation
+//===----------------------------------------------------------------------===//
+
+// XLA FFI attribute decoding must be defined by specializing this template.
+//
+// Example: decoding for the `MyType` attributes
+//
+//   template <>
+//   struct AttrDecoding<MyType> {
+//    static std::optional<MyType> Decode(std::string_view name,
+//                                        XLA_FFI_AttrType type, void* attr);
+//   }
+//
+template <typename T>
+struct AttrDecoding;
+
+//===----------------------------------------------------------------------===//
+// Context decoding implementation
+//===----------------------------------------------------------------------===//
+
+// XLA FFI execution context decoding must be defined by specializing this
+// template.
+//
+// Example: decoding for the `MyType` context
+//
+//   template <>
+//   struct CtxDecoding<MyType> {
+//    using Type = <handler argument type for context type MyType>;
+//    static std::optional<Type> Decode(const XLA_FFI_Api* api,
+//                                      XLA_FFI_ExecutionContext* ctx);
+//   }
+//
+// TODO(ezhulenev): Add an example for decoding opaque data passed together with
+// a handler registration (not yet implemented). Today this is only used as
+// an internal implementation detail of builtin FFI handlers.
+template <typename T>
+struct CtxDecoding;
+
+//===----------------------------------------------------------------------===//
+// Result encoding implementation
+//===----------------------------------------------------------------------===//
+
+// XLA FFI result encoding (conversion from a returned status-like type to FFI
+// error type) must be defined by specializing this template.
+//
+// Example: encoding `absl::Status` result
+//
+//   template<>
+//   struct ResultEncoding<absl::Status> {
+//     XLA_FFI_Error* Encode(const XLA_FFI_Api* api, absl::Status status) {...}
+//   }
+//
+template <typename T>
+struct ResultEncoding;
+
+//===----------------------------------------------------------------------===//
+// Decoding arguments and attributes
+//===----------------------------------------------------------------------===//
+
+namespace internal {
+
+// When decoding input data we need to keep track of how many arguments and
+// attributes we decoded so far to compute call frame offsets.
+struct DecodingOffsets {
+  int64_t args = 0;
+  int64_t attrs = 0;
+};
+
+struct DecodingContext {
+  const XLA_FFI_CallFrame* call_frame;
+
+  const std::string* attrs_names;  // not owned
+  const std::size_t* attrs_idx;    // not owned
+};
+
+template <typename T>
+struct Decode {
+  static std::optional<T> call(DecodingOffsets& offsets, DecodingContext& ctx) {
+    int64_t idx = offsets.args++;
+    return ArgDecoding<T>::Decode(ctx.call_frame->args.types[idx],
+                                  ctx.call_frame->args.args[idx]);
+  }
+};
+
+template <typename T>
+struct Decode<internal::AttrTag<T>> {
+  static std::optional<T> call(DecodingOffsets& offsets, DecodingContext& ctx) {
+    // Find decoded attribute corresponding to the given attribute index.
+    int64_t idx = offsets.attrs++;
+
+    // Get mapping from the attribute to its index in the sorted array.
+    size_t i = ctx.attrs_idx[idx];
+
+    // Load attribute from call frame using index into the sorted array.
+    XLA_FFI_AttrType type = ctx.call_frame->attrs.types[i];
+    XLA_FFI_ByteSpan* name = ctx.call_frame->attrs.names[i];
+    void* attr = ctx.call_frame->attrs.attrs[i];
+
+    // TODO(ezhulenev): Currently we require that attributes passed to the FFI
+    // handler match the attributes referenced in a binding, although we could
+    // safely ignore extra attributes. Relax this if needed.
+
+    // Attribute name does not match.
+    std::string_view name_view = {name->ptr, name->len};
+    if (name_view != ctx.attrs_names[idx]) return std::nullopt;
+
+    return AttrDecoding<T>::Decode(name_view, type, attr);
+  }
+};
+
+template <typename T>
+struct Decode<internal::CtxTag<T>> {
+  using R = typename CtxDecoding<T>::Type;
+
+  static std::optional<R> call(DecodingOffsets& offsets, DecodingContext& ctx) {
+    return CtxDecoding<T>::Decode(ctx.call_frame->api, ctx.call_frame->ctx);
+  }
+};
+
+}  // namespace internal
+
+//===----------------------------------------------------------------------===//
+// Template metaprogramming for decoding handler signature
+//===----------------------------------------------------------------------===//
+
+namespace internal {
+
+// A helper struct to extract the type of the handler argument.
+template <typename T>
+struct FnArgType {
+  using Type = T;
+};
+
+// Extracts the underlying type from the attribute type tag.
+template <typename T>
+struct FnArgType<internal::AttrTag<T>> {
+  using Type = T;
+};
+
+// Extracts the underlying type from the context type tag.
+template <typename T>
+struct FnArgType<internal::CtxTag<T>> {
+  using Type = typename CtxDecoding<T>::Type;
+};
+
+// A template for checking if a type is a wrapped attribute or user data.
+template <typename>
+struct IsWrapped : std::false_type {};
+template <typename T>
+struct IsWrapped<AttrTag<T>> : std::true_type {};
+template <typename T>
+struct IsWrapped<CtxTag<T>> : std::true_type {};
+
+// A template for counting regular arguments in the Ts pack.
+template <typename... Ts>
+struct NumArgs;
+
+template <>
+struct NumArgs<> {
+  static constexpr int64_t value = 0;
+};
+
+template <typename T, typename... Ts>
+struct NumArgs<T, Ts...> {
+  static constexpr int64_t value = !IsWrapped<T>::value + NumArgs<Ts...>::value;
+};
+
+// A template for counting tagged arguments in the Ts pack (i.e. attributes).
+template <template <typename> class Tag, typename... Ts>
+struct NumTagged;
+
+template <template <typename> class Tag>
+struct NumTagged<Tag> {
+  static constexpr int64_t value = 0;
+};
+
+template <template <typename> class Tag, typename T, typename... Ts>
+struct NumTagged<Tag, Tag<T>, Ts...> {
+  static constexpr int64_t value = 1 + NumTagged<Tag, Ts...>::value;
+};
+
+template <template <typename> class Tag, typename T, typename... Ts>
+struct NumTagged<Tag, T, Ts...> {
+  static constexpr int64_t value = 0 + NumTagged<Tag, Ts...>::value;
+};
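+
+// As a hedged illustration (the `MyBuffer` and `MyCtx` names below are
+// hypothetical and used only to show how the counters evaluate), regular
+// arguments and tagged attributes are counted independently:
+//
+//   static_assert(NumArgs<MyBuffer, AttrTag<int32_t>,
+//                         CtxTag<MyCtx>>::value == 1);
+//   static_assert(NumTagged<AttrTag, MyBuffer, AttrTag<int32_t>,
+//                           CtxTag<MyCtx>>::value == 1);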
+
+}  // namespace internal
+
+//===----------------------------------------------------------------------===//
+// Handler decodes FFI call frame and invokes `Fn` with decoded arguments
+//===----------------------------------------------------------------------===//
+
+template <typename Fn, typename... Ts>
+class Handler : public Ffi {
+  static constexpr int64_t kSize = sizeof...(Ts);
+
+  static constexpr int64_t kNumArgs = internal::NumArgs<Ts...>::value;
+  static constexpr int64_t kNumAttrs =
+      internal::NumTagged<internal::AttrTag, Ts...>::value;
+
+  template <typename T>
+  using FnArgType = typename internal::FnArgType<T>::Type;
+
+  static_assert(std::is_invocable_v<Fn, FnArgType<Ts>...>,
+                "FFI binding signature is not compatible with a function type");
+
+  using ResultType = std::invoke_result_t<Fn, FnArgType<Ts>...>;
+
+ public:
+  XLA_FFI_Error* Call(const XLA_FFI_CallFrame* call_frame) const override {
+    // Sanity checking call frame struct size.
+    if (auto* err = CheckStructSize(call_frame->api, "XLA_FFI_CallFrame",
+                                    XLA_FFI_CallFrame_STRUCT_SIZE,
+                                    call_frame->struct_size))
+      return err;
+
+    // Check that the number of passed arguments matches the signature. Each
+    // individual argument decoding will check the actual type.
+    if (call_frame->args.num_args != kNumArgs) {
+      return InvalidArgument(
+          call_frame->api,
+          StrCat("Wrong number of arguments: expected ", kNumArgs, " but got ",
+                 call_frame->args.num_args));
+    }
+
+    // Check that the number of passed attributes matches the signature. Each
+    // individual attribute decoding will check the actual type.
+    if (call_frame->attrs.num_attrs != kNumAttrs) {
+      return InvalidArgument(
+          call_frame->api,
+          StrCat("Wrong number of attributes: expected ", kNumAttrs,
+                 " but got ", call_frame->attrs.num_attrs));
+    }
+
+    // Define index sequences to access custom call operands.
+    using Is = std::make_index_sequence<kSize>;
+
+    return Call(call_frame, Is{});
+  }
+
+ private:
+  template <size_t... Is>
+  XLA_FFI_Error* Call(const XLA_FFI_CallFrame* call_frame,
+                      std::index_sequence<Is...>) const {
+    // A helper structure that lets each decoder find the correct offset.
+    internal::DecodingOffsets offsets;
+
+    // Package all the data required for decoding ffi handler operands.
+    internal::DecodingContext ctx = {call_frame, attrs_.data(),
+                                     attrs_idx_.data()};
+
+    std::tuple<std::optional<FnArgType<Ts>>...> args = {
+        internal::Decode<Ts>::call(offsets, ctx)...};
+
+    bool all_decoded = (std::get<Is>(args).has_value() && ...);
+    if (!all_decoded) {
+      return FailedDecodeError(call_frame, {std::get<Is>(args).has_value()...});
+    }
+
+    auto result = fn_(std::move(*std::get<Is>(args))...);
+    return ResultEncoding<ResultType>::Encode(call_frame->api,
+                                              std::move(result));
+  }
+
+  XLA_FFI_Error* FailedDecodeError(const XLA_FFI_CallFrame* call_frame,
+                                   std::array<bool, kSize> decoded) const {
+    std::string message =
+        "Failed to decode all FFI handler operands (bad operands at: ";
+    for (size_t cnt = 0, idx = 0; idx < kSize; ++idx) {
+      if (!decoded[idx]) {
+        if (cnt++) message.append(", ");
+        message.append(std::to_string(idx));
+      }
+    }
+    message.append(")");
+    return InvalidArgument(call_frame->api, message);
+  }
+
+  template <typename...>
+  friend class Binding;
+
+  Handler(Fn fn, std::vector<std::string> attrs)
+      : fn_(std::move(fn)), attrs_(std::move(attrs)) {
+    // Sort attributes' names and remove duplicates. These unique attributes are
+    // what we'll be looking for in the call frame attributes.
+    std::vector<std::string> sorted = attrs_;
+    std::sort(sorted.begin(), sorted.end());
+    sorted.erase(
+        std::unique(sorted.begin(), sorted.end(), std::equal_to<std::string>()),
+        sorted.end());
+
+    // Find index of every attribute in the sorted attributes vector.
+    for (size_t i = 0; i < attrs_.size(); ++i) {
+      attrs_idx_.push_back(std::distance(
+          sorted.begin(), std::find(sorted.begin(), sorted.end(), attrs_[i])));
+    }
+  }
+
+  Fn fn_;
+
+  std::vector<std::string> attrs_;  // names of bound attributes
+
+  // A mapping from the attribute index (index into the `attrs_` member) to its
+  // index in the lexicographically sorted vector of attribute names. The call
+  // frame passes attributes sorted by name, so with this index we can find the
+  // attribute we are looking for with an O(1) lookup, assuming the call frame
+  // has exactly the same attributes as the binding. If not, this index still
+  // allows a more efficient binary search by skipping part of the call frame
+  // attributes.
+  std::vector<size_t> attrs_idx_;
+};
+
+//===----------------------------------------------------------------------===//
+// Builtin attributes decoding
+//===----------------------------------------------------------------------===//
+
+#define XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(T, TYPE)                  \
+  template <>                                                           \
+  struct AttrDecoding<T> {                                              \
+    static std::optional<T> Decode(std::string_view name,               \
+                                   XLA_FFI_AttrType type, void* attr) { \
+      if (type != TYPE) {                                               \
+        return std::nullopt;                                            \
+      }                                                                 \
+                                                                        \
+      return *reinterpret_cast<T*>(attr);                               \
+    }                                                                   \
+  }
+
+XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(int32_t, XLA_FFI_AttrType_I32);
+XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(float, XLA_FFI_AttrType_F32);
+
+#undef XLA_FFI_REGISTER_SCALAR_ATTR_DECODING
+
+template <>
+struct AttrDecoding<std::string_view> {
+  static std::optional<std::string_view> Decode(std::string_view name,
+                                                XLA_FFI_AttrType type,
+                                                void* attr) {
+    if (type != XLA_FFI_AttrType_STRING) {
+      return std::nullopt;
+    }
+
+    auto* span = reinterpret_cast<XLA_FFI_ByteSpan*>(attr);
+    return std::string_view(span->ptr, span->len);
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Helper macro for registering FFI implementations
+//===----------------------------------------------------------------------===//
+
+#if (defined(__GNUC__) || defined(__APPLE__)) && !defined(SWIG)  // GCC-style
+#define XLA_FFI_ATTRIBUTE_UNUSED __attribute__((unused))
+#else  // Non-GCC equivalents
+#define XLA_FFI_ATTRIBUTE_UNUSED
+#endif
+
+// Use the captureless-lambda-to-function-pointer conversion to create a static
+// XLA_FFI_Handler function pointer variable.
+#define XLA_FFI_DEFINE_HANDLER(fn, impl, binding)                             \
+  static constexpr XLA_FFI_Handler* fn = +[](XLA_FFI_CallFrame* call_frame) { \
+    static auto* handler = binding.To(impl).release();                        \
+    return handler->Call(call_frame);                                         \
+  }
+
+// TODO(ezhulenev): Add a callback so that end users can log registration errors
+// to an appropriate logging destination, e.g. LOG(FATAL) for duplicate internal
+// FFI handlers.
+#define XLA_FFI_REGISTER_HANDLER(API, NAME, FUNC) \
+  XLA_FFI_REGISTER_HANDLER_(API, NAME, FUNC, __COUNTER__)
+#define XLA_FFI_REGISTER_HANDLER_(API, NAME, FUNC, N) \
+  XLA_FFI_REGISTER_HANDLER__(API, NAME, FUNC, N)
+#define XLA_FFI_REGISTER_HANDLER__(API, NAME, FUNC, N)                  \
+  XLA_FFI_ATTRIBUTE_UNUSED static const XLA_FFI_Error*                  \
+      xla_ffi_static_handler_##N##_registered_ = [] {                   \
+        return ::xla::ffi::Ffi::RegisterStaticHandler(API, NAME, FUNC); \
+      }()
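+
+// A hedged usage sketch (the handler, attribute, and registration names below
+// are illustrative; the pattern mirrors the tests in xla/ffi/ffi_test.cc, and
+// `GetXlaFfiApi()` stands for whatever XLA_FFI_Api* the integration provides):
+//
+//   static absl::Status MyHandlerImpl(int32_t answer) {
+//     return answer == 42 ? absl::OkStatus()
+//                         : absl::InvalidArgumentError("wrong answer");
+//   }
+//
+//   XLA_FFI_DEFINE_HANDLER(kMyHandler, MyHandlerImpl,
+//                          Ffi::Bind().Attr<int32_t>("answer"));
+//   XLA_FFI_REGISTER_HANDLER(GetXlaFfiApi(), "my-handler", kMyHandler);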
+
+}  // namespace xla::ffi
+
+#endif  // XLA_FFI_API_API_H_
diff --git a/third_party/xla/xla/ffi/api/c_api.h b/third_party/xla/xla/ffi/api/c_api.h
new file mode 100644
index 0000000..55d9282
--- /dev/null
+++ b/third_party/xla/xla/ffi/api/c_api.h
@@ -0,0 +1,255 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_FFI_API_C_API_H_
+#define XLA_FFI_API_C_API_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+// XLA FFI C API follows PJRT API style for consistency. See `pjrt_c_api.h`.
+
+// Every struct passed across the C API boundary has its size as a member, and
+// we use it as a sanity check for API compatibility.
+#define XLA_FFI_STRUCT_SIZE(struct_type, last_field) \
+  (offsetof(struct_type, last_field) + sizeof(((struct_type*)0)->last_field))
+
+#define XLA_FFI_DEFINE_STRUCT_TRAITS(sname, last_field) \
+  typedef struct sname sname;                           \
+  enum { sname##_STRUCT_SIZE = XLA_FFI_STRUCT_SIZE(sname, last_field) }
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef struct XLA_FFI_Api XLA_FFI_Api;                  // Forward declare
+typedef struct XLA_FFI_InternalApi XLA_FFI_InternalApi;  // Forward declare
+
+//===----------------------------------------------------------------------===//
+// Version
+//===----------------------------------------------------------------------===//
+
+// Incremented when an ABI-incompatible change is made to the interface.
+//
+// Major changes include:
+// * Deleting a method or argument
+// * Changing the type of an argument
+// * Rearranging fields in the XLA_FFI_Api or argument structs
+#define XLA_FFI_API_MAJOR 0
+
+// Incremented when the interface is updated in a way that is potentially
+// ABI-compatible with older versions, if supported by the caller and/or
+// implementation.
+//
+// Callers can implement forwards compatibility by using XLA_FFI_Api_Version to
+// check if the implementation is aware of newer interface additions.
+//
+// Implementations can implement backwards compatibility by using the
+// `struct_size` fields to detect how many struct fields the caller is aware of.
+//
+// Minor changes include:
+// * Adding a new field to the XLA_FFI_Api or argument structs
+// * Renaming a method or argument (doesn't affect ABI)
+#define XLA_FFI_API_MINOR 0
+
+struct XLA_FFI_Api_Version {
+  size_t struct_size;
+  void* priv;
+  int major_version;  // out
+  int minor_version;  // out
+};
+
+XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Api_Version, minor_version);
+
+//===----------------------------------------------------------------------===//
+// Error codes
+//===----------------------------------------------------------------------===//
+
+// An XLA FFI handler must return an XLA_FFI_Error*, which is NULL if there is
+// no error and non-NULL otherwise. The returned error is allocated on the
+// handler side (e.g. via XLA_FFI_Error_Create), and XLA FFI is responsible for
+// freeing it.
+typedef struct XLA_FFI_Error XLA_FFI_Error;
+
+// Codes are based on https://abseil.io/docs/cpp/guides/status-codes
+typedef enum {
+  XLA_FFI_Error_Code_CANCELLED = 1,
+  XLA_FFI_Error_Code_UNKNOWN = 2,
+  XLA_FFI_Error_Code_INVALID_ARGUMENT = 3,
+  XLA_FFI_Error_Code_DEADLINE_EXCEEDED = 4,
+  XLA_FFI_Error_Code_NOT_FOUND = 5,
+  XLA_FFI_Error_Code_ALREADY_EXISTS = 6,
+  XLA_FFI_Error_Code_PERMISSION_DENIED = 7,
+  XLA_FFI_Error_Code_RESOURCE_EXHAUSTED = 8,
+  XLA_FFI_Error_Code_FAILED_PRECONDITION = 9,
+  XLA_FFI_Error_Code_ABORTED = 10,
+  XLA_FFI_Error_Code_OUT_OF_RANGE = 11,
+  XLA_FFI_Error_Code_UNIMPLEMENTED = 12,
+  XLA_FFI_Error_Code_INTERNAL = 13,
+  XLA_FFI_Error_Code_UNAVAILABLE = 14,
+  XLA_FFI_Error_Code_DATA_LOSS = 15,
+  XLA_FFI_Error_Code_UNAUTHENTICATED = 16
+} XLA_FFI_Error_Code;
+
+//===----------------------------------------------------------------------===//
+// Error reporting APIs
+//===----------------------------------------------------------------------===//
+
+struct XLA_FFI_Error_Create_Args {
+  size_t struct_size;
+  void* priv;
+  const char* message;
+  XLA_FFI_Error_Code errc;
+};
+
+XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Error_Create_Args, errc);
+
+typedef XLA_FFI_Error* XLA_FFI_Error_Create(XLA_FFI_Error_Create_Args* args);
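+
+// A minimal sketch of creating an error (illustrative; `api` is the
+// XLA_FFI_Api* passed to the handler via the call frame, and the concrete
+// function pointer lives in the XLA_FFI_Api struct defined below):
+//
+//   XLA_FFI_Error_Create_Args args;
+//   args.struct_size = XLA_FFI_Error_Create_Args_STRUCT_SIZE;
+//   args.priv = NULL;
+//   args.message = "something went wrong";
+//   args.errc = XLA_FFI_Error_Code_INTERNAL;
+//   XLA_FFI_Error* error = api->XLA_FFI_Error_Create(&args);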
+
+//===----------------------------------------------------------------------===//
+// Builtin argument types
+//===----------------------------------------------------------------------===//
+
+struct XLA_FFI_Buffer {
+  size_t struct_size;
+  void* priv;
+
+  void* data;
+  uint8_t primitive_type;
+  int64_t rank;
+  int64_t* dims;  // length == rank
+};
+
+XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Buffer, dims);
+
+typedef enum {
+  XLA_FFI_ArgType_BUFFER = 1,
+} XLA_FFI_ArgType;
+
+//===----------------------------------------------------------------------===//
+// Builtin attribute types
+//===----------------------------------------------------------------------===//
+
+typedef enum {
+  XLA_FFI_AttrType_I32 = 1,
+  XLA_FFI_AttrType_F32 = 2,
+  XLA_FFI_AttrType_STRING = 3,
+} XLA_FFI_AttrType;
+
+//===----------------------------------------------------------------------===//
+// Execution context
+//===----------------------------------------------------------------------===//
+
+// Execution context provides access to per-invocation state.
+typedef struct XLA_FFI_ExecutionContext XLA_FFI_ExecutionContext;
+
+//===----------------------------------------------------------------------===//
+// Call frame
+//===----------------------------------------------------------------------===//
+
+// We use byte spans to pass strings to handlers because strings might not be
+// null terminated, and even if they are, looking for a null terminator can
+// become very expensive in tight loops.
+struct XLA_FFI_ByteSpan {
+  size_t struct_size;
+  void* priv;
+
+  const char* ptr;
+  size_t len;
+};
+
+XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_ByteSpan, len);
+
+struct XLA_FFI_Args {
+  size_t struct_size;
+  void* priv;
+
+  int64_t num_args;
+  XLA_FFI_ArgType* types;  // length == num_args
+  void** args;             // length == num_args
+};
+
+XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Args, args);
+
+// FFI handler attributes are always sorted by name, so that the handler can
+// rely on binary search to look up attributes by name.
+struct XLA_FFI_Attrs {
+  size_t struct_size;
+  void* priv;
+
+  int64_t num_attrs;
+  XLA_FFI_AttrType* types;   // length == num_attrs
+  XLA_FFI_ByteSpan** names;  // length == num_attrs
+  void** attrs;              // length == num_attrs
+};
+
+XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Attrs, attrs);
+
+struct XLA_FFI_CallFrame {
+  size_t struct_size;
+  void* priv;
+
+  XLA_FFI_Api* api;
+  XLA_FFI_ExecutionContext* ctx;
+  XLA_FFI_Args args;
+  XLA_FFI_Attrs attrs;
+};
+
+XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_CallFrame, attrs);
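+
+// A hedged sketch of a handler walking the call frame directly (most users
+// should prefer the typed decoding in `api.h`; the handler name below is
+// illustrative):
+//
+//   XLA_FFI_Error* MyRawHandler(XLA_FFI_CallFrame* frame) {
+//     for (int64_t i = 0; i < frame->args.num_args; ++i) {
+//       if (frame->args.types[i] == XLA_FFI_ArgType_BUFFER) {
+//         XLA_FFI_Buffer* buf = (XLA_FFI_Buffer*)frame->args.args[i];
+//         // ... use buf->data, buf->rank and buf->dims ...
+//       }
+//     }
+//     return NULL;  // success
+//   }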
+
+//===----------------------------------------------------------------------===//
+// FFI handler
+//===----------------------------------------------------------------------===//
+
+// External functions registered with XLA as FFI handlers.
+typedef XLA_FFI_Error* XLA_FFI_Handler(XLA_FFI_CallFrame* call_frame);
+
+struct XLA_FFI_Handler_Register_Args {
+  size_t struct_size;
+  void* priv;
+
+  const char* name;  // null terminated
+  XLA_FFI_Handler* handler;
+};
+
+XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Handler_Register_Args, handler);
+
+typedef XLA_FFI_Error* XLA_FFI_Handler_Register(
+    XLA_FFI_Handler_Register_Args* args);
+
+//===----------------------------------------------------------------------===//
+// API access
+//===----------------------------------------------------------------------===//
+
+#define _XLA_FFI_API_STRUCT_FIELD(fn_type) fn_type* fn_type
+
+struct XLA_FFI_Api {
+  size_t struct_size;
+  void* priv;
+
+  XLA_FFI_InternalApi* internal_api;
+
+  _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Create);
+  _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Handler_Register);
+};
+
+#undef _XLA_FFI_API_STRUCT_FIELD
+
+XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Api, XLA_FFI_Handler_Register);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // XLA_FFI_API_C_API_H_
diff --git a/third_party/xla/xla/ffi/api/c_api_internal.h b/third_party/xla/xla/ffi/api/c_api_internal.h
new file mode 100644
index 0000000..c8c9fc7
--- /dev/null
+++ b/third_party/xla/xla/ffi/api/c_api_internal.h
@@ -0,0 +1,64 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_FFI_API_C_API_INTERNAL_H_
+#define XLA_FFI_API_C_API_INTERNAL_H_
+
+#include "xla/ffi/api/c_api.h"
+
+// Internal XLA FFI API that gives access to XLA implementation details that
+// should be used only for implementing FFI handlers statically linked into
+// the binary. This API should be used only by XLA itself (to implement builtin
+// custom calls), or libraries tightly coupled to XLA and built from the exact
+// same commit using the same toolchain (e.g. jaxlib). Trying to use this API
+// from a dynamically loaded shared library can lead to undefined behavior and
+// run-time crashes that are likely impossible to debug.
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Because this is an internal XLA FFI API we use a slightly relaxed C API
+// style and do not track the struct size, as we expect this API to be used
+// only in statically linked binaries, and we do not need any backward or
+// forward compatibility.
+
+// Forwards the `absl::Status` object pointed to by `status` to an XLA FFI
+// error (the status is left in a moved-from state). Pointer ownership stays
+// with the caller.
+typedef XLA_FFI_Error* XLA_FFI_Error_Forward(void* status);
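+
+// A minimal access sketch (illustrative; `status` is an absl::Status lvalue
+// owned by the caller and `api` is the XLA_FFI_Api* from the call frame):
+//
+//   XLA_FFI_Error* error =
+//       api->internal_api->XLA_FFI_Error_Forward(&status);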
+
+// Returns a pointer to `xla::ServiceExecutableRunOptions`.
+typedef void* XLA_FFI_ServiceExecutableRunOptions_Get(
+    XLA_FFI_ExecutionContext* ctx);
+
+//===----------------------------------------------------------------------===//
+// API access
+//===----------------------------------------------------------------------===//
+
+#define _XLA_FFI_INTERNAL_API_STRUCT_FIELD(fn_type) fn_type* fn_type
+
+struct XLA_FFI_InternalApi {
+  _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_Error_Forward);
+  _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_ServiceExecutableRunOptions_Get);
+};
+
+#undef _XLA_FFI_INTERNAL_API_STRUCT_FIELD
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // XLA_FFI_API_C_API_INTERNAL_H_
diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h
new file mode 100644
index 0000000..1985864
--- /dev/null
+++ b/third_party/xla/xla/ffi/api/ffi.h
@@ -0,0 +1,30 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_FFI_API_FFI_H_
+#define XLA_FFI_API_FFI_H_
+
+#ifdef TENSORFLOW_COMPILER_XLA_FFI_FFI_H_
+#error Two different XLA FFI implementations cannot be included together
+#endif  // TENSORFLOW_COMPILER_XLA_FFI_FFI_H_
+
+// IWYU pragma: begin_exports
+#include "xla/ffi/api/api.h"
+// IWYU pragma: end_exports
+
+// TODO(ezhulenev): Implement FFI arguments and attributes decoding for external
+// FFI users without any dependencies on absl or other libraries.
+
+#endif  // XLA_FFI_API_FFI_H_
diff --git a/third_party/xla/xla/ffi/call_frame.cc b/third_party/xla/xla/ffi/call_frame.cc
new file mode 100644
index 0000000..e828ab2
--- /dev/null
+++ b/third_party/xla/xla/ffi/call_frame.cc
@@ -0,0 +1,288 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/ffi/call_frame.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/types/span.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/c_api_internal.h"  // IWYU pragma: keep
+#include "xla/stream_executor/device_memory.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla::ffi {
+
+//===----------------------------------------------------------------------===//
+// CallFrameBuilder
+//===----------------------------------------------------------------------===//
+
+void CallFrameBuilder::AddBufferArg(se::DeviceMemoryBase memory,
+                                    PrimitiveType type,
+                                    absl::Span<const int64_t> dims) {
+  args_.push_back(Buffer{memory, type, {dims.begin(), dims.end()}});
+}
+
+void CallFrameBuilder::AddI32Attr(std::string name, int32_t value) {
+  attrs_.try_emplace(std::move(name), value);
+}
+
+void CallFrameBuilder::AddF32Attr(std::string name, float value) {
+  attrs_.try_emplace(std::move(name), value);
+}
+
+void CallFrameBuilder::AddStringAttr(std::string name, std::string value) {
+  attrs_.try_emplace(std::move(name), value);
+}
+
+void CallFrameBuilder::AddAttribute(std::string name, Attribute attr) {
+  attrs_.try_emplace(std::move(name), attr);
+}
+
+void CallFrameBuilder::AddAttributes(const AttributesMap& attrs) {
+  attrs_.insert(attrs.begin(), attrs.end());
+}
+
+CallFrame CallFrameBuilder::Build() { return CallFrame(args_, attrs_); }
+
+// ------------------------    !!! !!! !!!     ------------------------------ //
+
+// WARNING: In many structs defined below we use a pattern where we declare
+// storage (e.g. an `std::string` member) and an XLA FFI reference type
+// pointing into that storage (e.g. XLA_FFI_ByteSpan) in the same struct. Extra
+// care must be taken to keep the reference type up to date: if a parent struct
+// is put into an `std::vector` container, every time the vector reallocates
+// its storage all reference types become invalid.
+
+// We intentionally do not use smart pointers that would guarantee pointer
+// stability for storage, as we are trying to minimize the number of heap
+// allocations required for building a call frame.
+
+// This is a low level internal implementation detail that should not leak via
+// public header files, and can be changed at any time in the future.
+
+//----------------------------------------------------------------------------//
+// Arguments storage + reference types
+//----------------------------------------------------------------------------//
+
+struct CallFrame::Buffer {
+  std::vector<int64_t> dims;  // XLA_FFI_Buffer::dims
+
+  XLA_FFI_Buffer buffer = {XLA_FFI_Buffer_STRUCT_SIZE, nullptr};
+};
+
+struct CallFrame::Arguments {
+  explicit Arguments(size_t size) {
+    arguments.reserve(size);
+    types.reserve(size);
+    args.reserve(size);
+  }
+
+  std::vector<Buffer> arguments;
+
+  std::vector<XLA_FFI_ArgType> types;  // XLA_FFI_Args::types
+  std::vector<void*> args;             // XLA_FFI_Args::args
+
+  XLA_FFI_Args ffi_args = {XLA_FFI_Args_STRUCT_SIZE, nullptr};
+};
+
+//----------------------------------------------------------------------------//
+// Attributes storage + reference types
+//----------------------------------------------------------------------------//
+
+struct CallFrame::String {
+  std::string value;  // XLA_FFI_ByteSpan::ptr
+
+  XLA_FFI_ByteSpan span = {XLA_FFI_ByteSpan_STRUCT_SIZE, nullptr};
+};
+
+struct CallFrame::NamedAttribute {
+  String name;
+  Attribute value;
+};
+
+struct CallFrame::Attributes {
+  explicit Attributes(size_t size) {
+    attributes.reserve(size);
+    names.reserve(size);
+    types.reserve(size);
+    attrs.reserve(size);
+  }
+
+  std::vector<NamedAttribute> attributes;
+
+  std::vector<XLA_FFI_ByteSpan*> names;  // XLA_FFI_Attributes::names
+  std::vector<XLA_FFI_AttrType> types;   // XLA_FFI_Attributes::types
+  std::vector<void*> attrs;              // XLA_FFI_Attributes::attrs
+
+  XLA_FFI_Attrs ffi_attrs = {XLA_FFI_Attrs_STRUCT_SIZE, nullptr};
+};
+
+//===----------------------------------------------------------------------===//
+// CallFrame
+//===----------------------------------------------------------------------===//
+
+CallFrame::CallFrame(absl::Span<const CallFrameBuilder::Buffer> args,
+                     const CallFrameBuilder::AttributesMap& attrs)
+    : arguments_(InitArgs(args)), attributes_(InitAttrs(attrs)) {}
+
+XLA_FFI_CallFrame CallFrame::Build(XLA_FFI_Api* api,
+                                   XLA_FFI_ExecutionContext* ctx) {
+  XLA_FFI_CallFrame call_frame = {XLA_FFI_CallFrame_STRUCT_SIZE, nullptr};
+  call_frame.api = api;
+  call_frame.ctx = ctx;
+  call_frame.args = arguments_->ffi_args;
+  call_frame.attrs = attributes_->ffi_attrs;
+  return call_frame;
+}
+
+CallFrame::~CallFrame() = default;
+
+//===----------------------------------------------------------------------===//
+// Call frame arguments
+//===----------------------------------------------------------------------===//
+
+/*static*/ std::unique_ptr<CallFrame::Arguments> CallFrame::InitArgs(
+    absl::Span<const CallFrameBuilder::Buffer> bargs) {
+  auto res = std::make_unique<Arguments>(bargs.size());
+
+  // Convert call frame builder arguments to call frame arguments.
+  for (const CallFrameBuilder::Buffer& barg : bargs) {
+    Buffer buffer;
+    buffer.dims = barg.dims;
+    buffer.buffer.data = const_cast<void*>(barg.memory.opaque());
+    buffer.buffer.primitive_type = static_cast<uint8_t>(barg.type);
+    buffer.buffer.rank = buffer.dims.size();
+    res->arguments.push_back(std::move(buffer));
+  }
+
+  // Fix up pointers in XLA FFI structs.
+  for (CallFrame::Buffer& arg : res->arguments) {
+    arg.buffer.dims = arg.dims.data();
+  }
+
+  // Initialize vectors required for building XLA_FFI_Args.
+  for (CallFrame::Buffer& arg : res->arguments) {
+    res->types.push_back(XLA_FFI_ArgType_BUFFER);
+    res->args.push_back(&arg.buffer);
+  }
+
+  // Finally initialize the XLA FFI struct. At this point all storage is
+  // allocated and it's safe to grab a pointer to it.
+  res->ffi_args.num_args = res->arguments.size();
+  res->ffi_args.types = res->types.data();
+  res->ffi_args.args = res->args.data();
+
+  return res;
+}
+
+//===----------------------------------------------------------------------===//
+// Call frame attributes
+//===----------------------------------------------------------------------===//
+
+// An std::visit overload set for converting CallFrameBuilder::Attribute to
+// CallFrame::Attribute.
+struct CallFrame::ConvertAttribute {
+  template <typename T>
+  CallFrame::Attribute operator()(const T& value) {
+    return value;
+  }
+
+  CallFrame::Attribute operator()(const std::string& str) {
+    return CallFrame::String{str};
+  }
+};
+
+// An std::visit overload set to fix up CallFrame::Attribute storage and
+// initialize XLA FFI structs with valid pointers into storage objects.
+struct CallFrame::FixupAttribute {
+  template <typename T>
+  void operator()(T& value) {}
+
+  void operator()(CallFrame::String& str) {
+    str.span.ptr = str.value.data();
+    str.span.len = str.value.size();
+  }
+};
+
+// An std::visit overload set to get CallFrame::Attribute XLA FFI type.
+struct CallFrame::AttributeType {
+  XLA_FFI_AttrType operator()(int32_t&) { return XLA_FFI_AttrType_I32; }
+
+  XLA_FFI_AttrType operator()(float&) { return XLA_FFI_AttrType_F32; }
+
+  XLA_FFI_AttrType operator()(CallFrame::String&) {
+    return XLA_FFI_AttrType_STRING;
+  }
+};
+
+// An std::visit overload set to get CallFrame::Attribute storage pointer.
+struct CallFrame::AttributeStorage {
+  template <typename T>
+  void* operator()(T& value) {
+    return &value;
+  }
+
+  void* operator()(CallFrame::String& str) { return &str.span; }
+};
+
+/*static*/ std::unique_ptr<CallFrame::Attributes> CallFrame::InitAttrs(
+    const CallFrameBuilder::AttributesMap& battrs) {
+  auto res = std::make_unique<Attributes>(battrs.size());
+
+  // Convert call frame builder attributes to a collection of named attributes.
+  for (auto& [name, battr] : battrs) {
+    NamedAttribute attr = {String{name}, std::visit(ConvertAttribute(), battr)};
+    res->attributes.push_back(std::move(attr));
+  }
+
+  // Sort attributes by name to enable binary search at run time.
+  absl::c_sort(res->attributes,
+               [](const NamedAttribute& a, const NamedAttribute& b) {
+                 return a.name.value < b.name.value;
+               });
+
+  // Fix up XLA FFI structs to point to correct storage.
+  for (NamedAttribute& attr : res->attributes) {
+    std::invoke(FixupAttribute{}, attr.name);
+    std::visit(FixupAttribute{}, attr.value);
+  }
+
+  // Initialize vectors required for building XLA_FFI_Attributes.
+  for (NamedAttribute& attr : res->attributes) {
+    res->names.push_back(&attr.name.span);
+    res->types.push_back(std::visit(AttributeType(), attr.value));
+    res->attrs.push_back(std::visit(AttributeStorage(), attr.value));
+  }
+
+  // Finally initialize XLA FFI struct. At this point all storage is allocated
+  // and it's safe to grab a pointer to it.
+  res->ffi_attrs.num_attrs = res->attributes.size();
+  res->ffi_attrs.names = res->names.data();
+  res->ffi_attrs.types = res->types.data();
+  res->ffi_attrs.attrs = res->attrs.data();
+
+  return res;
+}
+
+}  // namespace xla::ffi
diff --git a/third_party/xla/xla/ffi/call_frame.h b/third_party/xla/xla/ffi/call_frame.h
new file mode 100644
index 0000000..2f5327f
--- /dev/null
+++ b/third_party/xla/xla/ffi/call_frame.h
@@ -0,0 +1,117 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_FFI_CALL_FRAME_H_
+#define XLA_FFI_CALL_FRAME_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/types/span.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/types.h"  // IWYU pragma: keep
+#include "xla/xla_data.pb.h"
+
+namespace xla::ffi {
+
+// The CallFrame library encodes C++ arguments as XLA FFI C API structs in a
+// form compatible with the decoding defined in `xla/ffi/api/api.h`.
+
+//===----------------------------------------------------------------------===//
+// CallFrameBuilder
+//===----------------------------------------------------------------------===//
+
+class CallFrame;  // forward declare
+
+class CallFrameBuilder {
+ public:
+  using Attribute = std::variant<int32_t, float, std::string>;
+  using AttributesMap = absl::flat_hash_map<std::string, Attribute>;
+
+  CallFrame Build();
+
+  void AddBufferArg(se::DeviceMemoryBase memory, PrimitiveType type,
+                    absl::Span<const int64_t> dims);
+
+  void AddI32Attr(std::string name, int32_t value);
+  void AddF32Attr(std::string name, float value);
+  void AddStringAttr(std::string name, std::string value);
+
+  void AddAttribute(std::string name, Attribute attr);
+  void AddAttributes(const AttributesMap& attrs);
+
+ private:
+  friend class CallFrame;
+
+  struct Buffer {
+    se::DeviceMemoryBase memory;
+    PrimitiveType type;
+    std::vector<int64_t> dims;
+  };
+
+  std::vector<Buffer> args_;
+  AttributesMap attrs_;
+};
+
+//===----------------------------------------------------------------------===//
+// CallFrame
+//===----------------------------------------------------------------------===//
+
+class CallFrame {
+ public:
+  ~CallFrame();
+
+  // Builds an XLA_FFI_CallFrame from owned arguments and attributes.
+  XLA_FFI_CallFrame Build(XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx);
+
+ private:
+  friend class CallFrameBuilder;
+
+  // Declare implementation detail structs for call frame storage.
+  struct Arguments;
+  struct Attributes;
+  struct Buffer;
+  struct NamedAttribute;
+  struct String;
+
+  using Attribute = std::variant<int32_t, float, String>;
+
+  CallFrame(absl::Span<const CallFrameBuilder::Buffer> args,
+            const CallFrameBuilder::AttributesMap& attrs);
+
+  static std::unique_ptr<Arguments> InitArgs(
+      absl::Span<const CallFrameBuilder::Buffer> args);
+
+  static std::unique_ptr<Attributes> InitAttrs(
+      const CallFrameBuilder::AttributesMap& attrs);
+
+  std::unique_ptr<Arguments> arguments_;
+  std::unique_ptr<Attributes> attributes_;
+
+  // Declare implementation detail structs to grant access to private members.
+  struct ConvertAttribute;
+  struct FixupAttribute;
+  struct AttributeType;
+  struct AttributeStorage;
+};
+
+}  // namespace xla::ffi
+
+#endif  // XLA_FFI_CALL_FRAME_H_
diff --git a/third_party/xla/xla/ffi/ffi.cc b/third_party/xla/xla/ffi/ffi.cc
new file mode 100644
index 0000000..9899010
--- /dev/null
+++ b/third_party/xla/xla/ffi/ffi.cc
@@ -0,0 +1,222 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/ffi/ffi.h"
+
+#include <cstddef>
+#include <string>
+#include <string_view>
+#include <utility>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/c_api_internal.h"  // IWYU pragma: keep
+#include "xla/ffi/call_frame.h"
+#include "xla/service/service_executable_run_options.h"
+#include "xla/status.h"
+#include "xla/statusor.h"
+#include "tsl/platform/logging.h"
+
+//===----------------------------------------------------------------------===//
+// XLA FFI C structs definition
+//===----------------------------------------------------------------------===//
+
+struct XLA_FFI_Error {
+  xla::Status status;
+};
+
+struct XLA_FFI_ExecutionContext {
+  const xla::ServiceExecutableRunOptions* run_options;
+};
+
+//===----------------------------------------------------------------------===//
+
+namespace xla::ffi {
+
+Status TakeStatus(XLA_FFI_Error* error) {
+  if (error == nullptr) return absl::OkStatus();
+  Status status = std::move(error->status);
+  delete error;
+  return status;
+}
+
+Status Call(Ffi& handler, CallFrame& call_frame, const CallOptions& options) {
+  XLA_FFI_ExecutionContext ctx = {options.run_options};
+  XLA_FFI_CallFrame ffi_call_frame = call_frame.Build(GetXlaFfiApi(), &ctx);
+  return TakeStatus(handler.Call(&ffi_call_frame));
+}
+
+Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame,
+            const CallOptions& options) {
+  XLA_FFI_ExecutionContext ctx = {options.run_options};
+  XLA_FFI_CallFrame ffi_call_frame = call_frame.Build(GetXlaFfiApi(), &ctx);
+  return TakeStatus((*handler)(&ffi_call_frame));
+}
+
+//===----------------------------------------------------------------------===//
+// XLA FFI registry
+//===----------------------------------------------------------------------===//
+
+// TODO(ezhulenev): We have to support platform-specific handler registration.
+using HandlerRegistry = absl::flat_hash_map<std::string, XLA_FFI_Handler*>;
+
+static HandlerRegistry& GetHandlerRegistry() {
+  static auto* registry = new HandlerRegistry();
+  return *registry;
+}
+
+static Status RegisterHandler(std::string_view name, XLA_FFI_Handler* handler) {
+  auto emplaced = GetHandlerRegistry().try_emplace(std::string(name), handler);
+  if (!emplaced.second)
+    return absl::InvalidArgumentError(
+        absl::StrCat("Duplicate FFI handler registration for ", name));
+  return OkStatus();
+}
+
+StatusOr<XLA_FFI_Handler*> FindHandler(std::string_view name) {
+  auto it = GetHandlerRegistry().find(name);
+  if (it == GetHandlerRegistry().end())
+    return absl::NotFoundError(
+        absl::StrCat("No FFI handler registered for ", name));
+  return it->second;
+}
+
+//===----------------------------------------------------------------------===//
+// XLA FFI Api Implementation
+//===----------------------------------------------------------------------===//
+
+static std::string StructSizeErrorMsg(std::string_view struct_name,
+                                      size_t expected, size_t actual) {
+  return absl::StrCat("Unexpected ", struct_name, " size: expected ", expected,
+                      ", got ", actual, ". Check installed software versions. ",
+                      "The framework XLA FFI API version is ",
+                      XLA_FFI_API_MAJOR, ".", XLA_FFI_API_MINOR, ".");
+}
+
+static Status ActualStructSizeIsGreaterOrEqual(std::string_view struct_name,
+                                               size_t expected, size_t actual) {
+  if (actual < expected) {
+    return absl::InvalidArgumentError(
+        StructSizeErrorMsg(struct_name, expected, actual));
+  }
+  if (actual > expected) {
+    VLOG(2) << StructSizeErrorMsg(struct_name, expected, actual);
+  }
+  return absl::OkStatus();
+}
+
+static absl::StatusCode ToStatusCode(XLA_FFI_Error_Code errc) {
+  switch (errc) {
+    case XLA_FFI_Error_Code_CANCELLED:
+      return absl::StatusCode::kCancelled;
+    case XLA_FFI_Error_Code_UNKNOWN:
+      return absl::StatusCode::kUnknown;
+    case XLA_FFI_Error_Code_INVALID_ARGUMENT:
+      return absl::StatusCode::kInvalidArgument;
+    case XLA_FFI_Error_Code_DEADLINE_EXCEEDED:
+      return absl::StatusCode::kDeadlineExceeded;
+    case XLA_FFI_Error_Code_NOT_FOUND:
+      return absl::StatusCode::kNotFound;
+    case XLA_FFI_Error_Code_ALREADY_EXISTS:
+      return absl::StatusCode::kAlreadyExists;
+    case XLA_FFI_Error_Code_PERMISSION_DENIED:
+      return absl::StatusCode::kPermissionDenied;
+    case XLA_FFI_Error_Code_RESOURCE_EXHAUSTED:
+      return absl::StatusCode::kResourceExhausted;
+    case XLA_FFI_Error_Code_FAILED_PRECONDITION:
+      return absl::StatusCode::kFailedPrecondition;
+    case XLA_FFI_Error_Code_ABORTED:
+      return absl::StatusCode::kAborted;
+    case XLA_FFI_Error_Code_OUT_OF_RANGE:
+      return absl::StatusCode::kOutOfRange;
+    case XLA_FFI_Error_Code_UNIMPLEMENTED:
+      return absl::StatusCode::kUnimplemented;
+    case XLA_FFI_Error_Code_INTERNAL:
+      return absl::StatusCode::kInternal;
+    case XLA_FFI_Error_Code_UNAVAILABLE:
+      return absl::StatusCode::kUnavailable;
+    case XLA_FFI_Error_Code_DATA_LOSS:
+      return absl::StatusCode::kDataLoss;
+    case XLA_FFI_Error_Code_UNAUTHENTICATED:
+      return absl::StatusCode::kUnauthenticated;
+  }
+}
+
+#define XLA_FFI_RETURN_IF_ERROR(expr)                                   \
+  do {                                                                  \
+    Status _status = (expr);                                            \
+    if (!_status.ok()) {                                                \
+      XLA_FFI_Error* _c_status = new XLA_FFI_Error{std::move(_status)}; \
+      return _c_status;                                                 \
+    }                                                                   \
+  } while (false)
+
+static XLA_FFI_Error* XLA_FFI_Error_Create(XLA_FFI_Error_Create_Args* args) {
+  XLA_FFI_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
+      "XLA_FFI_Error_Create", XLA_FFI_Error_Create_Args_STRUCT_SIZE,
+      args->struct_size));
+
+  return new XLA_FFI_Error{Status(ToStatusCode(args->errc), args->message)};
+}
+
+static XLA_FFI_Error* XLA_FFI_Handler_Register(
+    XLA_FFI_Handler_Register_Args* args) {
+  XLA_FFI_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
+      "XLA_FFI_Handler_Register", XLA_FFI_Handler_Register_Args_STRUCT_SIZE,
+      args->struct_size));
+
+  if (auto status = RegisterHandler(args->name, args->handler); !status.ok()) {
+    return new XLA_FFI_Error{std::move(status)};
+  }
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// XLA FFI Internal Api Implementation
+//===----------------------------------------------------------------------===//
+
+static XLA_FFI_Error* XLA_FFI_Error_Forward(void* status) {
+  return new XLA_FFI_Error{std::move(*reinterpret_cast<Status*>(status))};
+}
+
+static void* XLA_FFI_ServiceExecutableRunOptions_Get(
+    XLA_FFI_ExecutionContext* ctx) {
+  return const_cast<ServiceExecutableRunOptions*>(ctx->run_options);
+}
+
+//===----------------------------------------------------------------------===//
+// XLA FFI Api access
+//===----------------------------------------------------------------------===//
+
+static XLA_FFI_InternalApi internal_api = {
+    XLA_FFI_Error_Forward,
+    XLA_FFI_ServiceExecutableRunOptions_Get,
+};
+
+static XLA_FFI_Api api = {
+    XLA_FFI_Api_STRUCT_SIZE,
+    /*priv=*/nullptr,
+
+    &internal_api,
+
+    XLA_FFI_Error_Create,      // creates error
+    XLA_FFI_Handler_Register,  // registers handler
+};
+
+XLA_FFI_Api* GetXlaFfiApi() { return &api; }
+
+}  // namespace xla::ffi
diff --git a/third_party/xla/xla/ffi/ffi.h b/third_party/xla/xla/ffi/ffi.h
new file mode 100644
index 0000000..18c9ca8
--- /dev/null
+++ b/third_party/xla/xla/ffi/ffi.h
@@ -0,0 +1,139 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_FFI_FFI_H_
+#define XLA_FFI_FFI_H_
+
+#ifdef TENSORFLOW_COMPILER_XLA_FFI_API_FFI_H_
+#error Two different XLA FFI implementations cannot be included together
+#endif  // TENSORFLOW_COMPILER_XLA_FFI_API_FFI_H_
+
+#include <cstdint>
+#include <optional>
+#include <string_view>
+
+// IWYU pragma: begin_exports
+#include "xla/ffi/api/api.h"
+// IWYU pragma: end_exports
+
+#include "absl/types/span.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/c_api_internal.h"  // IWYU pragma: keep
+#include "xla/ffi/call_frame.h"
+#include "xla/runtime/memref_view.h"
+#include "xla/service/service_executable_run_options.h"
+#include "xla/status.h"
+#include "xla/statusor.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/types.h"  // IWYU pragma: keep
+#include "xla/xla_data.pb.h"
+
+namespace xla::ffi {
+
+//===----------------------------------------------------------------------===//
+// Arguments
+//===----------------------------------------------------------------------===//
+
+struct Buffer {
+  PrimitiveType primitive_type;
+  se::DeviceMemoryBase data;
+  absl::Span<const int64_t> dimensions;
+
+  // TODO(ezhulenev): Remove this implicit conversion once we migrate from
+  // runtime custom calls to FFI handlers.
+  operator runtime::MemrefView() {  // NOLINT
+    return runtime::MemrefView{primitive_type, data.opaque(), dimensions};
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Arguments decoding
+//===----------------------------------------------------------------------===//
+
+template <>
+struct ArgDecoding<Buffer> {
+  static std::optional<Buffer> Decode(XLA_FFI_ArgType type, void* arg) {
+    if (type != XLA_FFI_ArgType_BUFFER) return std::nullopt;
+    auto* buf = reinterpret_cast<XLA_FFI_Buffer*>(arg);
+
+    Buffer buffer;
+    buffer.primitive_type = PrimitiveType(buf->primitive_type);
+    buffer.data = se::DeviceMemoryBase(buf->data);
+    buffer.dimensions = absl::MakeConstSpan(buf->dims, buf->rank);
+    return buffer;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Context decoding
+//===----------------------------------------------------------------------===//
+
+template <>
+struct CtxDecoding<ServiceExecutableRunOptions> {
+  using Type = const ServiceExecutableRunOptions*;
+
+  static std::optional<Type> Decode(const XLA_FFI_Api* api,
+                                    XLA_FFI_ExecutionContext* ctx) {
+    void* ptr = api->internal_api->XLA_FFI_ServiceExecutableRunOptions_Get(ctx);
+    return reinterpret_cast<Type>(ptr);
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Result encoding
+//===----------------------------------------------------------------------===//
+
+template <>
+struct ResultEncoding<Status> {
+  static XLA_FFI_Error* Encode(XLA_FFI_Api* api, Status status) {
+    return api->internal_api->XLA_FFI_Error_Forward(&status);
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Calling XLA FFI handlers
+//===----------------------------------------------------------------------===//
+
+// Takes ownership of the XLA FFI error and returns underlying status. Frees
+// `error` if it's not nullptr; returns OK status otherwise.
+Status TakeStatus(XLA_FFI_Error* error);
+
+struct CallOptions {
+  const ServiceExecutableRunOptions* run_options = nullptr;
+};
+
+Status Call(Ffi& handler, CallFrame& call_frame,
+            const CallOptions& options = {});
+
+Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame,
+            const CallOptions& options = {});
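+
+// A minimal usage sketch (the attribute name and handler body are
+// illustrative; the same pattern appears in xla/ffi/ffi_test.cc):
+//
+//   CallFrameBuilder builder;
+//   builder.AddI32Attr("answer", 42);
+//   CallFrame call_frame = builder.Build();
+//
+//   auto handler = Ffi::Bind().Attr<int32_t>("answer").To(
+//       [](int32_t answer) { return absl::OkStatus(); });
+//   Status status = Call(*handler, call_frame);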
+
+//===----------------------------------------------------------------------===//
+// XLA FFI registry
+//===----------------------------------------------------------------------===//
+
+// Returns registered FFI handler for a given name, or an error if it's not
+// found in the static registry.
+StatusOr<XLA_FFI_Handler*> FindHandler(std::string_view name);
+
+//===----------------------------------------------------------------------===//
+// XLA FFI Api Implementation
+//===----------------------------------------------------------------------===//
+
+XLA_FFI_Api* GetXlaFfiApi();
+
+}  // namespace xla::ffi
+
+#endif  // XLA_FFI_FFI_H_
diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc
new file mode 100644
index 0000000..cd93c39
--- /dev/null
+++ b/third_party/xla/xla/ffi/ffi_test.cc
@@ -0,0 +1,161 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/ffi/ffi.h"
+
+#include <cstdint>
+#include <string_view>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "xla/ffi/call_frame.h"
+#include "xla/service/service_executable_run_options.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/lib/core/status_test_util.h"
+#include "tsl/platform/test.h"
+
+namespace xla::ffi {
+
+TEST(FfiTest, StaticRegistration) {
+  static constexpr auto* noop = +[] { return absl::OkStatus(); };
+
+  XLA_FFI_DEFINE_HANDLER(NoOp, noop, Ffi::Bind());
+  XLA_FFI_REGISTER_HANDLER(GetXlaFfiApi(), "no-op", NoOp);
+
+  auto handler = FindHandler("no-op");
+  TF_ASSERT_OK(handler.status());
+}
+
+TEST(FfiTest, ForwardError) {
+  auto call_frame = CallFrameBuilder().Build();
+  auto handler = Ffi::Bind().To([] { return absl::AbortedError("Ooops!"); });
+  auto status = Call(*handler, call_frame);
+  ASSERT_EQ(status.message(), "Ooops!");
+}
+
+TEST(FfiTest, WrongNumArgs) {
+  CallFrameBuilder builder;
+  builder.AddBufferArg(se::DeviceMemoryBase(nullptr), PrimitiveType::F32, {});
+  auto call_frame = builder.Build();
+
+  auto handler = Ffi::Bind().Arg<Buffer>().Arg<Buffer>().To(
+      [](Buffer, Buffer) { return absl::OkStatus(); });
+
+  auto status = Call(*handler, call_frame);
+
+  ASSERT_EQ(status.message(),
+            "Wrong number of arguments: expected 2 but got 1");
+}
+
+TEST(FfiTest, WrongNumAttrs) {
+  CallFrameBuilder builder;
+  builder.AddI32Attr("i32", 42);
+  builder.AddF32Attr("f32", 42.0f);
+  auto call_frame = builder.Build();
+
+  auto handler = Ffi::Bind().Attr<int32_t>("i32").To(
+      [](int32_t) { return absl::OkStatus(); });
+
+  auto status = Call(*handler, call_frame);
+
+  ASSERT_EQ(status.message(),
+            "Wrong number of attributes: expected 1 but got 2");
+}
+
+TEST(FfiTest, BuiltinAttributes) {
+  CallFrameBuilder builder;
+  builder.AddI32Attr("i32", 42);
+  builder.AddF32Attr("f32", 42.0f);
+  builder.AddStringAttr("str", "foo");
+  auto call_frame = builder.Build();
+
+  auto fn = [&](int32_t i32, float f32, std::string_view str) {
+    EXPECT_EQ(i32, 42);
+    EXPECT_EQ(f32, 42.0f);
+    EXPECT_EQ(str, "foo");
+    return absl::OkStatus();
+  };
+
+  auto handler = Ffi::Bind()
+                     .Attr<int32_t>("i32")
+                     .Attr<float>("f32")
+                     .Attr<std::string_view>("str")
+                     .To(fn);
+
+  auto status = Call(*handler, call_frame);
+
+  TF_ASSERT_OK(status);
+}
+
+TEST(FfiTest, DecodingErrors) {
+  CallFrameBuilder builder;
+  builder.AddI32Attr("i32", 42);
+  builder.AddF32Attr("f32", 42.0f);
+  builder.AddStringAttr("str", "foo");
+  auto call_frame = builder.Build();
+
+  auto fn = [](int32_t, float, std::string_view) { return absl::OkStatus(); };
+
+  auto handler = Ffi::Bind()
+                     .Attr<int32_t>("not_i32_should_fail")
+                     .Attr<float>("f32")
+                     .Attr<std::string_view>("not_str_should_fail")
+                     .To(fn);
+
+  auto status = Call(*handler, call_frame);
+
+  ASSERT_EQ(
+      status.message(),
+      "Failed to decode all FFI handler operands (bad operands at: 0, 2)");
+}
+
+TEST(FfiTest, BufferArgument) {
+  std::vector<float> storage(4, 0.0f);
+  se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float));
+
+  CallFrameBuilder builder;
+  builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2});
+  auto call_frame = builder.Build();
+
+  auto fn = [&](Buffer buffer) {
+    EXPECT_EQ(buffer.data.opaque(), storage.data());
+    EXPECT_EQ(buffer.primitive_type, PrimitiveType::F32);
+    EXPECT_EQ(buffer.dimensions.size(), 2);
+    return absl::OkStatus();
+  };
+
+  auto handler = Ffi::Bind().Arg<Buffer>().To(fn);
+  auto status = Call(*handler, call_frame);
+
+  TF_ASSERT_OK(status);
+}
+
+TEST(FfiTest, RunOptionsCtx) {
+  auto call_frame = CallFrameBuilder().Build();
+  auto* expected = reinterpret_cast<ServiceExecutableRunOptions*>(0x01234567);
+
+  auto fn = [&](const ServiceExecutableRunOptions* run_options) {
+    EXPECT_EQ(run_options, expected);
+    return absl::OkStatus();
+  };
+
+  auto handler = Ffi::Bind().Ctx<ServiceExecutableRunOptions>().To(fn);
+  auto status = Call(*handler, call_frame, {expected});
+
+  TF_ASSERT_OK(status);
+}
+
+}  // namespace xla::ffi
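Taken together, these tests cover static registration, error forwarding, arity checks, attribute decoding, buffer arguments, and execution-context access. A minimal sketch of a handler that combines a buffer argument with a scalar attribute, built from the same binding API exercised above, might look as follows; the handler name "scale-buffer", the attribute name "scale", and the assumption that Arg and Attr can be mixed in a single binding are illustrative rather than taken from this change.

static constexpr auto* scale_buffer = +[](Buffer buffer, float scale) {
  // Sketch only: a real handler would launch device work on buffer.data,
  // scaled by `scale`; this stub merely checks that decoding succeeded.
  if (buffer.data.opaque() == nullptr && scale != 0.0f) {
    return absl::InvalidArgumentError("null buffer with non-zero scale");
  }
  return absl::OkStatus();
};

XLA_FFI_DEFINE_HANDLER(ScaleBuffer, scale_buffer,
                       Ffi::Bind().Arg<Buffer>().Attr<float>("scale"));
XLA_FFI_REGISTER_HANDLER(GetXlaFfiApi(), "scale-buffer", ScaleBuffer);

As with the "no-op" case above, FindHandler("scale-buffer") would then resolve the registered handler.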
diff --git a/third_party/xla/xla/hlo/evaluator/BUILD b/third_party/xla/xla/hlo/evaluator/BUILD
index 7b3af5b..c39e07f 100644
--- a/third_party/xla/xla/hlo/evaluator/BUILD
+++ b/third_party/xla/xla/hlo/evaluator/BUILD
@@ -1,8 +1,8 @@
 # Description:
 #   XLA evaluator implementation.
 
-load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library")
 load("//xla:xla.bzl", "xla_cc_test")
+load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -94,14 +94,21 @@
 xla_cc_test(
     name = "hlo_evaluator_test",
     srcs = ["hlo_evaluator_test.cc"],
+    tags = [
+        "noasan",  # times out
+    ],
     deps = [
         ":hlo_evaluator",
+        "//xla:array2d",
+        "//xla:array3d",
+        "//xla:array4d",
+        "//xla:comparison_util",
+        "//xla:debug_options_flags",
+        "//xla:error_spec",
         "//xla:literal",
+        "//xla:literal_util",
         "//xla:permutation_util",
-        "//xla:reference_util",
         "//xla:shape_util",
-        "//xla:status",
-        "//xla:status_macros",
         "//xla:statusor",
         "//xla:test",
         "//xla:types",
@@ -109,14 +116,21 @@
         "//xla:xla_data_proto_cc",
         "//xla/client:xla_builder",
         "//xla/hlo/ir:hlo",
+        "//xla/service:dynamic_dimension_inference",
         "//xla/service:hlo_element_type_converter",
+        "//xla/service:hlo_module_config",
+        "//xla/service:shape_inference",
         "//xla/tests:hlo_test_base",
         "//xla/tests:literal_test_util",
         "//xla/tests:test_utils",
         "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
-        "@local_tsl//tsl/lib/core:status_test_util",
-        "@local_tsl//tsl/platform:status",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/platform:test",
         "@local_tsl//tsl/platform:test_benchmark",
     ],
diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc
index aeb11a2..a831f51 100644
--- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc
+++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc
@@ -14,25 +14,44 @@
 ==============================================================================*/
 #include "xla/hlo/evaluator/hlo_evaluator.h"
 
+#include <array>
+#include <complex>
+#include <cstdint>
 #include <initializer_list>
+#include <limits>
 #include <memory>
+#include <numeric>
 #include <optional>
 #include <string>
-#include <tuple>
 #include <utility>
 #include <vector>
 
+#include "absl/algorithm/container.h"
+#include "absl/log/check.h"
 #include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/array2d.h"
+#include "xla/array3d.h"
+#include "xla/array4d.h"
 #include "xla/client/xla_builder.h"
+#include "xla/comparison_util.h"
+#include "xla/debug_options_flags.h"
+#include "xla/error_spec.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout_util.h"
 #include "xla/literal.h"
+#include "xla/literal_util.h"
 #include "xla/permutation_util.h"
-#include "xla/reference_util.h"
+#include "xla/primitive_util.h"
+#include "xla/service/dynamic_dimension_inference.h"
 #include "xla/service/hlo_element_type_converter.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/shape_inference.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
-#include "xla/status.h"
-#include "xla/status_macros.h"
 #include "xla/statusor.h"
 #include "xla/test.h"
 #include "xla/tests/hlo_test_base.h"
@@ -41,8 +60,8 @@
 #include "xla/types.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
-#include "tsl/lib/core/status_test_util.h"
-#include "tsl/platform/status.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
 #include "tsl/platform/test.h"
 #include "tsl/platform/test_benchmark.h"
 
@@ -4340,14 +4359,19 @@
   c2 = s32[] constant(-2147483648)  // -2^31
   sub = s32[] subtract(c2, c1)  // -2^31 - 2^30, underflows
 
+  c3 = u32[] constant(4294967295)
+  c4 = u32[] constant(33)
+
   mul = s32[] multiply(c1, c1)
-  ROOT tuple = (s32[], s32[], s32[]) tuple(sum, sub, mul)
+
+  pow = u32[] power(c3, c4)
+  ROOT tuple = (s32[], s32[], s32[], u32[]) tuple(sum, sub, mul, pow)
 }
 )";
   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
   TF_ASSERT_OK_AND_ASSIGN(auto literal, Evaluate({}));
   std::vector<Literal> actual = literal.DecomposeTuple();
-  ASSERT_EQ(actual.size(), 3);
+  ASSERT_EQ(actual.size(), 4);
 
   uint32_t pow30 = uint32_t{1} << 30;
   uint32_t pow31 = uint32_t{1} << 31;
@@ -4356,6 +4380,7 @@
             static_cast<int32_t>(-(pow31 + pow30)));
   EXPECT_EQ(actual[2].GetFirstElement<int32_t>(),
             static_cast<int32_t>(pow31 * pow31));
+  EXPECT_EQ(actual[3].GetFirstElement<uint32_t>(), uint32_t{4294967295});
 }
 
 TEST_F(HloEvaluatorTest, GetDimensionSize) {
@@ -4369,7 +4394,7 @@
   
   data_dynamic = s32[<=4] set-dimension-size(data, size), dimensions={0}
 
-  sum = s32[4] add(data_dynamic, data)
+  sum = s32[<=4] add(data_dynamic, data)
 
   ROOT dynamic_size = s32[] get-dimension-size(sum), dimensions={0}
 }
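The expected value for the new pow element follows from u32 wraparound: the evaluator computes integer powers with unsigned multiplication (see the IPow fallback in hlo_evaluator_typed_visitor.h below), so the result is reduced modulo 2^32. As a worked equation,

  4294967295 \equiv -1 \pmod{2^{32}}, \qquad (-1)^{33} = -1 \equiv 4294967295 \pmod{2^{32}},

which is exactly the value asserted for actual[3] above.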
diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
index 3ce52b0..63b4802 100644
--- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
+++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
@@ -43,6 +43,7 @@
 #include "xla/literal_util.h"
 #include "xla/primitive_util.h"
 #include "xla/service/shape_inference.h"
+#include "xla/types.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/statusor.h"
@@ -522,7 +523,7 @@
           // 1. inf^(a + 0i) = inf, if a > 0.
           // 2. inf^(a + 0i) = 0, if a < 0.
           if constexpr (is_complex_v<ElementwiseT>) {
-            auto is_positive_infinity = [](auto c) {
+            auto is_positive_infinity = [](ElementwiseT c) {
               return c.imag() == 0 && c.real() > 0 && std::isinf(c.real());
             };
             auto is_positive_real = [](ElementwiseT c) {
@@ -531,7 +532,7 @@
             auto is_negative_real = [](ElementwiseT c) {
               return c.real() < 0 && c.imag() == 0;
             };
-            if (is_positive_infinity(lhs_el) && is_positive_real(rhs_el) > 0) {
+            if (is_positive_infinity(lhs_el) && is_positive_real(rhs_el)) {
               return static_cast<ElementwiseT>(lhs_el);
             }
             if (is_positive_infinity(lhs_el) && is_negative_real(rhs_el)) {
@@ -539,8 +540,21 @@
             }
           }
           // Case 3:
-          // Fallback to std::pow.
-          return static_cast<ElementwiseT>(std::pow(lhs_el, rhs_el));
+          // Fallback to pow.
+          if constexpr (std::is_same_v<ElementwiseT, bool>) {
+            return lhs_el || !rhs_el;
+          } else if constexpr (std::is_integral_v<ElementwiseT>) {
+            if constexpr (std::is_signed_v<ElementwiseT>) {
+              if (rhs_el < static_cast<ElementwiseT>(0)) {
+                return static_cast<ElementwiseT>(
+                    lhs_el == static_cast<ElementwiseT>(1) ? 1 : 0);
+              }
+            }
+            return static_cast<ElementwiseT>(
+                IPow<std::make_unsigned_t<ElementwiseT>>(lhs_el, rhs_el));
+          } else {
+            return static_cast<ElementwiseT>(std::pow(lhs_el, rhs_el));
+          }
         }));
     return OkStatus();
   }
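The integer branch assumes IPow, invoked above on the unsigned counterpart of the element type, performs exponentiation with wrapping unsigned arithmetic; IPow itself is an existing XLA utility and is not defined in this change. A rough stand-in, shown only to illustrate the intended semantics, is exponentiation by squaring:

// Illustrative sketch only; the real IPow may differ in signature and details.
template <typename T>  // T: an unsigned integer type
T IPowSketch(T base, T exponent) {
  T result = T{1};
  while (exponent > 0) {
    if (exponent & T{1}) {
      result *= base;  // unsigned multiplication wraps modulo 2^N
    }
    base *= base;
    exponent >>= 1;
  }
  return result;
}
// IPowSketch<uint32_t>(4294967295u, 33u) == 4294967295u, matching the
// evaluator test expectation above; negative exponents on signed types are
// handled separately in the visitor (1 for base 1, otherwise 0).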
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD
index 5baace7..4ff47c6 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD
@@ -138,6 +138,7 @@
 
 cc_library(
     name = "auto_sharding_option",
+    srcs = ["auto_sharding_option.cc"],
     hdrs = ["auto_sharding_option.h"],
     visibility = ["//visibility:public"],
     deps = [
@@ -194,7 +195,7 @@
     hdrs = ["cluster_environment.h"],
     visibility = ["//visibility:public"],
     deps = [
-        ":auto_sharding_solver_option",
+        ":auto_sharding_option",
         ":auto_sharding_strategy",
         ":auto_sharding_util",
         ":profiling_result",
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
index 76553da..695b024 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -1102,7 +1102,7 @@
 void DisableIncompatibleMixedMeshShapeAndForceBatchDim(
     const InstructionBatchDimMap& batch_dim_map,
     const std::vector<HloInstruction*>& instructions, int num_devices,
-    AutoShardingSolverOption& solver_option) {
+    AutoShardingOption& option) {
   int64_t batch_size = INT_MAX;
   for (const auto& iter : batch_dim_map) {
     batch_size = std::min(batch_size, FindInstruction(instructions, iter.first)
@@ -1111,36 +1111,36 @@
   }
 
   if (IsDivisible(batch_size, num_devices)) {
-    if (solver_option.allow_mixed_mesh_shape) {
-      solver_option.allow_mixed_mesh_shape = false;
+    if (option.allow_mixed_mesh_shape) {
+      option.allow_mixed_mesh_shape = false;
       LOG(WARNING)
           << "Mixed mesh shape is disabled due to indivisible batch size.";
     }
   }
 
   if (batch_size == 1) {
-    solver_option.force_batch_dim_to_mesh_dim = -1;
+    option.force_batch_dim_to_mesh_dim = -1;
   }
 }
 
 StatusOr<std::unique_ptr<StrategyVector>> CreateAllStrategiesVector(
     const HloInstruction* ins, const Shape& shape, size_t instruction_id,
     LeafStrategies& leaf_strategies, const ClusterEnvironment& cluster_env,
-    const StrategyMap& strategy_map,
-    const AutoShardingSolverOption& solver_option, double replicated_penalty,
-    const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph,
-    bool only_allow_divisible, bool create_replicated_strategies) {
+    const StrategyMap& strategy_map, const AutoShardingOption& option,
+    double replicated_penalty, const InstructionBatchDimMap& batch_dim_map,
+    const CallGraph& call_graph, bool only_allow_divisible,
+    bool create_replicated_strategies) {
   std::unique_ptr<StrategyVector> strategies;
   if (shape.IsTuple()) {
     strategies = CreateTupleStrategyVector(instruction_id);
     strategies->childs.reserve(shape.tuple_shapes_size());
     for (size_t i = 0; i < shape.tuple_shapes_size(); ++i) {
       auto child_strategies =
-          CreateAllStrategiesVector(
-              ins, shape.tuple_shapes(i), instruction_id, leaf_strategies,
-              cluster_env, strategy_map, solver_option, replicated_penalty,
-              batch_dim_map, call_graph, only_allow_divisible,
-              create_replicated_strategies)
+          CreateAllStrategiesVector(ins, shape.tuple_shapes(i), instruction_id,
+                                    leaf_strategies, cluster_env, strategy_map,
+                                    option, replicated_penalty, batch_dim_map,
+                                    call_graph, only_allow_divisible,
+                                    create_replicated_strategies)
               .value();
       child_strategies->tuple_element_idx = i;
       strategies->childs.push_back(std::move(child_strategies));
@@ -1164,7 +1164,7 @@
                             only_allow_divisible, call_graph, /*partitions*/ 3);
     }
 
-    if (solver_option.allow_mixed_mesh_shape && cluster_env.IsDeviceMesh2D()) {
+    if (option.allow_mixed_mesh_shape && cluster_env.IsDeviceMesh2D()) {
       // Set penalty for 1d partial tiled layout
       for (size_t i = 0; i < strategies->leaf_vector.size(); ++i) {
         strategies->leaf_vector[i].compute_cost += replicated_penalty * 0.8;
@@ -1182,10 +1182,10 @@
 
     // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies
     // and only keep the data parallel strategies.
-    if (solver_option.force_batch_dim_to_mesh_dim >= 0 &&
+    if (option.force_batch_dim_to_mesh_dim >= 0 &&
         batch_dim_map.contains(GetBatchDimMapKey(ins))) {
       TF_RETURN_IF_ERROR(FilterStrategy(ins, shape, strategies, cluster_env,
-                                        batch_dim_map, solver_option));
+                                        batch_dim_map, option));
     }
   } else if (shape.IsToken()) {
     strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map,
@@ -1201,14 +1201,13 @@
 StatusOr<std::unique_ptr<StrategyVector>> CreateParameterStrategyVector(
     const HloInstruction* ins, const Shape& shape, size_t instruction_id,
     LeafStrategies& leaf_strategies, const ClusterEnvironment& cluster_env,
-    const StrategyMap& strategy_map,
-    const AutoShardingSolverOption& solver_option, double replicated_penalty,
-    const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph,
-    bool only_allow_divisible) {
+    const StrategyMap& strategy_map, const AutoShardingOption& option,
+    double replicated_penalty, const InstructionBatchDimMap& batch_dim_map,
+    const CallGraph& call_graph, bool only_allow_divisible) {
   return CreateAllStrategiesVector(
       ins, shape, instruction_id, leaf_strategies, cluster_env, strategy_map,
-      solver_option, replicated_penalty, batch_dim_map, call_graph,
-      only_allow_divisible, solver_option.allow_replicated_parameters);
+      option, replicated_penalty, batch_dim_map, call_graph,
+      only_allow_divisible, option.allow_replicated_parameters);
 }
 
 // The sharding is replicated or the total number of tiles is over or equal to
@@ -1515,6 +1514,183 @@
   }
 }
 
+// Enumerates sharding strategies for elementwise operators by following
+// strategies of an operand of the elementwise op.
+std::unique_ptr<StrategyVector> CreateElementwiseOperatorStrategies(
+    size_t instruction_id, const HloInstruction* ins,
+    const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env,
+    const InstructionDepthMap& depth_map, const AliasMap& alias_map,
+    const StableHashMap<int64_t, std::vector<ShardingStrategy>>&
+        pretrimmed_strategy_map,
+    int64_t max_depth, LeafStrategies& leaf_strategies,
+    AssociativeDotPairs& associative_dot_pairs) {
+  std::unique_ptr<StrategyVector> strategies = CreateLeafStrategyVector(
+      instruction_id, ins, strategy_map, leaf_strategies);
+
+  // Choose an operand to follow
+  int64_t follow_idx;
+  bool tie;
+  std::tie(follow_idx, tie) =
+      ChooseOperandToFollow(strategy_map, depth_map, alias_map, max_depth, ins);
+
+  if (!tie || AllowTieFollowing(ins)) {
+    strategies->following = strategy_map.at(ins->operand(follow_idx)).get();
+  } else {
+    strategies->following = nullptr;
+  }
+
+  // Get all possible sharding specs from operands
+  for (int64_t i = 0; i < ins->operand_count(); ++i) {
+    if (strategies->following != nullptr && i != follow_idx) {
+      // If ins follows one operand, do not consider sharding specs from
+      // other operands.
+      continue;
+    }
+
+    auto process_src_strategies =
+        [&](const std::vector<ShardingStrategy>& src_strategies_leaf_vector) {
+          for (int64_t sid = 0; sid < src_strategies_leaf_vector.size();
+               ++sid) {
+            HloSharding output_spec =
+                src_strategies_leaf_vector[sid].output_sharding;
+            std::string name = ToStringSimple(output_spec);
+            double compute_cost = 0, communication_cost = 0;
+            double memory_cost =
+                GetBytes(ins->shape()) / output_spec.NumTiles();
+            std::vector<std::vector<double>> resharding_costs;
+            std::vector<std::optional<HloSharding>> input_shardings;
+            for (int64_t k = 0; k < ins->operand_count(); ++k) {
+              resharding_costs.push_back(ReshardingCostVector(
+                  strategy_map.at(ins->operand(k)).get(),
+                  ins->operand(k)->shape(), output_spec, cluster_env));
+              input_shardings.push_back(output_spec);
+            }
+
+            strategies->leaf_vector.push_back(ShardingStrategy(
+                {name, output_spec, compute_cost, communication_cost,
+                 memory_cost, std::move(resharding_costs), input_shardings}));
+          }
+        };
+    StrategyVector* src_strategies = strategy_map.at(ins->operand(i)).get();
+    CHECK(!src_strategies->is_tuple);
+
+    process_src_strategies(src_strategies->leaf_vector);
+    if (pretrimmed_strategy_map.contains(src_strategies->node_idx)) {
+      process_src_strategies(
+          pretrimmed_strategy_map.at(src_strategies->node_idx));
+    }
+  }
+  if (ins->opcode() == HloOpcode::kAdd) {
+    // Adjust the resharding costs for the AllReduceReassociate pass.
+    // The AllReduceReassociate pass can simplify
+    // allreduce(x) + allreduce(y) to allreduce(x + y),
+    // so we adjust the resharding costs to reflect this optimization.
+
+    // TODO(zhuohan): The current implementation only works for
+    // x = a + b. We also need to cover cases where there are
+    // more than two operands (i.e., x = a + b + c).
+    if (ins->operand(0)->opcode() == HloOpcode::kDot &&
+        ins->operand(1)->opcode() == HloOpcode::kDot) {
+      associative_dot_pairs.push_back({strategy_map.at(ins->operand(0)).get(),
+                                       strategy_map.at(ins->operand(1)).get()});
+    }
+  }
+  return strategies;
+}
+
+// Enumerates sharding strategies for reshape operators. The function does so by
+// essentially reshaping the sharding of the operand in a manner similar to the
+// tensor reshape itself.
+std::unique_ptr<StrategyVector> CreateReshapeStrategies(
+    size_t instruction_id, const HloInstruction* ins,
+    const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env,
+    bool only_allow_divisible, double replicated_penalty,
+    const InstructionBatchDimMap& batch_dim_map,
+    const AutoShardingOption& option, LeafStrategies& leaf_strategies) {
+  std::unique_ptr<StrategyVector> strategies = CreateLeafStrategyVector(
+      instruction_id, ins, strategy_map, leaf_strategies);
+  const HloInstruction* operand = ins->operand(0);
+  const Array<int64_t>& device_mesh = cluster_env.device_mesh_;
+  const Array<int64_t>& device_mesh_1d = cluster_env.device_mesh_1d_;
+
+  int mesh_nn_dims = VectorGreaterThanOneElementCount(device_mesh.dimensions());
+  if (mesh_nn_dims < 2 || !option.allow_mixed_mesh_shape) {
+    // Create follow strategies
+    const StrategyVector* src_strategies = strategy_map.at(operand).get();
+    CHECK(!src_strategies->is_tuple);
+    strategies->following = src_strategies;
+
+    for (int64_t sid = 0; sid < src_strategies->leaf_vector.size(); ++sid) {
+      std::optional<HloSharding> output_spec =
+          hlo_sharding_util::ReshapeSharding(
+              operand->shape(), ins->shape(),
+              src_strategies->leaf_vector[sid].output_sharding);
+
+      if (!output_spec.has_value()) {
+        continue;
+      }
+
+      if (!IsValidTileAssignment(*output_spec)) {
+        continue;
+      }
+
+      if (!TileAssignmentMatchesMesh(*output_spec, device_mesh)) {
+        continue;
+      }
+      std::string name = ToStringSimple(*output_spec);
+      double compute_cost = 0, communication_cost = 0;
+      double memory_cost = GetBytes(ins->shape()) / output_spec->NumTiles();
+      std::vector<double> resharding_costs = ReshardingCostVector(
+          src_strategies, operand->shape(),
+          src_strategies->leaf_vector[sid].output_sharding, cluster_env);
+      strategies->leaf_vector.push_back(ShardingStrategy(
+          {name,
+           *output_spec,
+           compute_cost,
+           communication_cost,
+           memory_cost,
+           {resharding_costs},
+           {src_strategies->leaf_vector[sid].output_sharding}}));
+    }
+  }
+
+  // If we failed to create follow strategies, enumerate all possible cases.
+  if (strategies->leaf_vector.empty()) {
+    strategies->leaf_vector.clear();
+    strategies->following = nullptr;
+
+    // Split 1 dim
+    if (cluster_env.IsDeviceMesh1D()) {
+      EnumerateAll1DPartitionReshape(ins, device_mesh, cluster_env,
+                                     strategy_map, strategies,
+                                     only_allow_divisible, "");
+    }
+    if (option.allow_mixed_mesh_shape && cluster_env.IsDeviceMesh2D()) {
+      // Split 1 dim, but for 1d mesh
+      EnumerateAll1DPartitionReshape(ins, device_mesh_1d, cluster_env,
+                                     strategy_map, strategies,
+                                     only_allow_divisible, " 1d");
+    }
+    if (cluster_env.IsDeviceMesh2D()) {
+      // Split 2 dim, one is always the batch dim
+      EnumeratePartitionReshape(ins, device_mesh, cluster_env, strategy_map,
+                                batch_dim_map, strategies, only_allow_divisible,
+                                /*partitions*/ 2);
+    }
+    if (cluster_env.IsDeviceMesh3D()) {
+      // Split 3 dim, one is always the batch dim
+      EnumeratePartitionReshape(ins, device_mesh, cluster_env, strategy_map,
+                                batch_dim_map, strategies, only_allow_divisible,
+                                /*partitions*/ 3);
+    }
+
+    // Replicate
+    AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map,
+                          strategies, replicated_penalty);
+  }
+  return strategies;
+}
+
 // NOLINTBEGIN(readability/fn_size)
 // TODO(zhuohan): Decompose this function into smaller pieces
 // Build possible sharding strategies and their costs for all instructions.
@@ -1527,8 +1703,7 @@
                      const InstructionBatchDimMap& batch_dim_map,
                      const AliasMap& alias_map,
                      const ClusterEnvironment& cluster_env,
-                     AutoShardingSolverOption& solver_option,
-                     const CallGraph& call_graph,
+                     AutoShardingOption& option, const CallGraph& call_graph,
                      bool trying_multiple_mesh_shapes) {
   const Array<int64_t>& device_mesh = cluster_env.device_mesh_;
   const Array<int64_t>& device_mesh_1d = cluster_env.device_mesh_1d_;
@@ -1544,9 +1719,6 @@
 
   const std::vector<HloInstruction*>& instructions = sequence.instructions();
 
-  // Count the non-one mesh dimension.
-  int mesh_nn_dims = VectorGreaterThanOneElementCount(device_mesh.dimensions());
-
   // Add penalty for replicated tensors
   double replicated_penalty = std::round(cluster_env.AllReduceCost(1, 0) +
                                          cluster_env.AllReduceCost(1, 1));
@@ -1576,20 +1748,19 @@
       // strategies, outputs would be constrained as well, but if outputs are
       // still unevenly sharded in some cases, we need to fix the implementation
       // in auto sharding.
-      only_allow_divisible = solver_option.only_allow_divisible_input_output;
+      only_allow_divisible = option.only_allow_divisible_input_output;
     } else {
-      only_allow_divisible = solver_option.only_allow_divisible_intermediate;
+      only_allow_divisible = option.only_allow_divisible_intermediate;
     }
     switch (opcode) {
       case HloOpcode::kParameter:
       case HloOpcode::kRngBitGenerator:
       case HloOpcode::kRng: {
-        strategies =
-            CreateParameterStrategyVector(
-                ins, ins->shape(), instruction_id, leaf_strategies, cluster_env,
-                strategy_map, solver_option, replicated_penalty, batch_dim_map,
-                call_graph, only_allow_divisible)
-                .value();
+        strategies = CreateParameterStrategyVector(
+                         ins, ins->shape(), instruction_id, leaf_strategies,
+                         cluster_env, strategy_map, option, replicated_penalty,
+                         batch_dim_map, call_graph, only_allow_divisible)
+                         .value();
         break;
       }
       case HloOpcode::kConstant: {
@@ -1732,7 +1903,7 @@
                                 cluster_env, strategy_map, strategies,
                                 batch_dim_map, only_allow_divisible, call_graph,
                                 /*partitions*/ 2);
-          if (solver_option.allow_mixed_mesh_shape) {
+          if (option.allow_mixed_mesh_shape) {
             EnumerateAll1DPartition(ins, ins->shape(),
                                     cluster_env.device_mesh_1d_, cluster_env,
                                     strategy_map, strategies,
@@ -1745,86 +1916,10 @@
         break;
       }
       case HloOpcode::kReshape: {
-        strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map,
-                                              leaf_strategies);
-        const HloInstruction* operand = ins->operand(0);
-        if (!(mesh_nn_dims >= 2 && solver_option.allow_mixed_mesh_shape)) {
-          // Create follow strategies
-          const StrategyVector* src_strategies = strategy_map.at(operand).get();
-          CHECK(!src_strategies->is_tuple);
-          strategies->following = src_strategies;
-
-          for (int64_t sid = 0; sid < src_strategies->leaf_vector.size();
-               ++sid) {
-            std::optional<HloSharding> output_spec =
-                hlo_sharding_util::ReshapeSharding(
-                    operand->shape(), ins->shape(),
-                    src_strategies->leaf_vector[sid].output_sharding);
-
-            if (!output_spec.has_value()) {
-              continue;
-            }
-
-            if (!IsValidTileAssignment(*output_spec)) {
-              continue;
-            }
-
-            if (!TileAssignmentMatchesMesh(*output_spec, device_mesh)) {
-              continue;
-            }
-            std::string name = ToStringSimple(*output_spec);
-            double compute_cost = 0, communication_cost = 0;
-            double memory_cost =
-                GetBytes(ins->shape()) / output_spec->NumTiles();
-            auto resharding_costs = ReshardingCostVector(
-                src_strategies, operand->shape(),
-                src_strategies->leaf_vector[sid].output_sharding, cluster_env);
-            strategies->leaf_vector.push_back(ShardingStrategy(
-                {name,
-                 *output_spec,
-                 compute_cost,
-                 communication_cost,
-                 memory_cost,
-                 {resharding_costs},
-                 {src_strategies->leaf_vector[sid].output_sharding}}));
-          }
-        }
-
-        // Fail to create follow strategies, enumerate all possible cases
-        if (strategies->leaf_vector.empty()) {
-          strategies->leaf_vector.clear();
-          strategies->following = nullptr;
-
-          // Split 1 dim
-          if (cluster_env.IsDeviceMesh1D()) {
-            EnumerateAll1DPartitionReshape(ins, device_mesh, cluster_env,
-                                           strategy_map, strategies,
-                                           only_allow_divisible, "");
-          }
-          if (solver_option.allow_mixed_mesh_shape &&
-              cluster_env.IsDeviceMesh2D()) {
-            // Split 1 dim, but for 1d mesh
-            EnumerateAll1DPartitionReshape(ins, device_mesh_1d, cluster_env,
-                                           strategy_map, strategies,
-                                           only_allow_divisible, " 1d");
-          }
-          if (cluster_env.IsDeviceMesh2D()) {
-            // Split 2 dim, one is always the batch dim
-            EnumeratePartitionReshape(ins, device_mesh, cluster_env,
-                                      strategy_map, batch_dim_map, strategies,
-                                      only_allow_divisible, /*partitions*/ 2);
-          }
-          if (cluster_env.IsDeviceMesh3D()) {
-            // Split 3 dim, one is always the batch dim
-            EnumeratePartitionReshape(ins, device_mesh, cluster_env,
-                                      strategy_map, batch_dim_map, strategies,
-                                      only_allow_divisible, /*partitions*/ 3);
-          }
-
-          // Replicate
-          AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map,
-                                strategies, replicated_penalty);
-        }
+        strategies = CreateReshapeStrategies(instruction_id, ins, strategy_map,
+                                             cluster_env, only_allow_divisible,
+                                             replicated_penalty, batch_dim_map,
+                                             option, leaf_strategies);
         break;
       }
       case HloOpcode::kTranspose:
@@ -1978,6 +2073,20 @@
             pretrimmed_strategy_map);
         break;
       }
+      case HloOpcode::kBitcast: {
+        if (ins->shape() == ins->operand(0)->shape()) {
+          strategies = CreateElementwiseOperatorStrategies(
+              instruction_id, ins, strategy_map, cluster_env, depth_map,
+              alias_map, pretrimmed_strategy_map, max_depth, leaf_strategies,
+              associative_dot_pairs);
+        } else {
+          strategies = CreateReshapeStrategies(
+              instruction_id, ins, strategy_map, cluster_env,
+              only_allow_divisible, replicated_penalty, batch_dim_map, option,
+              leaf_strategies);
+        }
+        break;
+      }
       // Unary elementwise operations.
       case HloOpcode::kAbs:
       case HloOpcode::kRoundNearestAfz:
@@ -1985,7 +2094,6 @@
       case HloOpcode::kCeil:
       case HloOpcode::kClz:
       case HloOpcode::kConvert:
-      case HloOpcode::kBitcast:
       case HloOpcode::kBitcastConvert:
       case HloOpcode::kCopy:
       case HloOpcode::kCos:
@@ -2031,86 +2139,17 @@
       // Ternary elementwise operations.
       case HloOpcode::kSelect:
       case HloOpcode::kClamp: {
-        strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map,
-                                              leaf_strategies);
-
-        // Choose an operand to follow
-        int64_t follow_idx;
-        bool tie;
-        std::tie(follow_idx, tie) = ChooseOperandToFollow(
-            strategy_map, depth_map, alias_map, max_depth, ins);
-
-        if (!tie || AllowTieFollowing(ins)) {
-          strategies->following =
-              strategy_map.at(ins->operand(follow_idx)).get();
-        } else {
-          strategies->following = nullptr;
-        }
-
-        // Get all possible sharding specs from operands
-        for (int64_t i = 0; i < ins->operand_count(); ++i) {
-          if (strategies->following != nullptr && i != follow_idx) {
-            // If ins follows one operand, do not consider sharding specs from
-            // other operands.
-            continue;
-          }
-
-          auto process_src_strategies = [&](const std::vector<ShardingStrategy>
-                                                src_strategies_leaf_vector) {
-            for (int64_t sid = 0; sid < src_strategies_leaf_vector.size();
-                 ++sid) {
-              HloSharding output_spec =
-                  src_strategies_leaf_vector[sid].output_sharding;
-              std::string name = ToStringSimple(output_spec);
-              double compute_cost = 0, communication_cost = 0;
-              double memory_cost =
-                  GetBytes(ins->shape()) / output_spec.NumTiles();
-              std::vector<std::vector<double>> resharding_costs;
-              std::vector<std::optional<HloSharding>> input_shardings;
-              for (int64_t k = 0; k < ins->operand_count(); ++k) {
-                resharding_costs.push_back(ReshardingCostVector(
-                    strategy_map.at(ins->operand(k)).get(),
-                    ins->operand(k)->shape(), output_spec, cluster_env));
-                input_shardings.push_back(output_spec);
-              }
-
-              strategies->leaf_vector.push_back(ShardingStrategy(
-                  {name, output_spec, compute_cost, communication_cost,
-                   memory_cost, std::move(resharding_costs), input_shardings}));
-            }
-          };
-          auto src_strategies = strategy_map.at(ins->operand(i)).get();
-          CHECK(!src_strategies->is_tuple);
-
-          process_src_strategies(src_strategies->leaf_vector);
-          if (pretrimmed_strategy_map.contains(src_strategies->node_idx)) {
-            process_src_strategies(
-                pretrimmed_strategy_map.at(src_strategies->node_idx));
-          }
-        }
-        if (ins->opcode() == HloOpcode::kAdd) {
-          // Adjust the resharding costs for AllReduceReassociate pass.
-          // The AllReduceReassociate pass can simplify
-          // allreduce(x) + allreduce(y) to allreduce(x + y),
-          // so we adjust the resharidng costs to reflect this optimization.
-
-          // TODO(zhuohan): The current implementation only works for
-          // x = a + b. We also need to cover cases where there are
-          // more than two operands (i.e., x = a + b + c).
-          if (ins->operand(0)->opcode() == HloOpcode::kDot &&
-              ins->operand(1)->opcode() == HloOpcode::kDot) {
-            associative_dot_pairs.push_back(
-                {strategy_map.at(ins->operand(0)).get(),
-                 strategy_map.at(ins->operand(1)).get()});
-          }
-        }
+        strategies = CreateElementwiseOperatorStrategies(
+            instruction_id, ins, strategy_map, cluster_env, depth_map,
+            alias_map, pretrimmed_strategy_map, max_depth, leaf_strategies,
+            associative_dot_pairs);
         break;
       }
       case HloOpcode::kReduce: {
         auto strategies_status = FollowReduceStrategy(
             ins, ins->shape(), ins->operand(0), ins->operand(1), instruction_id,
             strategy_map, leaf_strategies, cluster_env,
-            solver_option.allow_mixed_mesh_shape, !trying_multiple_mesh_shapes);
+            option.allow_mixed_mesh_shape, !trying_multiple_mesh_shapes);
         if (strategies_status.ok()) {
           strategies = std::move(strategies_status.value());
         } else {
@@ -2121,18 +2160,18 @@
       case HloOpcode::kDot: {
         TF_RETURN_IF_ERROR(HandleDot(strategies, leaf_strategies, strategy_map,
                                      ins, instruction_id, cluster_env,
-                                     batch_dim_map, solver_option, call_graph));
-        if (solver_option.allow_replicated_strategy_for_dot_and_conv) {
+                                     batch_dim_map, option, call_graph));
+        if (option.allow_replicated_strategy_for_dot_and_conv) {
           AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map,
                                 strategies, 0);
         }
         break;
       }
       case HloOpcode::kConvolution: {
-        TF_RETURN_IF_ERROR(HandleConv(
-            strategies, leaf_strategies, strategy_map, ins, instruction_id,
-            cluster_env, batch_dim_map, solver_option, call_graph));
-        if (solver_option.allow_replicated_strategy_for_dot_and_conv) {
+        TF_RETURN_IF_ERROR(HandleConv(strategies, leaf_strategies, strategy_map,
+                                      ins, instruction_id, cluster_env,
+                                      batch_dim_map, option, call_graph));
+        if (option.allow_replicated_strategy_for_dot_and_conv) {
           AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map,
                                 strategies, 0);
         }
@@ -2165,8 +2204,7 @@
                                 strategy_map, strategies, batch_dim_map,
                                 only_allow_divisible, call_graph, /*parts*/ 3);
         }
-        if (cluster_env.IsDeviceMesh2D() &&
-            solver_option.allow_mixed_mesh_shape) {
+        if (cluster_env.IsDeviceMesh2D() && option.allow_mixed_mesh_shape) {
           // Split 1 dim, but for 1d flattened version of the 2d mesh
           // For example, when the mesh shape is (2, 4), we add strategies for
           // mesh shape (1, 8) here in addition.
@@ -2230,9 +2268,8 @@
                   strategies =
                       CreateAllStrategiesVector(
                           ins, ins->shape(), instruction_id, leaf_strategies,
-                          cluster_env, strategy_map, solver_option,
-                          replicated_penalty, batch_dim_map, call_graph,
-                          only_allow_divisible, true)
+                          cluster_env, strategy_map, option, replicated_penalty,
+                          batch_dim_map, call_graph, only_allow_divisible, true)
                           .value();
                 }
               } else {
@@ -2246,9 +2283,8 @@
                   strategies =
                       CreateAllStrategiesVector(
                           ins, ins->shape(), instruction_id, leaf_strategies,
-                          cluster_env, strategy_map, solver_option,
-                          replicated_penalty, batch_dim_map, call_graph,
-                          only_allow_divisible, true)
+                          cluster_env, strategy_map, option, replicated_penalty,
+                          batch_dim_map, call_graph, only_allow_divisible, true)
                           .value();
                 }
               }
@@ -2306,13 +2342,12 @@
       case HloOpcode::kConditional:
       case HloOpcode::kInfeed:
       case HloOpcode::kSort: {
-        strategies =
-            CreateAllStrategiesVector(
-                ins, ins->shape(), instruction_id, leaf_strategies, cluster_env,
-                strategy_map, solver_option, replicated_penalty, batch_dim_map,
-                call_graph, only_allow_divisible,
-                /*create_replicated_strategies*/ true)
-                .value();
+        strategies = CreateAllStrategiesVector(
+                         ins, ins->shape(), instruction_id, leaf_strategies,
+                         cluster_env, strategy_map, option, replicated_penalty,
+                         batch_dim_map, call_graph, only_allow_divisible,
+                         /*create_replicated_strategies*/ true)
+                         .value();
         break;
       }
       case HloOpcode::kOutfeed: {
@@ -2339,7 +2374,7 @@
       TrimOrGenerateStrategiesBasedOnExistingSharding(
           ins->shape(), strategies.get(), strategy_map, instructions,
           ins->sharding(), cluster_env, pretrimmed_strategy_map, call_graph,
-          solver_option.nd_sharding_iteratively_strict_search_space);
+          option.nd_sharding_iteratively_strict_search_space);
     }
     if (!strategies->is_tuple && strategies->following) {
       if (!LeafVectorsAreConsistent(
@@ -2376,11 +2411,9 @@
     XLA_VLOG_LINES(2, absl::StrCat("strategies:\n", strategies->ToString()));
 
     // Debug options: forcibly set the strategy of some instructions.
-    if (solver_option.force_strategy) {
-      std::vector<int64_t> inst_indices =
-          solver_option.force_strategy_inst_indices;
-      std::vector<std::string> stra_names =
-          solver_option.force_strategy_stra_names;
+    if (option.force_strategy) {
+      std::vector<int64_t> inst_indices = option.force_strategy_inst_indices;
+      std::vector<std::string> stra_names = option.force_strategy_stra_names;
       CHECK_EQ(inst_indices.size(), stra_names.size());
       auto it = absl::c_find(inst_indices, strategies->node_idx);
       if (it != inst_indices.end()) {
@@ -2419,7 +2452,7 @@
 
   // If gradient accumulation is used, adjust the cost of all-reduce for
   // gradient synchronization.
-  if (solver_option.grad_acc_num_micro_batches > 1) {
+  if (option.grad_acc_num_micro_batches > 1) {
     // find gradient-computation instructions
     std::vector<const HloInstruction*> grad_insts =
         GetGradientComputationInstructions(instructions);
@@ -2429,7 +2462,7 @@
 
       for (auto& stra : stra_vector->leaf_vector) {
         if (absl::StrContains(stra.name, "allreduce")) {
-          stra.communication_cost /= solver_option.grad_acc_num_micro_batches;
+          stra.communication_cost /= option.grad_acc_num_micro_batches;
         }
       }
     }
@@ -2478,19 +2511,20 @@
   const std::vector<HloInstruction*>& instructions = sequence.instructions();
 
   // Serialize node costs
+  int num_nodes_without_default = 0;
   for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) {
     const StrategyVector* strategies = leaf_strategies[node_idx];
     auto instruction_name = instructions.at(strategies->instruction_id)->name();
     request.instruction_names.push_back(
         absl::StrCat(instruction_name, " (id: ", node_idx, ")"));
     std::vector<double> ci, di, mi, pi;
-    auto default_strategy = HloSharding::Replicate();
+    std::optional<HloSharding> default_strategy;
     auto iter = sharding_propagation_solution.find(instruction_name);
     if (iter != sharding_propagation_solution.end()) {
       CHECK(iter->second->has_sharding()) << iter->second->ToString();
       default_strategy = iter->second->sharding();
       if (strategies->tuple_element_idx) {
-        const auto& tuple_elements = default_strategy.tuple_elements();
+        const auto& tuple_elements = default_strategy->tuple_elements();
         CHECK_LT(*strategies->tuple_element_idx, tuple_elements.size());
         default_strategy = tuple_elements.at(*strategies->tuple_element_idx);
       }
@@ -2502,15 +2536,20 @@
       di.push_back(strategy.communication_cost +
                    cost_graph.extra_node_costs_[node_idx][j]);
       mi.push_back(strategy.memory_cost);
-      // TODO(moffitt): Revisit the default strategy below, which is currently
-      // defined as the "trivial sharding" in hlo_sharding.h
-      pi.push_back(sharding == default_strategy ? 0.0 : 1.0);
+      pi.push_back(default_strategy && sharding == *default_strategy ? 0 : 1);
+    }
+    if (*std::min_element(pi.begin(), pi.end()) > 0) {
+      LOG(WARNING) << "No default strategy for {node_idx " << node_idx
+                   << ", instruction ID " << strategies->instruction_id
+                   << ", instruction name " << instruction_name << "}";
+      ++num_nodes_without_default;
     }
     request.c.push_back(ci);
     request.d.push_back(di);
     request.m.push_back(mi);
     request.p.push_back(pi);
   }
+  LOG(INFO) << "Total nodes without default: " << num_nodes_without_default;
 
   // Serialize special edges that force an alias pair to have the same
   // sharding spec.
@@ -3837,14 +3876,14 @@
   output->set_sharding(HloSharding::Tuple(tuple_sharding));
 }
 
-// Filter strategies according to the solver_option.force_batch_dim_to_mesh_dim.
+// Filter strategies according to the option.force_batch_dim_to_mesh_dim.
 // This can be used to forcibly generate data-parallel strategies.
 Status FilterStrategy(const HloInstruction* ins, const Shape& shape,
                       std::unique_ptr<StrategyVector>& strategies,
                       const ClusterEnvironment& cluster_env,
                       const InstructionBatchDimMap& batch_map,
-                      const AutoShardingSolverOption& solver_option) {
-  int mesh_dim = solver_option.force_batch_dim_to_mesh_dim;
+                      const AutoShardingOption& option) {
+  int mesh_dim = option.force_batch_dim_to_mesh_dim;
   int batch_dim = batch_map.at(GetBatchDimMapKey(ins));
   const Array<int64_t>& device_mesh = cluster_env.device_mesh_;
 
@@ -4141,11 +4180,11 @@
   solver_option.override_reduce_scatter_cost = false;
   solver_option.override_all_to_all_cost = false;
 
-  if (option_.force_all_gather_cost) {
+  if (option_.force_override_all_gather_cost) {
     solver_option.override_all_gather_cost = true;
     solver_option.all_gather_cost = option_.all_gather_cost;
   }
-  if (option_.force_all_to_all_cost) {
+  if (option_.force_override_all_to_all_cost) {
     solver_option.override_all_to_all_cost = true;
     solver_option.all_to_all_cost = option_.all_to_all_cost;
   }
@@ -4306,7 +4345,7 @@
     spmd::ProfilingResult prof_result;
     spmd::ClusterEnvironment cluster_env(
         original_device_mesh, device_mesh, option_.device_mesh_alpha,
-        option_.device_mesh_beta, prof_result, solver_option);
+        option_.device_mesh_beta, prof_result, option_);
 
     XLA_VLOG_LINES(3, module->ToString());
     int64_t memory_lower_bound = spmd::MemoryBudgetLowerBound(
@@ -4344,10 +4383,10 @@
       return AutoShardingResult::kModuleChangedShardingPerformed;
     }
 
-    if (solver_option.force_batch_dim_to_mesh_dim >= 0) {
-      DisableIncompatibleMixedMeshShapeAndForceBatchDim(
+    if (option_.force_batch_dim_to_mesh_dim >= 0) {
+      spmd::DisableIncompatibleMixedMeshShapeAndForceBatchDim(
           batch_dim_map, sequence.instructions(), device_mesh.num_elements(),
-          solver_option);
+          option_);
     }
 
     // ----- Analyze depth -----
@@ -4362,7 +4401,7 @@
         std::tie(strategy_map, leaf_strategies, associative_dot_pairs),
         BuildStrategyAndCost(sequence, module, instruction_execution_counts,
                              ins_depth_map, batch_dim_map, alias_map,
-                             cluster_env, solver_option, *call_graph,
+                             cluster_env, option_, *call_graph,
                              option_.try_multiple_mesh_shapes));
     spmd::AliasSet alias_set = spmd::BuildAliasSet(module, strategy_map);
     CheckAliasSetCompatibility(alias_set, leaf_strategies, sequence);
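With this refactoring, strategy construction is driven directly by AutoShardingOption rather than the former AutoShardingSolverOption. A hedged configuration sketch, using only field names that appear in the diff above (the helper function, defaults, and namespace placement are assumptions made for illustration):

#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"

// Hypothetical helper: field names come from the usages above; everything
// else about construction is assumed.
AutoShardingOption MakeExampleOption() {
  AutoShardingOption option;
  option.allow_mixed_mesh_shape = true;      // also enumerate 1d strategies on 2d meshes
  option.force_batch_dim_to_mesh_dim = -1;   // a negative value skips batch-dim forcing
  option.only_allow_divisible_input_output = true;
  option.only_allow_divisible_intermediate = false;
  option.allow_replicated_parameters = true;
  option.allow_replicated_strategy_for_dot_and_conv = true;
  option.grad_acc_num_micro_batches = 1;     // values > 1 rescale all-reduce costs
  option.nd_sharding_iteratively_strict_search_space = false;
  return option;
}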
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h
index 23a8679..037ffe8 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h
@@ -157,22 +157,21 @@
                       std::unique_ptr<StrategyVector>& strategies,
                       const ClusterEnvironment& cluster_env,
                       const InstructionBatchDimMap& batch_map,
-                      const AutoShardingSolverOption& solver_option);
+                      const AutoShardingOption& option);
 
 Status HandleDot(std::unique_ptr<StrategyVector>& strategies,
                  LeafStrategies& leaf_strategies, StrategyMap& strategy_map,
                  const HloInstruction* ins, size_t instruction_id,
                  const ClusterEnvironment& cluster_env,
                  const InstructionBatchDimMap& batch_map,
-                 const AutoShardingSolverOption& solver_option,
-                 const CallGraph& call_graph);
+                 const AutoShardingOption& option, const CallGraph& call_graph);
 
 Status HandleConv(std::unique_ptr<StrategyVector>& strategies,
                   LeafStrategies& leaf_strategies, StrategyMap& strategy_map,
                   const HloInstruction* ins, size_t instruction_id,
                   const ClusterEnvironment& cluster_env,
                   const InstructionBatchDimMap& batch_map,
-                  const AutoShardingSolverOption& solver_option,
+                  const AutoShardingOption& option,
                   const CallGraph& call_graph);
 
 void AnnotateShardingWithSimpleHeuristic(HloModule* module,
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
index ac862a0..fb082f1 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
@@ -1,6 +1,6 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+/*Copyright 2022 The TensorFlow Authors.All Rights Reserved.
 
-Licensed under the Apache License, Version 2.0 (the "License");
+Licensed under the Apache License, Version 2.0(the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 
@@ -29,7 +29,7 @@
 #include "absl/types/span.h"
 #include "xla/array.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding.h"
-#include "xla/hlo/experimental/auto_sharding/auto_sharding_solver_option.h"
+#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h"
 #include "xla/hlo/experimental/auto_sharding/cluster_environment.h"
@@ -58,14 +58,13 @@
               StrategyMap& strategy_map, const HloInstruction* ins,
               const ClusterEnvironment& cluster_env,
               const InstructionBatchDimMap& batch_map,
-              const AutoShardingSolverOption& solver_option,
-              const CallGraph& call_graph)
+              const AutoShardingOption& option, const CallGraph& call_graph)
       : strategies_(strategies),
         strategy_map_(strategy_map),
         ins_(ins),
         cluster_env_(cluster_env),
         batch_map_(batch_map),
-        solver_option_(solver_option),
+        option_(option),
         call_graph_(call_graph),
         device_mesh_(cluster_env.device_mesh_),
         device_mesh_1d_(cluster_env.device_mesh_1d_),
@@ -101,7 +100,7 @@
       auto shape_dim = ins->shape().dimensions().at(tensor_dim);
       auto device_mesh_dim = device_mesh_.dim(mesh_dim);
       if (shape_dim < device_mesh_dim) return false;
-      if (solver_option_.only_allow_divisible_intermediate &&
+      if (option_.only_allow_divisible_intermediate &&
           !IsDivisible(shape_dim, device_mesh_dim))
         return false;
     }
@@ -174,7 +173,7 @@
   const HloInstruction* ins_;
   const ClusterEnvironment& cluster_env_;
   const InstructionBatchDimMap& batch_map_;
-  const AutoShardingSolverOption& solver_option_;
+  const AutoShardingOption& option_;
   const CallGraph& call_graph_;
 
   const Array<int64_t>& device_mesh_;
@@ -189,20 +188,19 @@
              StrategyMap& strategy_map, const HloInstruction* ins,
              const ClusterEnvironment& cluster_env,
              const InstructionBatchDimMap& batch_map,
-             const AutoShardingSolverOption& solver_option,
-             const CallGraph& call_graph)
+             const AutoShardingOption& option, const CallGraph& call_graph)
       : HandlerBase(strategies, strategy_map, ins, cluster_env, batch_map,
-                    solver_option, call_graph),
-        dot_dnums_(ins->dot_dimension_numbers()),
-        space_base_dim_(dot_dnums_.lhs_batch_dimensions_size()),
+                    option, call_graph),
+        space_base_dim_(
+            ins->dot_dimension_numbers().lhs_batch_dimensions_size()),
         lhs_con_dims_(
             ins->dot_dimension_numbers().lhs_contracting_dimensions()),
         rhs_con_dims_(
             ins->dot_dimension_numbers().rhs_contracting_dimensions()),
         lhs_batch_dims_(ins->dot_dimension_numbers().lhs_batch_dimensions()),
         rhs_batch_dims_(ins->dot_dimension_numbers().rhs_batch_dimensions()) {
-    std::tie(lhs_space_dims_, rhs_space_dims_) =
-        GetSpaceDims(lhs_->shape(), rhs_->shape(), dot_dnums_);
+    std::tie(lhs_space_dims_, rhs_space_dims_) = GetSpaceDims(
+        lhs_->shape(), rhs_->shape(), ins->dot_dimension_numbers());
     CHECK_EQ(lhs_con_dims_.size(), rhs_con_dims_.size());
     CHECK_EQ(lhs_batch_dims_.size(), rhs_batch_dims_.size());
   }
@@ -432,8 +430,7 @@
       const DimMap rhs_dim_map = {{rhs_con_dims_[e.i], e.mesh_dims[0]}};
       HloSharding output_spec = HloSharding::Replicate();
       double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles();
-      double compute_cost =
-          cluster_env_.DotCost(lhs_->shape(), rhs_->shape(), dot_dnums_);
+      double compute_cost = cluster_env_.DotCost(lhs_->shape(), rhs_->shape());
       double communication_cost =
           cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]);
       MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, device_mesh_,
@@ -456,7 +453,7 @@
         if (lhs_->shape().dimensions(lhs_space_dims_[i]) < num_devices) {
           continue;
         }
-        if (solver_option_.only_allow_divisible_intermediate &&
+        if (option_.only_allow_divisible_intermediate &&
             !IsDivisible(lhs_->shape().dimensions(lhs_space_dims_[i]),
                          num_devices)) {
           continue;
@@ -474,7 +471,7 @@
         if (lhs_->shape().dimensions(lhs_con_dims_[i]) < num_devices) {
           continue;
         }
-        if (solver_option_.only_allow_divisible_intermediate &&
+        if (option_.only_allow_divisible_intermediate &&
             !IsDivisible(lhs_->shape().dimensions(lhs_con_dims_[i]),
                          num_devices)) {
           continue;
@@ -545,12 +542,11 @@
     RecomputeSplitBothContract();
 
     // Add 1d data parallel in multi-dimensional mesh
-    if (solver_option_.allow_mixed_mesh_shape) {
+    if (option_.allow_mixed_mesh_shape) {
       Add1DDataParallel();
     }
 
-    if (solver_option_.batch_matmul_always_split_batch &&
-        !lhs_batch_dims_.empty() &&
+    if (option_.batch_matmul_always_split_batch && !lhs_batch_dims_.empty() &&
         cluster_env_.non_zero_mesh_dims_.size() > 1) {
       // If there is a batch dim and the device mesh is multi-dimensional,
       // always split on batch dim. Clear all old strategies.
@@ -573,7 +569,7 @@
     // Split batch dim and contracting dim
     SplitBatchDimBothContract();
 
-    if (solver_option_.batch_matmul_always_split_batch &&
+    if (option_.batch_matmul_always_split_batch &&
         lhs_batch_dims_.size() == 2 &&
         absl::c_count_if(device_mesh_.dimensions(),
                          [](int64_t size) { return size > 1; }) > 1) {
@@ -586,24 +582,22 @@
     // Split batch dims.
     SplitTwoBatchDims();
 
-    if (solver_option_.allow_mixed_mesh_shape) {
+    if (option_.allow_mixed_mesh_shape) {
       Add1DBatchSplit();
     }
 
     // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies
     // and only keep the data parallel strategies.
-    if (solver_option_.force_batch_dim_to_mesh_dim >= 0 &&
+    if (option_.force_batch_dim_to_mesh_dim >= 0 &&
         batch_map_.contains(GetBatchDimMapKey(ins_))) {
       TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), strategies_,
-                                        cluster_env_, batch_map_,
-                                        solver_option_));
+                                        cluster_env_, batch_map_, option_));
     }
 
     return OkStatus();
   }
 
   // Dimension information
-  const DotDimensionNumbers& dot_dnums_;
   int64_t space_base_dim_;
   tsl::protobuf::RepeatedField<int64_t> lhs_space_dims_, rhs_space_dims_;
   const tsl::protobuf::RepeatedField<int64_t>& lhs_con_dims_;
@@ -618,13 +612,13 @@
                  const HloInstruction* ins, size_t instruction_id,
                  const ClusterEnvironment& cluster_env,
                  const InstructionBatchDimMap& batch_map,
-                 const AutoShardingSolverOption& solver_option,
+                 const AutoShardingOption& option,
                  const CallGraph& call_graph) {
   strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map,
                                         leaf_strategies);
 
   DotHandler handler(strategies, strategy_map, ins, cluster_env, batch_map,
-                     solver_option, call_graph);
+                     option, call_graph);
   TF_RETURN_IF_ERROR(handler.RegisterStrategies());
   return OkStatus();
 }
@@ -635,10 +629,9 @@
               StrategyMap& strategy_map, const HloInstruction* ins,
               const ClusterEnvironment& cluster_env,
               const InstructionBatchDimMap& batch_map,
-              const AutoShardingSolverOption& solver_option,
-              const CallGraph& call_graph)
+              const AutoShardingOption& option, const CallGraph& call_graph)
       : HandlerBase(strategies, strategy_map, ins, cluster_env, batch_map,
-                    solver_option, call_graph),
+                    option, call_graph),
         conv_dnums_(ins->convolution_dimension_numbers()) {
     lhs_batch_dim_ = conv_dnums_.input_batch_dimension();
     lhs_in_channel_dim_ = conv_dnums_.input_feature_dimension();
@@ -790,17 +783,16 @@
     SplitRhsOutchannelBothInchannel();
 
     // Add 1d data parallel in multi-dimensional mesh
-    if (solver_option_.allow_mixed_mesh_shape) {
+    if (option_.allow_mixed_mesh_shape) {
       Add1DDataParallel();
     }
 
     // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies
     // and only keep the data parallel strategies.
-    if (solver_option_.force_batch_dim_to_mesh_dim >= 0 &&
+    if (option_.force_batch_dim_to_mesh_dim >= 0 &&
         batch_map_.contains(GetBatchDimMapKey(ins_))) {
       TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), strategies_,
-                                        cluster_env_, batch_map_,
-                                        solver_option_));
+                                        cluster_env_, batch_map_, option_));
     }
 
     return OkStatus();
@@ -819,13 +811,13 @@
                   const HloInstruction* ins, size_t instruction_id,
                   const ClusterEnvironment& cluster_env,
                   const InstructionBatchDimMap& batch_map,
-                  const AutoShardingSolverOption& solver_option,
+                  const AutoShardingOption& option,
                   const CallGraph& call_graph) {
   strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map,
                                         leaf_strategies);
 
   ConvHandler handler(strategies, strategy_map, ins, cluster_env, batch_map,
-                      solver_option, call_graph);
+                      option, call_graph);
   TF_RETURN_IF_ERROR(handler.RegisterStrategies());
   return OkStatus();
 }
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc
new file mode 100644
index 0000000..661535c
--- /dev/null
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc
@@ -0,0 +1,242 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <numeric>
+#include <string>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h"
+
+namespace xla {
+std::string AutoShardingOption::ToString() const {
+  std::vector<std::string> lines;
+  lines.push_back(absl::StrCat("preserve_shardings: ", preserve_shardings));
+  lines.push_back(absl::StrCat("simplify_graph: ", simplify_graph));
+  if (memory_budget_per_device == -1) {
+    lines.push_back("memory_budget_per_device: -1");
+  } else {
+    lines.push_back(
+        absl::StrCat("memory_budget_per_device: ",
+                     memory_budget_per_device / (1024 * 1024 * 1024), " GB"));
+  }
+  lines.push_back(
+      absl::StrCat("try_multiple_mesh_shapes: ", try_multiple_mesh_shapes));
+
+  lines.push_back(absl::StrCat("force_override_all_gather_cost: ",
+                               force_override_all_gather_cost));
+  if (force_override_all_gather_cost) {
+    lines.push_back(absl::StrCat("all_gather_cost: ", all_gather_cost));
+  }
+
+  lines.push_back(absl::StrCat("force_override_all_to_all_cost: ",
+                               force_override_all_to_all_cost));
+  if (force_override_all_to_all_cost) {
+    lines.push_back(absl::StrCat("all_to_all_cost: ", all_to_all_cost));
+  }
+
+  lines.push_back(absl::StrCat("force_override_all_reduce_cost: ",
+                               force_override_all_reduce_cost));
+  if (force_override_all_reduce_cost) {
+    lines.push_back(absl::StrCat("all_reduce_cost: ", all_reduce_cost));
+  }
+
+  lines.push_back(absl::StrCat("force_override_reduce_scatter_cost: ",
+                               force_override_reduce_scatter_cost));
+  if (force_override_reduce_scatter_cost) {
+    lines.push_back(absl::StrCat("reduce_scatter_cost: ", reduce_scatter_cost));
+  }
+
+  lines.push_back(absl::StrCat("force_batch_dim_to_mesh_dim: ",
+                               force_batch_dim_to_mesh_dim));
+  lines.push_back(absl::StrCat("allow_replicated_parameters: ",
+                               allow_replicated_parameters));
+  lines.push_back(
+      absl::StrCat("prefer_reduce_scatter: ", prefer_reduce_scatter));
+  lines.push_back(absl::StrCat("reduce_scatter_grad_acc_friendly: ",
+                               reduce_scatter_grad_acc_friendly));
+  lines.push_back(absl::StrCat("reduce_scatter_aggressive_partition: ",
+                               reduce_scatter_aggressive_partition));
+  lines.push_back(absl::StrCat("batch_matmul_always_split_batch: ",
+                               batch_matmul_always_split_batch));
+  lines.push_back(
+      absl::StrCat("allow_recompute_heavy_op: ", allow_recompute_heavy_op));
+  lines.push_back(
+      absl::StrCat("allow_mixed_mesh_shape: ", allow_mixed_mesh_shape));
+  lines.push_back(
+      absl::StrCat("grad_acc_num_micro_batches: ", grad_acc_num_micro_batches));
+  lines.push_back(absl::StrCat("load_solution_vector: ", load_solution_vector));
+  lines.push_back(
+      absl::StrCat("force_simple_heuristic: ", force_simple_heuristic));
+  lines.push_back(absl::StrCat("force_strategy: ", force_strategy));
+
+  if (force_strategy) {
+    lines.push_back(
+        absl::StrCat("force_strategy_inst_indices: [",
+                     absl::StrJoin(force_strategy_inst_indices, ","), "]"));
+    lines.push_back(absl::StrCat("force_strategy_stra_names: [",
+                                 absl::StrJoin(force_strategy_stra_names, ","),
+                                 "]"));
+  }
+
+  lines.push_back(absl::StrCat("only_allow_divisible_input_output: ",
+                               only_allow_divisible_input_output));
+
+  lines.push_back(absl::StrCat("only_allow_divisible_intermediate: ",
+                               only_allow_divisible_intermediate));
+
+  lines.push_back(absl::StrCat("nd_sharding_iteratively_strict_search_space: ",
+                               nd_sharding_iteratively_strict_search_space));
+
+  lines.push_back(absl::StrCat("allow_replicated_strategy_for_dot_and_conv: ",
+                               allow_replicated_strategy_for_dot_and_conv));
+
+  lines.push_back(absl::StrCat("device_mesh_shape: [",
+                               absl::StrJoin(device_mesh_shape, ","), "]"));
+  lines.push_back(absl::StrCat("device_mesh_alpha: [",
+                               absl::StrJoin(device_mesh_alpha, ","), "]"));
+  lines.push_back(absl::StrCat("device_mesh_beta: [",
+                               absl::StrJoin(device_mesh_beta, ","), "]"));
+
+  lines.push_back(absl::StrCat("load_strategy: ", load_strategy));
+  if (load_strategy) {
+    lines.push_back(absl::StrCat("strategy_vector: [",
+                                 absl::StrJoin(strategy_vector, ","), "]"));
+  }
+
+  return absl::StrJoin(lines, "\n");
+}
+
+absl::Status AutoShardingOption::CheckAndSetup() {
+  only_allow_divisible_input_output = true;
+  only_allow_divisible_intermediate = false;
+
+  if (device_mesh_shape.empty()) {
+    return absl::OutOfRangeError(
+        "device_mesh_shape is empty and it needs to be specified.");
+  }
+  std::vector<int64_t> mesh_dims_greater_than_one_indices =
+      spmd::VectorGreaterThanOneElementIndices(device_mesh_shape);
+
+  // TODO(pratikf) The device mesh shape handling in this function currently
+  // does not work when try_multiple_mesh_shapes is true. Fix it.
+  if (mesh_dims_greater_than_one_indices.size() > 3 ||
+      (device_mesh_shape.size() > 3 && try_multiple_mesh_shapes)) {
+    return absl::OutOfRangeError(
+        absl::StrCat("Not supported: only device_mesh_shapes with 3 or less "
+                     "dimensions larger than 1 are supported. Instead we have ",
+                     mesh_dims_greater_than_one_indices.size(),
+                     " dimensions greater than 1."));
+  }
+  // All values in device_mesh_shape must be greater than 0.
+  if (absl::c_any_of(device_mesh_shape,
+                     [](const int64_t i) { return i <= 0; })) {
+    return absl::OutOfRangeError(
+        absl::StrCat("device_mesh_shape values need to be larger than 0: "
+                     "device_mesh_shape=",
+                     absl::StrJoin(device_mesh_shape, ",")));
+  }
+  if (spmd::VectorGreaterThanOneElementCount(device_mesh_shape) > 3) {
+    return absl::OutOfRangeError(
+        absl::StrCat("the auto-sharding pass currently does not support ",
+                     "more than three shardable dims: device_mesh_shape=",
+                     absl::StrJoin(device_mesh_shape, ",")));
+  }
+
+  if (device_mesh_alpha.empty()) {
+    // Generates simple device_mesh_alpha based on the size of
+    // device_mesh_shape.
+    device_mesh_alpha = std::vector(device_mesh_shape.size(), kDeviceMeshAlpha);
+    VLOG(0) << "Using default values for device_mesh_alpha: "
+            << absl::StrJoin(device_mesh_alpha, ",");
+  }
+  if (device_mesh_beta.empty()) {
+    // Generates simple device_mesh_beta based on the size of
+    // device_mesh_shape.
+    device_mesh_beta = std::vector(device_mesh_shape.size(), kDeviceMeshBeta);
+    VLOG(0) << "Using default values for device_mesh_beta: "
+            << absl::StrJoin(device_mesh_beta, ",");
+  }
+
+  if (device_mesh_shape.size() != device_mesh_alpha.size() ||
+      device_mesh_shape.size() != device_mesh_beta.size()) {
+    return absl::OutOfRangeError(absl::StrCat(
+        "Sizes do not match: length of device_mesh_shape is ",
+        device_mesh_shape.size(), ", length of device_mesh_alpha is ",
+        device_mesh_alpha.size(), ", length of device_mesh_beta is ",
+        device_mesh_beta.size(),
+        ". If not sure how to set device_mesh_alpha and "
+        "device_mesh_beta, "
+        "please leave them empty and default values will be used."));
+  }
+
+  if (!try_multiple_mesh_shapes) {
+    std::vector<int64_t> compressed_device_mesh_shape;
+    std::vector<double> compressed_device_mesh_alpha;
+    std::vector<double> compressed_device_mesh_beta;
+    int non_zero_counter = 0;
+    for (size_t i = 0; i < device_mesh_shape.size(); ++i) {
+      if (non_zero_counter < mesh_dims_greater_than_one_indices.size() &&
+          i == mesh_dims_greater_than_one_indices[non_zero_counter]) {
+        non_zero_counter++;
+        compressed_device_mesh_shape.push_back(device_mesh_shape[i]);
+        compressed_device_mesh_alpha.push_back(device_mesh_alpha[i]);
+        compressed_device_mesh_beta.push_back(device_mesh_beta[i]);
+      }
+    }
+    this->device_mesh_shape = compressed_device_mesh_shape;
+    this->device_mesh_alpha = compressed_device_mesh_alpha;
+    this->device_mesh_beta = compressed_device_mesh_beta;
+  }
+
+  // If device_mesh_shape has only one value, append 1 to it
+  if (device_mesh_shape.size() == 1) {
+    device_mesh_shape.push_back(1);
+    device_mesh_alpha.push_back(1.0);
+    device_mesh_beta.push_back(1.0);
+  }
+
+  int64_t total_devices = 1;
+  for (auto i : device_mesh_shape) {
+    total_devices *= i;
+  }
+  // Set up device_mesh_ids based on device_mesh_shape
+  if (device_mesh_ids.empty()) {
+    device_mesh_ids = std::vector<int64_t>(total_devices);
+    std::iota(device_mesh_ids.begin(), device_mesh_ids.end(), 0);
+    VLOG(0) << "Using default values for device_mesh_ids: "
+            << absl::StrJoin(device_mesh_ids, ",");
+  } else {
+    // Checks whether device_mesh_shape and device_mesh_ids are compatible.
+    if (total_devices != device_mesh_ids.size()) {
+      return absl::OutOfRangeError(absl::StrCat(
+          "Expect the product of device_mesh_shape to be the same as the "
+          "size of device_mesh_ids, but we have total devices = ",
+          total_devices,
+          " and device_mesh_ids.size()=", device_mesh_ids.size()));
+    }
+  }
+  return absl::OkStatus();
+}
+
+}  // namespace xla
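As an illustrative sketch (not part of this change), the new standalone option struct might be exercised as follows; the field and method names (device_mesh_shape, CheckAndSetup, ToString) come from the file above, while the main() scaffolding and the mesh values are assumptions for illustration only.

// Illustrative sketch only: exercises AutoShardingOption as defined above.
#include <iostream>
#include "absl/status/status.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"

int main() {
  xla::AutoShardingOption option;
  option.device_mesh_shape = {2, 4};  // Hypothetical 2x4 device mesh.
  // Leave device_mesh_alpha/beta/ids empty so CheckAndSetup() fills defaults.
  absl::Status status = option.CheckAndSetup();
  if (!status.ok()) {
    std::cerr << status.message() << std::endl;
    return 1;
  }
  std::cout << option.ToString() << std::endl;  // Dump the resolved options.
  return 0;
}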
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h
index 4bf1892..c2898ce 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h
@@ -16,20 +16,11 @@
 #ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_OPTION_H_
 #define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_OPTION_H_
 
-#include <cstddef>
 #include <cstdint>
-#include <numeric>
 #include <string>
 #include <vector>
 
-#include "absl/algorithm/container.h"
-#include "absl/log/log.h"
 #include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/string_view.h"
-#include "absl/types/span.h"
-#include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h"
 
 namespace xla {
 
@@ -73,12 +64,20 @@
   float memory_budget_ratio = 1.1;
 
   // Overwrite the all gather cost with the given value.
-  bool force_all_gather_cost = false;
-  double all_gather_cost;
+  bool force_override_all_gather_cost = false;
+  double all_gather_cost = 0;
 
   // Overwrite the all-to-all cost with the given value.
-  bool force_all_to_all_cost = false;
-  double all_to_all_cost;
+  bool force_override_all_to_all_cost = false;
+  double all_to_all_cost = 0;
+
+  // Overwrite the all reduce cost with the given value.
+  bool force_override_all_reduce_cost = false;
+  double all_reduce_cost = 0;
+
+  // Overwrite the reduce scatter cost with the given value.
+  bool force_override_reduce_scatter_cost = false;
+  double reduce_scatter_cost = 0;
 
   // Forcibly split the batch dimension and map it to a mesh dimension.
   // This can force the auto-sharding pass to generate the data parallel
@@ -89,7 +88,7 @@
   bool allow_replicated_parameters = true;
 
   // If true, prefer reduce-scatter + all-gather over all-reduce.
-  // A post process will be applied to replace all-reduce with reduce-scater +
+  // A post process will be applied to replace all-reduce with reduce-scatter +
   // all-gather if no communication overhead is introduced.
   bool prefer_reduce_scatter = false;
 
@@ -135,6 +134,23 @@
   std::vector<int64_t> force_strategy_inst_indices;
   std::vector<std::string> force_strategy_stra_names;
 
+  // Whether or not we allow sharding strategies where the tensor dim is
+  // indivisible by the #tiles in that dimension.
+  bool only_allow_divisible_input_output = true;
+  bool only_allow_divisible_intermediate = false;
+
+  // If true, strictly limit the following iterations to use the same number of
+  // shards for sharded tensor dimensions; if false, the following iterations
+  // can choose different number of shards for sharded tensor dimensions.
+  // Enabling it can hurt the performance of dot ops, but can make the search
+  // space more scalable. Therefore leaving it as an option.
+  bool nd_sharding_iteratively_strict_search_space = false;
+
+  // Whether or not to generate replicated strategies for dot/conv
+  // ops. Generating these seems to be beneficial for LLM serving models, but
+  // can increase the search space, so this feature is exposed as an option.
+  bool allow_replicated_strategy_for_dot_and_conv = true;
+
   // Device mesh shape.
   std::vector<int64_t> device_mesh_shape;
   // Device IDs in the mesh.
@@ -159,11 +175,6 @@
   // Static estimate for iteration count of a while loop, used in the cost model
   int64_t loop_iteration_count_estimate = 100;
 
-  // Whether or not to generate replicated strategies for dot/conv
-  // ops. Generating these seems to be beneficial for LLM serving models, but
-  // can increase the search space, so this feature is exposed as an option.
-  bool allow_replicated_strategy_for_dot_and_conv = true;
-
   // Allows the conversion of aliases to followers if their pairwise strategy
   // compatibilities are embodied by the identity matrix (which makes for a
   // smaller Mixed ILP).
@@ -180,188 +191,12 @@
   // a simple replicated default.
   bool use_sharding_propagation_for_default_shardings = true;
 
-  std::string ToString() {
-    std::vector<std::string> lines;
-    lines.push_back(absl::StrCat("preserve_shardings: ", preserve_shardings));
-    lines.push_back(absl::StrCat("simplify_graph: ", simplify_graph));
-    if (memory_budget_per_device == -1) {
-      lines.push_back("memory_budget_per_device: -1");
-    } else {
-      lines.push_back(
-          absl::StrCat("memory_budget_per_device: ",
-                       memory_budget_per_device / (1024 * 1024 * 1024), " GB"));
-    }
-    lines.push_back(
-        absl::StrCat("try_multiple_mesh_shapes: ", try_multiple_mesh_shapes));
-    lines.push_back(
-        absl::StrCat("force_all_gather_cost: ", force_all_gather_cost));
+  // Prints a debug string.
+  std::string ToString() const;
 
-    if (force_all_gather_cost) {
-      lines.push_back(absl::StrCat("all_gather_cost: ", all_gather_cost));
-    }
-    lines.push_back(
-        absl::StrCat("force_all_to_all_cost: ", force_all_to_all_cost));
-    if (force_all_to_all_cost) {
-      lines.push_back(absl::StrCat("all_to_all_cost: ", all_to_all_cost));
-    }
-    lines.push_back(absl::StrCat("force_batch_dim_to_mesh_dim: ",
-                                 force_batch_dim_to_mesh_dim));
-    lines.push_back(absl::StrCat("allow_replicated_parameters: ",
-                                 allow_replicated_parameters));
-    lines.push_back(
-        absl::StrCat("prefer_reduce_scatter: ", prefer_reduce_scatter));
-    lines.push_back(absl::StrCat("reduce_scatter_grad_acc_friendly: ",
-                                 reduce_scatter_grad_acc_friendly));
-    lines.push_back(absl::StrCat("reduce_scatter_aggressive_partition: ",
-                                 reduce_scatter_aggressive_partition));
-    lines.push_back(absl::StrCat("batch_matmul_always_split_batch: ",
-                                 batch_matmul_always_split_batch));
-    lines.push_back(
-        absl::StrCat("allow_recompute_heavy_op: ", allow_recompute_heavy_op));
-    lines.push_back(
-        absl::StrCat("allow_mixed_mesh_shape: ", allow_mixed_mesh_shape));
-    lines.push_back(absl::StrCat("grad_acc_num_micro_batches: ",
-                                 grad_acc_num_micro_batches));
-    lines.push_back(
-        absl::StrCat("load_solution_vector: ", load_solution_vector));
-    lines.push_back(
-        absl::StrCat("force_simple_heuristic: ", force_simple_heuristic));
-    lines.push_back(absl::StrCat("force_strategy: ", force_strategy));
-
-    if (force_strategy) {
-      lines.push_back(
-          absl::StrCat("force_strategy_inst_indices: [",
-                       absl::StrJoin(force_strategy_inst_indices, ","), "]"));
-      lines.push_back(
-          absl::StrCat("force_strategy_stra_names: [",
-                       absl::StrJoin(force_strategy_stra_names, ","), "]"));
-    }
-
-    lines.push_back(absl::StrCat("device_mesh_shape: [",
-                                 absl::StrJoin(device_mesh_shape, ","), "]"));
-    lines.push_back(absl::StrCat("device_mesh_alpha: [",
-                                 absl::StrJoin(device_mesh_alpha, ","), "]"));
-    lines.push_back(absl::StrCat("device_mesh_beta: [",
-                                 absl::StrJoin(device_mesh_beta, ","), "]"));
-
-    lines.push_back(absl::StrCat("load_strategy: ", load_strategy));
-    if (load_strategy) {
-      lines.push_back(absl::StrCat("strategy_vector: [",
-                                   absl::StrJoin(strategy_vector, ","), "]"));
-    }
-
-    return absl::StrJoin(lines, "\n");
-  }
-
-  Status CheckAndSetup() {
-    if (device_mesh_shape.empty()) {
-      return absl::OutOfRangeError(
-          "device_mesh_shape is empty and it needs to be specified.");
-    }
-    std::vector<int64_t> mesh_dims_greater_than_one_indices =
-        spmd::VectorGreaterThanOneElementIndices(device_mesh_shape);
-
-    // TODO(pratikf) The device mesh shape handling in this function currently
-    // does not work when try_multiple_mesh_shapes is true. Fix it.
-    if (mesh_dims_greater_than_one_indices.size() > 3 ||
-        (device_mesh_shape.size() > 3 && try_multiple_mesh_shapes)) {
-      return absl::OutOfRangeError(absl::StrCat(
-          "Not supported: only device_mesh_shapes with 3 or less "
-          "dimensions larger than 1 are supported. Instead we have ",
-          mesh_dims_greater_than_one_indices.size(),
-          " dimensions greater than 1."));
-    }
-    // All values in device_mesh_shape must be greater than 0.
-    if (absl::c_any_of(device_mesh_shape,
-                       [](const int64_t i) { return i <= 0; })) {
-      return absl::OutOfRangeError(
-          absl::StrCat("device_mesh_shape values need to be larger than 0: "
-                       "device_mesh_shape=",
-                       absl::StrJoin(device_mesh_shape, ",")));
-    }
-    if (spmd::VectorGreaterThanOneElementCount(device_mesh_shape) > 3) {
-      return absl::OutOfRangeError(
-          absl::StrCat("the auto-sharding pass currently does not support ",
-                       "more than three shardable dims: device_mesh_shape=",
-                       absl::StrJoin(device_mesh_shape, ",")));
-    }
-
-    if (device_mesh_alpha.empty()) {
-      // Generates simple device_mesh_alpha based on the size of
-      // device_mesh_shape.
-      device_mesh_alpha =
-          std::vector(device_mesh_shape.size(), kDeviceMeshAlpha);
-      VLOG(0) << "Using default values for device_mesh_alpha: "
-              << absl::StrJoin(device_mesh_alpha, ",");
-    }
-    if (device_mesh_beta.empty()) {
-      // Generates simple device_mesh_beta based on the size of
-      // device_mesh_shape.
-      device_mesh_beta = std::vector(device_mesh_shape.size(), kDeviceMeshBeta);
-      VLOG(0) << "Using default values for device_mesh_beta: "
-              << absl::StrJoin(device_mesh_beta, ",");
-    }
-
-    if (device_mesh_shape.size() != device_mesh_alpha.size() ||
-        device_mesh_shape.size() != device_mesh_beta.size()) {
-      return absl::OutOfRangeError(absl::StrCat(
-          "Sizes do not match: length of device_mesh_shape is ",
-          device_mesh_shape.size(), ", length of device_mesh_alpha is ",
-          device_mesh_alpha.size(), ", length of device_mesh_beta is ",
-          device_mesh_beta.size(),
-          ". If not sure how to set device_mesh_alpha and "
-          "device_mesh_beta, "
-          "please leave them empty and default values will be used."));
-    }
-
-    if (!try_multiple_mesh_shapes) {
-      std::vector<int64_t> compressed_device_mesh_shape;
-      std::vector<double> compressed_device_mesh_alpha;
-      std::vector<double> compressed_device_mesh_beta;
-      int non_zero_counter = 0;
-      for (size_t i = 0; i < device_mesh_shape.size(); ++i) {
-        if (non_zero_counter < mesh_dims_greater_than_one_indices.size() &&
-            i == mesh_dims_greater_than_one_indices[non_zero_counter]) {
-          non_zero_counter++;
-          compressed_device_mesh_shape.push_back(device_mesh_shape[i]);
-          compressed_device_mesh_alpha.push_back(device_mesh_alpha[i]);
-          compressed_device_mesh_beta.push_back(device_mesh_beta[i]);
-        }
-      }
-      this->device_mesh_shape = compressed_device_mesh_shape;
-      this->device_mesh_alpha = compressed_device_mesh_alpha;
-      this->device_mesh_beta = compressed_device_mesh_beta;
-    }
-
-    // If device_mesh_shape has only one value, append 1 to it
-    if (device_mesh_shape.size() == 1) {
-      device_mesh_shape.push_back(1);
-      device_mesh_alpha.push_back(1.0);
-      device_mesh_beta.push_back(1.0);
-    }
-
-    int64_t total_devices = 1;
-    for (auto i : device_mesh_shape) {
-      total_devices *= i;
-    }
-    // Set up device_mesh_ids based on device_mesh_shape
-    if (device_mesh_ids.empty()) {
-      device_mesh_ids = std::vector<int64_t>(total_devices);
-      std::iota(device_mesh_ids.begin(), device_mesh_ids.end(), 0);
-      VLOG(0) << "Using default values for device_mesh_ids: "
-              << absl::StrJoin(device_mesh_ids, ",");
-    } else {
-      // Checks whether device_mesh_shape and device_mesh_ids are compatible.
-      if (total_devices != device_mesh_ids.size()) {
-        return absl::OutOfRangeError(absl::StrCat(
-            "Expect the product of device_mesh_shape to be the same as the "
-            "size of device_mesh_ids, but we have total devices = ",
-            total_devices,
-            " and device_mesh_ids.size()=", device_mesh_ids.size()));
-      }
-    }
-    return OkStatus();
-  }
+  // Initializes uninitialized fields with default values and checks that the
+  // different options are consistent with each other.
+  absl::Status CheckAndSetup();
 };
 
 }  // namespace xla
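For readers of CheckAndSetup() above: when try_multiple_mesh_shapes is false, mesh dimensions of size 1 are dropped and then, if only a single dimension remains, a trailing 1 is appended. A standalone, shape-only sketch of that transformation (illustrative helper, not XLA code; alpha/beta are compressed with the same indices in the real pass):

#include <cstdint>
#include <vector>

// Illustrative: shape-only version of the mesh compression in CheckAndSetup().
std::vector<int64_t> CompressMeshShape(const std::vector<int64_t>& shape) {
  std::vector<int64_t> compressed;
  for (int64_t dim : shape) {
    if (dim > 1) compressed.push_back(dim);  // Keep only non-trivial dims.
  }
  if (compressed.size() == 1) compressed.push_back(1);  // Pad to rank 2.
  return compressed;
}
// CompressMeshShape({1, 4, 1, 2}) -> {4, 2};  CompressMeshShape({8}) -> {8, 1}.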
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc
index 84d96bc..fc4fe74 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc
@@ -17,14 +17,8 @@
 
 #include <algorithm>
 #include <cstdint>
-#include <functional>
-#include <iterator>
-#include <memory>
-#include <numeric>
 #include <optional>
-#include <ostream>
 #include <string>
-#include <tuple>
 #include <utility>
 #include <vector>
 
@@ -36,8 +30,8 @@
 namespace spmd {
 
 double ClusterEnvironment::AllGatherCost(double num_bytes, int mesh_dim) const {
-  if (solver_option_.override_all_gather_cost) {
-    return solver_option_.all_gather_cost;
+  if (auto_sharding_option_.force_override_all_gather_cost) {
+    return auto_sharding_option_.all_gather_cost;
   }
 
   if (prof_result_.Enabled()) {
@@ -45,7 +39,7 @@
                                               num_bytes / 4, "float32");
   }
 
-  if (solver_option_.force_batch_dim_to_mesh_dim == mesh_dim) {
+  if (auto_sharding_option_.force_batch_dim_to_mesh_dim == mesh_dim) {
     // if data-parallel is forced on this dim, we only allow all-reduce
     // in this dimension.
     return kInfinityCost;
@@ -61,8 +55,8 @@
 // TODO(zhuohan): distinguish dtype and reduce_op.
 double ClusterEnvironment::AllReduceCost(double num_bytes, int32_t mesh_dim,
                                          int32_t mesh_dim_another) const {
-  if (solver_option_.override_all_reduce_cost) {
-    return solver_option_.all_reduce_cost;
+  if (auto_sharding_option_.force_override_all_reduce_cost) {
+    return auto_sharding_option_.all_reduce_cost;
   }
 
   if (prof_result_.Enabled()) {
@@ -89,8 +83,8 @@
 
 double ClusterEnvironment::ReduceScatterCost(double num_bytes,
                                              int mesh_dim) const {
-  if (solver_option_.override_reduce_scatter_cost) {
-    return solver_option_.reduce_scatter_cost;
+  if (auto_sharding_option_.force_override_reduce_scatter_cost) {
+    return auto_sharding_option_.reduce_scatter_cost;
   }
 
   if (prof_result_.Enabled()) {
@@ -106,8 +100,8 @@
 }
 
 double ClusterEnvironment::AllToAllCost(double num_bytes, int mesh_dim) const {
-  if (solver_option_.override_all_to_all_cost) {
-    return solver_option_.all_to_all_cost;
+  if (auto_sharding_option_.force_override_all_to_all_cost) {
+    return auto_sharding_option_.all_to_all_cost;
   }
 
   if (prof_result_.Enabled()) {
@@ -115,7 +109,7 @@
                                              num_bytes / 4, "float32");
   }
 
-  if (solver_option_.force_batch_dim_to_mesh_dim == mesh_dim) {
+  if (auto_sharding_option_.force_batch_dim_to_mesh_dim == mesh_dim) {
     // if data-parallel is forced on this dim, we only allow all-reduce
     // in this dimension.
     return kInfinityCost;
@@ -127,9 +121,8 @@
 }
 
 double ClusterEnvironment::DotCost(const Shape& lhs_shape,
-                                   const Shape& rhs_shape,
-                                   const DotDimensionNumbers& dot_dnums) const {
-  if (!solver_option_.allow_recompute_heavy_op) {
+                                   const Shape& rhs_shape) const {
+  if (!auto_sharding_option_.allow_recompute_heavy_op) {
     return kInfinityCost;
   }
 
diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h
index ef00dc5..25c3e6b 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h
@@ -19,15 +19,11 @@
 #include <algorithm>
 #include <cstdint>
 #include <iterator>
-#include <memory>
-#include <optional>
-#include <ostream>
-#include <sstream>
 #include <string>
 #include <utility>
 #include <vector>
 
-#include "xla/hlo/experimental/auto_sharding/auto_sharding_solver_option.h"
+#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"
 #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h"
 #include "xla/hlo/experimental/auto_sharding/profiling_result.h"
 #include "xla/hlo/ir/hlo_sharding.h"
@@ -47,7 +43,7 @@
                      absl::Span<const double> mesh_alpha,
                      absl::Span<const double> mesh_beta,
                      const ProfilingResult& prof_result,
-                     const AutoShardingSolverOption& solver_option)
+                     const AutoShardingOption& auto_sharding_option)
       : original_device_mesh_(original_device_mesh),
         device_mesh_(device_mesh),
         mesh_alpha_(mesh_alpha.begin(), mesh_alpha.end()),
@@ -55,7 +51,7 @@
         prof_result_(prof_result),
         total_devices_(device_mesh.num_elements()),
         device_mesh_1d_(original_device_mesh),
-        solver_option_(solver_option) {
+        auto_sharding_option_(auto_sharding_option) {
     // Build replica group for each dimension.
     non_zero_mesh_dims_ =
         VectorGreaterThanOneElementIndices(device_mesh.dimensions());
@@ -137,8 +133,7 @@
                                            const HloSharding& src_spec,
                                            const HloSharding& dst_spec) const;
 
-  double DotCost(const Shape& lhs_shape, const Shape& rhs_shape,
-                 const DotDimensionNumbers& dot_dnums) const;
+  double DotCost(const Shape& lhs_shape, const Shape& rhs_shape) const;
 
   // This function attempts to overestimate the cost of replicating a tensor of
   // shape `shape` sharded according to `src_spec`.
@@ -175,8 +170,8 @@
   // Used for mixed mesh shape strategies.
   Array<int64_t> device_mesh_1d_;
 
-  // The solver option may override the cost of communication primitives
-  const AutoShardingSolverOption& solver_option_;
+  // The option may override the cost of communication primitives
+  const AutoShardingOption& auto_sharding_option_;
 
   // Cached replica groups. Shape: [mesh_dim, group_id, ids in this group].
   std::vector<std::vector<std::vector<int64_t>>> cached_replica_groups_;
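The renamed member above is consulted by every communication-cost query. A toy sketch of the override-then-model pattern used by AllGatherCost (the option field names come from this change; the fallback cost below is a placeholder, not the real alpha-beta model):

// Illustrative only: mirrors how force_override_* short-circuits the model.
// Assumes auto_sharding_option.h is included.
double AllGatherCostSketch(const xla::AutoShardingOption& option,
                           double num_bytes, int mesh_dim) {
  if (option.force_override_all_gather_cost) {
    return option.all_gather_cost;  // User-supplied cost wins unconditionally.
  }
  if (option.force_batch_dim_to_mesh_dim == mesh_dim) {
    return 1e13;  // Stand-in for kInfinityCost: forbid all-gather on this dim.
  }
  return num_bytes;  // Placeholder for the real alpha-beta cost model.
}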
diff --git a/third_party/xla/xla/hlo/ir/dynamic_parameter_binding.cc b/third_party/xla/xla/hlo/ir/dynamic_parameter_binding.cc
index a051018..049676a 100644
--- a/third_party/xla/xla/hlo/ir/dynamic_parameter_binding.cc
+++ b/third_party/xla/xla/hlo/ir/dynamic_parameter_binding.cc
@@ -20,21 +20,26 @@
 #include <string>
 #include <vector>
 
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_module.h"
+#include "xla/shape_util.h"
+#include "xla/status.h"
+#include "xla/status_macros.h"
+#include "tsl/platform/errors.h"
 
 namespace xla {
 
 Status DynamicParameterBinding::Bind(
-    const DynamicParameter& dynamic_parameter,
+    const DynamicSizeParameter& dynamic_parameter,
     const DynamicDimension& dynamic_dimension) {
   auto result = bindings_.emplace(dynamic_dimension, dynamic_parameter);
   TF_RET_CHECK(result.second);
   return OkStatus();
 }
 
-std::optional<DynamicParameterBinding::DynamicParameter>
+std::optional<DynamicParameterBinding::DynamicSizeParameter>
 DynamicParameterBinding::GetBinding(
     const DynamicDimension& dynamic_dimension) const {
   auto param_iter = bindings_.find(dynamic_dimension);
@@ -49,7 +54,7 @@
   pieces.push_back("DynamicParameterBinding: ");
   for (const auto& binding : bindings_) {
     const DynamicDimension& dynamic_dimension = binding.first;
-    const DynamicParameter& dynamic_param = binding.second;
+    const DynamicSizeParameter& dynamic_param = binding.second;
     pieces.push_back(absl::StrFormat(
         " -- Input param number %lld at %s has dim %lld as dynamic"
         " dimension, which is represented by param number %lld at "
@@ -69,24 +74,28 @@
   return OkStatus();
 }
 
-Status DynamicParameterBinding::Verify(const HloModule& module) const {
-  const HloComputation* entry = module.entry_computation();
-  return ForEachBinding([&](const DynamicParameter& dynamic_parameter,
+Status DynamicParameterBinding::Verify(
+    const HloComputation& computation) const {
+  return ForEachBinding([&](const DynamicSizeParameter& dynamic_parameter,
                             const DynamicDimension& dynamic_dimension)
                             -> Status {
     TF_RET_CHECK(dynamic_parameter.parameter_num >= 0 &&
-                 dynamic_parameter.parameter_num < entry->num_parameters());
-    TF_RET_CHECK(dynamic_dimension.parameter_num < entry->num_parameters());
+                 dynamic_parameter.parameter_num <
+                     computation.num_parameters());
+    TF_RET_CHECK(dynamic_dimension.parameter_num <
+                 computation.num_parameters());
     TF_RET_CHECK(ShapeUtil::IndexIsValid(
-        entry->parameter_instruction(dynamic_parameter.parameter_num)->shape(),
+        computation.parameter_instruction(dynamic_parameter.parameter_num)
+            ->shape(),
         dynamic_parameter.parameter_index));
     TF_RET_CHECK(ShapeUtil::IndexIsValid(
-        entry->parameter_instruction(dynamic_dimension.parameter_num)->shape(),
+        computation.parameter_instruction(dynamic_dimension.parameter_num)
+            ->shape(),
         dynamic_dimension.parameter_index));
     TF_RET_CHECK(
         dynamic_dimension.dimension <
         ShapeUtil::GetSubshape(
-            entry->parameter_instruction(dynamic_dimension.parameter_num)
+            computation.parameter_instruction(dynamic_dimension.parameter_num)
                 ->shape(),
             dynamic_dimension.parameter_index)
             .rank());
diff --git a/third_party/xla/xla/hlo/ir/dynamic_parameter_binding.h b/third_party/xla/xla/hlo/ir/dynamic_parameter_binding.h
index b40569d..77edb06 100644
--- a/third_party/xla/xla/hlo/ir/dynamic_parameter_binding.h
+++ b/third_party/xla/xla/hlo/ir/dynamic_parameter_binding.h
@@ -16,6 +16,7 @@
 #ifndef XLA_HLO_IR_DYNAMIC_PARAMETER_BINDING_H_
 #define XLA_HLO_IR_DYNAMIC_PARAMETER_BINDING_H_
 
+#include <cstdint>
 #include <functional>
 #include <optional>
 #include <ostream>
@@ -23,9 +24,9 @@
 #include <utility>
 
 #include "absl/container/flat_hash_map.h"
-#include "xla/service/hlo.pb.h"
-#include "xla/shape_tree.h"
+#include "xla/hlo/ir/hlo_computation.h"
 #include "xla/shape_util.h"
+#include "xla/status.h"
 
 namespace xla {
 
@@ -41,10 +42,10 @@
 // ready.
 class DynamicParameterBinding {
  public:
-  // DynamicParameter represents a special parameter that is used to represent
-  // the runtime size of a dimension of another parameter. A dynamic parameter
-  // has to be a scalar value.
-  struct DynamicParameter {
+  // DynamicSizeParameter represents a special parameter that is used to
+  // represent the runtime size of a dimension of another parameter. A dynamic
+  // size parameter has to be a scalar value.
+  struct DynamicSizeParameter {
     // The parameter number of dynamic parameter.
     int64_t parameter_num;
     // The index of the parameter.
@@ -52,8 +53,8 @@
   };
 
   // DynamicDimension represents a dimension whose size is determined at
-  // runtime. A DynamicDimension's runtime size is determined by the binded
-  // DynamicParameter using `DynamicParameterBinding::Bind` method.
+  // runtime. A DynamicDimension's runtime size is determined by the bound
+  // DynamicSizeParameter using `DynamicParameterBinding::Bind` method.
   struct DynamicDimension {
     // The parameter number of dynamic dimension.
     int64_t parameter_num;
@@ -77,25 +78,21 @@
     }
   };
 
-  DynamicParameterBinding() = default;
-
-  virtual ~DynamicParameterBinding() = default;
-
   // Adds binding which indicates that the dimension indicated by
   // `dynamic_dimension` is dynamic, and its runtime size is represented by
   // `dynamic_parameter`.
-  Status Bind(const DynamicParameter& dynamic_parameter,
+  Status Bind(const DynamicSizeParameter& dynamic_parameter,
               const DynamicDimension& dynamic_dimension);
 
   // Returns the parameter and the index representing the runtime size of
   // dimension `dim_num` of parameter `param_num` at `param_index`.
   //
   // Returns nullopt if the binding is not set.
-  std::optional<DynamicParameter> GetBinding(
+  std::optional<DynamicSizeParameter> GetBinding(
       const DynamicDimension& dynamic_dimension) const;
 
   using BindingFn =
-      std::function<Status(const DynamicParameter& dynamic_parameter,
+      std::function<Status(const DynamicSizeParameter& dynamic_parameter,
                            const DynamicDimension& dynamic_dimension)>;
 
   // Iterate through each binding.
@@ -103,16 +100,19 @@
 
   std::string ToString() const;
 
-  // Verifies that the given binding is valid for the given module.
+  // Verifies that the given binding is valid for the given computation.
   // Specifically, the binding's parameter and parameter size should be valid.
-  Status Verify(const HloModule& module) const;
+  Status Verify(const HloComputation& computation) const;
+
+  // Returns true iff there are no bindings.
+  bool empty() const { return bindings_.empty(); }
 
  private:
   // Keeps track of mappings from DynamicDimension to DynamicSizeParameter.
   // The direction is chosen so that we can easily query if a dimension is
   // dynamic and which dynamic parameter represents the real size of that
   // dimension.
-  absl::flat_hash_map<DynamicDimension, DynamicParameter> bindings_;
+  absl::flat_hash_map<DynamicDimension, DynamicSizeParameter> bindings_;
 };
 
 std::ostream& operator<<(std::ostream& out,
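A hypothetical call site for the renamed binding API above; the struct fields (parameter_num, parameter_index, dimension) and the Bind/GetBinding methods come from this header, while the concrete indices and the surrounding helper are made up for illustration.

// Sketch; assumes dynamic_parameter_binding.h, status macros, and logging
// headers are included, and that Status is an alias of absl::Status here.
absl::Status BindExampleSketch() {
  xla::DynamicParameterBinding binding;
  // Scalar parameter 1 carries the runtime size of dimension 0 of parameter 0.
  TF_RETURN_IF_ERROR(binding.Bind(
      /*dynamic_parameter=*/{/*parameter_num=*/1, /*parameter_index=*/{}},
      /*dynamic_dimension=*/{/*parameter_num=*/0, /*parameter_index=*/{},
                             /*dimension=*/0}));
  if (auto size_param = binding.GetBinding(
          {/*parameter_num=*/0, /*parameter_index=*/{}, /*dimension=*/0})) {
    LOG(INFO) << "dim 0 of parameter 0 is sized by parameter "
              << size_param->parameter_num;
  }
  return absl::OkStatus();
}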
diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc
index cf193ba..11631c8 100644
--- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc
+++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc
@@ -3203,11 +3203,29 @@
 }
 
 bool HloInstruction::IsCrossModuleAllReduce() const {
-  return opcode() == HloOpcode::kAllReduce && channel_id();
+  if (opcode() == HloOpcode::kAllReduce ||
+      opcode() == HloOpcode::kAllReduceStart) {
+    return channel_id() != std::nullopt;
+  } else if (opcode() == HloOpcode::kAllReduceDone) {
+    CHECK_EQ(operand_count(), 1);
+    const HloInstruction* operand = this->operand(0);
+    CHECK_EQ(operand->opcode(), HloOpcode::kAllReduceStart);
+    return operand->channel_id() != std::nullopt;
+  }
+  return false;
 }
 
 bool HloInstruction::IsCrossReplicaAllReduce() const {
-  return opcode() == HloOpcode::kAllReduce && !channel_id();
+  if (opcode() == HloOpcode::kAllReduce ||
+      opcode() == HloOpcode::kAllReduceStart) {
+    return channel_id() == std::nullopt;
+  } else if (opcode() == HloOpcode::kAllReduceDone) {
+    CHECK_EQ(operand_count(), 1);
+    const HloInstruction* operand = this->operand(0);
+    CHECK_EQ(operand->opcode(), HloOpcode::kAllReduceStart);
+    return operand->channel_id() == std::nullopt;
+  }
+  return false;
 }
 
 void HloInstruction::PrintWithCanonicalNameMap(
diff --git a/third_party/xla/xla/hlo/utils/BUILD b/third_party/xla/xla/hlo/utils/BUILD
index 3d46d69..90ab60d 100644
--- a/third_party/xla/xla/hlo/utils/BUILD
+++ b/third_party/xla/xla/hlo/utils/BUILD
@@ -101,12 +101,14 @@
         "//xla:array",
         "//xla:literal_util",
         "//xla:protobuf_util",
+        "//xla:shape_util",
         "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/ir:tile_assignment",
         "//xla/service:call_graph",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/log:check",
diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
index 652ece8..7cdbf2f 100644
--- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
+++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
@@ -18,17 +18,18 @@
 #include <algorithm>
 #include <cmath>
 #include <cstdint>
-#include <iostream>
 #include <iterator>
 #include <map>
 #include <memory>
 #include <optional>
 #include <set>
 #include <string>
+#include <tuple>
 #include <utility>
 #include <vector>
 
 #include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/log/check.h"
@@ -44,6 +45,7 @@
 #include "xla/literal_util.h"
 #include "xla/protobuf_util.h"
 #include "xla/service/call_graph.h"
+#include "xla/shape.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
 
@@ -68,9 +70,10 @@
   if (potential_subsharding.IsTileMaximal()) {
     return false;
   }
+  const int32_t tiled_data_rank = potential_subsharding.TiledDataRank();
   // Different tiled ranks can't be compared (something is wrong, are the
   // shardings for different shapes?)
-  if (potential_subsharding.TiledDataRank() != sharding.TiledDataRank()) {
+  if (tiled_data_rank != sharding.TiledDataRank()) {
     return false;
   }
   // Helper to construct the base tile bounds based on a shape and a sharding.
@@ -95,56 +98,61 @@
       return false;
     }
   }
-  const int32_t num_devices =
-      potential_subsharding.tile_assignment().num_elements();
+  // Use a single contiguous buffer to reduce allocation overhead.
+  auto storage = std::make_unique<int32_t[]>(
+      sharding.tile_assignment().num_elements() * tiled_data_rank);
+  int32_t* storage_cursor = storage.get();
   // Need a map here, because the MPMD partitioner sharding annotations can have
   // non contiguous partition numbers.
-  absl::flat_hash_map<int32_t, std::vector<int32_t>> subsharding_offsets;
-  absl::flat_hash_map<int32_t, std::vector<int32_t>> sharding_offsets;
-  const int32_t indices_count = potential_subsharding.TiledDataRank();
-  // Collect the start offsets for each tile for the subsharding we are
-  // evaluating.
-  potential_subsharding.tile_assignment().Each(
-      [&](absl::Span<const int64_t> indices, int64_t device) {
-        auto& indices_per_device = subsharding_offsets[device];
-        for (int64_t i = 0; i < indices_count; ++i) {
-          indices_per_device.push_back(potential_base_tile[i] * indices[i]);
-        }
-      });
+  absl::flat_hash_map<int32_t, int32_t*> sharding_offsets;
+  sharding_offsets.reserve(sharding.tile_assignment().num_elements());
+  auto get_sharding_offsets = [&](int64_t device) -> absl::Span<int32_t> {
+    auto it = sharding_offsets.find(device);
+    if (it == sharding_offsets.end()) {
+      bool emplaced;
+      std::tie(it, emplaced) = sharding_offsets.emplace(device, storage_cursor);
+      DCHECK(emplaced);
+      storage_cursor += tiled_data_rank;
+    }
+    return absl::MakeSpan(it->second, tiled_data_rank);
+  };
   // Collect the start offsets for each tile for the sharding we are evaluating
   // against.
   sharding.tile_assignment().Each(
       [&](absl::Span<const int64_t> indices, int64_t device) {
-        auto& indices_per_device = sharding_offsets[device];
-        for (int64_t i = 0; i < indices_count; ++i) {
-          indices_per_device.push_back(base_tile[i] * indices[i]);
+        auto indices_per_device = get_sharding_offsets(device);
+        for (int64_t i = 0; i < tiled_data_rank; ++i) {
+          indices_per_device[i] = base_tile[i] * indices[i];
         }
       });
   // Compare the start offsets and the end offset of the tiles for each device.
-  for (int i = 0; i < num_devices; ++i) {
-    const int32_t device_id =
-        potential_subsharding.tile_assignment().array().data()[i];
-    auto& subsharding_offset = subsharding_offsets[device_id];
-    auto& sharding_offset = sharding_offsets[device_id];
-    for (int j = 0; j < indices_count; ++j) {
-      // The subsharding contains data outside of the tile we are comparing
-      // against.
-      if (subsharding_offset[j] < sharding_offset[j]) {
-        return false;
-      }
-      // Skip last tile. It can never go beyond the limit as the shape is the
-      // same for both shardings and sometimes there's padding making one of the
-      // two limits bigger than the other, but it shouldn't be counted.
-      const bool is_last_tile =
-          subsharding_offset[j] + potential_base_tile[j] >=
-          potential_sharded_shape.dimensions(j);
-      if (!is_last_tile && subsharding_offset[j] + potential_base_tile[j] >
-                               sharding_offset[j] + base_tile[j]) {
-        return false;
-      }
-    }
-  }
-  return true;
+  auto& potential_ta = potential_subsharding.tile_assignment().array();
+  absl::Status ok_if_no_violation = potential_ta.EachStatus(
+      [&](absl::Span<const int64_t> indices, int64_t device) {
+        auto sharding_offset = get_sharding_offsets(device);
+        for (int j = 0; j < tiled_data_rank; ++j) {
+          const int32_t subsharding_offset_j =
+              potential_base_tile[j] * indices[j];
+          // The subsharding contains data outside of the tile we are comparing
+          // against.
+          if (subsharding_offset_j < sharding_offset[j]) {
+            return InternalError("");
+          }
+          // Skip last tile. It can never go beyond the limit as the shape is
+          // the same for both shardings and sometimes there's padding making
+          // one of the two limits bigger than the other, but it shouldn't be
+          // counted.
+          const bool is_last_tile =
+              subsharding_offset_j + potential_base_tile[j] >=
+              potential_sharded_shape.dimensions(j);
+          if (!is_last_tile && subsharding_offset_j + potential_base_tile[j] >
+                                   sharding_offset[j] + base_tile[j]) {
+            return InternalError("");
+          }
+        }
+        return absl::OkStatus();
+      });
+  return ok_if_no_violation.ok();
 }
 
 bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) {
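The rewrite above replaces per-device std::vectors with one contiguous buffer plus fixed-size spans handed out per device. A minimal, self-contained sketch of that allocation pattern (the class and its names are illustrative, not part of XLA):

#include <cstdint>
#include <memory>
#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"

// Illustrative arena: one allocation, a span of `rank` offsets per device.
class OffsetArena {
 public:
  OffsetArena(int64_t num_keys, int64_t rank)
      : rank_(rank), storage_(std::make_unique<int32_t[]>(num_keys * rank)) {
    offsets_.reserve(num_keys);
  }
  // Returns a stable span of `rank_` offsets for `device`, carving it out of
  // the shared buffer the first time the device is seen.
  absl::Span<int32_t> ForDevice(int64_t device) {
    auto [it, inserted] = offsets_.emplace(device, cursor_);
    if (inserted) cursor_ += rank_;
    return absl::MakeSpan(it->second, rank_);
  }

 private:
  int64_t rank_;
  std::unique_ptr<int32_t[]> storage_;
  int32_t* cursor_ = storage_.get();
  absl::flat_hash_map<int64_t, int32_t*> offsets_;
};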
diff --git a/third_party/xla/xla/literal_util.cc b/third_party/xla/xla/literal_util.cc
index d6f82e2..5724836 100644
--- a/third_party/xla/xla/literal_util.cc
+++ b/third_party/xla/xla/literal_util.cc
@@ -121,28 +121,13 @@
 };
 
 template <typename T>
-struct Is16BitFloat {
-  static constexpr bool value =
-      std::is_same<bfloat16, T>::value || std::is_same<half, T>::value;
-};
-
-template <typename T>
 struct IsReal {
-  static constexpr bool value =
-      std::is_integral<T>::value || std::is_floating_point<T>::value ||
-      std::is_same<bfloat16, T>::value || std::is_same<half, T>::value ||
-      std::is_same<tsl::float8_e5m2, T>::value ||
-      std::is_same<tsl::float8_e5m2fnuz, T>::value ||
-      std::is_same<tsl::float8_e4m3fn, T>::value ||
-      std::is_same<tsl::float8_e4m3b11, T>::value ||
-      std::is_same<tsl::float8_e4m3fnuz, T>::value;
+  static constexpr bool value = std::numeric_limits<T>::is_specialized;
 };
 
 template <typename T>
 struct IsValidScalarType {
-  static constexpr bool value = IsReal<T>::value ||
-                                std::is_same<complex64, T>::value ||
-                                std::is_same<complex128, T>::value;
+  static constexpr bool value = IsReal<T>::value || is_complex_v<T>;
 };
 
 template <typename NativeT>
diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/BUILD b/third_party/xla/xla/mlir/backends/gpu/transforms/BUILD
index 758b54e..fd1c250 100644
--- a/third_party/xla/xla/mlir/backends/gpu/transforms/BUILD
+++ b/third_party/xla/xla/mlir/backends/gpu/transforms/BUILD
@@ -77,6 +77,7 @@
         "//xla/stream_executor:blas",
         "//xla/stream_executor:device_description",
         "//xla/translate/mhlo_to_hlo:location_exporter",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc
index a32511c..63a2e72 100644
--- a/third_party/xla/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc
+++ b/third_party/xla/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc
@@ -19,6 +19,7 @@
 #include <utility>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
 #include "llvm/ADT/STLExtras.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
@@ -57,8 +58,10 @@
     : public impl::OutlineGpuGraphsPassBase<OutlineGpuGraphsPass> {
  public:
   OutlineGpuGraphsPass() = default;
-  explicit OutlineGpuGraphsPass(int gpu_graph_level, int min_graph_size)
-      : gpu_graph_level_(gpu_graph_level) {
+  explicit OutlineGpuGraphsPass(
+      absl::flat_hash_set<DebugOptions::CommandBufferCmdType> command_types,
+      int min_graph_size)
+      : command_types_(std::move(command_types)) {
     this->min_graph_size_ = min_graph_size;
   }
 
@@ -69,6 +72,8 @@
   }
 
  private:
+  absl::flat_hash_set<DebugOptions::CommandBufferCmdType> command_types_ = {
+      DebugOptions::FUSION, DebugOptions::CUBLAS, DebugOptions::CUDNN};
   int gpu_graph_level_ = 3;
 };
 
@@ -360,6 +365,10 @@
   // If an argument to parent_func has the "lmhlo.constant_name" attribute and
   // is passed to the graph capture function, we propagate the attribute to
   // the graph capture function.
+  //
+  // We also annotate all arguments with the "rt.allocation_index" attribute,
+  // which lets us forward the correct arguments to the graph capture function
+  // during GPU executable initialization (see `InstantiateAllGraphs`).
   for (unsigned i = 0; i < args.size(); ++i) {
     Value arg = args[i];
 
@@ -371,6 +380,12 @@
     Block* parent_block = block_arg.getParentBlock();
     if (!parent_block->isEntryBlock()) continue;
 
+    // If this is an argument to the entry block of the parent function, it is
+    // an XLA allocation, and we forward its index to the capture function.
+    func.setArgAttr(i, "rt.allocation_index",
+                    b.getIndexAttr(block_arg.getArgNumber()));
+
     // Check that the parent_block is in the SSACFG region of parent_func.
     Region& parent_func_region = parent_func.getRegion();
     if (parent_block->getParent() != &parent_func_region) continue;
@@ -458,7 +473,7 @@
 
   OpCapturePatternSet patterns;
 
-  if (gpu_graph_level_ >= 1) {
+  if (command_types_.contains(DebugOptions::FUSION)) {
     // Enable capturing fusions and memcpies.
     patterns.emplace_back(new LaunchFuncOpCapture());
     patterns.emplace_back(new ConstantOpCapture());
@@ -467,12 +482,12 @@
     patterns.emplace_back(new ReinterpretCastOpCapture());
   }
 
-  if (gpu_graph_level_ >= 2) {
+  if (command_types_.contains(DebugOptions::CUBLAS)) {
     // Enable capturing gemms.
     patterns.emplace_back(new GemmOpCapture());
   }
 
-  if (gpu_graph_level_ >= 3) {
+  if (command_types_.contains(DebugOptions::CUDNN)) {
     // Enable capturing convolutions.
     patterns.emplace_back(new ConvForwardOpCapture());
     patterns.emplace_back(new ConvBackwardInputOpCapture());
@@ -494,9 +509,9 @@
 }
 
 std::unique_ptr<OperationPass<ModuleOp>> createOutlineGpuGraphsPass(
-    int gpu_graph_level, int min_graph_size) {
-  return std::make_unique<OutlineGpuGraphsPass>(gpu_graph_level,
-                                                min_graph_size);
+    absl::flat_hash_set<DebugOptions::CommandBufferCmdType> command_types,
+    int min_graph_size) {
+  return std::make_unique<OutlineGpuGraphsPass>(command_types, min_graph_size);
 }
 
 }  // namespace gpu
diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/passes.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/passes.cc
index 9f6ea04..8fcb6d1 100644
--- a/third_party/xla/xla/mlir/backends/gpu/transforms/passes.cc
+++ b/third_party/xla/xla/mlir/backends/gpu/transforms/passes.cc
@@ -15,15 +15,43 @@
 
 #include "xla/mlir/backends/gpu/transforms/passes.h"
 
+#include <cstdint>
+#include <vector>
+
 #include "absl/log/log.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
+#include "xla/mlir/runtime/ir/rt_ops.h"
 
 namespace xla {
 namespace gpu {
 
 using namespace mlir;  // NOLINT
 
+std::vector<std::vector<int64_t>> GetAllocationIndices(mlir::ModuleOp module) {
+  std::vector<std::vector<int64_t>> res;
+
+  SymbolTable sym_table(module);
+  for (auto op : module.getOps<runtime::ExportOp>()) {
+    unsigned ordinal = *op.ordinal();
+    if (ordinal >= res.size()) res.resize(ordinal + 1);
+
+    auto func = sym_table.lookup<func::FuncOp>(op.getFunctionRef());
+    res[ordinal].resize(func.getNumArguments(), -1);
+
+    for (unsigned i = 0; i < func.getNumArguments(); ++i) {
+      auto idx = func.getArgAttrOfType<IntegerAttr>(i, "rt.allocation_index");
+      if (idx) res[ordinal][i] = idx.getInt();
+    }
+  }
+
+  return res;
+}
+
 void populateXlaGpuRuntimePasses(mlir::OpPassManager& pm,
                                  ThunkSequence* thunk_sequence,
                                  const GpuPipelineOpts& opts) {
@@ -39,7 +67,7 @@
 
   // Outline CUDA-Graph-compatible operations into graph capture functions.
   pm.addPass(
-      createOutlineGpuGraphsPass(opts.gpu_graph_level, opts.min_graph_size));
+      createOutlineGpuGraphsPass(opts.command_types, opts.min_graph_size));
   if (opts.enable_concurrent_region) {
     // Concurrent regions create repeated-fork-join topology inside CUDA graphs,
     // which is not optimized by architectures prior to Ampere and may cause
diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/passes.h b/third_party/xla/xla/mlir/backends/gpu/transforms/passes.h
index 4c3fc98..253eda6 100644
--- a/third_party/xla/xla/mlir/backends/gpu/transforms/passes.h
+++ b/third_party/xla/xla/mlir/backends/gpu/transforms/passes.h
@@ -18,11 +18,14 @@
 
 #include <cstdint>
 #include <memory>
+#include <vector>
 
+#include "absl/container/flat_hash_set.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "xla/stream_executor/device_description.h"
+#include "xla/xla.pb.h"
 
 namespace xla {
 namespace gpu {
@@ -40,11 +43,24 @@
 
 class ThunkSequence;  // forward declare
 
+// Collects `rt.allocation_index` attributes from all exported functions.
+//
+//   auto result = GetAllocationIndices(module);
+//   result[ordinal][argument_index] == allocation_index;
+//
+// Returns `-1` for all arguments that do not have an `rt.allocation_index`
+// attribute.
+//
+// TODO(ezhulenev): This is a very ugly hack for graph capture integration, but
+// given that we are moving towards a new runtime and command buffers, it's
+// supposed to be a very short lived hack.
+std::vector<std::vector<int64_t>> GetAllocationIndices(mlir::ModuleOp module);
+
 struct GpuPipelineOpts {
   // Enable experimental pass that outlines parts of the XLA computation into
   // CUDA Graphs, which allows us to amortize the cost of launching multiple
   // device kernels.
-  int32_t gpu_graph_level = 0;
+  absl::flat_hash_set<DebugOptions::CommandBufferCmdType> command_types;
   int32_t min_graph_size = 0;
   bool enable_concurrent_region = false;
   stream_executor::GpuComputeCapability compute_capability;
@@ -106,7 +122,8 @@
 createOutlineGpuGraphsPass();
 
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createOutlineGpuGraphsPass(
-    int32_t gpu_graph_level, int32_t min_graph_size);
+    absl::flat_hash_set<DebugOptions::CommandBufferCmdType> command_types,
+    int32_t min_graph_size);
 
 //===----------------------------------------------------------------------===//
 // Passes for marking concurrent region in CUDA graph capture function.
diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir
index 9a25e7b..cd28641 100644
--- a/third_party/xla/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir
+++ b/third_party/xla/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir
@@ -146,13 +146,17 @@
 func.func private @external()
 
 // CHECK: rt.export @[[CAPTURE]]
-// CHECK: func.func @[[CAPTURE]](%arg0: memref<?xf32>)
+// CHECK: func.func @[[CAPTURE]](
+// CHECK:   %arg0: memref<?xf32>
+// CHECK: )
 // CHECK-NEXT: arith.constant 1
 // CHECK-NEXT: gpu.launch_func @gpu_module::@fn0
 // CHECK-NEXT: gpu.launch_func @gpu_module::@fn1
 
 // CHECK: rt.export @[[CAPTURE_0]]
-// CHECK: func.func @[[CAPTURE_0]](%arg0: memref<?xf32>)
+// CHECK: func.func @[[CAPTURE_0]](
+// CHECK:   %arg0: memref<?xf32>
+// CHECK: )
 // CHECK-NEXT: arith.constant 2
 // CHECK-NEXT: gpu.launch_func @gpu_module::@fn1
 // CHECK-NEXT: gpu.launch_func @gpu_module::@fn0
@@ -665,8 +669,9 @@
 }
 
 // CHECK: func @local_xla.gpu.graph.capture(
-// CHECK-SAME:  %[[ARG0]]: memref<?xf32> {lmhlo.constant_name = "cst0"},
-// CHECK-SAME:  %[[ARG1]]: memref<?xf32> {lmhlo.constant_name = "cst1"})
+// CHECK-SAME:  %[[ARG0]]: memref<?xf32> {lmhlo.constant_name = "cst0",
+// CHECK-SAME:  %[[ARG1]]: memref<?xf32> {lmhlo.constant_name = "cst1",
+// CHECK-SAME: )
 // CHECK-NEXT:  %[[C1:.*]] = arith.constant 1
 // CHECK-NEXT:  gpu.launch_func @gpu_module::@fn0
 // CHECK-SAME:    blocks in (%[[C1]], %[[C1]], %[[C1]])
diff --git a/third_party/xla/xla/mlir/runtime/transforms/BUILD b/third_party/xla/xla/mlir/runtime/transforms/BUILD
index a1497f5..d9151d9 100644
--- a/third_party/xla/xla/mlir/runtime/transforms/BUILD
+++ b/third_party/xla/xla/mlir/runtime/transforms/BUILD
@@ -33,6 +33,7 @@
         "convert_asserts.cc",
         "convert_custom_calls.cc",
         "export_functions.cc",
+        "move_allocas_to_entry_block.cc",
         "ordinal_assignment.cc",
         "rt_to_llvm.cc",
     ],
@@ -47,20 +48,20 @@
         "//xla/runtime:custom_call",
         "//xla/runtime:tracing",
         "//xla/runtime:type_id",
+        "@com_google_absl//absl/log:check",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:AsyncDialect",
         "@llvm-project//mlir:ControlFlowDialect",
         "@llvm-project//mlir:FuncDialect",
-        "@llvm-project//mlir:FuncToLLVM",
         "@llvm-project//mlir:FuncTransforms",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LLVMCommonConversion",
         "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
-        "@llvm-project//mlir:Transforms",
     ],
 )
 
@@ -103,11 +104,11 @@
     deps = [
         ":compilation_pipeline_options",
         ":compiler",
-        ":custom_call_encoding",
         ":passes",
         "//xla/mlir/backends/cpu/transforms:passes",
         "//xla/mlir/math/transforms:passes",
         "//xla/mlir/memref/transforms:passes",
+        "//xla/mlir/runtime/ir:rt",
         "//xla/mlir_hlo:transforms_passes",
         "//xla/runtime:compiler",
         "@llvm-project//mlir:AMXToLLVMIRTranslation",
@@ -125,24 +126,21 @@
         "@llvm-project//mlir:ControlFlowDialect",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FuncExtensions",
-        "@llvm-project//mlir:FuncToLLVM",
         "@llvm-project//mlir:GPUToGPURuntimeTransforms",
         "@llvm-project//mlir:GPUTransforms",
         "@llvm-project//mlir:LLVMToLLVMIRTranslation",
+        "@llvm-project//mlir:LinalgDialect",
         "@llvm-project//mlir:LinalgTransforms",
         "@llvm-project//mlir:MathDialect",
         "@llvm-project//mlir:MathToLLVM",
-        "@llvm-project//mlir:MathToLibm",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:MemRefToLLVM",
         "@llvm-project//mlir:MemRefTransforms",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:ReconcileUnrealizedCasts",
         "@llvm-project//mlir:SCFDialect",
-        "@llvm-project//mlir:SCFToControlFlow",
         "@llvm-project//mlir:SparseTensorDialect",
         "@llvm-project//mlir:Transforms",
-        "@llvm-project//mlir:VectorToLLVM",
         "@llvm-project//mlir:X86VectorToLLVMIRTranslation",
     ],
     alwayslink = 1,  # has pipeline registration
diff --git a/third_party/xla/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc b/third_party/xla/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc
index 995a5f4..50488ca 100644
--- a/third_party/xla/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc
+++ b/third_party/xla/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc
@@ -20,14 +20,10 @@
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"  // from @llvm-project
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"  // from @llvm-project
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"  // from @llvm-project
-#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"  // from @llvm-project
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"  // from @llvm-project
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"  // from @llvm-project
-#include "mlir/Conversion/MathToLibm/MathToLibm.h"  // from @llvm-project
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"  // from @llvm-project
 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"  // from @llvm-project
-#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"  // from @llvm-project
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"  // from @llvm-project
 #include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
 #include "mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
 #include "mlir/Dialect/Arith/Transforms/Passes.h"  // from @llvm-project
@@ -37,13 +33,15 @@
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"  // from @llvm-project
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/Dialect/GPU/Transforms/Passes.h"  // from @llvm-project
+#include "mlir/Dialect/Linalg/IR/Linalg.h"  // from @llvm-project
 #include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
 #include "mlir/Dialect/Math/IR/Math.h"  // from @llvm-project
 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"  // from @llvm-project
 #include "mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"  // from @llvm-project
-#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
 #include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"  // from @llvm-project
 #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"  // from @llvm-project
 #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h"  // from @llvm-project
@@ -54,8 +52,9 @@
 #include "xla/mlir/backends/cpu/transforms/passes.h"
 #include "xla/mlir/math/transforms/passes.h"
 #include "xla/mlir/memref/transforms/passes.h"
+#include "xla/mlir/runtime/ir/rt_dialect.h"
+#include "xla/mlir/runtime/transforms/compilation_pipeline_options.h"
 #include "xla/mlir/runtime/transforms/compiler.h"
-#include "xla/mlir/runtime/transforms/custom_call_encoding.h"
 #include "xla/mlir/runtime/transforms/passes.h"
 #include "xla/mlir_hlo/transforms/passes.h"
 
@@ -113,6 +112,9 @@
   // Lower from high level async operations to async runtime.
   pm.addPass(mlir::createAsyncToAsyncRuntimePass());
 
+  // Move all memref.alloca ops to the entry block of each function.
+  pm.addPass(CreateMoveAllocasToEntryBlockPass());
+
   // Add async.runtime reference counting operations.
   pm.addPass(mlir::createAsyncRuntimePolicyBasedRefCountingPass());
 
diff --git a/third_party/xla/xla/mlir/runtime/transforms/move_allocas_to_entry_block.cc b/third_party/xla/xla/mlir/runtime/transforms/move_allocas_to_entry_block.cc
new file mode 100644
index 0000000..7a9a892
--- /dev/null
+++ b/third_party/xla/xla/mlir/runtime/transforms/move_allocas_to_entry_block.cc
@@ -0,0 +1,69 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "absl/log/check.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
+#include "mlir/IR/Block.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "xla/mlir/runtime/transforms/passes.h"
+
+namespace xla {
+namespace runtime {
+
+using namespace mlir;  // NOLINT
+
+#define GEN_PASS_DEF_MOVEALLOCASTOENTRYBLOCK
+#include "xla/mlir/runtime/transforms/passes.h.inc"
+
+class MoveAllocasToEntryBlockPass
+    : public impl::MoveAllocasToEntryBlockBase<MoveAllocasToEntryBlockPass> {
+  void runOnOperation() override;
+};
+
+//===----------------------------------------------------------------------===//
+
+void MoveAllocasToEntryBlockPass::runOnOperation() {
+  ModuleOp module = getOperation();
+  module.walk([](mlir::func::FuncOp func) {
+    CHECK(!func.getBlocks().empty());
+    Block* entryBlock = &func.getBlocks().front();
+    llvm::SmallVector<memref::AllocaOp> allocas;
+    for (auto op : func.getOps<memref::AllocaOp>()) {
+      if (op->getBlock() != entryBlock) {
+        allocas.push_back(op);
+      }
+    }
+
+    auto builder =
+        ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
+    builder.setInsertionPointToStart(entryBlock);
+    for (auto op : allocas) {
+      op->moveBefore(builder.getInsertionBlock(), builder.getInsertionPoint());
+    }
+  });
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> CreateMoveAllocasToEntryBlockPass() {
+  return std::make_unique<MoveAllocasToEntryBlockPass>();
+}
+
+}  // namespace runtime
+}  // namespace xla
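
A dependency-free sketch of the hoisting this pass performs, using plain structs as stand-ins for MLIR blocks and ops; the names and the toy `main` are illustrative only and do not use the real MLIR API.

```cpp
#include <cstddef>
#include <iostream>
#include <list>
#include <string>
#include <vector>

struct Op {
  std::string name;
  bool is_alloca = false;
};

struct Block {
  std::list<Op> ops;
};

struct Function {
  std::vector<Block> blocks;  // blocks[0] is the entry block
};

// Moves every alloca found outside the entry block to the front of the entry
// block, mirroring the intent of MoveAllocasToEntryBlockPass above.
void MoveAllocasToEntryBlock(Function& func) {
  if (func.blocks.empty()) return;
  Block& entry = func.blocks.front();
  for (std::size_t i = 1; i < func.blocks.size(); ++i) {
    Block& block = func.blocks[i];
    for (auto it = block.ops.begin(); it != block.ops.end();) {
      if (it->is_alloca) {
        entry.ops.push_front(*it);
        it = block.ops.erase(it);
      } else {
        ++it;
      }
    }
  }
}

int main() {
  Function f;
  f.blocks.push_back({{{"coro.begin"}, {"coro.suspend"}}});
  f.blocks.push_back({{{"alloca", true}, {"call"}}});
  MoveAllocasToEntryBlock(f);
  std::cout << "entry block front: " << f.blocks[0].ops.front().name << "\n";
  return 0;
}
```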
diff --git a/third_party/xla/xla/mlir/runtime/transforms/passes.h b/third_party/xla/xla/mlir/runtime/transforms/passes.h
index 8fa0918..9f18660 100644
--- a/third_party/xla/xla/mlir/runtime/transforms/passes.h
+++ b/third_party/xla/xla/mlir/runtime/transforms/passes.h
@@ -28,6 +28,7 @@
 namespace runtime {
 
 #define GEN_PASS_DECL_ORDINALASSIGNMENT
+#define GEN_PASS_DECL_MOVEALLOCASTOENTRYBLOCK
 #define GEN_PASS_DECL_EXPORTFUNCTIONS
 #define GEN_PASS_DECL_CONVERTCUSTOMCALLS
 #define GEN_PASS_DECL_CONVERTASSERTS
@@ -43,6 +44,9 @@
 CreateOrdinalAssignmentPass();
 
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+CreateMoveAllocasToEntryBlockPass();
+
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 CreateExportRuntimeFunctionsPass();
 
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
diff --git a/third_party/xla/xla/mlir/runtime/transforms/passes.td b/third_party/xla/xla/mlir/runtime/transforms/passes.td
index 4b2a736..aadf975 100644
--- a/third_party/xla/xla/mlir/runtime/transforms/passes.td
+++ b/third_party/xla/xla/mlir/runtime/transforms/passes.td
@@ -148,6 +148,17 @@
   let dependentDialects = ["xla::runtime::RuntimeDialect"];
 }
 
+def MoveAllocasToEntryBlock : Pass<"xla-rt-move-allocas-to-entry-block", "ModuleOp"> {
+  let summary = "Move all `memref.alloca` ops to the entry block of each function";
+
+  let description = [{
+    For every function, move `memref.alloca` ops to the entry block so that
+    stack allocations remain valid across coroutine suspension points.
+  }];
+
+  let constructor = "xla::runtime::CreateMoveAllocasToEntryBlockPass()";
+  let dependentDialects = ["xla::runtime::RuntimeDialect"];
+}
+
 def ConvertAsserts : Pass<"xla-rt-convert-asserts", "ModuleOp"> {
   let summary = "Converts asserts in exported functions to run-time errors";
 
diff --git a/third_party/xla/xla/mlir/runtime/transforms/tests/move_allocas_to_entry_block.mlir b/third_party/xla/xla/mlir/runtime/transforms/tests/move_allocas_to_entry_block.mlir
new file mode 100644
index 0000000..dff76dfc
--- /dev/null
+++ b/third_party/xla/xla/mlir/runtime/transforms/tests/move_allocas_to_entry_block.mlir
@@ -0,0 +1,55 @@
+// RUN: xla-runtime-opt %s --xla-rt-move-allocas-to-entry-block | FileCheck %s
+
+func.func @compute(
+  %arg0: !rt.execution_context,
+  %arg1: !async.value<memref<f32>>
+) -> !async.token attributes {passthrough = ["presplitcoroutine"]} {
+  // CHECK:   %alloca = memref.alloca() {alignment = 64 : i64} : memref<f32>
+  // CHECK:   %0 = async.runtime.create : !async.token
+  // CHECK:   %1 = async.coro.id
+  // CHECK:   %2 = async.coro.begin %1
+  // CHECK:   %3 = async.coro.save %2
+  // CHECK:   async.runtime.resume %2
+  // CHECK:   async.coro.suspend %3, ^bb9, ^bb1, ^bb8
+  // CHECK: ^bb1:  // pred: ^bb0
+  // CHECK:   %status = rt.call %arg0["test.producer"] (%alloca)
+  // CHECK:     : (memref<f32>) -> ()
+  %0 = async.runtime.create : !async.token
+  %1 = async.coro.id
+  %2 = async.coro.begin %1
+  %3 = async.coro.save %2
+  async.runtime.resume %2
+  async.coro.suspend %3, ^bb9, ^bb1, ^bb8
+^bb1:  // pred: ^bb0
+  %alloca = memref.alloca() {alignment = 64 : i64} : memref<f32>
+  %status = rt.call %arg0["test.producer"] (%alloca) : (memref<f32>) -> ()
+  %4 = rt.is_ok %status
+  cf.cond_br %4, ^bb2, ^bb6
+^bb2:  // pred: ^bb1
+  %5 = async.coro.save %2
+  async.runtime.await_and_resume %arg1, %2 : !async.value<memref<f32>>
+  async.coro.suspend %5, ^bb9, ^bb3, ^bb8
+^bb3:  // pred: ^bb2
+  %6 = async.runtime.is_error %arg1 : !async.value<memref<f32>>
+  cf.cond_br %6, ^bb6, ^bb4
+^bb4:  // pred: ^bb3
+  %7 = async.runtime.load %arg1 : <memref<f32>>
+  %status_0 = rt.call %arg0["test.consumer"] (%alloca) : (memref<f32>) -> ()
+  %8 = rt.is_ok %status_0
+  cf.cond_br %8, ^bb5, ^bb6
+^bb5:  // pred: ^bb4
+  async.runtime.set_available %0 : !async.token
+  cf.br ^bb7
+^bb6:  // 3 preds: ^bb1, ^bb3, ^bb4
+  async.runtime.set_error %0 : !async.token
+  cf.br ^bb7
+^bb7:  // 2 preds: ^bb5, ^bb6
+  async.coro.free %1, %2
+  cf.br ^bb9
+^bb8:  // 2 preds: ^bb0, ^bb2
+  async.coro.free %1, %2
+  cf.br ^bb9
+^bb9:  // 4 preds: ^bb0, ^bb2, ^bb7, ^bb8
+  async.coro.end %2
+  return %0 : !async.token
+}
diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/scalarization/scalarization.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/scalarization/scalarization.cc
index 8bc17d1..df12ce7 100644
--- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/scalarization/scalarization.cc
+++ b/third_party/xla/xla/mlir_hlo/gml_st/transforms/scalarization/scalarization.cc
@@ -243,7 +243,7 @@
       dyn_cast<RankedTensorType>(iterOperand.get().getType());
   if (!iterArgTensorTy || !hasSingleElement(iterArgTensorTy)) return failure();
 
-  Value bbArg = forOp.getRegionIterArgForOpOperand(iterOperand);
+  Value bbArg = forOp.getTiedLoopRegionIterArg(&iterOperand);
 
   if (!bbArg.hasOneUse()) return failure();
 
diff --git a/third_party/xla/xla/mlir_hlo/lhlo/IR/lhlo_dialect.td b/third_party/xla/xla/mlir_hlo/lhlo/IR/lhlo_dialect.td
index b68c45f..7cddf4e 100644
--- a/third_party/xla/xla/mlir_hlo/lhlo/IR/lhlo_dialect.td
+++ b/third_party/xla/xla/mlir_hlo/lhlo/IR/lhlo_dialect.td
@@ -22,7 +22,6 @@
 def LHLO_Dialect : Dialect {
   let name = "lmhlo";
   let cppNamespace = "::mlir::lmhlo";
-  let usePropertiesForAttributes = 0;
 }
 
 #endif  // LHLO_DIALECT
diff --git a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td
index 1c742e5..dd828c3 100644
--- a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td
+++ b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td
@@ -165,12 +165,15 @@
     Arg<LHLO_Buffer, "", [MemRead]>:$a,
     Arg<LHLO_Buffer, "", [MemRead]>:$b,
     Arg<LHLO_Buffer, "", [MemRead, MemWrite]>:$c,
+    Arg<Optional<LHLO_Buffer>, "", [MemRead, MemWrite]>:$workspace,
     MHLO_DotDimensionNumbers:$dot_dimension_numbers,
     MHLO_PrecisionConfigAttr:$precision_config,
     F64Attr:$alpha_real,
     F64Attr:$alpha_imag,
     F64Attr:$beta,
-    OptionalAttr<I64Attr>:$algorithm);
+    OptionalAttr<I64Attr>:$algorithm,
+    OptionalAttr<BoolAttr>:$grad_x,
+    OptionalAttr<BoolAttr>:$grad_y);
 }
 
 def LHLOGPU_CublasLtMatmulOp : LHLOGPU_Op<"cublas.lt.matmul", [AttrSizedOperandSegments]> {
@@ -187,7 +190,9 @@
     F64Attr:$alpha_imag,
     F64Attr:$beta,
     CublasLtMatmulEpilogueAttr:$epilogue,
-    I64Attr:$algorithm);
+    I64Attr:$algorithm,
+    OptionalAttr<BoolAttr>:$grad_x,
+    OptionalAttr<BoolAttr>:$grad_y);
 }
 
 def LHLOGPU_CublasLtMatmulF8Op : LHLOGPU_Op<"cublas.lt.matmul.f8", [AttrSizedOperandSegments]> {
@@ -208,7 +213,9 @@
     F64Attr:$alpha_imag,
     F64Attr:$beta,
     CublasLtMatmulEpilogueAttr:$epilogue,
-    I64Attr:$algorithm);
+    I64Attr:$algorithm,
+    OptionalAttr<BoolAttr>:$grad_x,
+    OptionalAttr<BoolAttr>:$grad_y);
 }
 
 def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
diff --git a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_base.td b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_base.td
index 8190e64..878dc2c 100644
--- a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_base.td
+++ b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_base.td
@@ -25,7 +25,6 @@
   let cppNamespace = "::mlir::lmhlo_gpu";
 
   let useDefaultAttributePrinterParser = 1;
-  let usePropertiesForAttributes = 0;
 }
 
 #endif // LHLO_GPU_OPS_BASE
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc
index 166b132..4774aac 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc
@@ -107,13 +107,6 @@
     // Proposal: https://github.com/openxla/stablehlo/issues/742.
     if (hasPackedNibble(hloOp.getPrecisionConfig())) return true;
   }
-  if constexpr (std::is_same<HloOpTy, mhlo::CustomCallOp>::value) {
-    // StableHLO CustomCall doesn't support API_VERSION_TYPED_FFI yet.
-    // Proposal: https://github.com/openxla/stablehlo/issues/637.
-    if (hloOp.getApiVersion() ==
-        mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI)
-      return true;
-  }
   if constexpr (std::is_same<HloOpTy, mhlo::DotGeneralOp>::value) {
     // StableHLO DotGeneral doesn't support PACKED_NIBBLE yet.
     // Proposal: https://github.com/openxla/stablehlo/issues/742.
@@ -132,13 +125,21 @@
 // frontends but are not yet part of StableHLO. Such features might be a good
 // fit for StableHLO, and are usually accompanied by a StableHLO GitHub ticket.
 template <typename HloOpTy>
-std::optional<int64_t> getPublicFeaturesNotInStablehlo(HloOpTy) {
+std::optional<int64_t> getPublicFeaturesNotInStablehlo(HloOpTy hloOp) {
   // StableHLO doesn't support TanOp yet.
   // Proposal: https://github.com/openxla/stablehlo/issues/954
   if constexpr (std::is_same<HloOpTy, mhlo::TanOp>::value) {
     // Version 1: Initial version for TanOp.
     return 1;
   }
+  // StableHLO CustomCall doesn't support API_VERSION_TYPED_FFI yet.
+  // Proposal: https://github.com/openxla/stablehlo/issues/637.
+  if constexpr (std::is_same<HloOpTy, mhlo::CustomCallOp>::value) {
+    // Version 1: Initial version for TYPED_FFI
+    if (hloOp.getApiVersion() ==
+        mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI)
+      return 1;
+  }
   return std::nullopt;
 }
 
diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc
index 3675c70..a3c401a 100644
--- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc
+++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc
@@ -263,6 +263,34 @@
   return success();
 }
 
+// Preserves backward compatibility of typed_ffi custom calls by converting:
+// `stablehlo.custom_call @foo(%arg0) { mhlo.backend_config = {...} }`
+// ==>
+// `mhlo.custom_call @foo(%arg0) { backend_config = {...}, api_version = 4}`
+//
+// Fails if the StableHLO op has a non-empty backend_config or uses an API
+// version other than API_VERSION_ORIGINAL.
+LogicalResult fixupMhloBackendConfig(stablehlo::CustomCallOp stablehloOp,
+                                     mhlo::CustomCallOp hloOp) {
+  auto stablehloBackendConfig = stablehloOp->getAttr("mhlo.backend_config");
+  if (stablehloBackendConfig) {
+    if (auto oldHloBackendConfig =
+            hloOp.getBackendConfigAttr()
+                .template dyn_cast_or_null<StringAttr>()) {
+      if (!oldHloBackendConfig.empty()) return failure();
+    } else {
+      return failure();
+    }
+    if (stablehloOp.getApiVersion() !=
+        stablehlo::CustomCallApiVersion::API_VERSION_ORIGINAL)
+      return failure();
+
+    hloOp.setBackendConfigAttr(stablehloBackendConfig);
+    hloOp.setApiVersion(mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI);
+  }
+  return success();
+}
+
 template <typename StablehloOpTy>
 class StablehloToHloOpConverter : public OpConversionPattern<StablehloOpTy> {
  public:
@@ -323,23 +351,9 @@
           stablehloOp, hloTypes, hloOperands, hloAttrs);
     }
 
+    // For backward compatibility, fix custom call with mhlo.backend_config
     if constexpr (std::is_same<StablehloOpTy, stablehlo::CustomCallOp>::value) {
-      auto stablehloBackendConfig = stablehloOp->getAttr("mhlo.backend_config");
-      if (stablehloBackendConfig) {
-        if (auto oldHloBackendConfig =
-                hloOp.getBackendConfigAttr()
-                    .template dyn_cast_or_null<StringAttr>()) {
-          if (oldHloBackendConfig != "") return failure();
-        } else {
-          return failure();
-        }
-        if (stablehloOp.getApiVersion() !=
-            stablehlo::CustomCallApiVersion::API_VERSION_ORIGINAL)
-          return failure();
-
-        hloOp.setBackendConfigAttr(stablehloBackendConfig);
-        hloOp.setApiVersion(mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI);
-      }
+      if (failed(fixupMhloBackendConfig(stablehloOp, hloOp))) return failure();
     }
 
     // Finally, populate the regions while converting argument types
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir
index a7454be..da2d9bc 100644
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir
+++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir
@@ -2573,6 +2573,20 @@
   func.return %1 : tensor<i32>
 }
 
+func.func @fold_predtosi() -> tensor<i8> {
+  %0 = mhlo.constant dense<false> : tensor<i1>
+  // CHECK: mhlo.constant dense<0> : tensor<i8>
+  %1 = "mhlo.convert"(%0) : (tensor<i1>) -> tensor<i8>
+  func.return %1 : tensor<i8>
+}
+
+func.func @not_fold_itouq() -> tensor<!quant.uniform<i8:f32, 1.000000e+00:3>> {
+  // CHECK: mhlo.constant dense<1> : tensor<i8>
+  %0 = mhlo.constant dense<1> : tensor<i8>
+  %1 = "mhlo.convert"(%0) : (tensor<i8>) -> tensor<!quant.uniform<i8:f32, 1.000000e+00:3>>
+  func.return %1 : tensor<!quant.uniform<i8:f32, 1.000000e+00:3>>
+}
+
 // CHECK-LABEL: @eliminate_redundant_reshape
 func.func @eliminate_redundant_reshape(%arg : tensor<1x32xi16>) -> tensor<1x32xi16> {
   %0 = "mhlo.reshape"(%arg) : (tensor<1x32xi16>) -> tensor<2x16xi16>
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir
index bbbc30d..7756978 100644
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir
+++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir
@@ -121,8 +121,8 @@
 // CHECK-DAG: %[[O0:.*]] = memref.alloc() {{.*}} : memref<2xf32>
 // CHECK-DAG: %[[O1:.*]] = memref.alloc() {{.*}} : memref<2xf32>
 // CHECK-DAG: %[[O2:.*]] = memref.alloc() {{.*}} : memref<5xi32>
-// CHECK: "lmhlo.custom_call"(%[[I0]], %[[I1]], %[[O0]], %[[O1]], %[[O2]]) ({
-// CHECK-NEXT: }) {backend_config = "", call_target_name = "foo", has_side_effect = false, operandSegmentSizes = array<i32: 2, 3>} : (memref<2xf32>, memref<5xi32>, memref<2xf32>, memref<2xf32>, memref<5xi32>) -> ()
+// CHECK: "lmhlo.custom_call"(%[[I0]], %[[I1]], %[[O0]], %[[O1]], %[[O2]]) <{backend_config = "", call_target_name = "foo", has_side_effect = false, operandSegmentSizes = array<i32: 2, 3>}>
+// CHECK-NEXT: }) : (memref<2xf32>, memref<5xi32>, memref<2xf32>, memref<2xf32>, memref<5xi32>) -> ()
 // CHECK-DAG: %[[T0:.+]] = bufferization.to_tensor %[[O0]] : memref<2xf32>
 // CHECK-DAG: %[[T1:.+]] = bufferization.to_tensor %[[O1]] : memref<2xf32>
 // CHECK: return %[[T0]], %[[T1]] : tensor<2xf32>, tensor<2xf32>
@@ -148,6 +148,6 @@
 // CHECK-DAG: %[[I0:.+]] = bufferization.to_memref %[[ARG0]] : memref<2xf32>
 // CHECK-DAG: %[[I1:.+]] = bufferization.to_memref %[[ARG1]] : memref<5xi32>
 // CHECK-DAG: %[[ALLOC:.+]] = memref.alloc
-// CHECK: "lmhlo.custom_call"(%[[I0]], %[[I1]], %[[ALLOC]]) ({
-// CHECK-NEXT: }) {backend_config = "", call_target_name = "bar", has_side_effect = true, operandSegmentSizes = array<i32: 2, 1>, target_arg_mapping = #lmhlo.custom_call_target_arg_mapping<num_args = 3, num_results = 2, args_to_target_args = [0, 1], results_to_target_results = [1]>} : (memref<2xf32>, memref<5xi32>, memref<2xi32>)
+// CHECK: "lmhlo.custom_call"(%[[I0]], %[[I1]], %[[ALLOC]]) <{backend_config = "", call_target_name = "bar", has_side_effect = true, operandSegmentSizes = array<i32: 2, 1>, target_arg_mapping = #lmhlo.custom_call_target_arg_mapping<num_args = 3, num_results = 2, args_to_target_args = [0, 1], results_to_target_results = [1]>}>
+// CHECK-NEXT: }) : (memref<2xf32>, memref<5xi32>, memref<2xi32>) -> ()
 // CHECK: return %[[TOKEN]] : !mhlo.token
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir
index fb1a9e0..ed2d07b 100644
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir
+++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir
@@ -42,23 +42,6 @@
 
 // -----
 
-// CHECK-LABEL: "op_custom_call_api_version_typed_ffi"
-func.func @op_custom_call_api_version_typed_ffi(%arg0: tensor<f32>) -> tensor<f32> {
-  //      CHECK: "stablehlo.custom_call"(%arg0) {
-  // CHECK-SAME:   call_target_name = "mhlo.custom_call"
-  // CHECK-SAME:   mhlo.attributes = {api_version = 4 : i32, backend_config = {foo = "bar"}, call_target_name = "foo"}
-  // CHECK-SAME: } : (tensor<f32>) -> tensor<f32>
-  // expected-error@+1 {{failed to legalize operation 'mhlo.custom_call' that was explicitly marked illegal}}
-  %0 = "mhlo.custom_call"(%arg0) {
-    call_target_name = "foo",
-    backend_config = {foo = "bar"},
-    api_version = 4 : i32
-  } : (tensor<f32>) -> tensor<f32>
-  return %0 : tensor<f32>
-}
-
-// -----
-
 // CHECK-LABEL: "attr_precision_packed_nibble"
 func.func @attr_precision_packed_nibble(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> {
   //      CHECK: "stablehlo.custom_call"(%arg0, %arg1) {
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir
index 0ad2e01..6de9c1a 100644
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir
+++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir
@@ -152,6 +152,23 @@
   func.return %0 : tensor<f32>
 }
 
+// -----
+
+// CHECK-LABEL: "attr_custom_call_api_version_typed_ffi"
+func.func @attr_custom_call_api_version_typed_ffi(%arg0: tensor<f32>) -> tensor<f32> {
+  //      CHECK: "stablehlo.custom_call"(%arg0) {
+  // CHECK-SAME:   call_target_name = "mhlo.custom_call"
+  // CHECK-SAME:   mhlo.attributes = {api_version = 4 : i32, backend_config = {foo = "bar"}, call_target_name = "foo"},
+  // CHECK-SAME:   mhlo.version = 1 : i64
+  // CHECK-SAME: } : (tensor<f32>) -> tensor<f32>
+  %0 = "mhlo.custom_call"(%arg0) {
+    call_target_name = "foo",
+    backend_config = {foo = "bar"},
+    api_version = 4 : i32
+  } : (tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
 // CustomCallSchedule aka #mhlo<custom_call_schedule> is unsupported at the moment (see negative test below).
 // DequantizeMode aka #mhlo<dequantize_mode> is unused at the moment.
 // DomainKind aka #mhlo<kind> is unsupported at the moment (see negative test below).
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir
index d1dc28f..af84198 100644
--- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir
+++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir
@@ -715,7 +715,8 @@
   // CHECK-SAME: } : (tensor<f32>) -> tensor<f32>
   %0 = "stablehlo.custom_call"(%arg0) {
     call_target_name = "mhlo.custom_call",
-    mhlo.attributes = {api_version = 4 : i32, backend_config = {foo = "bar"}, call_target_name = "foo"}
+    mhlo.attributes = {api_version = 4 : i32, backend_config = {foo = "bar"}, call_target_name = "foo"},
+    mhlo.version = 1 : i64
   } : (tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
@@ -729,7 +730,7 @@
   // CHECK-SAME: } : (tensor<16x256xbf16>) -> tensor<16x4xbf16>
   %4 = stablehlo.custom_call @foo(%arg0) {
     "mhlo.backend_config" = {aggregate_to_topk = true}
-    } : (tensor<16x256xbf16>) -> tensor<16x4xbf16>
+  } : (tensor<16x256xbf16>) -> tensor<16x4xbf16>
   return %4 : tensor<16x4xbf16>
 }
 
diff --git a/third_party/xla/xla/mlir_hlo/utils/convert_op_folder.cc b/third_party/xla/xla/mlir_hlo/utils/convert_op_folder.cc
index 55ba8de..9421c6e 100644
--- a/third_party/xla/xla/mlir_hlo/utils/convert_op_folder.cc
+++ b/third_party/xla/xla/mlir_hlo/utils/convert_op_folder.cc
@@ -31,7 +31,8 @@
                                        mlir::Type newType) {
   auto oldType = getElementTypeOrSelf(elements);
   // TODO(kramerb): Add support when MLIR can represent const complex tensors.
-  if (oldType.isa<mlir::ComplexType>() || newType.isa<mlir::ComplexType>()) {
+  if (!oldType.isa<mlir::IntegerType, mlir::FloatType>() ||
+      !newType.isa<mlir::IntegerType, mlir::FloatType>()) {
     return {};
   }
 
diff --git a/third_party/xla/xla/pjrt/abstract_tfrt_cpu_buffer.cc b/third_party/xla/xla/pjrt/abstract_tfrt_cpu_buffer.cc
index e8b5009..b1e7be8 100644
--- a/third_party/xla/xla/pjrt/abstract_tfrt_cpu_buffer.cc
+++ b/third_party/xla/xla/pjrt/abstract_tfrt_cpu_buffer.cc
@@ -73,23 +73,45 @@
 
 constexpr size_t kSmallDataTransferByteSize = 102400;  // 100 KiB
 
+// Unpacks and copies the int4 data at 'input' into the literal at the given
+// ShapeIndex.
+void UnpackInt4ToLiteral(const MaybeOwningCpuMemory& input,
+                         MutableLiteralBase* literal,
+                         const ShapeIndex& shape_index) {
+  absl::Span<const char> input_span{static_cast<const char*>(input.data()),
+                                    input.size()};
+  size_t output_size = static_cast<size_t>(ShapeUtil::ByteSizeOf(
+      ShapeUtil::GetSubshape(literal->shape(), shape_index)));
+  absl::Span<char> output_span{
+      static_cast<char*>(literal->untyped_data(shape_index)), output_size};
+  UnpackInt4(input_span, output_span);
+}
+
 void CopyCpuBufferToLiteral(const Shape& device_shape,
                             TrackedTfrtCpuDeviceBuffer* device_buffer,
                             MutableLiteralBase* literal) {
   if (!device_shape.IsTuple()) {
     const std::shared_ptr<MaybeOwningCpuMemory>& b =
         device_buffer->Buffers()[0];
-    std::memcpy(literal->untyped_data(), b->data(),
-                ShapeUtil::ByteSizeOf(device_shape));
+    if (primitive_util::Is4BitType(device_shape.element_type())) {
+      UnpackInt4ToLiteral(*b, literal, /*shape_index=*/{});
+    } else {
+      std::memcpy(literal->untyped_data(), b->data(),
+                  ShapeUtil::ByteSizeOf(device_shape));
+    }
   } else {
     // Tuple case.
     int num_leaves = literal->shape().tuple_shapes().size();
     for (int i = 0; i < num_leaves; ++i) {
       const std::shared_ptr<MaybeOwningCpuMemory>& b =
           device_buffer->Buffers()[i];
-      std::memcpy(
-          literal->untyped_data({i}), b->data(),
-          ShapeUtil::ByteSizeOf(ShapeUtil::GetSubshape(device_shape, {i})));
+      if (primitive_util::Is4BitType(device_shape.element_type())) {
+        UnpackInt4ToLiteral(*b, literal, {i});
+      } else {
+        std::memcpy(
+            literal->untyped_data({i}), b->data(),
+            ShapeUtil::ByteSizeOf(ShapeUtil::GetSubshape(device_shape, {i})));
+      }
     }
   }
 }
@@ -666,12 +688,14 @@
     TransposePlanCache* transpose_cache) {
   bool has_default_layout =
       !byte_strides || HasMajorToMinorLayout(type, dims, *byte_strides);
+  // Int4 arrays are stored unpacked on the host and packed on the device.
+  bool is_int4 = primitive_util::Is4BitType(type);
   // If the input buffer has a default layout and is sufficiently aligned, we
   // can simply point to the input array's data without any further copies. At
   // the time of writing we require a 16-byte alignment because XLA may generate
   // code which requires it.
   bool can_use_zero_copy =
-      has_default_layout &&
+      has_default_layout && !is_int4 &&
       host_buffer_semantics == PjRtClient::HostBufferSemantics::kZeroCopy &&
       ((absl::bit_cast<std::uintptr_t>(data) &
         (cpu_function_runtime::MinAlign() - 1)) == 0);
@@ -685,11 +709,13 @@
     buffers.push_back(std::move(device_buffer));
     on_delete_callback = std::move(on_done_with_host_buffer);
   } else {
+    size_t dst_byte_size =
+        is_int4 ? CeilOfRatio(byte_size, size_t{2}) : byte_size;
     TF_ASSIGN_OR_RETURN(std::shared_ptr<MaybeOwningCpuMemory> device_buffer,
-                        MaybeOwningCpuMemory::AllocateShared(byte_size));
+                        MaybeOwningCpuMemory::AllocateShared(dst_byte_size));
     auto dst_data_ptr = device_buffer->data();
     buffers.push_back(device_buffer);
-    if (!has_default_layout) {
+    if (!has_default_layout || is_int4) {
       // If the input array does not have a major-to-minor layout, transpose it
       // into major-to-minor layout. Currently we choose to always do this
       // synchronously.
@@ -705,7 +731,20 @@
                            primitive_util::ByteWidth(type), dims, permutation,
                            TransposePlan::Striding{*byte_strides}));
       }
-      transpose->Execute(data, dst_data_ptr);
+      if (!is_int4) {
+        transpose->Execute(data, dst_data_ptr);
+      } else {
+        // First transpose the unpacked data into a new temporary buffer, then
+        // pack the data.
+        // TODO(reedwm): Fuse the transpose and packing by having TransposePlan
+        // support packing.
+        auto data_transposed = std::make_unique<char[]>(byte_size);
+        transpose->Execute(data, data_transposed.get());
+        absl::Span<const char> src_data_span(data_transposed.get(), byte_size);
+        absl::Span<char> dst_data_span(static_cast<char*>(dst_data_ptr),
+                                       dst_byte_size);
+        PackInt4(src_data_span, dst_data_span);
+      }
       if (on_done_with_host_buffer) {
         on_done_with_host_buffer();
         on_done_with_host_buffer = nullptr;
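
A stand-alone sketch of the nibble packing the int4 path above depends on: host buffers keep one 4-bit value per byte while device buffers keep two, which is why `dst_byte_size` is `CeilOfRatio(byte_size, 2)`. XLA's real `PackInt4`/`UnpackInt4` have different signatures; the element order chosen here (first value in the high nibble) is an assumption for illustration.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Packs one-value-per-byte int4 data into two-values-per-byte form.
std::vector<uint8_t> PackNibbles(const std::vector<uint8_t>& unpacked) {
  std::vector<uint8_t> packed((unpacked.size() + 1) / 2, 0);
  for (std::size_t i = 0; i < unpacked.size(); ++i) {
    uint8_t nibble = unpacked[i] & 0x0F;
    if (i % 2 == 0) {
      packed[i / 2] = nibble << 4;  // first element -> high nibble
    } else {
      packed[i / 2] |= nibble;      // second element -> low nibble
    }
  }
  return packed;
}

// Expands packed data back to one value per byte.
std::vector<uint8_t> UnpackNibbles(const std::vector<uint8_t>& packed,
                                   std::size_t num_elements) {
  std::vector<uint8_t> unpacked(num_elements);
  for (std::size_t i = 0; i < num_elements; ++i) {
    uint8_t byte = packed[i / 2];
    unpacked[i] = (i % 2 == 0) ? (byte >> 4) : (byte & 0x0F);
  }
  return unpacked;
}

int main() {
  std::vector<uint8_t> host = {1, 2, 3, 4, 5};  // one int4 value per byte
  auto device = PackNibbles(host);              // 3 bytes hold 5 values
  auto roundtrip = UnpackNibbles(device, host.size());
  std::cout << "packed bytes: " << device.size()
            << ", roundtrip ok: " << (roundtrip == host) << "\n";
  return 0;
}
```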
diff --git a/third_party/xla/xla/pjrt/c/CHANGELOG.md b/third_party/xla/xla/pjrt/c/CHANGELOG.md
index db6d03c..95c49a9 100644
--- a/third_party/xla/xla/pjrt/c/CHANGELOG.md
+++ b/third_party/xla/xla/pjrt/c/CHANGELOG.md
@@ -1,28 +1,35 @@
 # PJRT C API changelog
 
-## 0.34 (Oct 9, 2023)
+## 0.38 (Oct 30, 2023)
+* Use `enum` to define STRUCT_SIZE constants in a header file.
 
+## 0.37 (Oct 27, 2023)
+* Added `const` to a number of list and value types in the C API.
+
+## 0.36 (Oct 24, 2023)
+* Added PJRT_Client_TopologyDescription.
+
+## 0.35 (Oct 20, 2023)
+* Added PJRT_Executable_Fingerprint method.
+* Deprecated PJRT_LoadedExecutable_Fingerprint.
+
+## 0.34 (Oct 9, 2023)
 * Added PJRT_Structure_Type::PJRT_Structure_Type_Profiler.
 
 ## 0.33 (Oct 3, 2023)
-
 * Added PJRT_Client_CreateViewOfDeviceBuffer.
 
 ## 0.32 (Sep 26, 2023)
-
 * Added PJRT_Buffer_CopyToMemory.
 
 ## 0.31 (Sep 22, 2023)
-
 * Added PJRT_Structure_Base.
 * Added PJRT_Structure_Type.
 * Renamed PJRT_Api.priv to PJRT_Api.extension_start.
 
 ## 0.30 (Sep 14, 2023)
-
 * Added PJRT_NamedValue_Type::PJRT_NamedValue_kBool.
 
 ## 0.29 (Sep 6, 2023)
-
 * Added PJRT_Executable_OutputElementTypes.
-* Added PJRT_Executable_OutputDimensions.
\ No newline at end of file
+* Added PJRT_Executable_OutputDimensions.
diff --git a/third_party/xla/xla/pjrt/c/README.md b/third_party/xla/xla/pjrt/c/README.md
index e93d810..cdead8e 100644
--- a/third_party/xla/xla/pjrt/c/README.md
+++ b/third_party/xla/xla/pjrt/c/README.md
@@ -9,10 +9,11 @@
 ## Communication channel
 
 *   Please file issues in the [OpenXla/xla repo](https://github.com/openxla/xla).
-*   Join discussion in the #pjrt-plugin channel of the [IREE discord server](https://github.com/openxla/iree/#communication-channels).
+*   Join the [pjrt-announce mailing list](https://groups.google.com/g/pjrt-announce/).
 
 ## Resources
 
-*   [OpenXLA/IREE PJRT plugin implementation](https://github.com/openxla/openxla-pjrt-plugin)
+*   [PJRT C API changelog](https://github.com/openxla/xla/blob/main/xla/pjrt/c/CHANGELOG.md)
 *   [PJRT integration guide](https://github.com/openxla/xla/blob/main/xla/pjrt/c/docs/pjrt_integration_guide.md)
 *   [PJRT Plugin Mechanism design doc](https://docs.google.com/document/d/1Qdptisz1tUPGn1qFAVgCV2omnfjN01zoQPwKLdlizas/edit)
+*   [OpenXLA/IREE PJRT plugin implementation](https://github.com/openxla/openxla-pjrt-plugin)
diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api.h b/third_party/xla/xla/pjrt/c/pjrt_c_api.h
index 20f3560..122f941 100644
--- a/third_party/xla/xla/pjrt/c/pjrt_c_api.h
+++ b/third_party/xla/xla/pjrt/c/pjrt_c_api.h
@@ -25,7 +25,7 @@
 
 #define PJRT_DEFINE_STRUCT_TRAITS(sname, last_field) \
   typedef struct sname sname;                        \
-  const size_t sname##_STRUCT_SIZE = PJRT_STRUCT_SIZE(sname, last_field);
+  enum { sname##_STRUCT_SIZE = PJRT_STRUCT_SIZE(sname, last_field) }
 
 #ifdef __cplusplus
 extern "C" {
@@ -53,7 +53,7 @@
 // Changes include:
 // * Adding a new field to the PJRT_Api or argument structs
 // * Renaming a method or argument (doesn't affect ABI)
-#define PJRT_API_MINOR 34
+#define PJRT_API_MINOR 38
 
 // The plugin should set the major_version and minor_version of
 // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
@@ -183,8 +183,8 @@
   size_t struct_size;
   void* priv;
   // Returned attributes have the lifetime of the process.
-  PJRT_NamedValue* attributes;  // out
-  size_t num_attributes;        // out
+  const PJRT_NamedValue* attributes;  // out
+  size_t num_attributes;              // out
 };
 PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Attributes_Args, attributes);
 
@@ -282,6 +282,7 @@
 typedef struct PJRT_Device PJRT_Device;
 typedef struct PJRT_Memory PJRT_Memory;
 typedef struct PJRT_DeviceDescription PJRT_DeviceDescription;
+typedef struct PJRT_TopologyDescription PJRT_TopologyDescription;
 typedef struct PJRT_Executable PJRT_Executable;
 typedef struct PJRT_LoadedExecutable PJRT_LoadedExecutable;
 typedef struct PJRT_Buffer PJRT_Buffer;
@@ -345,7 +346,7 @@
   size_t struct_size;
   void* priv;
   // Extra platform-specific options to create a client.
-  PJRT_NamedValue* create_options;
+  const PJRT_NamedValue* create_options;
   size_t num_options;
   // Key-value get/put callback provided by the caller of PJRT_Client_Create.
   // PJRT client can use these callbacks to share information between
@@ -418,12 +419,26 @@
 typedef PJRT_Error* PJRT_Client_PlatformVersion(
     PJRT_Client_PlatformVersion_Args* args);
 
+struct PJRT_Client_TopologyDescription_Args {
+  size_t struct_size;
+  void* priv;
+  PJRT_Client* client;
+  // Is owned by and has the same lifetime as `client`.
+  PJRT_TopologyDescription* topology;  // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_TopologyDescription_Args, topology);
+
+// Returns the client's runtime topology description. The returned topology is
+// owned by the client and should not be deleted by the caller.
+typedef PJRT_Error* PJRT_Client_TopologyDescription(
+    PJRT_Client_TopologyDescription_Args* args);
+
 struct PJRT_Client_Devices_Args {
   size_t struct_size;
   void* priv;
   PJRT_Client* client;
-  PJRT_Device** devices;  // out
-  size_t num_devices;     // out
+  PJRT_Device* const* devices;  // out
+  size_t num_devices;           // out
 };
 PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Devices_Args, num_devices);
 
@@ -435,8 +450,8 @@
   size_t struct_size;
   void* priv;
   PJRT_Client* client;
-  PJRT_Device** addressable_devices;  // out
-  size_t num_addressable_devices;     // out
+  PJRT_Device* const* addressable_devices;  // out
+  size_t num_addressable_devices;           // out
 };
 PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_AddressableDevices_Args,
                           num_addressable_devices);
@@ -483,8 +498,8 @@
   size_t struct_size;
   void* priv;
   PJRT_Client* client;
-  PJRT_Memory** addressable_memories;  // out
-  size_t num_addressable_memories;     // out
+  PJRT_Memory* const* addressable_memories;  // out
+  size_t num_addressable_memories;           // out
 };
 PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_AddressableMemories_Args,
                           num_addressable_memories);
@@ -518,7 +533,7 @@
   PJRT_Client* client;
   // Only needs to stay alive for the duration of the Compile call.
   // `program->format` and `program->format_size` are owned by the caller.
-  PJRT_Program* program;
+  const PJRT_Program* program;
   // TODO(b/240560013): consider putting some of option fields in priv.
   // Serialized CompileOptionsProto
   // (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/compile_options.proto)
@@ -799,8 +814,8 @@
   size_t struct_size;
   void* priv;
   PJRT_DeviceDescription* device_description;
-  size_t num_attributes;        // out
-  PJRT_NamedValue* attributes;  // out
+  size_t num_attributes;              // out
+  const PJRT_NamedValue* attributes;  // out
 };
 PJRT_DEFINE_STRUCT_TRAITS(PJRT_DeviceDescription_Attributes_Args, attributes);
 
@@ -898,8 +913,8 @@
   void* priv;
   PJRT_Device* device;
   // Has the lifetime of `device`.
-  PJRT_Memory** memories;  // out
-  size_t num_memories;     // out
+  PJRT_Memory* const* memories;  // out
+  size_t num_memories;           // out
 };
 PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_AddressableMemories_Args, memories);
 
@@ -1025,8 +1040,8 @@
   size_t struct_size;
   void* priv;
   PJRT_Memory* memory;
-  PJRT_Device** devices;  // out
-  size_t num_devices;     // out
+  PJRT_Device* const* devices;  // out
+  size_t num_devices;           // out
 };
 PJRT_DEFINE_STRUCT_TRAITS(PJRT_Memory_AddressableByDevices_Args, num_devices);
 
@@ -1114,8 +1129,8 @@
   size_t struct_size;
   void* priv;
   PJRT_LoadedExecutable* executable;
-  PJRT_Device** addressable_devices;  // out
-  size_t num_addressable_devices;     // out
+  PJRT_Device* const* addressable_devices;  // out
+  size_t num_addressable_devices;           // out
 };
 PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_AddressableDevices_Args,
                           num_addressable_devices);
@@ -1259,7 +1274,7 @@
   // Only needs to stay alive for the duration of the Execute call.
   PJRT_ExecuteOptions* options;
   // Execution input of size [`num_devices`, `num_args`].
-  PJRT_Buffer*** argument_lists;
+  PJRT_Buffer* const* const* argument_lists;
   size_t num_devices;
   size_t num_args;
   // Execution output of size [`num_devices`, num_outputs`], where `num_outputs`
@@ -1267,7 +1282,7 @@
   // outer (`PJRT_Buffer***`) and inner lists (`PJRT_Buffer**`) must be
   // allocated and deallocated by the caller. PJRT_Buffer_Destroy must be called
   // on the output PJRT_Buffer*.
-  PJRT_Buffer*** output_lists;  // in/out
+  PJRT_Buffer** const* output_lists;  // in/out
   // If `device_complete_events` isn't nullptr, `device_complete_events` needs
   // to be the same length as `output_lists` (i.e. of length `num_devices`), and
   // each `PJRT_Event` will become ready once the corresponding device execution
@@ -1315,6 +1330,24 @@
 typedef PJRT_Error* PJRT_Executable_SizeOfGeneratedCodeInBytes(
     PJRT_Executable_SizeOfGeneratedCodeInBytes_Args* args);
 
+struct PJRT_Executable_Fingerprint_Args {
+  size_t struct_size;
+  void* priv;
+  PJRT_Executable* executable;
+  // Has the lifetime of `executable`
+  const char* executable_fingerprint;  // out
+  size_t executable_fingerprint_size;  // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Fingerprint_Args,
+                          executable_fingerprint_size);
+
+// A unique fingerprint for `executable`. Two executables that were produced by
+// compiling with identical inputs (same program, compile options, compiler
+// version, etc.) should have the same fingerprint. May not be implemented by
+// all platforms.
+typedef PJRT_Error* PJRT_Executable_Fingerprint(
+    PJRT_Executable_Fingerprint_Args* args);
+
 struct PJRT_Executable_GetCostAnalysis_Args {
   size_t struct_size;
   void* priv;
@@ -1322,7 +1355,7 @@
   size_t num_properties;  // out
   // `properties` and any embedded data are owned by and have the same lifetime
   // as `executable`.
-  PJRT_NamedValue* properties;  // out
+  const PJRT_NamedValue* properties;  // out
 };
 PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_GetCostAnalysis_Args, properties);
 
@@ -1371,7 +1404,7 @@
   PJRT_Executable* executable;
   size_t num_outputs;
   // Has length `num_outputs`.
-  const char** memory_kinds;  // out
+  const char* const* memory_kinds;  // out
   // Has length `num_outputs`.
   const size_t* memory_kind_sizes;  // out
 };
@@ -1434,10 +1467,11 @@
 };
 PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Fingerprint_Args,
                           executable_fingerprint_size);
-// A unique fingerprint for `executable`. Two executables that were produced by
-// compiling with identical inputs (same program, compile options, compiler
-// version, etc.) should have the same fingerprint. May not be implemented by
-// all platforms.
+// DEPRECATED. Will be removed in PJRT version 2.0. Please use
+// PJRT_Executable_Fingerprint instead. A unique fingerprint for `executable`.
+// Two executables that were produced by compiling with identical inputs (same
+// program, compile options, compiler version, etc.) should have the same
+// fingerprint. May not be implemented by all platforms.
 typedef PJRT_Error* PJRT_LoadedExecutable_Fingerprint(
     PJRT_LoadedExecutable_Fingerprint_Args* args);
 
@@ -1810,15 +1844,13 @@
 
 // ------------------------------ Device Topology ------------------------------
 
-typedef struct PJRT_TopologyDescription PJRT_TopologyDescription;
-
 struct PJRT_TopologyDescription_Create_Args {
   size_t struct_size;
   void* priv;
   const char* topology_name;
   size_t topology_name_size;
   // Extra platform-specific options to create a client.
-  PJRT_NamedValue* create_options;
+  const PJRT_NamedValue* create_options;
   size_t num_options;
   PJRT_TopologyDescription* topology;  // out
 };
@@ -1878,8 +1910,8 @@
   void* priv;
   PJRT_TopologyDescription* topology;
   // Has the same lifetime as topology.
-  PJRT_DeviceDescription** descriptions;  // out
-  size_t num_descriptions;                // out
+  PJRT_DeviceDescription* const* descriptions;  // out
+  size_t num_descriptions;                      // out
 };
 PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_GetDeviceDescriptions_Args,
                           num_descriptions);
@@ -1920,8 +1952,8 @@
   PJRT_TopologyDescription* topology;
 
   // Only lives as long as topology.
-  PJRT_NamedValue* attributes;  // out
-  size_t num_attributes;        // out
+  const PJRT_NamedValue* attributes;  // out
+  size_t num_attributes;              // out
 };
 PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_Attributes_Args,
                           num_attributes);
@@ -1936,7 +1968,7 @@
   const PJRT_TopologyDescription* topology;
   // Only needs to stay alive for the duration of the Compile call.
   // `program->format` and `program->format_size` are owned by the caller.
-  PJRT_Program* program;
+  const PJRT_Program* program;
   // TODO(b/240560013): consider putting some of option fields in priv.
   // Serialized CompileOptionsProto
   // (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/compile_options.proto)
@@ -2090,10 +2122,16 @@
   _PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyToMemory);
 
   _PJRT_API_STRUCT_FIELD(PJRT_Client_CreateViewOfDeviceBuffer);
+
+  _PJRT_API_STRUCT_FIELD(PJRT_Executable_Fingerprint);
+
+  _PJRT_API_STRUCT_FIELD(PJRT_Client_TopologyDescription);
 } PJRT_Api;
 
-const size_t PJRT_Api_STRUCT_SIZE =
-    PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Client_CreateViewOfDeviceBuffer);
+enum {
+  PJRT_Api_STRUCT_SIZE =
+      PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Client_TopologyDescription)
+};
 
 #undef _PJRT_API_STRUCT_FIELD
 
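
A self-contained illustration of the `STRUCT_SIZE` change: an enumerator is a compile-time constant with no per-translation-unit object definition, unlike the previous `const size_t` globals in this C header. `MyArgs` and `STRUCT_SIZE_UPTO` are stand-ins; the real `PJRT_STRUCT_SIZE` macro lives in pjrt_c_api.h.

```cpp
#include <cstddef>
#include <iostream>

// Size of the struct up to and including `last_field`, mirroring the idea of
// versioning extensible arg structs by the offset of their last known field.
#define STRUCT_SIZE_UPTO(sname, last_field) \
  (offsetof(sname, last_field) + sizeof(((sname*)nullptr)->last_field))

struct MyArgs {
  size_t struct_size;
  void* priv;
  int value;         // field present since "version 1"
  const char* name;  // field added in a later "version"
};

enum { MyArgs_STRUCT_SIZE = STRUCT_SIZE_UPTO(MyArgs, name) };

int main() {
  MyArgs args;
  args.struct_size = MyArgs_STRUCT_SIZE;  // caller advertises the size it knows
  std::cout << "MyArgs_STRUCT_SIZE = " << MyArgs_STRUCT_SIZE
            << ", sizeof(MyArgs) = " << sizeof(MyArgs) << "\n";
  return 0;
}
```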
diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h
index f86803a..a94d080 100644
--- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h
+++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h
@@ -16,7 +16,7 @@
 #ifndef XLA_PJRT_C_PJRT_C_API_GPU_EXTENSION_H_
 #define XLA_PJRT_C_PJRT_C_API_GPU_EXTENSION_H_
 
-#include <cstddef>
+#include <stddef.h>
 
 #include "xla/pjrt/c/pjrt_c_api.h"
 
diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc
index 3ebc3ef..8b62aca 100644
--- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc
+++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc
@@ -480,7 +480,8 @@
 }
 
 absl::flat_hash_map<std::string, xla::PjRtValueType>
-ConvertFromPjRtNamedValueList(PJRT_NamedValue* c_value_list, size_t list_size) {
+ConvertFromPjRtNamedValueList(const PJRT_NamedValue* c_value_list,
+                              size_t list_size) {
   absl::flat_hash_map<std::string, xla::PjRtValueType> cpp_value_map;
   for (int i = 0; i < list_size; ++i) {
     const PJRT_NamedValue& c_value = c_value_list[i];
@@ -612,6 +613,16 @@
   return platform_name;
 }
 
+xla::StatusOr<PJRT_TopologyDescription*> GetTopologyDescription(
+    PJRT_Client* client, const PJRT_Api* api) {
+  PJRT_Client_TopologyDescription_Args args;
+  args.struct_size = PJRT_Client_TopologyDescription_Args_STRUCT_SIZE;
+  args.priv = nullptr;
+  args.client = client;
+  RETURN_STATUS_IF_PJRT_ERROR(api->PJRT_Client_TopologyDescription(&args), api);
+  return args.topology;
+}
+
 PJRT_Chunk ConvertFromCppChunk(xla::PjRtChunk chunk) {
   // `deleter_arg` holds a copy of the original xla::PjRtChunk
   // deleter. The original xla::PjRtChunk `input` releases its ownership
@@ -651,8 +662,8 @@
   return args.device_description;
 }
 
-absl::Span<PJRT_Memory*> GetAddressableMemories(const PJRT_Api* api,
-                                                PJRT_Device* device) {
+absl::Span<PJRT_Memory* const> GetAddressableMemories(const PJRT_Api* api,
+                                                      PJRT_Device* device) {
   PJRT_Device_AddressableMemories_Args args;
   args.struct_size = PJRT_Device_AddressableMemories_Args_STRUCT_SIZE;
   args.priv = nullptr;
@@ -661,6 +672,13 @@
   return absl::MakeSpan(args.memories, args.num_memories);
 }
 
+int GetId(const PJRT_Api* api, PJRT_DeviceDescription* device_desc) {
+  PJRT_DeviceDescription_Id_Args args = PJRT_DeviceDescription_Id_Args{
+      PJRT_DeviceDescription_Id_Args_STRUCT_SIZE, nullptr, device_desc};
+  pjrt::LogFatalIfPjrtError(api->PJRT_DeviceDescription_Id(&args), api);
+  return args.id;
+}
+
 static void PjRtValueDeleterCallback(char* value) { delete[] value; }
 
 static PJRT_KeyValueGetCFunc ToKVGetCFunc(
@@ -858,4 +876,26 @@
   return shape;
 }
 
+absl::string_view PlatformName(const PJRT_Api* api,
+                               const PJRT_TopologyDescription* topo_desc) {
+  PJRT_TopologyDescription_PlatformName_Args args;
+  args.struct_size = PJRT_TopologyDescription_PlatformName_Args_STRUCT_SIZE;
+  args.priv = nullptr;
+  args.topology = const_cast<PJRT_TopologyDescription*>(topo_desc);
+  LogFatalIfPjrtError(api->PJRT_TopologyDescription_PlatformName(&args), api);
+  return {args.platform_name, args.platform_name_size};
+}
+
+absl::Span<PJRT_DeviceDescription* const> DeviceDescriptions(
+    const PJRT_Api* api, const PJRT_TopologyDescription* topo_desc) {
+  PJRT_TopologyDescription_GetDeviceDescriptions_Args args;
+  args.struct_size =
+      PJRT_TopologyDescription_GetDeviceDescriptions_Args_STRUCT_SIZE;
+  args.priv = nullptr;
+  args.topology = const_cast<PJRT_TopologyDescription*>(topo_desc);
+  LogFatalIfPjrtError(
+      api->PJRT_TopologyDescription_GetDeviceDescriptions(&args), api);
+  return {args.descriptions, args.num_descriptions};
+}
+
 }  // namespace pjrt
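
A sketch of how the new topology helpers above compose: fetch the client's topology, then walk its device descriptions. It is not runnable on its own (a loaded PJRT plugin must supply `api` and `client`), and it only uses the helper signatures declared in pjrt_c_api_helpers.h in this change.

```cpp
#include <iostream>

#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_helpers.h"

// Prints the platform name and the id of every device description reachable
// from the client's topology.
void DumpTopology(const PJRT_Api* api, PJRT_Client* client) {
  auto topology = pjrt::GetTopologyDescription(client, api);
  if (!topology.ok()) {
    std::cout << "topology not available: " << topology.status().ToString()
              << "\n";
    return;
  }
  std::cout << "platform: " << pjrt::PlatformName(api, *topology) << "\n";
  for (PJRT_DeviceDescription* desc :
       pjrt::DeviceDescriptions(api, *topology)) {
    std::cout << "  device id " << pjrt::GetId(api, desc) << "\n";
  }
}
```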
diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h
index 7a9e8ba..569f3af 100644
--- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h
+++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h
@@ -142,7 +142,8 @@
     int api_minor_version);
 
 absl::flat_hash_map<std::string, xla::PjRtValueType>
-ConvertFromPjRtNamedValueList(PJRT_NamedValue* c_value_list, size_t list_size);
+ConvertFromPjRtNamedValueList(const PJRT_NamedValue* c_value_list,
+                              size_t list_size);
 
 // Validates that all entries in value_map have a matching name and type in
 // expected_name_and_type. expected_name_and_type may contain extra entries
@@ -163,6 +164,9 @@
 absl::string_view GetPlatformVersion(PJRT_Client* client, const PJRT_Api* api);
 absl::string_view GetPlatformName(PJRT_Client* client, const PJRT_Api* api);
 
+xla::StatusOr<PJRT_TopologyDescription*> GetTopologyDescription(
+    PJRT_Client* client, const PJRT_Api* api);
+
 // Releases `chunk`.
 PJRT_Chunk ConvertFromCppChunk(xla::PjRtChunk chunk);
 
@@ -173,8 +177,10 @@
 PJRT_DeviceDescription* GetDeviceDescription(const PJRT_Api* api,
                                              PJRT_Device* device);
 
-absl::Span<PJRT_Memory*> GetAddressableMemories(const PJRT_Api* api,
-                                                PJRT_Device* device);
+absl::Span<PJRT_Memory* const> GetAddressableMemories(const PJRT_Api* api,
+                                                      PJRT_Device* device);
+
+int GetId(const PJRT_Api* api, PJRT_DeviceDescription* device_desc);
 
 using PJRT_KeyValueGetCFunc =
     std::function<PJRT_Error*(PJRT_KeyValueGetCallback_Args* args)>;
@@ -236,6 +242,11 @@
                                              size_t num_dims,
                                              PJRT_Buffer_MemoryLayout* layout);
 
+absl::string_view PlatformName(const PJRT_Api* api,
+                               const PJRT_TopologyDescription* topo_desc);
+absl::Span<PJRT_DeviceDescription* const> DeviceDescriptions(
+    const PJRT_Api* api, const PJRT_TopologyDescription* topo_desc);
+
 }  // namespace pjrt
 
 #endif  // XLA_PJRT_C_PJRT_C_API_HELPERS_H_
diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc
index 146d4df..36e9344 100644
--- a/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc
+++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc
@@ -170,7 +170,7 @@
     return args.local_hardware_id;
   }
 
-  absl::Span<PJRT_Device*> GetClientDevices() const {
+  absl::Span<PJRT_Device* const> GetClientDevices() const {
     PJRT_Client_Devices_Args dev_args;
     dev_args.struct_size = PJRT_Client_Devices_Args_STRUCT_SIZE;
     dev_args.priv = nullptr;
@@ -315,7 +315,7 @@
 }
 
 TEST_F(PjrtCApiTest, ClientDevices) {
-  absl::Span<PJRT_Device*> devices = GetClientDevices();
+  absl::Span<PJRT_Device* const> devices = GetClientDevices();
 
   ASSERT_FALSE(devices.empty());
   for (auto& device : devices) {
@@ -324,14 +324,15 @@
 }
 
 TEST_F(PjrtCApiTest, ClientAddressableDevices) {
-  absl::Span<PJRT_Device*> addressable_devices = GetClientAddressableDevices();
+  absl::Span<PJRT_Device* const> addressable_devices =
+      GetClientAddressableDevices();
 
   ASSERT_FALSE(addressable_devices.empty());
   for (auto& device : addressable_devices) {
     ASSERT_TRUE(this->IsValidDeviceId(device));
   }
 
-  absl::Span<PJRT_Device*> client_devices = GetClientDevices();
+  absl::Span<PJRT_Device* const> client_devices = GetClientDevices();
   for (auto& addressable_device : addressable_devices) {
     ASSERT_THAT(client_devices, ::testing::Contains(addressable_device));
   }
diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc
index f7245c9..c7d43ca 100644
--- a/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc
+++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc
@@ -68,7 +68,8 @@
   CHECK_EQ(error, nullptr);
 }
 
-absl::Span<PJRT_Device*> PjrtCApiTestBase::GetClientAddressableDevices() const {
+absl::Span<PJRT_Device* const> PjrtCApiTestBase::GetClientAddressableDevices()
+    const {
   PJRT_Client_AddressableDevices_Args addr_args;
   addr_args.struct_size = PJRT_Client_AddressableDevices_Args_STRUCT_SIZE;
   addr_args.priv = nullptr;
diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.h
index cb96775..35201f3 100644
--- a/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.h
+++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.h
@@ -39,7 +39,7 @@
   PJRT_Client* client_;
   void destroy_client(PJRT_Client* client);
 
-  absl::Span<PJRT_Device*> GetClientAddressableDevices() const;
+  absl::Span<PJRT_Device* const> GetClientAddressableDevices() const;
 
   PJRT_Client_BufferFromHostBuffer_Args CreateBufferFromHostBufferArgs(
       const std::vector<float>& data, const xla::Shape& shape,
diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_tpu.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_tpu.h
index 2f8082d..898dd37 100644
--- a/third_party/xla/xla/pjrt/c/pjrt_c_api_tpu.h
+++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_tpu.h
@@ -18,17 +18,15 @@
 
 #include "xla/pjrt/c/pjrt_c_api.h"
 
-namespace pjrt {
 enum PjRtCApiTpuInitType {
   // Build with static linking and deploy internally.
-  kInternalStaticLinking,
+  kPjRtCApiTpuInitTypeInternalStaticLinking,
   // Build with static linking and deploy on cloud.
-  kExternalStaticLinking,
+  kPjRtCApiTpuInitTypeExternalStaticLinking,
   // Build with dynamic linking and deploy on cloud.
-  kDynamicLinking
+  kPjRtCApiTpuInitTypeDynamicLinking
 };
 extern enum PjRtCApiTpuInitType kPjRtCApiTpuInitType;
-}  // namespace pjrt
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
index 9e7570e..1605021 100644
--- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
+++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
@@ -375,6 +375,17 @@
   return nullptr;
 }
 
+PJRT_Error* PJRT_Client_TopologyDescription(
+    PJRT_Client_TopologyDescription_Args* args) {
+  PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
+      "PJRT_Client_TopologyDescription_Args",
+      PJRT_Client_TopologyDescription_Args_STRUCT_SIZE, args->struct_size));
+
+  PJRT_RETURN_IF_ERROR(args->client->topology.status());
+  args->topology = args->client->topology->get();
+  return nullptr;
+}
+
 PJRT_Error* PJRT_Client_Devices(PJRT_Client_Devices_Args* args) {
   PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
       "PJRT_Client_Devices_Args", PJRT_Client_Devices_Args_STRUCT_SIZE,
@@ -416,6 +427,8 @@
   return nullptr;
 }
 
+// TODO: b/306669267 - this method is deprecated. When can we return
+// unimplemented?
 PJRT_Error* PJRT_LoadedExecutable_Fingerprint(
     PJRT_LoadedExecutable_Fingerprint_Args* args) {
   PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
@@ -517,7 +530,7 @@
 xla::StatusOr<
     std::variant<mlir::OwningOpRef<mlir::ModuleOp>, xla::XlaComputation>>
 ParsePjrtProgram(std::optional<mlir::MLIRContext>& context,
-                 PJRT_Program* program) {
+                 const PJRT_Program* program) {
   auto format_str = absl::string_view(program->format, program->format_size);
   auto module_str = absl::string_view(program->code, program->code_size);
 
@@ -1115,6 +1128,18 @@
   }
 }
 
+PJRT_Error* PJRT_Executable_Fingerprint(
+    PJRT_Executable_Fingerprint_Args* args) {
+  PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
+      "PJRT_Executable_Fingerprint_Args",
+      PJRT_Executable_Fingerprint_Args_STRUCT_SIZE, args->struct_size));
+  PJRT_RETURN_IF_ERROR(args->executable->fingerprint.status());
+  args->executable_fingerprint = args->executable->fingerprint.value().c_str();
+  args->executable_fingerprint_size =
+      args->executable->fingerprint.value().size();
+  return nullptr;
+}
+
 PJRT_Error* PJRT_Executable_GetCostAnalysis(
     PJRT_Executable_GetCostAnalysis_Args* args) {
   PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
@@ -1296,7 +1321,7 @@
 }
 
 static std::vector<std::vector<xla::PjRtBuffer*>> Convert2DCBuffersToCppBuffers(
-    PJRT_Buffer*** c_lists, size_t outer_size, size_t inner_size) {
+    PJRT_Buffer* const* const* c_lists, size_t outer_size, size_t inner_size) {
   std::vector<std::vector<xla::PjRtBuffer*>> cpp_lists;
   cpp_lists.reserve(outer_size);
   for (int i = 0; i < outer_size; ++i) {
@@ -2122,8 +2147,19 @@
   }
 }
 
+static xla::StatusOr<std::unique_ptr<PJRT_TopologyDescription>>
+GetStatusOrTopologyDescription(const xla::PjRtClient& cpp_client) {
+  xla::StatusOr<const xla::PjRtTopologyDescription*> status_or_cpp_topo =
+      cpp_client.GetTopologyDescription();
+  if (!status_or_cpp_topo.ok()) {
+    return status_or_cpp_topo.status();
+  }
+  return std::unique_ptr<PJRT_TopologyDescription>(
+      CreateWrapperDeviceTopology(*status_or_cpp_topo));
+}
+
 PJRT_Client* CreateWrapperClient(std::unique_ptr<xla::PjRtClient> cpp_client) {
-  PJRT_Client* c_client = new PJRT_Client{std::move(cpp_client)};
+  PJRT_Client* c_client = new PJRT_Client(std::move(cpp_client));
   PopulatePjrtClientDevices(c_client);
   PopulatePjrtClientMemories(c_client);
   AttachDevicesAndMemories(c_client);
@@ -2131,9 +2167,9 @@
 }
 
 PJRT_TopologyDescription* CreateWrapperDeviceTopology(
-    std::unique_ptr<xla::PjRtTopologyDescription> cpp_topology) {
+    const xla::PjRtTopologyDescription* cpp_topology) {
   PJRT_TopologyDescription* c_topology =
-      new PJRT_TopologyDescription{std::move(cpp_topology)};
+      new PJRT_TopologyDescription{/*owned_topology=*/nullptr, cpp_topology};
   c_topology->cpp_descriptions = c_topology->topology->DeviceDescriptions();
   c_topology->descriptions.reserve(c_topology->cpp_descriptions.size());
   c_topology->description_pointers.reserve(c_topology->cpp_descriptions.size());
@@ -2150,11 +2186,24 @@
   return c_topology;
 }
 
+PJRT_TopologyDescription* CreateWrapperDeviceTopology(
+    std::unique_ptr<xla::PjRtTopologyDescription> cpp_topology) {
+  PJRT_TopologyDescription* topo_desc =
+      CreateWrapperDeviceTopology(cpp_topology.get());
+  topo_desc->owned_topology = std::move(cpp_topology);
+  return topo_desc;
+}
+
 }  // namespace pjrt
 
+PJRT_Client::PJRT_Client(std::unique_ptr<xla::PjRtClient> cpp_client)
+    : client(std::move(cpp_client)),
+      topology(pjrt::GetStatusOrTopologyDescription(*client)) {}
+
 PJRT_Executable::PJRT_Executable(
     std::shared_ptr<xla::PjRtExecutable> executable)
-    : executable(std::move(executable)) {}
+    : executable(std::move(executable)),
+      fingerprint(this->executable->FingerprintExecutable()) {}
 
 PJRT_LoadedExecutable::PJRT_LoadedExecutable(
     std::shared_ptr<xla::PjRtLoadedExecutable> executable, PJRT_Client* client)
diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h
index f26450c..0caa851 100644
--- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h
+++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h
@@ -41,6 +41,21 @@
   xla::Status status;
 };
 
+struct PJRT_TopologyDescription {
+  // nullptr iff the wrapped PjRtTopologyDescription isn't owned by this
+  // struct. The PJRT C API sometimes returns a topo desc that's owned by the
+  // caller and must be
+  // freed using PJRT_TopologyDescription_Destroy
+  // (e.g. PJRT_TopologyDescription_Create), and sometimes returns a topo desc
+  // that's owned by something else (e.g. PJRT_Client_TopologyDescription).
+  std::unique_ptr<xla::PjRtTopologyDescription> owned_topology;
+  const xla::PjRtTopologyDescription* topology;
+  std::vector<std::unique_ptr<const xla::PjRtDeviceDescription>>
+      cpp_descriptions;
+  std::vector<PJRT_DeviceDescription> descriptions;
+  std::vector<PJRT_DeviceDescription*> description_pointers;
+  std::vector<PJRT_NamedValue> attributes;
+};
+
 struct PJRT_Client {
   std::unique_ptr<xla::PjRtClient> client;
   std::vector<PJRT_Device> owned_devices;
@@ -62,6 +77,9 @@
   // `owned_memories`.
   absl::flat_hash_map<xla::PjRtMemorySpace*, PJRT_Memory*>
       c_memory_from_cpp_memory;
+  xla::StatusOr<std::unique_ptr<PJRT_TopologyDescription>> topology;
+
+  explicit PJRT_Client(std::unique_ptr<xla::PjRtClient> cpp_client);
 };
 
 // PJRT_DeviceDescriptions are owned by their corresponding PJRT_Device.
@@ -93,6 +111,8 @@
   // Must be shared_ptr so that we can share with PJRT_LoadedExecutable.
   std::shared_ptr<xla::PjRtExecutable> executable;
 
+  xla::StatusOr<std::string> fingerprint;
+
   // Used to synchronize concurrent setting of cached values.
   mutable absl::Mutex mutex;
 
@@ -171,15 +191,6 @@
   std::string serialized;
 };
 
-struct PJRT_TopologyDescription {
-  std::unique_ptr<xla::PjRtTopologyDescription> topology;
-  std::vector<std::unique_ptr<const xla::PjRtDeviceDescription>>
-      cpp_descriptions;
-  std::vector<PJRT_DeviceDescription> descriptions;
-  std::vector<PJRT_DeviceDescription*> description_pointers;
-  std::vector<PJRT_NamedValue> attributes;
-};
-
 struct PJRT_TransferMetadata {
   // Decompose xla::Shape into C API type fields, without any Tuple information.
   // TODO(b/238999986) support other `xla::Shape` fields when they are fully
@@ -210,6 +221,8 @@
 PJRT_Error* PJRT_Client_PlatformName(PJRT_Client_PlatformName_Args* args);
 PJRT_Error* PJRT_Client_ProcessIndex(PJRT_Client_ProcessIndex_Args* args);
 PJRT_Error* PJRT_Client_PlatformVersion(PJRT_Client_PlatformVersion_Args* args);
+PJRT_Error* PJRT_Client_TopologyDescription(
+    PJRT_Client_TopologyDescription_Args* args);
 PJRT_Error* PJRT_Client_Devices(PJRT_Client_Devices_Args* args);
 PJRT_Error* PJRT_Client_AddressableDevices(
     PJRT_Client_AddressableDevices_Args* args);
@@ -262,6 +275,7 @@
 PJRT_Error* PJRT_Executable_NumOutputs(PJRT_Executable_NumOutputs_Args* args);
 PJRT_Error* PJRT_Executable_SizeOfGeneratedCodeInBytes(
     PJRT_Executable_SizeOfGeneratedCodeInBytes_Args* args);
+PJRT_Error* PJRT_Executable_Fingerprint(PJRT_Executable_Fingerprint_Args* args);
 PJRT_Error* PJRT_Executable_GetCostAnalysis(
     PJRT_Executable_GetCostAnalysis_Args* args);
 PJRT_Error* PJRT_Executable_OutputElementTypes(
@@ -286,6 +300,8 @@
     PJRT_Executable_DeserializeAndLoad_Args* args);
 PJRT_Error* PJRT_LoadedExecutable_GetExecutable(
     PJRT_LoadedExecutable_GetExecutable_Args* args);
+// TODO: b/306669267 - this method is deprecated. When can we return
+// unimplemented?
 PJRT_Error* PJRT_LoadedExecutable_Fingerprint(
     PJRT_LoadedExecutable_Fingerprint_Args* args);
 
@@ -375,11 +391,22 @@
 std::string ProgramFormatErrorMsg(absl::string_view program_format);
 
 // Creates a C PJRT topology from a C++ PJRT topology.
-// The returned topology is owned by the caller and
-// should be destroyed with PJRT_TopologyDescription_Destroy.
+//
+// The returned topology is owned by the caller and should be destroyed with
+// PJRT_TopologyDescription_Destroy. This can be used to implement functions
+// like PJRT_TopologyDescription_Create that return an owned topo desc.
 PJRT_TopologyDescription* CreateWrapperDeviceTopology(
     std::unique_ptr<xla::PjRtTopologyDescription> cpp_topology);
 
+// Creates a C PJRT topology from a C++ PJRT topology.
+//
+// The returned topology is *not* owned by the caller and should *not* be
+// destroyed with PJRT_TopologyDescription_Destroy. This can be used to
+// implement functions like PJRT_Client_TopologyDescription that return a topo
+// desc owned by something else.
+PJRT_TopologyDescription* CreateWrapperDeviceTopology(
+    const xla::PjRtTopologyDescription* cpp_topology);
+
 // Creates a C PJRT client from a C++ PJRT client and creates C PJRT devices
 // from cpp_client's devices. The returned client is owned by the caller and
 // should be destroyed with PJRT_Client_Destroy.
@@ -563,6 +590,9 @@
       pjrt::PJRT_Buffer_CopyToMemory,
       /*PJRT_Client_CreateViewOfDeviceBuffer=*/
       pjrt::PJRT_Client_CreateViewOfDeviceBuffer,
+      /*PJRT_Executable_Fingerprint=*/pjrt::PJRT_Executable_Fingerprint,
+      /*PJRT_Client_TopologyDescription=*/
+      pjrt::PJRT_Client_TopologyDescription,
   };
 }
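
The two CreateWrapperDeviceTopology overloads declared above differ only in who keeps the underlying C++ topology alive. A hedged sketch of the owned case as a plugin might write it; MakeMyTopology and the plugin function name are assumptions, not part of this change, while the non-owned overload is exactly what PJRT_Client_TopologyDescription uses via the client's cached topology:

    // Owned case: the wrapper handed back to the caller must eventually be
    // freed with PJRT_TopologyDescription_Destroy.
    PJRT_Error* MyPlugin_TopologyDescription_Create(
        PJRT_TopologyDescription_Create_Args* args) {
      std::unique_ptr<xla::PjRtTopologyDescription> topo = MakeMyTopology();
      args->topology = pjrt::CreateWrapperDeviceTopology(std::move(topo));
      return nullptr;
    }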
 
diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc
index a8c8226..8bc2141 100644
--- a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc
+++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc
@@ -100,6 +100,20 @@
 
 // ---------------------------------- Client -----------------------------------
 
+static StatusOr<const PjRtCApiTopologyDescription> InitClientTopoDesc(
+    const PJRT_Api* c_api, PJRT_Client* c_client) {
+  if (c_api->pjrt_api_version.major_version == 0 &&
+      c_api->pjrt_api_version.minor_version < 36) {
+    return Unimplemented(
+        "Getting TopologyDescription for PJRT client requires plugin with PJRT "
+        "C API version >= 0.36");
+  }
+  StatusOr<PJRT_TopologyDescription*> c_topo =
+      pjrt::GetTopologyDescription(c_client, c_api);
+  TF_RETURN_IF_ERROR(c_topo.status());
+  return PjRtCApiTopologyDescription(c_api, *c_topo, /*owned=*/false);
+}
+
 PjRtCApiClient::PjRtCApiClient(
     const PJRT_Api* c_api, PJRT_Client* c_client,
     std::unique_ptr<pjrt::PJRT_KeyValueCallbackData> kv_callback_data)
@@ -107,6 +121,7 @@
       c_client_(std::unique_ptr<PJRT_Client, ::pjrt::PJRT_ClientDeleter>(
           c_client, ::pjrt::MakeClientDeleter(c_api))),
       kv_callback_data_(std::move(kv_callback_data)),
+      topo_desc_(InitClientTopoDesc(c_api, c_client)),
       // Example platform version string:
       //   PJRT C API
       //   TFRT TPU v2
@@ -430,6 +445,14 @@
       std::make_unique<PjRtCApiLoadedExecutable>(this, c_exec));
 }
 
+StatusOr<const PjRtTopologyDescription*>
+PjRtCApiClient::GetTopologyDescription() const {
+  if (!topo_desc_.ok()) {
+    return topo_desc_.status();
+  }
+  return &(*topo_desc_);
+}
+
 StatusOr<std::uintptr_t> PjRtCApiClient::UnsafeBufferPointer(
     PjRtBuffer* buffer) {
   // Validate that the buffer's client matches the function call's client, since
@@ -1140,6 +1163,19 @@
   return std::string(ser_args.serialized_bytes, ser_args.serialized_bytes_size);
 }
 
+StatusOr<std::string> PjRtCApiExecutable::FingerprintExecutable() const {
+  PJRT_Executable_Fingerprint_Args args;
+  args.struct_size = PJRT_Executable_Fingerprint_Args_STRUCT_SIZE;
+  args.priv = nullptr;
+  args.executable = c_executable();
+
+  RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_Executable_Fingerprint(&args),
+                              c_api_);
+
+  return std::string(args.executable_fingerprint,
+                     args.executable_fingerprint_size);
+}
+
 // ------------------------ Loaded Executables ---------------------------------
 
 PjRtCApiLoadedExecutable::PjRtCApiLoadedExecutable(
@@ -1197,7 +1233,7 @@
 }
 
 static std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>
-Convert2DCBuffersToCppBuffers(PJRT_Buffer*** c_lists, size_t outer_size,
+Convert2DCBuffersToCppBuffers(PJRT_Buffer** const* c_lists, size_t outer_size,
                               int inner_size, xla::PjRtCApiClient* client) {
   std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> ret;
   for (size_t i = 0; i < outer_size; ++i) {
@@ -2016,16 +2052,21 @@
 // ------------------------------ Device Topology ------------------------------
 
 PjRtCApiTopologyDescription::PjRtCApiTopologyDescription(
-    const PJRT_Api* c_api, PJRT_TopologyDescription* c_topology)
+    const PJRT_Api* c_api, PJRT_TopologyDescription* c_topology, bool owned)
     : compiler_(std::make_unique<PjRtCApiCompiler>(c_api)),
       c_api_(c_api),
-      c_topology_(c_topology, ::pjrt::MakeTopologyDescriptionDeleter(c_api)) {
+      c_topology_(c_topology) {
+  if (owned) {
+    owned_c_topology_ = std::unique_ptr<PJRT_TopologyDescription,
+                                        pjrt::PJRT_TopologyDescriptionDeleter>(
+        c_topology, pjrt::MakeTopologyDescriptionDeleter(c_api));
+  }
   InitAttributes();
 }
 
 absl::string_view PjRtCApiTopologyDescription::platform_name() const {
   PJRT_TopologyDescription_PlatformName_Args args;
-  args.topology = c_topology_.get();
+  args.topology = c_topology_;
   args.struct_size = PJRT_TopologyDescription_PlatformName_Args_STRUCT_SIZE;
   args.priv = nullptr;
   pjrt::LogFatalIfPjrtError(
@@ -2037,7 +2078,7 @@
   PJRT_TopologyDescription_PlatformVersion_Args args;
   args.struct_size = PJRT_TopologyDescription_PlatformVersion_Args_STRUCT_SIZE;
   args.priv = nullptr;
-  args.topology = c_topology_.get();
+  args.topology = c_topology_;
   pjrt::LogFatalIfPjrtError(
       c_api_->PJRT_TopologyDescription_PlatformVersion(&args), c_api_);
   return absl::string_view(args.platform_version, args.platform_version_size);
@@ -2049,14 +2090,14 @@
   args.struct_size =
       PJRT_TopologyDescription_GetDeviceDescriptions_Args_STRUCT_SIZE;
   args.priv = nullptr;
-  args.topology = c_topology_.get();
+  args.topology = c_topology_;
   pjrt::LogFatalIfPjrtError(
       c_api_->PJRT_TopologyDescription_GetDeviceDescriptions(&args), c_api_);
   std::vector<std::unique_ptr<const PjRtDeviceDescription>> out;
   out.reserve(args.num_descriptions);
   for (PJRT_DeviceDescription* device_desc :
-       absl::Span<PJRT_DeviceDescription*>(args.descriptions,
-                                           args.num_descriptions)) {
+       absl::Span<PJRT_DeviceDescription* const>(args.descriptions,
+                                                 args.num_descriptions)) {
     out.push_back(
         std::make_unique<PjRtCApiDeviceDescription>(c_api_, device_desc));
   }
@@ -2067,7 +2108,7 @@
   PJRT_TopologyDescription_Serialize_Args args;
   args.struct_size = PJRT_TopologyDescription_Serialize_Args_STRUCT_SIZE;
   args.priv = nullptr;
-  args.topology = c_topology_.get();
+  args.topology = c_topology_;
   RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_TopologyDescription_Serialize(&args),
                               c_api_);
   auto out = std::string(args.serialized_bytes, args.serialized_bytes_size);
@@ -2079,7 +2120,7 @@
   PJRT_TopologyDescription_Attributes_Args args;
   args.struct_size = PJRT_TopologyDescription_Attributes_Args_STRUCT_SIZE;
   args.priv = nullptr;
-  args.topology = c_topology_.get();
+  args.topology = c_topology_;
   pjrt::LogFatalIfPjrtError(c_api_->PJRT_TopologyDescription_Attributes(&args),
                             c_api_);
   attributes_ =
@@ -2220,7 +2261,8 @@
       c_api->PJRT_TopologyDescription_Create(&init_args), c_api);
   PJRT_TopologyDescription* c_topology = init_args.topology;
   return std::unique_ptr<PjRtTopologyDescription>(
-      std::make_unique<PjRtCApiTopologyDescription>(c_api, c_topology));
+      std::make_unique<PjRtCApiTopologyDescription>(c_api, c_topology,
+                                                    /*owned=*/true));
 }
 
 }  // namespace xla
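
PjRtCApiExecutable::FingerprintExecutable above simply forwards to PJRT_Executable_Fingerprint, which serves the fingerprint cached when the PJRT_Executable wrapper is constructed. A small consumption sketch, assuming an existing xla::PjRtExecutable backed by the C API (LogFingerprint is illustrative only):

    xla::Status LogFingerprint(const xla::PjRtExecutable& executable) {
      TF_ASSIGN_OR_RETURN(std::string fingerprint,
                          executable.FingerprintExecutable());
      LOG(INFO) << executable.name() << " fingerprint: " << fingerprint;
      return xla::OkStatus();
    }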
diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.h b/third_party/xla/xla/pjrt/pjrt_c_api_client.h
index 1b1eecf..7e16ee3 100644
--- a/third_party/xla/xla/pjrt/pjrt_c_api_client.h
+++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.h
@@ -165,6 +165,72 @@
   std::vector<PjRtMemorySpace*> memory_spaces_;
 };
 
+class PjRtCApiCompiler : public PjRtCompiler {
+ public:
+  explicit PjRtCApiCompiler(const PJRT_Api* c_api) : c_api_(c_api) {}
+
+  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+      CompileOptions options, const XlaComputation& computation,
+      const PjRtTopologyDescription& topology, PjRtClient* client) override;
+
+  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+      CompileOptions options, mlir::ModuleOp module,
+      const PjRtTopologyDescription& topology, PjRtClient* client) override;
+
+ private:
+  const PJRT_Api* c_api_;
+};
+
+class PjRtCApiTopologyDescription : public PjRtTopologyDescription {
+ public:
+  // `owned` indicates whether this PjRtCApiTopologyDescription should take
+  // ownership of `c_topology`, i.e., if `owned` is true,
+  // PJRT_TopologyDescription_Destroy will be called on `c_topology` when this
+  // PjRtCApiTopologyDescription is destroyed.
+  PjRtCApiTopologyDescription(const PJRT_Api* c_api,
+                              PJRT_TopologyDescription* c_topology, bool owned);
+
+  PjRtPlatformId platform_id() const override {
+    CHECK(false) << "PJRT C API does not support platform_id.";
+  }
+
+  absl::string_view platform_name() const override;
+
+  absl::string_view platform_version() const override;
+
+  std::optional<PjRtCompiler*> compiler() const override {
+    return compiler_.get();
+  }
+
+  PJRT_TopologyDescription* c_topology() const { return c_topology_; }
+
+  std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
+      const override;
+
+  absl::StatusOr<std::string> Serialize() const override;
+
+  // Returns vendor specific attributes about the topology.
+  const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
+      const override {
+    return attributes_;
+  }
+
+ private:
+  std::unique_ptr<PjRtCApiCompiler> compiler_;
+  const PJRT_Api* c_api_;
+  // nullptr iff the PJRT_TopologyDescription isn't owned by this wrapper
+  // (i.e. `owned` was false at construction).
+  std::unique_ptr<PJRT_TopologyDescription,
+                  ::pjrt::PJRT_TopologyDescriptionDeleter>
+      owned_c_topology_;
+  PJRT_TopologyDescription* c_topology_;
+  // Device specific attributes with corresponding values.
+  absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_;
+
+  // Initializes device specific attributes.
+  void InitAttributes();
+};
+
 class PjRtCApiClient : public PjRtClient {
  public:
   PjRtCApiClient(
@@ -225,6 +291,9 @@
         "PJRT C API does not support CreateUninitializedBuffer");
   }
 
+  StatusOr<const PjRtTopologyDescription*> GetTopologyDescription()
+      const override;
+
   StatusOr<std::unique_ptr<AsyncHostToDeviceTransferManager>>
   CreateBuffersForAsyncHostToDevice(absl::Span<const Shape> shapes,
                                     PjRtDevice* device) override {
@@ -352,6 +421,10 @@
   // supported.
   std::vector<PjRtMemorySpace*> addressable_memory_spaces_;
   absl::flat_hash_map<PJRT_Memory*, PjRtCApiMemorySpace*> c_to_cpp_memory_map_;
+  // There may be an error fetching the topology desc via the C API
+  // (e.g. unimplemented). Save the error during client init so we can return it
+  // from GetTopologyDescription().
+  StatusOr<const PjRtCApiTopologyDescription> topo_desc_;
 
   const std::string platform_version_;
   const std::string platform_name_;
@@ -518,6 +591,8 @@
 
   StatusOr<std::string> SerializeExecutable() const override;
 
+  StatusOr<std::string> FingerprintExecutable() const override;
+
  private:
   const PJRT_Api* c_api_;
   std::unique_ptr<PJRT_Executable, ::pjrt::PJRT_ExecutableDeleter> executable_;
@@ -663,67 +738,6 @@
   void InitDevices();
 };
 
-class PjRtCApiCompiler : public PjRtCompiler {
- public:
-  explicit PjRtCApiCompiler(const PJRT_Api* c_api) : c_api_(c_api) {}
-
-  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
-      CompileOptions options, const XlaComputation& computation,
-      const PjRtTopologyDescription& topology, PjRtClient* client) override;
-
-  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
-      CompileOptions options, mlir::ModuleOp module,
-      const PjRtTopologyDescription& topology, PjRtClient* client) override;
-
- private:
-  const PJRT_Api* c_api_;
-};
-
-class PjRtCApiTopologyDescription : public PjRtTopologyDescription {
- public:
-  PjRtCApiTopologyDescription(const PJRT_Api* c_api,
-                              PJRT_TopologyDescription* c_topology);
-
-  PjRtPlatformId platform_id() const override {
-    CHECK(false) << "PJRT C API does not support platform_id.";
-  }
-
-  absl::string_view platform_name() const override;
-
-  absl::string_view platform_version() const override;
-
-  std::optional<PjRtCompiler*> compiler() const override {
-    return compiler_.get();
-  }
-
-  const PJRT_TopologyDescription* c_topology() const {
-    return c_topology_.get();
-  }
-
-  std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
-      const override;
-
-  absl::StatusOr<std::string> Serialize() const override;
-
-  // Returns vendor specific attributes about the topology.
-  const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
-      const override {
-    return attributes_;
-  }
-
- private:
-  std::unique_ptr<PjRtCApiCompiler> compiler_;
-  const PJRT_Api* c_api_;
-  std::unique_ptr<PJRT_TopologyDescription,
-                  ::pjrt::PJRT_TopologyDescriptionDeleter>
-      c_topology_;
-  // Device specific attributes with corresponding values.
-  absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_;
-
-  // Initializes device specific attributes.
-  void InitAttributes();
-};
-
 class CApiCopyToDeviceStream : public CopyToDeviceStream {
  public:
   CApiCopyToDeviceStream(PJRT_CopyToDeviceStream* c_stream,
diff --git a/third_party/xla/xla/pjrt/pjrt_client.h b/third_party/xla/xla/pjrt/pjrt_client.h
index c4b73ed..175291d 100644
--- a/third_party/xla/xla/pjrt/pjrt_client.h
+++ b/third_party/xla/xla/pjrt/pjrt_client.h
@@ -598,7 +598,8 @@
   // Gets the pointer to the topology description held by the client.
   virtual StatusOr<const PjRtTopologyDescription*> GetTopologyDescription()
       const {
-    return Unimplemented("GetTopologyDescription not supported!");
+    return Unimplemented("GetTopologyDescription not supported on platform %s",
+                         platform_name());
   }
 
   // Returns topology object for compilation based on this client's topology.
diff --git a/third_party/xla/xla/primitive_util.h b/third_party/xla/xla/primitive_util.h
index 7ed8f48..63fa4e1 100644
--- a/third_party/xla/xla/primitive_util.h
+++ b/third_party/xla/xla/primitive_util.h
@@ -71,7 +71,7 @@
 // Returns the XLA primitive type (eg, F32) corresponding to the given
 // template parameter native type (eg, float).
 template <typename NativeT>
-PrimitiveType NativeToPrimitiveType() {
+constexpr PrimitiveType NativeToPrimitiveType() {
   // Make the expression depend on the template parameter NativeT so
   // that this compile-time error only appears if this function is
   // instantiated with some concrete type that is not specialized
@@ -82,119 +82,118 @@
 }
 
 // Declarations of specializations for each native type which correspond to a
-// XLA primitive type.  As an optimization, these are declared inline in the
-// header.
+// XLA primitive type.
 template <>
-inline PrimitiveType NativeToPrimitiveType<bool>() {
+constexpr PrimitiveType NativeToPrimitiveType<bool>() {
   return PRED;
 }
 
 // Unsigned integer
 template <>
-inline PrimitiveType NativeToPrimitiveType<u4>() {
+constexpr PrimitiveType NativeToPrimitiveType<u4>() {
   return U4;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<uint8_t>() {
+constexpr PrimitiveType NativeToPrimitiveType<uint8_t>() {
   return U8;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<uint16_t>() {
+constexpr PrimitiveType NativeToPrimitiveType<uint16_t>() {
   return U16;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<uint32_t>() {
+constexpr PrimitiveType NativeToPrimitiveType<uint32_t>() {
   return U32;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<uint64_t>() {
+constexpr PrimitiveType NativeToPrimitiveType<uint64_t>() {
   return U64;
 }
 
 // Signed integer
 template <>
-inline PrimitiveType NativeToPrimitiveType<s4>() {
+constexpr PrimitiveType NativeToPrimitiveType<s4>() {
   return S4;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<int8_t>() {
+constexpr PrimitiveType NativeToPrimitiveType<int8_t>() {
   return S8;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<int16_t>() {
+constexpr PrimitiveType NativeToPrimitiveType<int16_t>() {
   return S16;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<int32_t>() {
+constexpr PrimitiveType NativeToPrimitiveType<int32_t>() {
   return S32;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<int64_t>() {
+constexpr PrimitiveType NativeToPrimitiveType<int64_t>() {
   return S64;
 }
 
 // Floating point
 template <>
-inline PrimitiveType NativeToPrimitiveType<float>() {
+constexpr PrimitiveType NativeToPrimitiveType<float>() {
   return F32;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<double>() {
+constexpr PrimitiveType NativeToPrimitiveType<double>() {
   return F64;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<half>() {
+constexpr PrimitiveType NativeToPrimitiveType<half>() {
   return F16;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<bfloat16>() {
+constexpr PrimitiveType NativeToPrimitiveType<bfloat16>() {
   return BF16;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<tsl::float8_e5m2>() {
+constexpr PrimitiveType NativeToPrimitiveType<tsl::float8_e5m2>() {
   return F8E5M2;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<tsl::float8_e4m3fn>() {
+constexpr PrimitiveType NativeToPrimitiveType<tsl::float8_e4m3fn>() {
   return F8E4M3FN;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<tsl::float8_e4m3b11>() {
+constexpr PrimitiveType NativeToPrimitiveType<tsl::float8_e4m3b11>() {
   return F8E4M3B11FNUZ;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<tsl::float8_e5m2fnuz>() {
+constexpr PrimitiveType NativeToPrimitiveType<tsl::float8_e5m2fnuz>() {
   return F8E5M2FNUZ;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<tsl::float8_e4m3fnuz>() {
+constexpr PrimitiveType NativeToPrimitiveType<tsl::float8_e4m3fnuz>() {
   return F8E4M3FNUZ;
 }
 
 // Complex
 template <>
-inline PrimitiveType NativeToPrimitiveType<complex64>() {
+constexpr PrimitiveType NativeToPrimitiveType<complex64>() {
   return C64;
 }
 
 template <>
-inline PrimitiveType NativeToPrimitiveType<complex128>() {
+constexpr PrimitiveType NativeToPrimitiveType<complex128>() {
   return C128;
 }
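
With the specializations now constexpr, the native-type mapping can participate in constant expressions, e.g. to pin assumptions down at compile time. A minimal sketch; these asserts are illustrative, not part of the change:

    #include "xla/primitive_util.h"

    static_assert(xla::primitive_util::NativeToPrimitiveType<float>() == xla::F32);
    static_assert(xla::primitive_util::NativeToPrimitiveType<int32_t>() ==
                  xla::S32);
    static_assert(xla::primitive_util::NativeToPrimitiveType<bool>() == xla::PRED);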
 
diff --git a/third_party/xla/xla/python/ifrt/dtype.h b/third_party/xla/xla/python/ifrt/dtype.h
index 8d8088f..184e359 100644
--- a/third_party/xla/xla/python/ifrt/dtype.h
+++ b/third_party/xla/xla/python/ifrt/dtype.h
@@ -72,11 +72,11 @@
     // dtype will have empty dimensions.
     kToken = 17,
 
-    kF8E4M3FN = 19,
+    kF8E4M3FN = 20,
     kF8E4M3B11FNUZ = 23,
-    kF8E4M3FNUZ = 24,
-    kF8E5M2 = 20,
-    kF8E5M2FNUZ = 25,
+    kF8E4M3FNUZ = 25,
+    kF8E5M2 = 19,
+    kF8E5M2FNUZ = 24,
 
     // Next = 26
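
The FP8 kinds are renumbered so each DType kind keeps the same numeric value as its xla::PrimitiveType counterpart; the CASE static_asserts added to ToPrimitiveType in pjrt_array.cc below rely on exactly this alignment. A standalone restatement of that invariant, illustrative only:

    #include "xla/python/ifrt/dtype.h"
    #include "xla/xla_data.pb.h"

    static_assert(static_cast<int>(xla::ifrt::DType::kF8E5M2) ==
                  static_cast<int>(xla::PrimitiveType::F8E5M2));
    static_assert(static_cast<int>(xla::ifrt::DType::kF8E4M3FNUZ) ==
                  static_cast<int>(xla::PrimitiveType::F8E4M3FNUZ));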
 
diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc
index 452262b..92c079f 100644
--- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc
+++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc
@@ -22,6 +22,7 @@
 #include <vector>
 
 #include "absl/log/check.h"
+#include "absl/status/status.h"
 #include "absl/strings/str_join.h"
 #include "xla/literal.h"
 #include "xla/pjrt/pjrt_client.h"
@@ -43,31 +44,36 @@
 
 StatusOr<xla::PrimitiveType> ToPrimitiveType(DType dtype) {
   switch (dtype.kind()) {
-    case DType::kInvalid:
-    case DType::kPred:
-    case DType::kS4:
-    case DType::kS8:
-    case DType::kS16:
-    case DType::kS32:
-    case DType::kS64:
-    case DType::kU4:
-    case DType::kU8:
-    case DType::kU16:
-    case DType::kU32:
-    case DType::kU64:
-    case DType::kF8E4M3FN:
-    case DType::kF8E4M3B11FNUZ:
-    case DType::kF8E4M3FNUZ:
-    case DType::kF8E5M2:
-    case DType::kF8E5M2FNUZ:
-    case DType::kF16:
-    case DType::kF32:
-    case DType::kBF16:
-    case DType::kF64:
-    case DType::kC64:
-    case DType::kC128:
-    case DType::kToken:
-      return static_cast<xla::PrimitiveType>(static_cast<int>(dtype.kind()));
+#define CASE(DT, PT)                                                      \
+  case DT:                                                                \
+    static_assert(PT ==                                                   \
+                  static_cast<xla::PrimitiveType>(static_cast<int>(DT))); \
+    return PT
+    CASE(DType::kInvalid, xla::PrimitiveType::PRIMITIVE_TYPE_INVALID);
+    CASE(DType::kPred, xla::PrimitiveType::PRED);
+    CASE(DType::kS4, xla::PrimitiveType::S4);
+    CASE(DType::kS8, xla::PrimitiveType::S8);
+    CASE(DType::kS16, xla::PrimitiveType::S16);
+    CASE(DType::kS32, xla::PrimitiveType::S32);
+    CASE(DType::kS64, xla::PrimitiveType::S64);
+    CASE(DType::kU4, xla::PrimitiveType::U4);
+    CASE(DType::kU8, xla::PrimitiveType::U8);
+    CASE(DType::kU16, xla::PrimitiveType::U16);
+    CASE(DType::kU32, xla::PrimitiveType::U32);
+    CASE(DType::kU64, xla::PrimitiveType::U64);
+    CASE(DType::kF8E4M3FN, xla::PrimitiveType::F8E4M3FN);
+    CASE(DType::kF8E4M3B11FNUZ, xla::PrimitiveType::F8E4M3B11FNUZ);
+    CASE(DType::kF8E4M3FNUZ, xla::PrimitiveType::F8E4M3FNUZ);
+    CASE(DType::kF8E5M2, xla::PrimitiveType::F8E5M2);
+    CASE(DType::kF8E5M2FNUZ, xla::PrimitiveType::F8E5M2FNUZ);
+    CASE(DType::kF16, xla::PrimitiveType::F16);
+    CASE(DType::kF32, xla::PrimitiveType::F32);
+    CASE(DType::kBF16, xla::PrimitiveType::BF16);
+    CASE(DType::kF64, xla::PrimitiveType::F64);
+    CASE(DType::kC64, xla::PrimitiveType::C64);
+    CASE(DType::kC128, xla::PrimitiveType::C128);
+    CASE(DType::kToken, xla::PrimitiveType::TOKEN);
+#undef CASE
     case DType::kString:
       return InvalidArgument("Not supported as XLA PrimitiveType: %d",
                              static_cast<int>(dtype.kind()));
@@ -381,35 +387,55 @@
             "first fetched to the host and then sent to the destination "
             "device.");
       }
-      // Use `PjRtBuffer::CopyToMemorySpace` instead of
+      if (new_sharding_has_memory_kind && memories_supported &&
+          semantics == ArrayCopySemantics::kDonateInput && !memory_kind_equal) {
+        return Unimplemented(
+            "Donation across different memory kinds is not implemented.");
+      }
+      // Try using `PjRtBuffer::CopyToMemorySpace` instead of
       // `PjRtBuffer::CopyToDevice` when memories are supported. Because the
       // semantics of the latter one is to copy to the default memory space of
       // the device.
+      std::unique_ptr<PjRtBuffer> copied_buffer;
       if (new_sharding_has_memory_kind && memories_supported) {
         TF_ASSIGN_OR_RETURN(
             auto memory_space,
             GetMemorySpaceFromMemoryKind(new_sharding->devices()[i],
                                          canonicalized_sharding_memory_kind));
-        TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> copied_buffer,
-                            pjrt_buffers_[i]->CopyToMemorySpace(memory_space));
-        if (semantics == ArrayCopySemantics::kDonateInput) {
-          if (!memory_kind_equal) {
-            return Unimplemented(
-                "Donation across different memory kinds is not implemented.");
+        StatusOr<std::unique_ptr<PjRtBuffer>> copied_buffer_using_memory_space =
+            pjrt_buffers_[i]->CopyToMemorySpace(memory_space);
+        if (copied_buffer_using_memory_space.ok()) {
+          copied_buffer = std::move(*copied_buffer_using_memory_space);
+        } else if (!absl::IsUnimplemented(
+                       copied_buffer_using_memory_space.status())) {
+          return copied_buffer_using_memory_space.status();
+        } else {
+          // Return the (unimplemented) CopyToMemorySpace error if the
+          // sharding's memory kind isn't the kind of the device's default
+          // memory space. Otherwise continue on to the CopyToDevice fallback.
+          // TODO(b/307743645): clean up this branch when memory space is better
+          // supported.
+          TF_ASSIGN_OR_RETURN(
+              PjRtMemorySpace * default_memory_space,
+              new_sharding->devices()[i]->default_memory_space());
+          if (canonicalized_sharding_memory_kind.memory_kind() !=
+              default_memory_space->memory_space_kind()) {
+            return copied_buffer_using_memory_space.status();
           }
-          pjrt_buffers_[i] = nullptr;
         }
-        buffers.push_back(std::shared_ptr<PjRtBuffer>(copied_buffer.release()));
-      } else {
-        // Use `PjRtBuffer::CopyToDevice` when memories are not supported.
-        TF_ASSIGN_OR_RETURN(
-            std::unique_ptr<xla::PjRtBuffer> copied_buffer,
-            pjrt_buffers_[i]->CopyToDevice(new_sharding->devices()[i]));
-        if (semantics == ArrayCopySemantics::kDonateInput) {
-          pjrt_buffers_[i] = nullptr;
-        }
-        buffers.push_back(std::shared_ptr<PjRtBuffer>(copied_buffer.release()));
       }
+      // Fall back to `PjRtBuffer::CopyToDevice` if (1) memories are not
+      // supported or (2) `PjRtBuffer::CopyToMemorySpace` returned unimplemented
+      // and canonicalized_sharding_memory_kind matches the kind of the default
+      // memory space of `new_sharding->devices()[i]`.
+      if (copied_buffer == nullptr) {
+        TF_ASSIGN_OR_RETURN(copied_buffer, pjrt_buffers_[i]->CopyToDevice(
+                                               new_sharding->devices()[i]));
+      }
+      if (semantics == ArrayCopySemantics::kDonateInput) {
+        pjrt_buffers_[i] = nullptr;
+      }
+      buffers.push_back(std::shared_ptr<PjRtBuffer>(copied_buffer.release()));
     }
   }
   return PjRtArray::Create(client_, dtype_, shape_, std::move(new_sharding),
diff --git a/third_party/xla/xla/python/py_buffer.cc b/third_party/xla/xla/python/py_buffer.cc
index 481337e..ed98c84 100644
--- a/third_party/xla/xla/python/py_buffer.cc
+++ b/third_party/xla/xla/python/py_buffer.cc
@@ -29,6 +29,7 @@
 #include "pybind11/pytypes.h"  // from @pybind11
 #include "xla/pjrt/pjrt_client.h"
 #include "xla/pjrt/pjrt_compiler.h"
+#include "xla/primitive_util.h"
 #include "xla/python/ifrt/array.h"
 #include "xla/python/ifrt/device.h"
 #include "xla/python/pjrt_ifrt/pjrt_array.h"
@@ -151,8 +152,10 @@
   if (arr != nullptr) {
     auto* pjrt_buffer = arr->pjrt_buffers().front().get();
     TF_RET_CHECK(!pjrt_buffer->IsTuple());
-    // On CPU, we can return the value in a zero-copy way.
-    if (pjrt_buffer->IsOnCpu()) {
+    // On CPU for non-int4 values, we can return the value in a zero-copy way.
+    // For int4 values, we must copy in order to unpack the array.
+    if (pjrt_buffer->IsOnCpu() &&
+        !primitive_util::Is4BitType(pjrt_buffer->element_type())) {
       TF_ASSIGN_OR_RETURN(
           const auto* shape,
           IfrtHelpers::xla_dynamic_shape(ifrt_array, dynamic_shape_holder));
@@ -203,7 +206,8 @@
   auto* arr = llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array);
   if (arr != nullptr) {
     auto* pjrt_buffer = arr->pjrt_buffers().front().get();
-    if (pjrt_buffer->IsOnCpu()) {
+    if (pjrt_buffer->IsOnCpu() &&
+        !primitive_util::Is4BitType(pjrt_buffer->element_type())) {
       return OkStatus();
     }
   }
diff --git a/third_party/xla/xla/python/py_executable.h b/third_party/xla/xla/python/py_executable.h
index c58c8e1..71c1ddf 100644
--- a/third_party/xla/xla/python/py_executable.h
+++ b/third_party/xla/xla/python/py_executable.h
@@ -24,6 +24,7 @@
 #include <vector>
 
 #include "absl/types/span.h"
+#include "pybind11/gil.h"  // from @pybind11
 #include "xla/pjrt/pjrt_client.h"
 #include "xla/python/pjrt_ifrt/pjrt_executable.h"
 #include "xla/python/py_array.h"
@@ -134,6 +135,7 @@
   }
 
   StatusOr<CompiledMemoryStats> GetCompiledMemoryStats() const {
+    pybind11::gil_scoped_release scope;
     return ifrt_loaded_executable_->GetCompiledMemoryStats();
   }
 
diff --git a/third_party/xla/xla/python/py_host_callback.cc b/third_party/xla/xla/python/py_host_callback.cc
index 625bcba..6905540 100644
--- a/third_party/xla/xla/python/py_host_callback.cc
+++ b/third_party/xla/xla/python/py_host_callback.cc
@@ -206,6 +206,8 @@
 PyHostSendAndRecvLoadedHostCallback::~PyHostSendAndRecvLoadedHostCallback() {
   GlobalPyRefManager()->AddGarbage(
       absl::MakeSpan(static_cast<pybind11::object*>(&callable_), 1));
+  GlobalPyRefManager()->AddGarbage(
+      absl::MakeSpan(static_cast<pybind11::object*>(&serializer_), 1));
 }
 
 StatusOr<std::string> PyHostSendAndRecvLoadedHostCallback::Serialize() const {
diff --git a/third_party/xla/xla/python/py_values.cc b/third_party/xla/xla/python/py_values.cc
index 5fccf23..85de0aa 100644
--- a/third_party/xla/xla/python/py_values.cc
+++ b/third_party/xla/xla/python/py_values.cc
@@ -272,7 +272,9 @@
   // We only allow single device case for PyArray in device put.
   if (py_array.num_shards() != 1) {
     return InvalidArgument(
-        "Only single-sharded Array is expected in device_put.");
+        "device_put expects an array with exactly one shard, got an array with "
+        "with %d shards.",
+        py_array.num_shards());
   }
 
   ifrt::Array* ifrt_array = py_array.ifrt_array();
diff --git a/third_party/xla/xla/python/pytree.cc b/third_party/xla/xla/python/pytree.cc
index 5b5e0cc..1a47247 100644
--- a/third_party/xla/xla/python/pytree.cc
+++ b/third_party/xla/xla/python/pytree.cc
@@ -637,6 +637,9 @@
         "PyTree registries of PyTreeDefs passed to Compose() must match.");
   }
   auto out = std::make_unique<PyTreeDef>(registry_->shared_from_this());
+  out->traversal_.reserve(static_cast<size_t>(num_leaves()) *
+                              inner.num_nodes() +
+                          num_nodes() - num_leaves());
   for (const Node& n : traversal_) {
     if (n.kind == PyTreeKind::kLeaf) {
       absl::c_copy(inner.traversal_, std::back_inserter(out->traversal_));
@@ -644,13 +647,7 @@
       out->traversal_.push_back(n);
     }
   }
-  const auto& root = traversal_.back();
-  const auto& inner_root = inner.traversal_.back();
-  // TODO(tomhennigan): This should update all nodes in the traversal.
-  auto& out_root = out->traversal_.back();
-  out_root.num_nodes = (root.num_nodes - root.num_leaves) +
-                       (inner_root.num_nodes * root.num_leaves);
-  out_root.num_leaves *= inner_root.num_leaves;
+  out->SetNumLeavesAndNumNodes();
   return out;
 }
 
diff --git a/third_party/xla/xla/python/pytree_test.py b/third_party/xla/xla/python/pytree_test.py
index 8c80ec5..2bba921 100644
--- a/third_party/xla/xla/python/pytree_test.py
+++ b/third_party/xla/xla/python/pytree_test.py
@@ -87,6 +87,11 @@
     self.roundtrip_node_data(ExampleType(field0=o, field1=o))
     self.roundtrip_node_data(ExampleType2(field0=o, field1=o))
 
+  def testCompose(self):
+    x = registry.flatten(0)[1]
+    y = registry.flatten((0, 0))[1]
+    self.assertEqual((x.compose(y)).num_leaves, 2)
+
 
 if __name__ == "__main__":
   absltest.main()
diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py
index 3d61b9a..7d46baf 100644
--- a/third_party/xla/xla/python/xla_client.py
+++ b/third_party/xla/xla/python/xla_client.py
@@ -47,7 +47,7 @@
 
 # Just an internal arbitrary increasing number to help with backward-compatible
 # changes. In JAX, reference this via jax._src.lib.xla_extension_version.
-_version = 207
+_version = 210
 
 # Version number for MLIR:Python components.
 mlir_api_version = 54
diff --git a/third_party/xla/xla/python/xla_client.pyi b/third_party/xla/xla/python/xla_client.pyi
index 5b85d74..11cfb5f 100644
--- a/third_party/xla/xla/python/xla_client.pyi
+++ b/third_party/xla/xla/python/xla_client.pyi
@@ -109,7 +109,7 @@
   ...
 
 
-def make_tpu_client(library_path: Optional[str]) -> Client:
+def make_tpu_client(library_path: Optional[str] = None) -> Client:
   ...
 
 
diff --git a/third_party/xla/xla/python/xla_compiler.cc b/third_party/xla/xla/python/xla_compiler.cc
index 335795f..d27bc48 100644
--- a/third_party/xla/xla/python/xla_compiler.cc
+++ b/third_party/xla/xla/python/xla_compiler.cc
@@ -533,6 +533,40 @@
                     &HloPrintOptions::is_in_nested_computation,
                     &HloPrintOptions::set_is_in_nested_computation);
 
+  // HloModule.computations() returns raw pointers, but pybind11 prefers to
+  // manage bound objects through smart-pointer holders. We give pybind11 a
+  // smart pointer to a wrapper around the raw pointer to satisfy pybind11 and
+  // avoid double frees.
+  class ComputationWrapper {
+   public:
+    ComputationWrapper(const HloComputation* comp,
+                       const std::shared_ptr<HloModule> module)
+        : comp{comp}, module{module} {}
+    absl::string_view name() const { return comp->name(); }
+    void render_html(const std::string& filename) {
+      std::string html = xla::ValueOrThrow(RenderGraph(
+          *comp, /*label=*/"", comp->parent()->config().debug_options(),
+          RenderedGraphFormat::kHtml, HloRenderOptions()));
+      xla::ThrowIfError(tsl::WriteStringToFile(
+          tsl::Env::Default(), absl::StrCat(filename, ".html"), html));
+    }
+
+   private:
+    const HloComputation* comp;
+    // The module owns the computations: if its destructor is called, the
+    // computations are freed. To prevent that from happening in cases where the
+    // module Python object goes out of scope and gets garbage collected before
+    // the computations, we keep a shared_ptr to the module that originated the
+    // computation.
+    const std::shared_ptr<HloModule> module;
+  };
+
+  py::class_<ComputationWrapper, std::shared_ptr<ComputationWrapper>>
+      hlo_computation_class(m, "HloComputation");
+
+  hlo_computation_class.def_property_readonly("name", &ComputationWrapper::name)
+      .def("render_html", &ComputationWrapper::render_html);
+
   py::class_<HloModule, std::shared_ptr<HloModule>> hlo_module_class(
       m, "HloModule");
   hlo_module_class.def_property_readonly("name", &HloModule::name)
@@ -545,6 +579,15 @@
            xla::ValueOrThrowWrapper(GetHloModuleSerializedProto))
       .def("from_serialized_hlo_module_proto",
            xla::ValueOrThrowWrapper(HloModuleFromSerializedProto))
+      .def("computations",
+           [](const std::shared_ptr<HloModule> m)
+               -> std::vector<std::shared_ptr<ComputationWrapper>> {
+             std::vector<std::shared_ptr<ComputationWrapper>> computations;
+             for (HloComputation* comp : m->computations())
+               computations.push_back(
+                   std::make_shared<ComputationWrapper>(comp, m));
+             return computations;
+           })
       .def_property_readonly(
           "spmd_output_sharding",
           [](const HloModule& m) -> std::optional<xla::OpSharding> {
diff --git a/third_party/xla/xla/python/xla_extension/__init__.pyi b/third_party/xla/xla/python/xla_extension/__init__.pyi
index 2bf60da..95f63f9 100644
--- a/third_party/xla/xla/python/xla_extension/__init__.pyi
+++ b/third_party/xla/xla/python/xla_extension/__init__.pyi
@@ -167,6 +167,9 @@
   indent_amount: int
   is_in_nested_computation: bool
 
+class HloComputation:
+  def render_html(self, filename: str) -> None: ...
+
 class HloModule:
   spmd_output_sharding: Optional[OpSharding]
   spmd_parameters_shardings: Optional[List[OpSharding]]
@@ -177,6 +180,7 @@
   @staticmethod
   def from_serialized_hlo_module_proto(
     serialized_hlo_module_proto: bytes) -> HloModule: ...
+  def computations(self) -> List[HloComputation]: ...
 
 class HloModuleGroup:
   def __init__(self, name: str, modules: List[HloModule]) -> None: ...
diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD
index dc62b9e..65dfc5b 100644
--- a/third_party/xla/xla/service/BUILD
+++ b/third_party/xla/xla/service/BUILD
@@ -11,6 +11,7 @@
     "xla_cc_test",
     "xla_py_proto_library",
     "xla_py_test_deps",
+    "xla_symbol_repository_deps",
 )
 load("//xla/service:xla_compile.bzl", "xla_aot_compile_cpu", "xla_aot_compile_gpu", "xla_aot_compile_gpu_runtime_autotuning")
 load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured")
@@ -522,6 +523,7 @@
     visibility = ["//visibility:public"],
     deps = [
         ":constant_value",
+        ":hlo_dce",
         ":hlo_pass",
         ":value_range",
         "//xla:comparison_util",
@@ -574,6 +576,7 @@
         "//xla:util",
         "//xla:xla_proto_cc",
         "//xla/hlo/ir:hlo",
+        "@com_google_absl//absl/functional:any_invocable",
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
@@ -597,6 +600,7 @@
     deps = [
         "//xla:permutation_util",
         "//xla:shape_util",
+        "//xla:status",
         "//xla:status_macros",
         "//xla:statusor",
         "//xla:types",
@@ -606,12 +610,14 @@
         "//xla/hlo/ir:hlo",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:protobuf",
+        "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:statusor",
     ],
 )
@@ -661,6 +667,7 @@
         ":dot_as_convolution_util",
         ":hlo_graph_dumper",
         ":hlo_pass",
+        "//xla:array",
         "//xla:protobuf_util",
         "//xla:shape_tree",
         "//xla:shape_util",
@@ -760,16 +767,13 @@
     name = "dynamic_parameter_binding_test",
     srcs = ["dynamic_parameter_binding_test.cc"],
     deps = [
-        ":hlo_dce",
-        ":hlo_memory_scheduler",
-        ":hlo_ordering",
         "//xla:shape_util",
-        "//xla:types",
         "//xla/hlo/ir:hlo",
         "//xla/tests:hlo_test_base",
         "//xla/tests:xla_internal_test_main",
-        "@com_google_absl//absl/algorithm:container",
+        "@com_google_googletest//:gtest",
         "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -1675,6 +1679,7 @@
         "@com_google_absl//absl/container:btree",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
@@ -1774,6 +1779,7 @@
         ":hlo_dataflow_analysis",
         ":hlo_ordering",
         ":hlo_proto_cc",
+        ":time_utils",
         ":tuple_points_to_analysis",
         "//xla:comparison_util",
         "//xla:status",
@@ -2060,8 +2066,9 @@
         ":hlo_module_config",
         ":shape_inference",
         "//xla:comparison_util",
-        "//xla:literal",
         "//xla:literal_util",
+        "//xla:shape_util",
+        "//xla:status_macros",
         "//xla:statusor",
         "//xla:util",
         "//xla/client:xla_builder",
@@ -2069,7 +2076,11 @@
         "//xla/client/lib:comparators",
         "//xla/hlo/ir:hlo",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -2131,15 +2142,18 @@
         "//xla:literal",
         "//xla:literal_util",
         "//xla:shape_util",
-        "//xla:status_macros",
-        "//xla:types",
+        "//xla:status",
+        "//xla:statusor",
         "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/functional:function_ref",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:logging",
         "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -2472,26 +2486,30 @@
     deps = [
         ":hlo_cost_analysis",
         ":hlo_creation_utils",
+        ":hlo_module_config",
         ":hlo_pass",
         ":pattern_matcher",
+        ":shape_inference",
         "//xla:comparison_util",
         "//xla:literal",
         "//xla:literal_comparison",
         "//xla:literal_util",
         "//xla:permutation_util",
         "//xla:shape_util",
+        "//xla:status",
         "//xla:status_macros",
-        "//xla:types",
+        "//xla:statusor",
         "//xla:util",
         "//xla:window_util",
         "//xla:xla_data_proto_cc",
         "//xla/hlo/evaluator:hlo_evaluator",
         "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/numeric:bits",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:errors",
@@ -3554,6 +3572,7 @@
     deps = [
         ":hlo_creation_utils",
         ":hlo_pass",
+        "//xla:shape_util",
         "//xla/hlo/ir:hlo",
     ],
 )
@@ -3742,16 +3761,18 @@
         ":call_inliner",
         ":dynamic_window_utils",
         ":hlo_creation_utils",
+        ":hlo_dataflow_analysis",
+        ":hlo_value",
         ":tuple_util",
         ":while_util",
         "//xla:comparison_util",
+        "//xla:literal",
         "//xla:literal_util",
         "//xla:shape_tree",
         "//xla:shape_util",
         "//xla:status",
         "//xla:status_macros",
         "//xla:statusor",
-        "//xla:types",
         "//xla:util",
         "//xla:window_util",
         "//xla:xla_data_proto_cc",
@@ -3760,6 +3781,7 @@
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/functional:function_ref",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
@@ -3767,6 +3789,7 @@
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -3814,6 +3837,7 @@
     hdrs = ["dynamic_padder.h"],
     visibility = ["//visibility:public"],
     deps = [
+        ":call_graph",
         ":dynamic_dimension_inference",
         ":dynamic_window_utils",
         ":hlo_creation_utils",
@@ -3821,24 +3845,29 @@
         ":hlo_pass",
         ":pattern_matcher",
         ":shape_inference",
+        ":tuple_util",
         "//xla:comparison_util",
-        "//xla:literal",
         "//xla:literal_util",
         "//xla:shape_util",
         "//xla:status",
         "//xla:status_macros",
+        "//xla:statusor",
         "//xla:util",
         "//xla:window_util",
         "//xla:xla_data_proto_cc",
         "//xla/client:xla_builder",
         "//xla/hlo/ir:hlo",
         "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/functional:function_ref",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/lib/monitoring:gauge",
         "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -3846,6 +3875,7 @@
     name = "dynamic_padder_test",
     srcs = ["dynamic_padder_test.cc"],
     deps = [
+        ":algebraic_simplifier",
         ":dynamic_dimension_inference",
         ":dynamic_dimension_simplifier",
         ":dynamic_padder",
@@ -3854,9 +3884,13 @@
         ":pattern_matcher",
         ":pattern_matcher_gmock",
         ":tuple_simplifier",
+        "//xla:error_spec",
         "//xla:literal",
+        "//xla:literal_util",
         "//xla:shape_util",
+        "//xla:status",
         "//xla:status_macros",
+        "//xla:statusor",
         "//xla:test",
         "//xla:test_helpers",
         "//xla:util",
@@ -3870,8 +3904,14 @@
         "//xla/tests:llvm_irgen_test_base",
         "//xla/tests:test_macros_header",
         "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/platform:test_benchmark",
         "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc",
     ],
@@ -3895,6 +3935,7 @@
         "//xla/tests:hlo_test_base",
         "//xla/tests:xla_internal_test_main",
         "@local_tsl//tsl/lib/core:status_test_util",
+        "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/platform:test_benchmark",
     ],
 )
@@ -4262,6 +4303,7 @@
     deps = [
         ":call_graph",
         ":hlo_phi_graph",
+        ":hlo_value",
         "//xla:shape_util",
         "//xla:status",
         "//xla:statusor",
@@ -4270,10 +4312,12 @@
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/functional:function_ref",
+        "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:errors",
@@ -4292,19 +4336,21 @@
         ":hlo_dce",
         ":hlo_graph_dumper",
         ":hlo_ordering",
-        ":instruction_fusion",
-        "//xla:literal",
+        ":hlo_value",
+        "//xla:comparison_util",
+        "//xla:literal_util",
         "//xla:shape_util",
-        "//xla:status_macros",
+        "//xla:status",
         "//xla:test",
-        "//xla:test_helpers",
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_matchers",
         "//xla/tests:hlo_test_base",
         "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings",
+        "@com_google_googletest//:gtest",
         "@local_tsl//tsl/lib/core:status_test_util",
-        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/platform:test",
     ],
 )
@@ -4765,14 +4811,16 @@
     deps = [
         ":hlo_pass",
         "//xla:status",
-        "//xla:status_macros",
         "//xla:statusor",
-        "//xla:types",
         "//xla:util",
         "//xla/hlo/ir:hlo",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -4854,9 +4902,10 @@
         ":hlo_verifier",
         "//xla:shape_util",
         "//xla:status",
-        "//xla:util",
         "//xla/hlo/ir:hlo",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings:str_format",
+        "@local_tsl//tsl/platform:errors",
     ],
 )
 
@@ -5368,6 +5417,7 @@
     ],
     deps = [
         ":hlo_parser",
+        "//xla:error_spec",
         "//xla:execution_options_util",
         "//xla:status_macros",
         "//xla:test",
@@ -5375,6 +5425,7 @@
         "//xla/tests:hlo_test_base",
         "//xla/tests:test_macros_header",
         "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -5765,8 +5816,17 @@
     visibility = ["//visibility:public"],
     deps = [
         ":hlo_value",
+        "//xla:shape_tree",
+        "//xla:shape_util",
+        "//xla:statusor",
         "//xla/hlo/ir:hlo",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -5818,13 +5878,21 @@
         ":call_inliner",
         ":hlo_creation_utils",
         ":tuple_util",
+        "//xla:comparison_util",
         "//xla:literal_util",
+        "//xla:shape_util",
+        "//xla:statusor",
+        "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/functional:function_ref",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -7076,11 +7144,14 @@
     visibility = ["//visibility:public"],
     deps = [
         ":hlo_pass",
+        "//xla:shape_layout",
+        "//xla:shape_util",
         "//xla:statusor",
         "//xla/hlo/ir:hlo",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
         "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:status",
     ],
 )
 
@@ -7139,20 +7210,32 @@
     visibility = ["//visibility:public"],
     deps = [
         ":compiler",
+        ":executable",
+        ":hlo_module_config",
+        ":symbol_repository",
         "//xla:autotune_results_proto_cc",
         "//xla:debug_options_flags",
         "//xla:statusor",
+        "//xla:util",
+        "//xla/hlo/ir:hlo_module_group",
         "//xla/mlir_hlo",
         "//xla/pjrt:mlir_to_hlo",
         "//xla/service:cpu_plugin",
         "//xla/service/cpu:cpu_compiler",
         "//xla/service/cpu:cpu_executable",
+        "//xla/service/gpu:autotuner_util",
+        "//xla/service/gpu:gpu_symbol_repository",
+        "//xla/stream_executor",
+        "//xla/stream_executor:device_memory_allocator",
         "//xla/tools:hlo_module_loader",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
         "@local_tsl//tsl/platform:env",
+        "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:path",
         "@local_tsl//tsl/platform:platform_port",
         "@local_tsl//tsl/platform:protobuf",
@@ -7161,16 +7244,22 @@
     ] + if_cuda_is_configured([
         "//xla/service/gpu:executable_proto_cc",
         "//xla/service/gpu:gpu_compiler",
+        "//xla/service/gpu:nvptx_compiler",
         "//xla/service/gpu:nvptx_compiler_impl",
+        "//xla/stream_executor/gpu:gpu_init",
+        "//xla/stream_executor/cuda:cuda_platform",
     ]) + if_rocm_is_configured([
         "//xla/service/gpu:executable_proto_cc",
         "//xla/service/gpu:gpu_compiler",
+        "//xla/service/gpu:amdgpu_compiler",
         "//xla/service/gpu:amdgpu_compiler_impl",
+        "//xla/stream_executor/gpu:gpu_init",
+        "//xla/stream_executor/rocm:rocm_platform",
     ]) + if_cuda([
         "//xla/stream_executor/cuda:cublas_plugin",
     ]) + if_rocm([
         "//xla/stream_executor/rocm:rocblas_plugin",
-    ]),
+    ]) + xla_symbol_repository_deps(),
 )
 
 # A simple test of xla_aot_compile which generates an output file from an mhlo file.
@@ -7344,3 +7433,29 @@
         "@local_tsl//tsl/platform:protobuf",
     ],
 )
+
+cc_library(
+    name = "symbol_repository",
+    hdrs = ["symbol_repository.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":compiler",
+        "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/stream_executor:device_description_proto_cc",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
+    ],
+)
+
+cc_library(
+    name = "time_utils",
+    srcs = ["time_utils.cc"],
+    hdrs = ["time_utils.h"],
+    visibility = ["//visibility:public"],
+    deps = [],
+)
diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc
index 1ddd9d2..434e3bf 100644
--- a/third_party/xla/xla/service/algebraic_simplifier.cc
+++ b/third_party/xla/xla/service/algebraic_simplifier.cc
@@ -18,14 +18,13 @@
 #include <algorithm>
 #include <array>
 #include <cmath>
-#include <functional>
+#include <cstdint>
 #include <iterator>
 #include <memory>
 #include <numeric>
 #include <optional>
 #include <string>
 #include <tuple>
-#include <type_traits>
 #include <utility>
 #include <vector>
 
@@ -33,17 +32,19 @@
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/numeric/bits.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "xla/comparison_util.h"
 #include "xla/hlo/evaluator/hlo_evaluator.h"
-#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
+#include "xla/layout.h"
 #include "xla/layout_util.h"
 #include "xla/literal.h"
 #include "xla/literal_comparison.h"
@@ -53,11 +54,14 @@
 #include "xla/primitive_util.h"
 #include "xla/service/hlo_cost_analysis.h"
 #include "xla/service/hlo_creation_utils.h"
+#include "xla/service/hlo_module_config.h"
 #include "xla/service/pattern_matcher.h"
+#include "xla/service/shape_inference.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
+#include "xla/status.h"
 #include "xla/status_macros.h"
-#include "xla/types.h"
+#include "xla/statusor.h"
 #include "xla/util.h"
 #include "xla/window_util.h"
 #include "xla/xla_data.pb.h"
@@ -4491,8 +4495,9 @@
           OutputIsSubsetOfOperandElements(user, broadcast)) {
         VLOG(10) << "transform permuting/subset of a scalar broadcast into "
                  << "a single broadcast";
-        HloInstruction* new_broadcast = user->AddInstruction(
-            HloInstruction::CreateBroadcast(user->shape(), operand, {}));
+        HloInstruction* new_broadcast =
+            user->AddInstruction(HloInstruction::CreateBroadcast(
+                ShapeUtil::MakeStaticShape(user->shape()), operand, {}));
         // Use HloInstruction::ReplaceAllUsesWith instead of
         // HloComputation::ReplaceWithNewInstruction because we are replacing an
         // instruction other than the visited instruction.
@@ -4600,6 +4605,28 @@
         return ReplaceInstruction(compare, MakeScalarLike(compare, true));
     }
   }
+  if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
+      ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
+    if (compare->comparison_direction() == ComparisonDirection::kNe) {
+      // A != false -> A
+      if (IsAll(rhs, false)) {
+        return ReplaceInstruction(compare, lhs);
+      }
+      // false != A -> A
+      if (IsAll(lhs, false)) {
+        return ReplaceInstruction(compare, rhs);
+      }
+    } else if (compare->comparison_direction() == ComparisonDirection::kEq) {
+      // A == true -> A
+      if (IsAll(rhs, true)) {
+        return ReplaceInstruction(compare, lhs);
+      }
+      // true == A -> A
+      if (IsAll(lhs, true)) {
+        return ReplaceInstruction(compare, rhs);
+      }
+    }
+  }
   return OkStatus();
 }
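The PRED-compare folding added above follows from ordinary boolean identities; a quick standalone check in plain C++ (illustration only, not XLA code):

#include <cassert>

int main() {
  const bool values[] = {false, true};
  for (bool a : values) {
    assert((a != false) == a);  // A != false -> A
    assert((false != a) == a);  // false != A -> A
    assert((a == true) == a);   // A == true  -> A
    assert((true == a) == a);   // true == A  -> A
  }
  return 0;
}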
 
@@ -7039,6 +7066,28 @@
                                                         /*index=*/0));
     }
   }
+
+  // Replace Reduce(Broadcast(x), dims, Sum()) with Broadcast(x * N), where N is the number of reduced elements.
+  if (HloInstruction * broadcast_arg;
+      Match(arg, m::Broadcast(m::ConstantScalar(&broadcast_arg))) &&
+      Match(function->root_instruction(),
+            m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) {
+    if (auto broadcast_value = GetConstantValue(broadcast_arg);
+        broadcast_value.has_value() &&
+        // Skip float64: the product is more precise than the repeated sum it replaces.
+        broadcast_arg->shape().element_type() != PrimitiveType::F64) {
+      auto result_value = broadcast_value.value() *
+                          ShapeUtil::ElementsIn(arg->shape()) /
+                          ShapeUtil::ElementsIn(reduce_result_shape);
+      return ReplaceWithNewInstruction(
+          reduce, HloInstruction::CreateBroadcast(
+                      reduce_result_shape,
+                      reduce->AddInstruction(
+                          MakeScalarInstruction(reduce, result_value)),
+                      {}));
+    }
+  }
+
   return OkStatus();
 }
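Arithmetic behind the new Reduce(Broadcast(constant)) rewrite, sketched outside XLA with the shapes used by the ReplaceReduceSumOfConstantBroadcast test added further down:

#include <cassert>
#include <cstdint>

int main() {
  // f32[8,128] broadcast of 1.0f, summed over dimension 1 into f32[8].
  const int64_t elements_in_arg = 8 * 128;
  const int64_t elements_in_result = 8;
  const float broadcast_value = 1.0f;
  // Each output element sums 128 copies of the scalar, so the rewrite emits a
  // broadcast of broadcast_value * (elements_in_arg / elements_in_result).
  const float result_value =
      broadcast_value * (elements_in_arg / elements_in_result);
  assert(result_value == 128.0f);
  return 0;
}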
 
@@ -7378,16 +7427,19 @@
 
 Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) {
   // select(x, y, y) -> y.
-  if (select->operand(1) == select->operand(2)) {
-    return ReplaceInstruction(select, select->mutable_operand(1));
+  if (select->operand(1) == select->operand(2) &&
+      ReplaceInstructionIfCompatible(select, select->mutable_operand(1))) {
+    return OkStatus();
   }
   // select(true, x, y) -> x.
-  if (IsAll(select->operand(0), true)) {
-    return ReplaceInstruction(select, select->mutable_operand(1));
+  if (IsAll(select->operand(0), true) &&
+      ReplaceInstructionIfCompatible(select, select->mutable_operand(1))) {
+    return OkStatus();
   }
   // select(false, x, y) -> y.
-  if (IsAll(select->operand(0), false)) {
-    return ReplaceInstruction(select, select->mutable_operand(2));
+  if (IsAll(select->operand(0), false) &&
+      ReplaceInstructionIfCompatible(select, select->mutable_operand(2))) {
+    return OkStatus();
   }
   // select(not(pred), a, b) -> select(pred, b, a)
   if (HloOpcode::kNot == select->operand(0)->opcode()) {
@@ -7399,6 +7451,20 @@
         HloInstruction::CreateTernary(select->shape(), HloOpcode::kSelect,
                                       pred_operand, on_false, on_true));
   }
+  // select(PRED, PRED, PRED)
+  if (ShapeUtil::HasPrimitiveType(select->shape(), xla::PRED)) {
+    // select(a, true, false) -> a
+    if (IsAll(select->operand(1), true) && IsAll(select->operand(2), false)) {
+      return ReplaceInstruction(select, select->mutable_operand(0));
+    }
+    // select(a, false, true) -> not(a)
+    if (IsAll(select->operand(1), false) && IsAll(select->operand(2), true)) {
+      return ReplaceWithNewInstruction(
+          select, HloInstruction::CreateUnary(
+                      select->mutable_operand(0)->shape(), HloOpcode::kNot,
+                      select->mutable_operand(0)));
+    }
+  }
 
   // select(pred, xs, dynamic_update_slice(xs, x, i))
   //     -> dynamic_update_slice(xs, select(pred, dynamic_slice(xs, i), x), i)
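The two PRED-select identities added above, checked in plain C++ (illustration only, not XLA code):

#include <cassert>

int main() {
  const bool values[] = {false, true};
  for (bool a : values) {
    assert((a ? true : false) == a);   // select(a, true, false) -> a
    assert((a ? false : true) == !a);  // select(a, false, true) -> not(a)
  }
  return 0;
}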
diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc
index bfc53ad..e5dea99 100644
--- a/third_party/xla/xla/service/algebraic_simplifier_test.cc
+++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc
@@ -554,6 +554,28 @@
   EXPECT_EQ(computation->root_instruction(), param0);
 }
 
+// Test that select(true, a, b) is not simplified to a with mixed-precision operands
+TEST_F(AlgebraicSimplifierTest, SelectTrueMixedPrecision) {
+  Shape r0bf16 = ShapeUtil::MakeShape(BF16, {});
+  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0bf16, "param0"));
+  HloInstruction* param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, r0f32, "param1"));
+  HloInstruction* one = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
+  builder.AddInstruction(HloInstruction::CreateTernary(
+      r0f32, HloOpcode::kSelect, one, param0, param1));
+
+  auto module = CreateNewVerifiedModule();
+  auto computation = module->AddEntryComputationWithLayouts(builder.Build());
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
+  AlgebraicSimplifier simplifier(default_options_);
+  ASSERT_FALSE(simplifier.Run(module.get()).value());
+}
+
 // Test that select(false, a, b) is simplified to b
 TEST_F(AlgebraicSimplifierTest, SelectFalse) {
   Shape r0s32 = ShapeUtil::MakeShape(S32, {});
@@ -576,6 +598,28 @@
   EXPECT_EQ(computation->root_instruction(), param1);
 }
 
+// Test that select(false, a, b) is not simplified to b with mixed-precision operands
+TEST_F(AlgebraicSimplifierTest, SelectFalseMixedPrecision) {
+  Shape r0bf16 = ShapeUtil::MakeShape(BF16, {});
+  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0f32, "param0"));
+  HloInstruction* param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, r0bf16, "param1"));
+  HloInstruction* one = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+  builder.AddInstruction(HloInstruction::CreateTernary(
+      r0f32, HloOpcode::kSelect, one, param0, param1));
+
+  auto module = CreateNewVerifiedModule();
+  auto computation = module->AddEntryComputationWithLayouts(builder.Build());
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
+  AlgebraicSimplifier simplifier(default_options_);
+  ASSERT_FALSE(simplifier.Run(module.get()).value());
+}
+
 // Test that select(a, b, b) is simplified to b
 TEST_F(AlgebraicSimplifierTest, SelectIdentical) {
   Shape r0s32 = ShapeUtil::MakeShape(S32, {});
@@ -596,6 +640,27 @@
   EXPECT_EQ(computation->root_instruction(), param1);
 }
 
+// Test that select(a, b, b) is not simplified to b with mixed-precision operands
+TEST_F(AlgebraicSimplifierTest, SelectIdenticalMixedPrecision) {
+  Shape r0bf16 = ShapeUtil::MakeShape(BF16, {});
+  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+  Shape r0pred = ShapeUtil::MakeShape(PRED, {});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0pred, "param0"));
+  HloInstruction* param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, r0bf16, "param1"));
+  builder.AddInstruction(HloInstruction::CreateTernary(
+      r0f32, HloOpcode::kSelect, param0, param1, param1));
+
+  auto module = CreateNewVerifiedModule();
+  auto computation = module->AddEntryComputationWithLayouts(builder.Build());
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
+  AlgebraicSimplifier simplifier(default_options_);
+  ASSERT_FALSE(simplifier.Run(module.get()).value());
+}
+
 // Test that select(not(pred), a, b) is simplified to select(pred, b, a)
 TEST_F(AlgebraicSimplifierTest, SelectWithNotPred) {
   Shape pred_ty = ShapeUtil::MakeShape(PRED, {});
@@ -624,6 +689,52 @@
   EXPECT_EQ(operands[2], param1);
 }
 
+// Test that select(a, true, false) is simplified to a
+TEST_F(AlgebraicSimplifierTest, SelectPredPred) {
+  Shape r0pred = ShapeUtil::MakeShape(PRED, {});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0pred, "param0"));
+  HloInstruction* one = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
+  HloInstruction* zero = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+  builder.AddInstruction(HloInstruction::CreateTernary(
+      r0pred, HloOpcode::kSelect, param0, one, zero));
+
+  auto module = CreateNewVerifiedModule();
+  auto computation = module->AddEntryComputationWithLayouts(builder.Build());
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
+  AlgebraicSimplifier simplifier(default_options_);
+  ASSERT_TRUE(simplifier.Run(module.get()).value());
+  EXPECT_EQ(computation->root_instruction(), param0);
+}
+
+// Test that select(a, false, true) is simplified to not(a)
+TEST_F(AlgebraicSimplifierTest, SelectPredPred2) {
+  auto m = CreateNewVerifiedModule();
+  Shape r0pred = ShapeUtil::MakeShape(PRED, {});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0pred, "param0"));
+  HloInstruction* zero = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+  HloInstruction* one = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
+  builder.AddInstruction(HloInstruction::CreateTernary(
+      r0pred, HloOpcode::kSelect, param0, zero, one));
+
+  auto module = CreateNewVerifiedModule();
+  auto computation = module->AddEntryComputationWithLayouts(builder.Build());
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
+  AlgebraicSimplifier simplifier(default_options_);
+  ASSERT_TRUE(simplifier.Run(module.get()).value());
+  EXPECT_THAT(computation->root_instruction(),
+              GmockMatch(m::Not(m::Parameter(0))));
+}
+
 // Test that select(pred, xs, dynamic_update_slice(xs, x, i)) is simplified
 // to dynamic_update_slice(xs, select(pred, dynamic_slice(xs, i), x), i)
 TEST_F(AlgebraicSimplifierTest, SelectDUSWithShapedPred) {
@@ -7944,6 +8055,90 @@
                      .WithComparisonDirection(ComparisonDirection::kLt)));
 }
 
+// Test that A != False is simplified to A
+TEST_F(AlgebraicSimplifierTest, NeFalse) {
+  auto m = CreateNewVerifiedModule();
+  Shape r0pred = ShapeUtil::MakeShape(PRED, {});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0pred, "param0"));
+  HloInstruction* const_false = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+  builder.AddInstruction(HloInstruction::CreateCompare(
+      r0pred, param0, const_false, ComparisonDirection::kNe));
+
+  auto computation = m->AddEntryComputationWithLayouts(builder.Build());
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root->opcode(), HloOpcode::kCompare);
+  AlgebraicSimplifier simplifier(default_options_);
+  ASSERT_TRUE(simplifier.Run(m.get()).value());
+  root = computation->root_instruction();
+  EXPECT_EQ(root, param0);
+}
+
+// Test that False != A is simplified to A
+TEST_F(AlgebraicSimplifierTest, NeFalse2) {
+  auto m = CreateNewVerifiedModule();
+  Shape r0pred = ShapeUtil::MakeShape(PRED, {});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0pred, "param0"));
+  HloInstruction* const_false = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+  builder.AddInstruction(HloInstruction::CreateCompare(
+      r0pred, const_false, param0, ComparisonDirection::kNe));
+
+  auto computation = m->AddEntryComputationWithLayouts(builder.Build());
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root->opcode(), HloOpcode::kCompare);
+  AlgebraicSimplifier simplifier(default_options_);
+  ASSERT_TRUE(simplifier.Run(m.get()).value());
+  root = computation->root_instruction();
+  EXPECT_EQ(root, param0);
+}
+
+// Test that A == True is simplified to A
+TEST_F(AlgebraicSimplifierTest, EqTrue) {
+  auto m = CreateNewVerifiedModule();
+  Shape r0pred = ShapeUtil::MakeShape(PRED, {});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0pred, "param0"));
+  HloInstruction* const_true = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
+  builder.AddInstruction(HloInstruction::CreateCompare(
+      r0pred, param0, const_true, ComparisonDirection::kEq));
+
+  auto computation = m->AddEntryComputationWithLayouts(builder.Build());
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root->opcode(), HloOpcode::kCompare);
+  AlgebraicSimplifier simplifier(default_options_);
+  ASSERT_TRUE(simplifier.Run(m.get()).value());
+  root = computation->root_instruction();
+  EXPECT_EQ(root, param0);
+}
+
+// Test that True == A is simplified to A
+TEST_F(AlgebraicSimplifierTest, EqTrue2) {
+  auto m = CreateNewVerifiedModule();
+  Shape r0pred = ShapeUtil::MakeShape(PRED, {});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0pred, "param0"));
+  HloInstruction* const_true = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
+  builder.AddInstruction(HloInstruction::CreateCompare(
+      r0pred, const_true, param0, ComparisonDirection::kEq));
+
+  auto computation = m->AddEntryComputationWithLayouts(builder.Build());
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root->opcode(), HloOpcode::kCompare);
+  AlgebraicSimplifier simplifier(default_options_);
+  ASSERT_TRUE(simplifier.Run(m.get()).value());
+  root = computation->root_instruction();
+  EXPECT_EQ(root, param0);
+}
+
 TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) {
   // Some backends may have better performance by treating an outer product as a
   // Dot, rather than a broadcast Multiply
@@ -8783,6 +8978,33 @@
               GmockMatch(m::Add(m::Parameter(0), m::Parameter(1))));
 }
 
+TEST_F(AlgebraicSimplifierTest, ReplaceReduceSumOfConstantBroadcast) {
+  const char* kModuleStr = R"(
+HloModule ReplaceReduceSumOfConstantBroadcast
+
+add_f32 {
+  p0 = f32[] parameter(0)
+  p1 = f32[] parameter(1)
+  ROOT r = f32[] add(p0, p1)
+}
+
+ENTRY main {
+  init_value = f32[] constant(0)
+  const_value = f32[] constant(1)
+  const_bcast = f32[8, 128] broadcast(f32[] const_value), dimensions={}
+  ROOT reduce = f32[8] reduce(f32[8, 128] const_bcast, f32[] init_value), dimensions={1}, to_apply=add_f32
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
+  int64_t reduce_count =
+      absl::c_count_if(m->entry_computation()->instructions(),
+                       HloPredicateIsOp<HloOpcode::kReduce>);
+  // Expect no Reduce operation after simplification.
+  EXPECT_EQ(0, reduce_count);
+}
+
 TEST_F(AlgebraicSimplifierTest, ReplaceReduceMaxWithReduceArgMax) {
   const char* kModuleStr = R"(
 HloModule ReplaceReduceMaxWithReduceArgMax
diff --git a/third_party/xla/xla/service/all_reduce_promotion.cc b/third_party/xla/xla/service/all_reduce_promotion.cc
index 00469a2..3096512 100644
--- a/third_party/xla/xla/service/all_reduce_promotion.cc
+++ b/third_party/xla/xla/service/all_reduce_promotion.cc
@@ -49,6 +49,7 @@
     return inst->GetModule()->AddEmbeddedComputation(promoted.Build());
   }();
   new_inst->set_to_apply(to_apply_promoted);
+  to_apply_promoted->SetCollectiveCallInstruction(new_inst.get());
   return new_inst;
 }
 
diff --git a/third_party/xla/xla/service/all_reduce_reassociate.cc b/third_party/xla/xla/service/all_reduce_reassociate.cc
index 7575a75..f7be2f8 100644
--- a/third_party/xla/xla/service/all_reduce_reassociate.cc
+++ b/third_party/xla/xla/service/all_reduce_reassociate.cc
@@ -359,7 +359,9 @@
       }
       if (reduce_scatter_pattern_match) {
         TF_RETURN_IF_ERROR(computation->RemoveInstruction(lhs));
-        TF_RETURN_IF_ERROR(computation->RemoveInstruction(rhs));
+        if (lhs != rhs) {
+          TF_RETURN_IF_ERROR(computation->RemoveInstruction(rhs));
+        }
       }
       TF_RETURN_IF_ERROR(computation->RemoveInstruction(ar0));
       if (ar0 != ar1) {
diff --git a/third_party/xla/xla/service/all_reduce_reassociate_test.cc b/third_party/xla/xla/service/all_reduce_reassociate_test.cc
index 13b17ef..b6b8618 100644
--- a/third_party/xla/xla/service/all_reduce_reassociate_test.cc
+++ b/third_party/xla/xla/service/all_reduce_reassociate_test.cc
@@ -667,5 +667,39 @@
   EXPECT_EQ(AllReduceCount(module), 1);
 }
 
+TEST_F(AllReduceSimplifierTest, AllReduceDynamicSlicePatternSameOperand) {
+  absl::string_view hlo_string = R"(
+HloModule m
+
+sum {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT add.2 = f32[] add(a, b)
+}
+
+ENTRY main {
+  p0 = f32[1,8] parameter(0)
+  p1 = f32[1,8] parameter(1)
+  p2 = s32[] parameter(2)
+  cst = s32[] constant(0)
+  ar0 = f32[1,8] all-reduce(p0), replica_groups={}, to_apply=sum
+  ar2 = f32[1,8] all-reduce(p1), replica_groups={}, to_apply=sum
+  dyn0 = f32[1,4] dynamic-slice(ar0, cst, p2), dynamic_slice_sizes={1,4}
+  dyn2 = f32[1,4] dynamic-slice(ar2, cst, p2), dynamic_slice_sizes={1,4}
+  add = f32[1,4] add(dyn0, dyn0)
+  ROOT add1 = f32[1,4] add(add, dyn2)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          RunPass(hlo_string, /*expect_change=*/true));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              m::DynamicSlice(
+                  m::AllReduce(m::Add(m::Add(m::Parameter(0), m::Parameter(0)),
+                                      m::Parameter(1))),
+                  m::Constant(), m::Parameter(2)));
+  XLA_VLOG_LINES(1, module->ToString());
+  EXPECT_EQ(AllReduceCount(module), 1);
+}
+
 }  // namespace
 }  // namespace xla
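The `lhs != rhs` guard added in all_reduce_reassociate.cc above handles the case this test exercises, where both dynamic-slice operands resolve to the same instruction; a sketch with a hypothetical graph container (not the XLA API):

#include <cassert>
#include <set>

struct Instruction {};

int main() {
  std::set<Instruction*> computation;
  Instruction dyn;
  computation.insert(&dyn);
  Instruction* lhs = &dyn;
  Instruction* rhs = &dyn;  // add(dyn0, dyn0): both operands are the same node.
  computation.erase(lhs);
  if (lhs != rhs) {
    computation.erase(rhs);  // Skipped; removing the node twice was the bug.
  }
  assert(computation.empty());
  return 0;
}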
diff --git a/third_party/xla/xla/service/batchnorm_expander.cc b/third_party/xla/xla/service/batchnorm_expander.cc
index fed640a..0e2edcc 100644
--- a/third_party/xla/xla/service/batchnorm_expander.cc
+++ b/third_party/xla/xla/service/batchnorm_expander.cc
@@ -15,27 +15,32 @@
 
 #include "xla/service/batchnorm_expander.h"
 
+#include <cstdint>
 #include <memory>
 #include <optional>
-#include <string>
 #include <utility>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
+#include "absl/functional/function_ref.h"
+#include "absl/log/check.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/ir/hlo_sharding.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
-#include "xla/status_macros.h"
-#include "xla/types.h"
+#include "xla/status.h"
+#include "xla/statusor.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
 #include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 
@@ -91,8 +96,8 @@
       HloInstruction* element_count, HloInstruction* operand,
       absl::FunctionRef<HloInstruction*(std::unique_ptr<HloInstruction>)>
           add_instruction) {
-    auto broadcast = add_instruction(
-        HloInstruction::CreateBroadcast(operand->shape(), element_count, {}));
+    auto broadcast = add_instruction(HloInstruction::CreateBroadcast(
+        ShapeUtil::MakeStaticShape(operand->shape()), element_count, {}));
     return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kDivide,
                                         operand, broadcast);
   }
@@ -180,8 +185,9 @@
 
   auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
   TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
+  Shape scalar_broadcast_shape = ShapeUtil::MakeStaticShape(operand_shape);
   auto epsilon = add(HloInstruction::CreateBroadcast(
-      operand_shape,
+      scalar_broadcast_shape,
       add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
   std::vector<int64_t> dimensions_without_feature;
   const int64_t rank = operand_shape.rank();
@@ -196,11 +202,17 @@
   auto elements_per_feature =
       add(DynamicElementCountPerFeature(operand, feature_index, add));
 
-  auto scale_broadcasted = add(
-      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));
+  auto feature_broadcast = [&](HloInstruction* inst) -> HloInstruction* {
+    Shape feature_broadcast_shape = scalar_broadcast_shape;
+    feature_broadcast_shape.set_dynamic_dimension(
+        feature_index, inst->shape().is_dynamic_dimension(0));
+    return add(HloInstruction::CreateBroadcast(feature_broadcast_shape, inst,
+                                               {feature_index}));
+  };
 
-  auto offset_broadcasted = add(
-      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));
+  auto scale_broadcasted = feature_broadcast(scale);
+
+  auto offset_broadcasted = feature_broadcast(offset);
 
   HloComputation* add_reduce_computation =
       GetOrCreateScalarAddComputation(ptype);
@@ -221,8 +233,7 @@
   // E[X].
   auto mean = add(Mean(elements_per_feature, sum, add));
 
-  auto mean_broadcasted = add(
-      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));
+  auto mean_broadcasted = feature_broadcast(mean);
 
   // E[X^2].
   auto square_mean = add(Mean(elements_per_feature, squared_sum, add));
@@ -235,12 +246,11 @@
   auto var =
       add_binary(feature_shape, HloOpcode::kSubtract, square_mean, mean_square);
 
-  auto var_broadcasted =
-      add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
+  auto var_broadcasted = feature_broadcast(var);
 
   // Var[X] + epsilon.
-  auto var_add_epsilon =
-      add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);
+  auto var_add_epsilon = add_binary(var_broadcasted->shape(), HloOpcode::kAdd,
+                                    var_broadcasted, epsilon);
 
   // 1 / Sqrt[Var[X] + epsilon].
   auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon));
@@ -304,11 +314,12 @@
   HloInstruction* mean = batch_norm->mutable_operand(3);
   HloInstruction* var = batch_norm->mutable_operand(4);
   const Shape feature_shape = scale->shape();
+  Shape scalar_broadcast_shape = ShapeUtil::MakeStaticShape(feature_shape);
 
   auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
   TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
   auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
-      feature_shape,
+      scalar_broadcast_shape,
       computation_->AddInstruction(
           HloInstruction::CreateConstant(std::move(epsilon_literal))),
       {}));
@@ -335,8 +346,11 @@
     return add(HloInstruction::CreateBinary(shape, opcode, a, b));
   };
   auto feature_broadcast = [&](HloInstruction* a) {
+    Shape broadcast_shape = ShapeUtil::MakeStaticShape(operand_shape);
+    broadcast_shape.set_dynamic_dimension(feature_index,
+                                          a->shape().is_dynamic_dimension(0));
     return add(
-        HloInstruction::CreateBroadcast(operand_shape, a, {feature_index}));
+        HloInstruction::CreateBroadcast(broadcast_shape, a, {feature_index}));
   };
 
   int64_t instruction_count_before = computation_->instruction_count();
@@ -428,10 +442,10 @@
   TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
   auto epsilon_scalar =
       add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
-  auto epsilon_activation = add(
-      HloInstruction::CreateBroadcast(activation_shape, epsilon_scalar, {}));
-  auto epsilon_feature =
-      add(HloInstruction::CreateBroadcast(feature_shape, epsilon_scalar, {}));
+  auto epsilon_activation = add(HloInstruction::CreateBroadcast(
+      ShapeUtil::MakeStaticShape(activation_shape), epsilon_scalar, {}));
+  auto epsilon_feature = add(HloInstruction::CreateBroadcast(
+      ShapeUtil::MakeStaticShape(feature_shape), epsilon_scalar, {}));
 
   std::vector<int64_t> dimensions_without_feature;
   const int64_t rank = activation_shape.rank();
@@ -443,18 +457,23 @@
     }
   }
 
-  auto scale_broadcasted = add(HloInstruction::CreateBroadcast(
-      activation_shape, scale, {feature_index}));
-  auto variance_broadcasted = add(HloInstruction::CreateBroadcast(
-      activation_shape, variance, {feature_index}));
+  auto activation_broadcast = [&](HloInstruction* hlo) -> HloInstruction* {
+    Shape broadcast_shape = ShapeUtil::MakeStaticShape(activation_shape);
+    broadcast_shape.set_dynamic_dimension(feature_index,
+                                          hlo->shape().is_dynamic_dimension(0));
+    return add(
+        HloInstruction::CreateBroadcast(broadcast_shape, hlo, {feature_index}));
+  };
+
+  auto scale_broadcasted = activation_broadcast(scale);
+  auto variance_broadcasted = activation_broadcast(variance);
 
   // E[X].
-  auto mean_broadcasted = add(
-      HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index}));
+  auto mean_broadcasted = activation_broadcast(mean);
 
   // rsqrt[Var[X] + epsilon].
   auto rsqrt_var_add_epsilon_broadcasted =
-      add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd,
+      add(Rsqrt(add_binary(variance_broadcasted->shape(), HloOpcode::kAdd,
                            variance_broadcasted, epsilon_activation)));
 
   auto rsqrt_var_add_epsilon = add(Rsqrt(
@@ -489,34 +508,40 @@
                                rsqrt_var_add_epsilon);
 
   // I2 = Sum(Grad[Y])
-  auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta,
-                                                {feature_index}));
+  auto i2 = activation_broadcast(grad_beta);
 
   // I3 = Sum(Grad[Y] * (X - E[X]))
-  auto i3 = add(HloInstruction::CreateBroadcast(
-      activation_shape, sum_grad_output_times_activation_minus_mean,
-      {feature_index}));
+  auto i3 = activation_broadcast(sum_grad_output_times_activation_minus_mean);
 
   // I4 = (X - E[X]) * I3
   auto i4 = add_binary(activation_shape, HloOpcode::kMultiply, i3,
                        activation_minus_mean);
 
   // I5 = I4 / (Var[X] + epsilon)
-  auto i5 = add_binary(activation_shape, HloOpcode::kDivide, i4,
-                       add_binary(activation_shape, HloOpcode::kAdd,
-                                  variance_broadcasted, epsilon_activation));
+  auto i5 =
+      add_binary(activation_shape, HloOpcode::kDivide, i4,
+                 add_binary(variance_broadcasted->shape(), HloOpcode::kAdd,
+                            variance_broadcasted, epsilon_activation));
 
   // scale * rsqrt[Var[X] + epsilon] * 1/N
+  Shape scale_times_rsqrt_var_add_epsilon_shape = scale_broadcasted->shape();
+  for (int64_t i = 0; i < rsqrt_var_add_epsilon_broadcasted->shape().rank();
+       ++i) {
+    if (rsqrt_var_add_epsilon_broadcasted->shape().is_dynamic_dimension(i)) {
+      scale_times_rsqrt_var_add_epsilon_shape.set_dynamic_dimension(i, true);
+    }
+  }
   auto scale_times_rsqrt_var_add_epsilon =
-      add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted,
-                 rsqrt_var_add_epsilon_broadcasted);
+      add_binary(scale_times_rsqrt_var_add_epsilon_shape, HloOpcode::kMultiply,
+                 scale_broadcasted, rsqrt_var_add_epsilon_broadcasted);
 
   scale_times_rsqrt_var_add_epsilon =
       add(Mean(elements_per_feature, scale_times_rsqrt_var_add_epsilon, add));
 
-  auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
+  auto i1 = add_binary(grad_output->shape(), HloOpcode::kMultiply, grad_output,
                        add(HloInstruction::CreateBroadcast(
-                           activation_shape, elements_per_feature, {})));
+                           ShapeUtil::MakeStaticShape(activation_shape),
+                           elements_per_feature, {})));
 
   // I6 = I1 - I2 - I5
   auto i6 = add_binary(
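The broadcasts in batchnorm_expander.cc now start from a fully static copy of the target shape and re-mark only the feature dimension as dynamic when the broadcast operand is dynamic. A sketch with a hypothetical minimal Shape type (dynamic flags only, unlike xla::Shape):

#include <cassert>
#include <vector>

struct Shape {
  std::vector<bool> dynamic;  // One flag per dimension.
};

Shape MakeStaticShape(const Shape& s) {
  return Shape{std::vector<bool>(s.dynamic.size(), false)};
}

int main() {
  Shape activation_shape{{true, false}};  // Dynamic batch dim, static feature dim.
  const int feature_index = 1;
  const bool operand_is_dynamic = false;  // e.g. whether scale's dimension 0 is dynamic.

  Shape broadcast_shape = MakeStaticShape(activation_shape);
  broadcast_shape.dynamic[feature_index] = operand_is_dynamic;

  // Dynamic flags on non-feature dims are dropped; only the feature dim can
  // remain dynamic, and only if the broadcast operand itself is dynamic.
  assert(broadcast_shape.dynamic[0] == false);
  assert(broadcast_shape.dynamic[1] == operand_is_dynamic);
  return 0;
}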
diff --git a/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc b/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc
index 00b5ec2..b400b28 100644
--- a/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc
+++ b/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc
@@ -45,26 +45,26 @@
   EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
 // CHECK: HloModule bitcast_to_smaller
 // CHECK: %xla.bitcast_convert_s32_10__2_s8_10_4_.17 (a.1: s32[10]) -> s8[10,4] {
-// CHECK:  %a.1 = s32[10]{0} parameter(0)
-// CHECK:  %reshape.2 = s32[10,1]{1,0} reshape(s32[10]{0} %a.1)
-// CHECK:  %broadcast.3 = s32[10,1]{1,0} broadcast(s32[10,1]{1,0} %reshape.2), dimensions={0,1}
-// CHECK:  %reshape.4 = s32[10]{0} reshape(s32[10,1]{1,0} %broadcast.3)
-// CHECK:  %broadcast.5 = s32[10,4]{1,0} broadcast(s32[10]{0} %reshape.4), dimensions={0}
-// CHECK:  %bitcast-convert.6 = u32[10,4]{1,0} bitcast-convert(s32[10,4]{1,0} %broadcast.5)
-// CHECK:  %constant.8 = u32[] constant(8)
-// CHECK:  %broadcast.9 = u32[10,4]{1,0} broadcast(u32[] %constant.8), dimensions={}
-// CHECK:  %iota.7 = u32[10,4]{1,0} iota(), iota_dimension=1
-// CHECK:  %multiply.10 = u32[10,4]{1,0} multiply(u32[10,4]{1,0} %broadcast.9, u32[10,4]{1,0} %iota.7)
-// CHECK:  %shift-right-logical{{\.?[0-9]*}} = u32[10,4]{1,0} shift-right-logical(u32[10,4]{1,0} %bitcast-convert.6, u32[10,4]{1,0} %multiply.10)
-// CHECK:  %constant{{\.?[0-9]*}} = u32[] constant(255)
-// CHECK:  %broadcast.13 = u32[10,4]{1,0} broadcast(u32[] %constant{{\.?[0-9]*}}), dimensions={}
-// CHECK:  %and.14 = u32[10,4]{1,0} and(u32[10,4]{1,0} %shift-right-logical{{\.?[0-9]*}}, u32[10,4]{1,0} %broadcast.13)
-// CHECK:  %convert.15 = u8[10,4]{1,0} convert(u32[10,4]{1,0} %and.14)
-// CHECK:  ROOT %bitcast-convert.16 = s8[10,4]{1,0} bitcast-convert(u8[10,4]{1,0} %convert.15)
+// CHECK:  %[[VAL_0:.*]] = s32[10]{0} parameter(0)
+// CHECK:  %[[VAL_1:.*]] = s32[10,1]{1,0} reshape(s32[10]{0} %[[VAL_0]])
+// CHECK:  %[[VAL_2:.*]] = s32[10,1]{1,0} broadcast(s32[10,1]{1,0} %[[VAL_1]]), dimensions={0,1}
+// CHECK:  %[[VAL_3:.*]] = s32[10]{0} reshape(s32[10,1]{1,0} %[[VAL_2]])
+// CHECK:  %[[VAL_4:.*]] = s32[10,4]{1,0} broadcast(s32[10]{0} %[[VAL_3]]), dimensions={0}
+// CHECK:  %[[VAL_5:.*]] = u32[10,4]{1,0} bitcast-convert(s32[10,4]{1,0} %[[VAL_4]])
+// CHECK:  %[[VAL_6:.*]] = u32[] constant(8)
+// CHECK:  %[[VAL_7:.*]] = u32[10,4]{1,0} broadcast(u32[] %[[VAL_6]]), dimensions={}
+// CHECK:  %[[VAL_8:.*]] = u32[10,4]{1,0} iota(), iota_dimension=1
+// CHECK:  %[[VAL_9:.*]] = u32[10,4]{1,0} multiply(u32[10,4]{1,0} %[[VAL_7]], u32[10,4]{1,0} %[[VAL_8]])
+// CHECK:  %[[VAL_10:.*]] = u32[10,4]{1,0} shift-right-logical(u32[10,4]{1,0} %[[VAL_5]], u32[10,4]{1,0} %[[VAL_9]])
+// CHECK:  %[[VAL_11:.*]] = u32[] constant(255)
+// CHECK:  %[[VAL_12:.*]] = u32[10,4]{1,0} broadcast(u32[] %[[VAL_11]]), dimensions={}
+// CHECK:  %[[VAL_13:.*]] = u32[10,4]{1,0} and(u32[10,4]{1,0} %[[VAL_10]], u32[10,4]{1,0} %[[VAL_12]])
+// CHECK:  %[[VAL_14:.*]] = u8[10,4]{1,0} convert(u32[10,4]{1,0} %[[VAL_13]])
+// CHECK:  ROOT %[[VAL_15:.*]] = s8[10,4]{1,0} bitcast-convert(u8[10,4]{1,0} %[[VAL_14]])
 // CHECK: }
 // CHECK: ENTRY %main (p: s32[10]) -> s8[10,4] {
-// CHECK:  %p = s32[10]{0} parameter(0)
-// CHECK:  ROOT %call = s8[10,4]{1,0} call(s32[10]{0} %p), to_apply=%xla.bitcast_convert_s32_10__2_s8_10_4_.17
+// CHECK:  %[[VAL_16:.*]] = s32[10]{0} parameter(0)
+// CHECK:  ROOT %[[VAL_17:.*]] = s8[10,4]{1,0} call(s32[10]{0} %[[VAL_16]]), to_apply=%[[VAL_18:.*]]
 // CHECK: }
 )"));
 }
@@ -88,26 +88,26 @@
   EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
 // CHECK: HloModule bitcast_to_smaller, entry_computation_layout={(s64[10]{0})->s32[10,2]{1,0}}
 // CHECK: %xla.bitcast_convert_s64_10__2_s32_10_2_.17 (a.1: s64[10]) -> s32[10,2] {
-// CHECK:   %a.1 = s64[10]{0} parameter(0)
-// CHECK:   %reshape.2 = s64[10,1]{1,0} reshape(s64[10]{0} %a.1)
-// CHECK:   %broadcast.3 = s64[10,1]{1,0} broadcast(s64[10,1]{1,0} %reshape.2), dimensions={0,1}
-// CHECK:   %reshape.4 = s64[10]{0} reshape(s64[10,1]{1,0} %broadcast.3)
-// CHECK:   %broadcast.5 = s64[10,2]{1,0} broadcast(s64[10]{0} %reshape.4), dimensions={0}
-// CHECK:   %bitcast-convert.6 = u64[10,2]{1,0} bitcast-convert(s64[10,2]{1,0} %broadcast.5)
-// CHECK:   %constant.8 = u64[] constant(32)
-// CHECK:   %broadcast.9 = u64[10,2]{1,0} broadcast(u64[] %constant.8), dimensions={}
-// CHECK:   %iota.7 = u64[10,2]{1,0} iota(), iota_dimension=1
-// CHECK:   %multiply.10 = u64[10,2]{1,0} multiply(u64[10,2]{1,0} %broadcast.9, u64[10,2]{1,0} %iota.7)
-// CHECK:   %shift-right-logical.11 = u64[10,2]{1,0} shift-right-logical(u64[10,2]{1,0} %bitcast-convert.6, u64[10,2]{1,0} %multiply.10)
-// CHECK:   %constant.12 = u64[] constant(4294967295)
-// CHECK:   %broadcast.13 = u64[10,2]{1,0} broadcast(u64[] %constant.12), dimensions={}
-// CHECK:   %and.14 = u64[10,2]{1,0} and(u64[10,2]{1,0} %shift-right-logical.11, u64[10,2]{1,0} %broadcast.13)
-// CHECK:   %convert.15 = u32[10,2]{1,0} convert(u64[10,2]{1,0} %and.14)
-// CHECK:   ROOT %bitcast-convert.16 = s32[10,2]{1,0} bitcast-convert(u32[10,2]{1,0} %convert.15)
+// CHECK:   %[[VAL_0:.*]] = s64[10]{0} parameter(0)
+// CHECK:   %[[VAL_1:.*]] = s64[10,1]{1,0} reshape(s64[10]{0} %[[VAL_0]])
+// CHECK:   %[[VAL_2:.*]] = s64[10,1]{1,0} broadcast(s64[10,1]{1,0} %[[VAL_1]]), dimensions={0,1}
+// CHECK:   %[[VAL_3:.*]] = s64[10]{0} reshape(s64[10,1]{1,0} %[[VAL_2]])
+// CHECK:   %[[VAL_4:.*]] = s64[10,2]{1,0} broadcast(s64[10]{0} %[[VAL_3]]), dimensions={0}
+// CHECK:   %[[VAL_5:.*]] = u64[10,2]{1,0} bitcast-convert(s64[10,2]{1,0} %[[VAL_4]])
+// CHECK:   %[[VAL_6:.*]] = u64[] constant(32)
+// CHECK:   %[[VAL_7:.*]] = u64[10,2]{1,0} broadcast(u64[] %[[VAL_6]]), dimensions={}
+// CHECK:   %[[VAL_8:.*]] = u64[10,2]{1,0} iota(), iota_dimension=1
+// CHECK:   %[[VAL_9:.*]] = u64[10,2]{1,0} multiply(u64[10,2]{1,0} %[[VAL_7]], u64[10,2]{1,0} %[[VAL_8]])
+// CHECK:   %[[VAL_10:.*]] = u64[10,2]{1,0} shift-right-logical(u64[10,2]{1,0} %[[VAL_5]], u64[10,2]{1,0} %[[VAL_9]])
+// CHECK:   %[[VAL_11:.*]] = u64[] constant(4294967295)
+// CHECK:   %[[VAL_12:.*]] = u64[10,2]{1,0} broadcast(u64[] %[[VAL_11]]), dimensions={}
+// CHECK:   %[[VAL_13:.*]] = u64[10,2]{1,0} and(u64[10,2]{1,0} %[[VAL_10]], u64[10,2]{1,0} %[[VAL_12]])
+// CHECK:   %[[VAL_14:.*]] = u32[10,2]{1,0} convert(u64[10,2]{1,0} %[[VAL_13]])
+// CHECK:   ROOT %[[VAL_15:.*]] = s32[10,2]{1,0} bitcast-convert(u32[10,2]{1,0} %[[VAL_14]])
 // CHECK: }
 // CHECK: ENTRY %main (p: s64[10]) -> s32[10,2] {
-// CHECK:   %p = s64[10]{0} parameter(0)
-// CHECK:   ROOT %call = s32[10,2]{1,0} call(s64[10]{0} %p), to_apply=%xla.bitcast_convert_s64_10__2_s32_10_2_.17
+// CHECK:   %[[VAL_16:.*]] = s64[10]{0} parameter(0)
+// CHECK:   ROOT %[[VAL_17:.*]] = s32[10,2]{1,0} call(s64[10]{0} %[[VAL_16]]), to_apply=%[[VAL_18:.*]]
 // CHECK: }
 )"));
 }
@@ -132,22 +132,27 @@
   EXPECT_TRUE(changed);
   EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
 // CHECK: HloModule bitcast_to_larger
+// CHECK: %or_U32.10 (lhs.11: u32[], rhs.12: u32[]) -> u32[] {
+// CHECK:  %[[VAL_0:.*]] = u32[] parameter(0)
+// CHECK:  %[[VAL_1:.*]] = u32[] parameter(1)
+// CHECK:  ROOT %[[VAL_2:.*]] = u32[] or(u32[] %[[VAL_0]], u32[] %[[VAL_1]])
+// CHECK: }
 // CHECK: %xla.bitcast_convert_s8_10_4__2_s32_10_.16 (a.1: s8[10,4]) -> s32[10] {
-// CHECK:  %a.1 = s8[10,4]{1,0} parameter(0)
-// CHECK:  %bitcast-convert.2 = u8[10,4]{1,0} bitcast-convert(s8[10,4]{1,0} %a.1)
-// CHECK:  %convert.3 = u32[10,4]{1,0} convert(u8[10,4]{1,0} %bitcast-convert.2)
-// CHECK:  %constant{{\.?[0-9]*}} = u32[] constant(8)
-// CHECK:  %broadcast.6 = u32[10,4]{1,0} broadcast(u32[] %constant{{\.?[0-9]*}}), dimensions={}
-// CHECK:  %iota{{\.?[0-9]*}} = u32[10,4]{1,0} iota(), iota_dimension=1
-// CHECK:  %multiply.7 = u32[10,4]{1,0} multiply(u32[10,4]{1,0} %broadcast.6, u32[10,4]{1,0} %iota{{\.?[0-9]*}})
-// CHECK:  %shift-left.8 = u32[10,4]{1,0} shift-left(u32[10,4]{1,0} %convert.3, u32[10,4]{1,0} %multiply.7)
-// CHECK:  %constant.9 = u32[] constant(0)
-// CHECK:  %reduce.14 = u32[10]{0} reduce(u32[10,4]{1,0} %shift-left.8, u32[] %constant.9), dimensions={1}, to_apply=%or_U32.10
-// CHECK:  ROOT %bitcast-convert.15 = s32[10]{0} bitcast-convert(u32[10]{0} %reduce.14)
+// CHECK:  %[[VAL_3:.*]] = s8[10,4]{1,0} parameter(0)
+// CHECK:  %[[VAL_4:.*]] = u8[10,4]{1,0} bitcast-convert(s8[10,4]{1,0} %[[VAL_3]])
+// CHECK:  %[[VAL_5:.*]] = u32[10,4]{1,0} convert(u8[10,4]{1,0} %[[VAL_4]])
+// CHECK:  %[[VAL_6:.*]] = u32[] constant(8)
+// CHECK:  %[[VAL_7:.*]] = u32[10,4]{1,0} broadcast(u32[] %[[VAL_6]]), dimensions={}
+// CHECK:  %[[VAL_8:.*]] = u32[10,4]{1,0} iota(), iota_dimension=1
+// CHECK:  %[[VAL_9:.*]] = u32[10,4]{1,0} multiply(u32[10,4]{1,0} %[[VAL_7]], u32[10,4]{1,0} %[[VAL_8]])
+// CHECK:  %[[VAL_10:.*]] = u32[10,4]{1,0} shift-left(u32[10,4]{1,0} %[[VAL_5]], u32[10,4]{1,0} %[[VAL_9]])
+// CHECK:  %[[VAL_11:.*]] = u32[] constant(0)
+// CHECK:  %[[VAL_12:.*]] = u32[10]{0} reduce(u32[10,4]{1,0} %[[VAL_10]], u32[] %[[VAL_11]]), dimensions={1}, to_apply=%[[VAL_13:.*]]
+// CHECK:  ROOT %[[VAL_14:.*]] = s32[10]{0} bitcast-convert(u32[10]{0} %[[VAL_12]])
 // CHECK: }
 // CHECK: ENTRY %main (p: s8[10,4]) -> s32[10] {
-// CHECK:  %p = s8[10,4]{1,0} parameter(0)
-// CHECK:  ROOT %call = s32[10]{0} call(s8[10,4]{1,0} %p), to_apply=%xla.bitcast_convert_s8_10_4__2_s32_10_.16
+// CHECK:  %[[VAL_15:.*]] = s8[10,4]{1,0} parameter(0)
+// CHECK:  ROOT %[[VAL_16:.*]] = s32[10]{0} call(s8[10,4]{1,0} %[[VAL_15]]), to_apply=%[[VAL_17:.*]]
 // CHECK: }
 )"));
 }
diff --git a/third_party/xla/xla/service/buffer_assignment.cc b/third_party/xla/xla/service/buffer_assignment.cc
index 3235e5d..84663e9 100644
--- a/third_party/xla/xla/service/buffer_assignment.cc
+++ b/third_party/xla/xla/service/buffer_assignment.cc
@@ -30,6 +30,7 @@
 #include "absl/container/btree_map.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
@@ -707,8 +708,14 @@
   for (size_t index = 0; index < allocations_.size(); ++index) {
     BufferAllocation* allocation = &allocations_[index];
     allocation->set_index(index);
+    std::vector<const HloValue*> sorted_values;
+    sorted_values.reserve(allocation->assigned_buffers_.size());
     for (const auto& buffer_offset_size : allocation->assigned_buffers_) {
       const HloValue* value = buffer_offset_size.first;
+      sorted_values.emplace(sorted_values.end(), value);
+    }
+    absl::c_sort(sorted_values, &CompareHloValuesById);
+    for (const HloValue* value : sorted_values) {
       allocation_index_for_value_[value] = index;
     }
   }
@@ -1189,7 +1196,7 @@
       for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
         const HloValue* value = buffer_offset_size.first;
         if ((*must_not_live_out_)(value->instruction(), value->index())) {
-          VLOG(4) << "Can't assign: " << buffer_offset_size.first->instruction()
+          VLOG(4) << "Can't assign: " << value->instruction()
                   << " cannot live out of the module";
           return false;
         }
@@ -1650,8 +1657,14 @@
                                    buffers_to_assign.end());
     }
     auto color_map = SplitBuffersByColor(all_buffers_to_assign);
+    std::vector<LogicalBuffer::Color> sorted_colors;
+    sorted_colors.reserve(color_map.size());
     for (auto& single_colored_set : color_map) {
       auto color = single_colored_set.first;
+      sorted_colors.emplace(sorted_colors.end(), color);
+    }
+    absl::c_sort(sorted_colors);
+    for (auto color : sorted_colors) {
       VLOG(2) << "Simulating heap for color " << color;
       int64_t alignment = assignment->color_alignment_(color);
       HeapSimulator::Options options;
@@ -1666,7 +1679,7 @@
         // whole-module heap simulation. Performing heap simulation from the
         // private stack computation allows better temporal reuse of buffers.
         auto computation_map = SplitBuffersByPrivateStackComputation(
-            single_colored_set.second, private_stacks_it->second,
+            color_map[color], private_stacks_it->second,
             assignment->alias_analysis().dataflow_analysis().call_graph());
         for (const HloComputation* private_stack_computation :
              private_stacks_it->second) {
@@ -1684,19 +1697,19 @@
                   get_heap_algorithm(alignment), *private_stack_computation,
                   *instruction_sequence, assignment->alias_analysis(),
                   assignment->buffer_size_, &schedule, options));
-          AssignBuffersFromHeapSimulator(
-              result, assignment, single_colored_set.first, isolation_options);
+          AssignBuffersFromHeapSimulator(result, assignment, color,
+                                         isolation_options);
         }
       } else {
-        options.buffers_to_assign = &single_colored_set.second;
+        options.buffers_to_assign = &color_map[color];
         TF_ASSIGN_OR_RETURN(
             HeapSimulator::Result<HloValue> result,
             HeapSimulator::Run(get_heap_algorithm(alignment),
                                assignment->module(), schedule,
                                assignment->alias_analysis(),
                                assignment->buffer_size_, options));
-        AssignBuffersFromHeapSimulator(
-            result, assignment, single_colored_set.first, isolation_options);
+        AssignBuffersFromHeapSimulator(result, assignment, color,
+                                       isolation_options);
       }
     }
   } else {
@@ -1711,20 +1724,26 @@
           hlo_ordering.SequentialOrder(*computation);
       CHECK(instruction_sequence != nullptr) << computation->name();
       auto color_map = SplitBuffersByColor(buffers_to_assign);
+      std::vector<LogicalBuffer::Color> sorted_colors;
+      sorted_colors.reserve(color_map.size());
       for (auto& single_colored_set : color_map) {
         auto color = single_colored_set.first;
+        sorted_colors.emplace(sorted_colors.end(), color);
+      }
+      absl::c_sort(sorted_colors);
+      for (auto color : sorted_colors) {
         VLOG(2) << "Simulating heap for color " << color;
         int64_t alignment = assignment->color_alignment_(color);
         HeapSimulator::Options options;
-        options.buffers_to_assign = &single_colored_set.second;
+        options.buffers_to_assign = &color_map[color];
         TF_ASSIGN_OR_RETURN(
             HeapSimulator::Result<HloValue> result,
             HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
                                *instruction_sequence,
                                assignment->alias_analysis(),
                                assignment->buffer_size_, options));
-        AssignBuffersFromHeapSimulator(
-            result, assignment, single_colored_set.first, isolation_options);
+        AssignBuffersFromHeapSimulator(result, assignment, color,
+                                       isolation_options);
       }
     }
   }
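
Both hunks above apply the same determinism fix: instead of iterating an absl hash container directly (whose order can vary from run to run), the keys are copied into a vector, sorted, and then iterated. A self-contained sketch of that pattern, with made-up types, for reference:

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"

// Iterate a flat_hash_map in a stable order by sorting its keys first.
void PrintInKeyOrder(const absl::flat_hash_map<int, std::string>& m) {
  std::vector<int> keys;
  keys.reserve(m.size());
  for (const auto& kv : m) keys.push_back(kv.first);
  std::sort(keys.begin(), keys.end());  // fixed order, independent of hashing
  for (int k : keys) std::printf("%d -> %s\n", k, m.at(k).c_str());
}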
diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc
index 1b64e3a..3d5ddf3 100644
--- a/third_party/xla/xla/service/collective_pipeliner.cc
+++ b/third_party/xla/xla/service/collective_pipeliner.cc
@@ -45,6 +45,7 @@
 #include "xla/map_util.h"
 #include "xla/primitive_util.h"
 #include "xla/service/constant_value.h"
+#include "xla/service/hlo_dce.h"
 #include "xla/service/value_range.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
@@ -283,7 +284,8 @@
 CheckStoreIntoSliceIsCompatible(HloInstruction* instr,
                                 const HloComputation* while_body,
                                 int64_t level_to_operate_on,
-                                bool multi_uses_pipelining) {
+                                bool multi_uses_pipelining,
+                                HloPredicate acceptable_formatting) {
   if ((!multi_uses_pipelining && instr->user_count() != 1) ||
       instr->operand_count() != 1 || instr->HasControlDependencies()) {
     return std::make_pair(nullptr, std::vector<HloInstruction*>{});
@@ -297,13 +299,20 @@
   // of being saved across the loop. So protect them through
   // "multi_uses_pipelining" flag.
   auto is_acceptable_user = [&](HloInstruction* i) {
-    return (HloPredicateIsOp<HloOpcode::kSlice, HloOpcode::kDynamicSlice,
-                             HloOpcode::kPad, HloOpcode::kCollectivePermute,
-                             HloOpcode::kConvert, HloOpcode::kReshape,
-                             HloOpcode::kTranspose>(i) ||
-            (multi_uses_pipelining && i->IsElementwise()) ||
-            i->IsCustomCall(CollectivePipeliner::kInsertedByPreviousStep)) &&
-           !i->HasControlDependencies();
+    if (i->HasControlDependencies() || !acceptable_formatting(i)) {
+      return false;
+    }
+    if (i->opcode() == HloOpcode::kReduce &&
+        ShapeUtil::ElementsIn(i->shape()) ==
+            ShapeUtil::ElementsIn(instr->operand(0)->shape())) {
+      return true;
+    }
+    return HloPredicateIsOp<HloOpcode::kSlice, HloOpcode::kDynamicSlice,
+                            HloOpcode::kPad, HloOpcode::kCollectivePermute,
+                            HloOpcode::kConvert, HloOpcode::kReshape,
+                            HloOpcode::kAllReduce, HloOpcode::kTranspose>(i) ||
+           (multi_uses_pipelining && i->IsElementwise()) ||
+           i->IsCustomCall(CollectivePipeliner::kInsertedByPreviousStep);
   };
   // Returns if this instruction is a dynamic-update-slice inserting the value
   // into a bigger buffer that we are going to pipeline to the next iteration.
@@ -353,7 +362,6 @@
         return std::make_pair(nullptr, std::vector<HloInstruction*>{});
       }
       final_slice_insertion = Cast<HloDynamicUpdateSliceInstruction>(next_user);
-      stack.pop_back();
       continue;
     }
     if (!is_acceptable_user(next_user)) {
@@ -398,7 +406,6 @@
         return std::make_pair(nullptr, std::vector<HloInstruction*>{});
       }
       final_slice_insertion = Cast<HloDynamicUpdateSliceInstruction>(next_user);
-      stack.pop_back();
       continue;
     }
     if (--formatting_map[next_user] > 0) {
@@ -647,7 +654,7 @@
   void CollectCollectivesToMove(
       int64_t level_to_operate_on,
       CollectivePipeliner::PipeliningDirection direction,
-      HloPredicate should_process);
+      HloPredicate should_process, HloPredicate acceptable_formatting);
   HloInstruction* while_loop_instruction() const { return while_; }
 
  private:
@@ -755,7 +762,7 @@
 void WhileLoopAnalysis::CollectCollectivesToMove(
     int64_t level_to_operate_on,
     CollectivePipeliner::PipeliningDirection direction,
-    HloPredicate should_process) {
+    HloPredicate should_process, HloPredicate acceptable_formatting) {
   move_infos_.clear();
   HloComputation* while_body = while_->while_body();
   const HloInstruction* loop_parameter =
@@ -783,24 +790,28 @@
   absl::flat_hash_map<const HloInstruction*, Range> index_ranges;
   absl::flat_hash_map<const HloInstruction*, int64_t>
       index_per_dyn_update_slice;
+  std::optional<Range> index_range;
   if (loop_bound_) {
     // Compute the range of the index as "start + iteration_count * increment"
-    Range index_range =
-        Range{*loop_start_,
-              loop_start_->add(
-                  loop_iteration_count_
-                      ->sub(ConstantValue::GetOne(loop_start_->GetBitwidth(),
-                                                  loop_start_->IsSigned()))
-                      .mul(*loop_increment_)),
-              /*is_linear=*/true};
-    for (auto* instr : while_body->instructions()) {
-      if (instr->opcode() == HloOpcode::kGetTupleElement) {
-        if (instr->tuple_index() == 0) {
-          index_ranges.insert({instr, index_range});
-        }
+    index_range = Range{*loop_start_,
+                        loop_start_->add(loop_iteration_count_
+                                             ->sub(ConstantValue::GetOne(
+                                                 loop_start_->GetBitwidth(),
+                                                 loop_start_->IsSigned()))
+                                             .mul(*loop_increment_)),
+                        /*is_linear=*/true};
+  }
+  int64_t count = 0;
+  absl::flat_hash_map<const HloInstruction*, int64_t> instruction_order;
+  for (auto* instr : while_body->MakeInstructionPostOrder()) {
+    if (instr->opcode() == HloOpcode::kGetTupleElement) {
+      if (index_range && instr->tuple_index() == 0) {
+        index_ranges.insert({instr, *index_range});
       }
     }
+    instruction_order[instr] = count++;
   }
+
   for (auto* instr : while_body->instructions()) {
     if (direction == CollectivePipeliner::PipeliningDirection::kForward &&
         (instr->operand_count() != 1 ||
@@ -814,7 +825,8 @@
     if (direction == CollectivePipeliner::PipeliningDirection::kForward ||
         direction == CollectivePipeliner::PipeliningDirection::kForwardSink) {
       auto [dyn_update, formatting_ops] = CheckStoreIntoSliceIsCompatible(
-          instr, while_body, level_to_operate_on, pipeline_use_tree_);
+          instr, while_body, level_to_operate_on, pipeline_use_tree_,
+          acceptable_formatting);
       if (dyn_update == nullptr) {
         VLOG(5)
             << "Skipping " << instr->ToString()
@@ -916,9 +928,8 @@
                 << " because couldn't find unique output index for insertion";
         continue;
       }
-      //
       auto merge_as_formatting =
-          [this](
+          [this, &instruction_order](
               absl::flat_hash_map<const HloInstruction*, int64_t>::iterator it,
               HloInstruction* instr, HloInstruction* dyn_upd,
               absl::Span<HloInstruction* const> formatting_ops) {
@@ -929,21 +940,21 @@
                 move_infos_[it->second].formatting_ops.end());
             existing_entry_instrs.insert(
                 move_infos_[it->second].collective_to_move);
-            std::vector<HloInstruction*> to_merge;
             // If instr is already in the set then this instruction is already
             // in formatting-ops of the other one, so its already pipelined.
             if (existing_entry_instrs.count(instr)) {
               return;
             }
-            to_merge.push_back(instr);
+            move_infos_[it->second].formatting_ops.push_back(instr);
             for (auto* op : formatting_ops) {
               if (!existing_entry_instrs.count(op)) {
-                to_merge.push_back(op);
+                move_infos_[it->second].formatting_ops.push_back(op);
               }
             }
-            move_infos_[it->second].formatting_ops.insert(
-                move_infos_[it->second].formatting_ops.begin(),
-                to_merge.begin(), to_merge.end());
+            absl::c_sort(move_infos_[it->second].formatting_ops,
+                         [&](const HloInstruction* a, const HloInstruction* b) {
+                           return instruction_order[a] < instruction_order[b];
+                         });
           };
       auto it = index_per_dyn_update_slice.find(dyn_update);
       if (it != index_per_dyn_update_slice.end()) {
@@ -1121,6 +1132,8 @@
                             int64_t level_to_operate_on, bool pipeline_use_tree,
                             bool process_different_sized_ops,
                             HloPredicate should_process,
+                            HloPredicate acceptable_formatting,
+                            HloPredicate reuse_output_buffer,
                             int64_t& next_channel_id) {
   // Defining some maps/sets to keep track of instructions duplicated.
   InstructionMap while_body_to_peeled;
@@ -1139,7 +1152,8 @@
     const Shape& output_shape = to_move.formatting_ops.empty()
                                     ? to_move.collective_to_move->shape()
                                     : to_move.formatting_ops.back()->shape();
-    if (output_shape != to_move.collective_to_move->operand(0)->shape()) {
+    if (!reuse_output_buffer(to_move.collective_to_move) ||
+        output_shape != to_move.collective_to_move->operand(0)->shape()) {
       moves_requiring_special_output.push_back(count);
       to_skip_set.insert(to_move.dynamic_update_slice);
     }
@@ -1300,7 +1314,7 @@
   new_loop_analysis.ComputeLoopStatistics();
   new_loop_analysis.CollectCollectivesToMove(
       level_to_operate_on, CollectivePipeliner::PipeliningDirection::kForward,
-      should_process);
+      should_process, acceptable_formatting);
   CHECK_EQ(new_loop_analysis.GetMoveInfos().size(),
            loop_analysis.GetMoveInfos().size());
   for (int64_t i = new_loop_tuple_operand_count;
@@ -1776,12 +1790,37 @@
       }
       return operands;
     };
-    // We are adding a batch dimension to the formatting ops, so we need to
-    // specially rewrite each instruction potentially if adding dimensions has
-    // an effect on the instruction itself (like say broadcast, slices ... etc).
+    absl::flat_hash_set<HloInstruction*> to_add_batch_set;
+    absl::flat_hash_set<HloInstruction*> formatting_ops_set(
+        to_move.formatting_ops.begin(), to_move.formatting_ops.end());
+    std::vector<HloInstruction*> stack(1, to_move.collective_to_move);
+    while (!stack.empty()) {
+      auto* current = stack.back();
+      stack.pop_back();
+      to_add_batch_set.insert(current);
+      for (auto* u : current->users()) {
+        if (formatting_ops_set.contains(u)) {
+          stack.push_back(u);
+        }
+      }
+    }
+    // We are adding a batch dimension to the formatting ops, so we may need
+    // to rewrite each instruction specially if adding dimensions affects the
+    // instruction itself (e.g. broadcasts, slices, etc.).
     for (HloInstruction* formatting_op : to_move.formatting_ops) {
+      if (!to_add_batch_set.contains(formatting_op) &&
+          formatting_op->opcode() != HloOpcode::kBroadcast) {
+        HloInstruction* cloned_not_to_batch = loop_computation->AddInstruction(
+            formatting_op->CloneWithNewOperands(
+                formatting_op->shape(), collect_operands(formatting_op)));
+        pipelined_map[formatting_op] = cloned_not_to_batch;
+        continue;
+      }
       if (formatting_op->IsElementwise() ||
           formatting_op->opcode() == HloOpcode::kReshape ||
+          formatting_op->opcode() == HloOpcode::kReduce ||
+          formatting_op->opcode() == HloOpcode::kAllReduce ||
           formatting_op->opcode() == HloOpcode::kConvert ||
           formatting_op->opcode() == HloOpcode::kCollectivePermute) {
         HloInstruction* cloned_elementwise = loop_computation->AddInstruction(
@@ -1834,13 +1873,13 @@
                     ->shape()
                     .element_type())));
         std::vector<HloInstruction*> indices(1, zero);
-        indices.insert(indices.end(), dynslice->index_operands().begin(),
-                       dynslice->index_operands().end());
+        auto collected_operands = collect_operands(formatting_op);
+        indices.insert(indices.end(), std::next(collected_operands.begin()),
+                       collected_operands.end());
         HloInstruction* expanded_dynslice =
             loop_computation->AddInstruction(HloInstruction::CreateDynamicSlice(
                 ComputeFullOutputShape(to_move, formatting_op->shape()),
-                collect_operands(formatting_op)[0], indices,
-                dynamic_slice_sizes));
+                collected_operands[0], indices, dynamic_slice_sizes));
         pipelined_map[formatting_op] = expanded_dynslice;
         continue;
       }
@@ -1879,7 +1918,7 @@
         pipelined_map[formatting_op] = expanded_transpose;
         continue;
       }
-      CHECK(false) << "Unsupported instruction";
+      CHECK(false) << "Unsupported instruction " << formatting_op->ToString();
     }
     HloInstruction* inserted_operand =
         to_move.dynamic_update_slice->mutable_operand(1);
@@ -1939,6 +1978,7 @@
                                     int64_t level_to_operate_on,
                                     bool process_different_sized_ops,
                                     HloPredicate should_process,
+                                    HloPredicate acceptable_formatting,
                                     int64_t& next_channel_id) {
   // Defining some maps/sets to keep track of instructions duplicated.
   absl::flat_hash_map<HloInstruction*, HloInstruction*> while_body_to_peeled;
@@ -2244,9 +2284,9 @@
     }
     VLOG(1) << "While iterations: "
             << loop_analysis.GetLoopIterationCount()->ToString();
-    loop_analysis.CollectCollectivesToMove(config_.level_to_operate_on,
-                                           config_.pipelining_direction,
-                                           config_.should_process);
+    loop_analysis.CollectCollectivesToMove(
+        config_.level_to_operate_on, config_.pipelining_direction,
+        config_.should_process, config_.acceptable_formatting);
     if (loop_analysis.GetMoveInfos().empty()) {
       continue;
     }
@@ -2265,7 +2305,8 @@
       TF_RETURN_IF_ERROR(TransformLoopForward(
           loop_analysis, !config_.last_run, config_.level_to_operate_on,
           config_.pipeline_use_tree, config_.process_different_sized_ops,
-          config_.should_process, next_channel_id));
+          config_.should_process, config_.acceptable_formatting,
+          config_.reuse_pipelined_op_buffer, next_channel_id));
     } else if (config_.pipelining_direction ==
                PipeliningDirection::kForwardSink) {
       TF_RETURN_IF_ERROR(TransformLoopForwardSink(
@@ -2277,7 +2318,7 @@
       TF_RETURN_IF_ERROR(TransformLoopBackward(
           loop_analysis, !config_.last_run, config_.level_to_operate_on,
           config_.process_different_sized_ops, config_.should_process,
-          next_channel_id));
+          config_.acceptable_formatting, next_channel_id));
     }
     ++transformed_loops;
     changed = true;
@@ -2307,6 +2348,11 @@
           << " and transformed instructions: " << transformed_instructions
           << " for pipelining direction: "
           << GetPipelineDirectionString(config_.pipelining_direction);
+  // Run cleanup so that leftover dead code doesn't trigger the HloVerifier.
+  if (changed) {
+    TF_RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status());
+  }
+
   return changed;
 }
 
diff --git a/third_party/xla/xla/service/collective_pipeliner.h b/third_party/xla/xla/service/collective_pipeliner.h
index d1de713..01b167d 100644
--- a/third_party/xla/xla/service/collective_pipeliner.h
+++ b/third_party/xla/xla/service/collective_pipeliner.h
@@ -75,6 +75,13 @@
     bool process_different_sized_ops = false;
     PipeliningDirection pipelining_direction = PipeliningDirection::kForward;
     HloPredicate should_process;
+    // Filters acceptable formatting ops for forward pipelining, discarding
+    // formatting operations that we don't want to support.
+    HloPredicate acceptable_formatting;
+    // If the pipelined op has the same input/output size, we reuse the buffer
+    // the value is stored in within the output loop for forward pipelining.
+    // This predicate allows opting out of that reuse for certain ops.
+    HloPredicate reuse_pipelined_op_buffer;
   };
   static const char* const kInsertedByPreviousStep;
   static const char* const kSunkByPreviousStep;
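
The two predicates above are ordinary HloPredicate callbacks, so callers can plug in arbitrary policies. A minimal sketch of populating them, assuming the usual surrounding Config setup; the lambda bodies are illustrative only, not part of this change:

CollectivePipeliner::Config config;
config.level_to_operate_on = 0;
config.pipelining_direction = CollectivePipeliner::PipeliningDirection::kForward;
config.should_process = HloPredicateIsOp<HloOpcode::kAllReduce>;
// Hypothetical policy: refuse to pipeline formatting ops that carry control
// dependencies.
config.acceptable_formatting = [](const HloInstruction* instr) {
  return !instr->HasControlDependencies();
};
// Hypothetical policy: never reuse the pipelined op's input buffer for its
// output, forcing a separate output buffer in the transformed loop.
config.reuse_pipelined_op_buffer = [](const HloInstruction*) { return false; };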
diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc
index 845576a..efdd4f9 100644
--- a/third_party/xla/xla/service/collective_pipeliner_test.cc
+++ b/third_party/xla/xla/service/collective_pipeliner_test.cc
@@ -58,7 +58,11 @@
     bool pipeline_use_tree = false, bool process_different_sized_ops = true,
     CollectivePipeliner::PipeliningDirection direction =
         CollectivePipeliner::PipeliningDirection::kForward,
-    HloPredicate should_process = HloPredicateIsOp<HloOpcode::kAllReduce>) {
+    HloPredicate should_process = HloPredicateIsOp<HloOpcode::kAllReduce>,
+    HloPredicate acceptable_formatting =
+        [](const HloInstruction*) { return true; },
+    HloPredicate reuse_pipelined_op_buffer =
+        [](const HloInstruction* i) { return true; }) {
   CollectivePipeliner::Config config = {
       /*level_to_operate_on=*/level_to_operate_on,
       /*max_pipelining_per_loop=*/INT64_MAX,
@@ -68,6 +72,8 @@
       /*direction=*/
       direction,
       /*should_process=*/should_process,
+      /*acceptable_formatting=*/acceptable_formatting,
+      /*reuse_pipelined_op_buffer=*/reuse_pipelined_op_buffer,
   };
   HloPassPipeline pass("optimizer");
   pass.AddPass<HloVerifier>(/*layout_sensitive=*/false,
@@ -157,6 +163,70 @@
   EXPECT_EQ(get_tuple_index->tuple_index(), 3);
 }
 
+TEST_F(CollectivePipelinerTest, TransformIncrementIndexByOneNoReuse) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+add {
+  lhs = bf16[] parameter(0)
+  rhs = bf16[] parameter(1)
+  ROOT add = bf16[] add(lhs, rhs)
+}
+
+while_cond {
+  param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
+  gte = s32[] get-tuple-element(param), index=0
+  constant.1 = s32[] constant(3)
+  ROOT cmp = pred[] compare(gte, constant.1), direction=LT
+}
+
+while_body {
+  param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
+  get-tuple-element.394 = s32[] get-tuple-element(param), index=0
+  get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
+  get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
+  constant.2557 = s32[] constant(1)
+  add.230 = s32[] add(get-tuple-element.394, constant.2557)
+  constant.2559 = s32[] constant(3)
+  subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
+  constant.2560 = s32[] constant(-1)
+  add.231 = s32[] add(subtract.139, constant.2560)
+  constant.2561 = s32[] constant(0)
+  compare.747 = pred[] compare(add.231, constant.2561), direction=LT
+  constant.2562 = s32[] constant(2)
+  add.232 = s32[] add(subtract.139, constant.2562)
+  select.1348 = s32[] select(compare.747, add.232, add.231)
+  dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
+  mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
+  ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
+  dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561)
+  ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
+}
+
+ENTRY entry {
+  c0 = s32[] constant(0)
+  p0 = bf16[3,8,128] parameter(0)
+  tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
+  while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
+  ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
+}
+)";
+  auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value();
+  EXPECT_TRUE(RunOptimizer(
+                  module.get(), /*last_run=*/true, 0, false, true,
+                  CollectivePipeliner::PipeliningDirection::kForward,
+                  HloPredicateIsOp<HloOpcode::kAllReduce>,
+                  /*acceptable_formatting=*/
+                  [](const HloInstruction* i) { return true; },
+                  /*reuse_pipelined_op_buffer=*/
+                  [](const HloInstruction* i) { return false; })
+                  .value());
+  XLA_VLOG_LINES(1, module->ToString());
+  HloInstruction* while_instr =
+      FindInstruction(module.get(), HloOpcode::kWhile);
+  EXPECT_EQ(while_instr->shape().tuple_shapes_size(), 5);
+}
+
 TEST_F(CollectivePipelinerTest, TransformIncrementIndexByOneNotFirstIdx) {
   constexpr absl::string_view hlo_string = R"(
 HloModule module
@@ -1803,5 +1873,146 @@
   XLA_VLOG_LINES(1, module->ToString());
 }
 
+TEST_F(CollectivePipelinerTest, MultiUsesElementwiseFeedTwoWithReduce) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+add {
+  lhs = bf16[] parameter(0)
+  rhs = bf16[] parameter(1)
+  ROOT add = bf16[] add(lhs, rhs)
+}
+
+add.1 {
+  lhs = bf16[] parameter(0)
+  rhs = bf16[] parameter(1)
+  ROOT add = bf16[] add(lhs, rhs)
+}
+
+add.2 {
+  lhs = bf16[] parameter(0)
+  rhs = bf16[] parameter(1)
+  ROOT add = bf16[] add(lhs, rhs)
+}
+
+while_cond {
+  param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
+  gte = s32[] get-tuple-element(param), index=0
+  constant.1 = s32[] constant(3)
+  ROOT cmp = pred[] compare(gte, constant.1), direction=LT
+}
+
+while_body {
+  param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
+  get-tuple-element.394 = s32[] get-tuple-element(param), index=0
+  get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
+  get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
+  constant.2557 = s32[] constant(1)
+  add.230 = s32[] add(get-tuple-element.394, constant.2557)
+  constant.2559 = s32[] constant(3)
+  subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
+  constant.2560 = s32[] constant(-1)
+  add.231 = s32[] add(subtract.139, constant.2560)
+  constant.2561 = s32[] constant(0)
+  compare.747 = pred[] compare(add.231, constant.2561), direction=LT
+  constant.2562 = s32[] constant(2)
+  add.232 = s32[] add(subtract.139, constant.2562)
+  select.1348 = s32[] select(compare.747, add.232, add.231)
+  dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
+  mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
+  bm = bf16[1,1,8,128] broadcast(mul), dimensions={1,2,3}
+  c2 = bf16[] constant(2.0)
+  bc = bf16[1,8,128] broadcast(c2)
+  ar.1 = bf16[1,1,8,128] all-reduce(bm), replica_groups={}, to_apply=add, channel_id=1
+  ar.2 = bf16[1,1,8,128] all-reduce(ar.1), replica_groups={}, to_apply=add, channel_id=2
+  red.1 = bf16[1,8,128] reduce(ar.1, c2), to_apply=add.1, dimensions={0}
+  red.2 = bf16[1,8,128] reduce(ar.2, c2), to_apply=add.2, dimensions={0}
+  mul2 = bf16[1,8,128] multiply(red.1, bc), control-predecessors={ar.1}
+  mul3 = bf16[1,8,128] multiply(mul2, red.2)
+  mul4 = bf16[1,8,128] multiply(mul3, mul)
+  dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, mul4, select.1348, constant.2561, constant.2561)
+  ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5), control-predecessors={ar.1}
+}
+
+ENTRY entry {
+  c0 = s32[] constant(0)
+  p0 = bf16[3,8,128] parameter(0)
+  tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
+  while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
+  ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
+}
+)";
+  auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value();
+  EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, 0,
+                           /*pipeline_use_tree=*/true,
+                           /*process_different_sized_ops=*/true,
+                           CollectivePipeliner::PipeliningDirection::kForward)
+                  .value());
+  XLA_VLOG_LINES(1, module->ToString());
+}
+
+TEST_F(CollectivePipelinerTest, PipelinedReduceScatterCanPassVerifier) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule module
+
+to_apply0 {
+  Arg_0.732 = bf16[] parameter(0)
+  Arg_1.733 = bf16[] parameter(1)
+  ROOT add.734 = bf16[] add(Arg_0.732, Arg_1.733)
+}
+
+body {
+  p2 = (s32[], bf16[3,4096,4096]{2,1,0}, bf16[10,512,3,4096]{3,2,1,0}) parameter(0)
+  gte2 = bf16[3,4096,4096]{2,1,0} get-tuple-element(p2), index=1
+  gte3 = bf16[10,512,3,4096]{3,2,1,0} get-tuple-element(p2), index=2
+  c2 = s32[] constant(9)
+  gte4 = s32[] get-tuple-element(p2), index=0
+  sub0 = s32[] subtract(c2, gte4)
+  c3 = s32[] constant(0)
+  comp1 = pred[] compare(sub0, c3), direction=LT
+  c4 = s32[] constant(19)
+  sub2 = s32[] subtract(c4, gte4)
+  sel0 = s32[] select(comp1, sub2, sub0)
+
+  rsp0 = bf16[3,4096,4096]{2,1,0} reshape(gte2)
+  rs0 = bf16[3,4096,512]{2,1,0} reduce-scatter(rsp0), channel_id=75, replica_groups={{0,1,2,3}}, dimensions={2}, to_apply=to_apply0
+  tran0 = bf16[512,3,4096]{0,2,1} transpose(rs0), dimensions={2,0,1}
+  rsp1 = bf16[1,512,3,4096]{3,2,1,0} reshape(tran0)
+  dus0 = bf16[10,512,3,4096]{3,2,1,0} dynamic-update-slice(gte3, rsp1, sel0, c3, c3, /*index=5*/c3)
+  c5 = s32[] constant(1)
+  add0 = s32[] add(gte4, c5)
+  ROOT t1 = (s32[], bf16[3,4096,4096]{2,1,0}, bf16[10,512,3,4096]{3,2,1,0}) tuple(add0, rsp0, dus0)
+} // body
+
+condition {
+  cond_p1 = (s32[], bf16[3,4096,4096]{2,1,0}, bf16[10,512,3,4096]{3,2,1,0}) parameter(0)
+  gte1 = s32[] get-tuple-element(cond_p1), index=0
+  c1 = s32[] constant(9)
+  ROOT comp0 = pred[] compare(gte1, c1), direction=LT
+}
+
+ENTRY main.3813_spmd {
+  p0 = bf16[3,4096,4096]{2,1,0} parameter(0)
+  p1 = bf16[10,512,3,4096]{3,2,1,0} parameter(1)
+  c0 = s32[] constant(0)
+
+  t0 = (s32[], bf16[3,4096,4096]{2,1,0}, bf16[10,512,3,4096]{3,2,1,0}) tuple(c0, p0, p1)
+  w0 = (s32[], bf16[3,4096,4096]{2,1,0}, bf16[10,512,3,4096]{3,2,1,0}) while(t0), condition=condition, body=body
+  ROOT gte0 = bf16[3,4096,4096]{2,1,0} get-tuple-element(w0), index=1
+}
+)";
+  auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value();
+  EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, 0,
+                           /*pipeline_use_tree=*/true,
+                           /*process_different_sized_ops=*/true,
+                           CollectivePipeliner::PipeliningDirection::kForward,
+                           HloPredicateIsOp<HloOpcode::kReduceScatter>)
+                  .value());
+  XLA_VLOG_LINES(1, module->ToString());
+  HloVerifier verifier(/*layout_sensitive=*/false,
+                       /*allow_mixed_precision=*/true);
+  ASSERT_IS_OK(verifier.Run(module.get()).status());
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/third_party/xla/xla/service/convert_mover.cc b/third_party/xla/xla/service/convert_mover.cc
index d351e64..cb76816 100644
--- a/third_party/xla/xla/service/convert_mover.cc
+++ b/third_party/xla/xla/service/convert_mover.cc
@@ -16,6 +16,7 @@
 #include "xla/service/convert_mover.h"
 
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/primitive_util.h"
 #include "xla/service/hlo_creation_utils.h"
 
 namespace xla {
@@ -113,6 +114,12 @@
       continue;
     }
 
+    // Currently int4 is not supported in most ops, so moving the convert is
+    // not safe.
+    if (primitive_util::Is4BitType(src_ty)) {
+      continue;
+    }
+
     VLOG(2) << "Moving increase-precision convert op " << convert_op->ToString()
             << " down the graph: " << instr->ToString();
 
@@ -162,6 +169,9 @@
     if (primitive_util::BitWidth(src_ty) <= primitive_util::BitWidth(dst_ty)) {
       continue;
     }
+    if (primitive_util::Is4BitType(dst_ty)) {
+      continue;
+    }
 
     VLOG(2) << "Moving decrease-precision convert up the graph: "
             << instr->ToString();
diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD
index b230166..5652d5a 100644
--- a/third_party/xla/xla/service/cpu/BUILD
+++ b/third_party/xla/xla/service/cpu/BUILD
@@ -1,7 +1,7 @@
 # Description:
 #    LLVM-based CPU backend for XLA.
 
-load("@local_tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable")
+load("@bazel_skylib//rules:build_test.bzl", "build_test")
 load(
     "//xla:xla.bzl",
     "ORC_JIT_MEMORY_MAPPER_TARGETS",
@@ -9,19 +9,19 @@
     "xla_cc_test",
 )
 load(
-    "@local_tsl//tsl/mkl:build_defs.bzl",
-    "mkl_deps",
-)
-load("@local_tsl//tsl:tsl.bzl", "tf_openmp_copts", "tsl_copts")
-load(
     "//third_party/compute_library:build_defs.bzl",
     "acl_deps",
     "if_enable_acl",
 )
-load(":build_defs.bzl", "runtime_copts")
+load("@local_tsl//tsl:tsl.bzl", "tf_openmp_copts", "tsl_copts")
+load("@local_tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable")
+load(
+    "@local_tsl//tsl/mkl:build_defs.bzl",
+    "mkl_deps",
+)
 load("@local_tsl//tsl/platform:build_config.bzl", "if_llvm_system_z_available", "tf_proto_library")
 load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library")
-load("@bazel_skylib//rules:build_test.bzl", "build_test")
+load(":build_defs.bzl", "runtime_copts")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -322,6 +322,7 @@
         "//xla/service:slow_operation_alarm",
         "//xla/service:sort_simplifier",
         "//xla/service:stochastic_convert_decomposer",
+        "//xla/service:sub_byte_normalization",
         "//xla/service:topk_rewriter",
         "//xla/service:transpose_folding",
         "//xla/service:tree_reduction_rewriter",
@@ -1700,3 +1701,10 @@
         "//xla/service:pattern_matcher",
     ] + mkl_deps(),
 )
+
+cc_library(
+    name = "cpu_symbol_repository",
+    hdrs = ["cpu_symbol_repository.h"],
+    visibility = ["//visibility:public"],
+    deps = ["//xla/service:symbol_repository"],
+)
diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc
index 8f56714..376da95 100644
--- a/third_party/xla/xla/service/cpu/cpu_compiler.cc
+++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc
@@ -205,6 +205,7 @@
 #include "xla/service/sort_simplifier.h"
 #include "xla/service/spmd/stateful_rng_spmd_partitioner.h"
 #include "xla/service/stochastic_convert_decomposer.h"
+#include "xla/service/sub_byte_normalization.h"
 #include "xla/service/topk_rewriter.h"
 #include "xla/service/transpose_folding.h"
 #include "xla/service/tree_reduction_rewriter.h"
@@ -648,6 +649,16 @@
     TF_RETURN_IF_ERROR(sharding_removal_pipeline.Run(module).status());
   }
 
+  {
+    // Int4Packer must be run before the rest of the pipeline since it modifies
+    // the layout of the entry computation inputs/outputs, which is passed to
+    // LayoutAssignment.
+    HloPassPipeline int4_packer_pipeline("Int4Packer pipeline");
+    int4_packer_pipeline.AddPass<SubByteNormalization>(
+        SubByteNormalization::SET_ELEMENT_SIZE);
+    TF_RETURN_IF_ERROR(int4_packer_pipeline.Run(module).status());
+  }
+
   HloPassPipeline pipeline("HLO passes through layout assignment");
   AddHloVerifier(&pipeline, allow_sparse_shapes_);
 
@@ -742,8 +753,15 @@
   pipeline.AddPass<ConditionalCanonicalizer>();
   pipeline.AddPass<DynamicDimensionSimplifier>();
   auto dynamic_padder_options = DynamicPadderOptions();
+  // TODO(pgavin): ShapeChecks were never implemented correctly by the dynamic
+  // padder.  The mode defaults to kIgnore, and it was not overridden for nested
+  // computations (such as while bodies or conditional branches), and so cases
+  // that could not be proven would still be accepted even with compile-time
+  // checks enabled.  Recent changes to the DynamicPadder correctly
+  // override the mode.  However, some models have started to rely on the check
+  // being ignored, and they would be broken if it is enforced.
   dynamic_padder_options.shape_check_mode =
-      DynamicDimensionInference::ShapeCheckMode::kCompileTime;
+      DynamicDimensionInference::ShapeCheckMode::kIgnore;
   pipeline.AddPass<DynamicPadder>(dynamic_padder_options);
   if (!is_mlir_compile) {
     pipeline.AddPass<SelectAndScatterExpander>();
@@ -840,6 +858,10 @@
     pipeline.AddPass<CpuLayoutAssignment>(
         module->mutable_entry_computation_layout(), target_machine_features,
         &layout_constraints);
+    // Run SubByteNormalization because CpuLayoutAssignment may modify a
+    // Layout's element_size_in_bits field.
+    pipeline.AddPass<SubByteNormalization>(
+        SubByteNormalization::SET_ELEMENT_SIZE);
   }
 
   return pipeline.Run(module).status();
diff --git a/third_party/xla/xla/service/cpu/cpu_symbol_repository.h b/third_party/xla/xla/service/cpu/cpu_symbol_repository.h
new file mode 100644
index 0000000..da5bd03
--- /dev/null
+++ b/third_party/xla/xla/service/cpu/cpu_symbol_repository.h
@@ -0,0 +1,29 @@
+#ifndef XLA_SERVICE_CPU_CPU_SYMBOL_REPOSITORY_H_
+#define XLA_SERVICE_CPU_CPU_SYMBOL_REPOSITORY_H_
+
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/symbol_repository.h"
+#include "xla/xla.pb.h"
+
+namespace xla::cpu {
+
+// CPU-specific fields for SymbolRepositories.
+struct CpuBackendSpecificData : public BackendSpecificData {};
+
+}  // namespace xla::cpu
+
+#endif  // XLA_SERVICE_CPU_CPU_SYMBOL_REPOSITORY_H_
diff --git a/third_party/xla/xla/service/cpu/cpu_transfer_manager.h b/third_party/xla/xla/service/cpu/cpu_transfer_manager.h
index 4ce26e3..bf9b10f 100644
--- a/third_party/xla/xla/service/cpu/cpu_transfer_manager.h
+++ b/third_party/xla/xla/service/cpu/cpu_transfer_manager.h
@@ -59,6 +59,8 @@
                            Shape* device_shape) override;
 
  private:
+  bool PackSubbyteTypes() const override { return true; }
+
   CpuTransferManager(const CpuTransferManager&) = delete;
   CpuTransferManager& operator=(const CpuTransferManager&) = delete;
 };
diff --git a/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc b/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc
index 82bf70d..ba678fe 100644
--- a/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc
+++ b/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc
@@ -16,27 +16,55 @@
 #include "xla/service/cpu_gpu_shape_verifier.h"
 
 #include "absl/status/status.h"
+#include "absl/strings/str_format.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/primitive_util.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/status.h"
+#include "tsl/platform/errors.h"
 
 namespace xla {
 
 namespace {
-
-bool HasInt4(const Shape& shape) {
-  return ShapeUtil::HasPrimitiveType(shape, S4) ||
-         ShapeUtil::HasPrimitiveType(shape, U4);
-}
-
 Status VerifyS4U4Usage(HloInstruction* instruction) {
-  if (HasInt4(instruction->shape())) {
-    return absl::InvalidArgumentError(absl::StrFormat(
-        "S4/U4 is currently not support on XLA CPU/GPU, but got instruction "
-        " with S4/U4 output: %s",
-        instruction->ToString()));
+  switch (instruction->opcode()) {
+    case HloOpcode::kBitcast:
+    case HloOpcode::kConstant:
+    case HloOpcode::kConvert:
+    case HloOpcode::kCopy:
+    case HloOpcode::kFusion:
+    case HloOpcode::kGetTupleElement:
+    case HloOpcode::kParameter:
+    case HloOpcode::kTuple:
+      TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+          instruction->shape(), [&](const Shape& shape, const ShapeIndex&) {
+            if (primitive_util::Is4BitType(shape.element_type()) &&
+                ShapeUtil::ElementsIn(shape) % 2 == 1) {
+              return absl::InvalidArgumentError(absl::StrFormat(
+                  "S4/U4 arrays must have an even number of elements, but got "
+                  "instruction with S4/U4 input with odd number of elements: "
+                  "%s",
+                  instruction->ToString()));
+            }
+            return OkStatus();
+          }));
+      break;
+    default:
+      TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+          instruction->shape(), [&](const Shape& shape, const ShapeIndex&) {
+            if (primitive_util::Is4BitType(shape.element_type())) {
+              return absl::InvalidArgumentError(absl::StrFormat(
+                  "S4/U4 is currently only supported in convert instructions, "
+                  "but got instruction with S4/U4 input: %s",
+                  instruction->ToString()));
+            }
+            return OkStatus();
+          }));
+      break;
   }
+
   return OkStatus();
 }
 }  // namespace
@@ -50,10 +78,11 @@
                 "The XLA CPU/GPU backend does not support sparse shapes: %s",
                 hlo->ToString()));
           }
-          if (shape.layout().element_size_in_bits() != 0) {
+          if (!primitive_util::Is4BitType(shape.element_type()) &&
+              shape.layout().element_size_in_bits() != 0) {
             return absl::InvalidArgumentError(absl::StrFormat(
-                "The XLA CPU/GPU backend does not support custom element "
-                "sizes: %s",
+                "The XLA CPU/GPU backend does not support custom element sizes "
+                "on non-4-bit types: %s",
                 hlo->ToString()));
           }
         }
diff --git a/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc b/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc
index 8b5cd45..b31add3 100644
--- a/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc
+++ b/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc
@@ -28,29 +28,54 @@
 namespace xla {
 namespace {
 
-using CpuGpuShapeVerifierTest = HloTestBase;
 using ::testing::HasSubstr;
 
-TEST_F(CpuGpuShapeVerifierTest, Int4NotSupported) {
+class CpuGpuShapeVerifierTest : public HloTestBase {
+ public:
+  CpuGpuShapeVerifierTest() {
+    // Create an HloVerifier that uses CpuGpuShapeVerifier.
+    HloVerifierOpts opts;
+    std::unique_ptr<TargetVerifierMetadata> metadata =
+        std::make_unique<CpuGpuVerifierMetadata>(std::move(opts));
+    hlo_verifier_ = std::make_unique<HloVerifier>(std::move(metadata));
+  }
+};
+
+TEST_F(CpuGpuShapeVerifierTest, Int4UnsupportedInstruction) {
   const char* const hlo_string = R"(
   HloModule Module
 
   ENTRY main {
-    p0 = u4[10] parameter(0)
-    ROOT out = u8[10] convert(p0)
+    p0 = u4[2,5] parameter(0)
+    ROOT out = u4[2,5] add(p0, p0)
   }
   )";
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           ParseAndReturnUnverifiedModule(hlo_string));
 
-  HloVerifierOpts opts;
-  std::unique_ptr<TargetVerifierMetadata> metadata =
-      std::make_unique<CpuGpuVerifierMetadata>(std::move(opts));
-  HloVerifier hlo_verifier(std::move(metadata));
-  auto status = hlo_verifier.Run(module.get()).status();
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(
+      status.message(),
+      HasSubstr("S4/U4 is currently only supported in convert instructions"));
+}
+
+TEST_F(CpuGpuShapeVerifierTest, Int4OddNumberOfElements) {
+  const char* const hlo_string = R"(
+  HloModule Module
+
+  ENTRY main {
+    p0 = u4[11] parameter(0)
+    ROOT out = u8[11] convert(p0)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
   ASSERT_FALSE(status.ok());
   EXPECT_THAT(status.message(),
-              HasSubstr("S4/U4 is currently not support on XLA CPU/GPU"));
+              HasSubstr("S4/U4 arrays must have an even number of elements,"));
 }
 
 }  // namespace
diff --git a/third_party/xla/xla/service/despecializer.cc b/third_party/xla/xla/service/despecializer.cc
index c11a04d..c63aeb3 100644
--- a/third_party/xla/xla/service/despecializer.cc
+++ b/third_party/xla/xla/service/despecializer.cc
@@ -32,7 +32,8 @@
   pipeline_.AddPass<ControlDepRemover>();
   pipeline_.AddPass<Defuser>();
   pipeline_.AddPass<BFloat16MixedPrecisionRemoval>();
-  pipeline_.AddPass<SubByteNormalization>();
+  pipeline_.AddPass<SubByteNormalization>(
+      SubByteNormalization::REMOVE_ELEMENT_SIZE);
 }
 
 void Despecializer::AddReduceWindowToReduceBroadcastDeconstruct() {
diff --git a/third_party/xla/xla/service/dot_decomposer.cc b/third_party/xla/xla/service/dot_decomposer.cc
index dd71c63..45d8554 100644
--- a/third_party/xla/xla/service/dot_decomposer.cc
+++ b/third_party/xla/xla/service/dot_decomposer.cc
@@ -55,18 +55,25 @@
   std::vector<int64_t> lhs_non_contracting_dims;
   lhs_non_contracting_dims.reserve(num_lhs_non_contracting_dims);
   int64_t lhs_contracting_size = 1;
+  bool lhs_contracting_dynamic = false;
   int64_t lhs_non_contracting_size = 1;
+  bool lhs_non_contracting_dynamic = false;
   std::vector<int64_t> batch_dim_sizes;
   batch_dim_sizes.reserve(num_batch_dims);
+  std::vector<bool> batch_dynamic_dims;
+  batch_dynamic_dims.reserve(num_batch_dims);
   for (int64_t i = 0; i < lhs_rank; ++i) {
     if (absl::c_linear_search(original_dnums.lhs_contracting_dimensions(), i)) {
       lhs_contracting_size *= lhs_shape.dimensions(i);
+      lhs_contracting_dynamic |= lhs_shape.is_dynamic_dimension(i);
     } else if (absl::c_linear_search(original_dnums.lhs_batch_dimensions(),
                                      i)) {
       batch_dim_sizes.push_back(lhs_shape.dimensions(i));
+      batch_dynamic_dims.push_back(lhs_shape.is_dynamic_dimension(i));
     } else {
       lhs_non_contracting_dims.push_back(i);
       lhs_non_contracting_size *= lhs_shape.dimensions(i);
+      lhs_non_contracting_dynamic |= lhs_shape.is_dynamic_dimension(i);
     }
   }
   // The canonical form of the lhs is
@@ -90,14 +97,18 @@
       &lhs_operand->metadata());
 
   std::vector<int64_t> lhs_reshape_dims = batch_dim_sizes;
+  std::vector<bool> lhs_reshape_dynamic_dims = batch_dynamic_dims;
   if (lhs_non_contracting_size > 1) {
     lhs_reshape_dims.push_back(lhs_non_contracting_size);
+    lhs_reshape_dynamic_dims.push_back(lhs_non_contracting_dynamic);
   }
   lhs_reshape_dims.push_back(lhs_contracting_size);
+  lhs_reshape_dynamic_dims.push_back(lhs_contracting_dynamic);
   // Reshape the contracting and non-contracting dimensions together.
   HloInstruction* reshaped_lhs = computation->AddInstruction(
       HloInstruction::CreateReshape(
-          ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims),
+          ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims,
+                               lhs_reshape_dynamic_dims),
           transposed_lhs),
       &transposed_lhs->metadata());
 
@@ -108,14 +119,18 @@
   std::vector<int64_t> rhs_non_contracting_dims;
   rhs_non_contracting_dims.reserve(num_rhs_non_contracting_dims);
   int64_t rhs_non_contracting_size = 1;
+  bool rhs_non_contracting_dynamic = false;
   int64_t rhs_contracting_size = 1;
+  bool rhs_contracting_dynamic = false;
   for (int64_t i = 0; i < rhs_rank; ++i) {
     if (absl::c_linear_search(original_dnums.rhs_contracting_dimensions(), i)) {
       rhs_contracting_size *= rhs_shape.dimensions(i);
+      rhs_contracting_dynamic |= rhs_shape.is_dynamic_dimension(i);
     } else if (!absl::c_linear_search(original_dnums.rhs_batch_dimensions(),
                                       i)) {
       rhs_non_contracting_dims.push_back(i);
       rhs_non_contracting_size *= rhs_shape.dimensions(i);
+      rhs_non_contracting_dynamic |= rhs_shape.is_dynamic_dimension(i);
     }
   }
 
@@ -141,22 +156,29 @@
 
   std::vector<int64_t> rhs_reshape_dims = batch_dim_sizes;
   rhs_reshape_dims.push_back(rhs_contracting_size);
+  std::vector<bool> rhs_reshape_dynamic_dims = batch_dynamic_dims;
+  rhs_reshape_dynamic_dims.push_back(rhs_contracting_dynamic);
   if (rhs_non_contracting_size > 1) {
     rhs_reshape_dims.push_back(rhs_non_contracting_size);
+    rhs_reshape_dynamic_dims.push_back(rhs_non_contracting_dynamic);
   }
   // Reshape the contracting and non-contracting dimensions together.
   HloInstruction* reshaped_rhs = computation->AddInstruction(
       HloInstruction::CreateReshape(
-          ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims),
+          ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims,
+                               rhs_reshape_dynamic_dims),
           transposed_rhs),
       &transposed_rhs->metadata());
 
   std::vector<int64_t> dot_dims = batch_dim_sizes;
+  std::vector<bool> dot_dynamic_dims = batch_dynamic_dims;
   if (lhs_non_contracting_size > 1) {
     dot_dims.push_back(lhs_non_contracting_size);
+    dot_dynamic_dims.push_back(lhs_non_contracting_dynamic);
   }
   if (rhs_non_contracting_size > 1) {
     dot_dims.push_back(rhs_non_contracting_size);
+    dot_dynamic_dims.push_back(rhs_non_contracting_dynamic);
   }
 
   DotDimensionNumbers dot_dnums;
@@ -169,7 +191,8 @@
   dot_dnums.add_rhs_contracting_dimensions(num_batch_dims);
 
   HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot(
-      ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims),
+      ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims,
+                           dot_dynamic_dims),
       reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config()));
   original_dot->SetupDerivedInstruction(dot);
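
The hunk above threads dynamic-dimension bits through the canonicalized reshape and dot shapes via the three-argument ShapeUtil::MakeShape overload. A rough sketch of that overload, assuming it takes a parallel vector of per-dimension dynamic flags as used here:

#include "xla/shape.h"
#include "xla/shape_util.h"

// Builds f32[<=8,128]: the first dimension is dynamic, the second is static.
xla::Shape MakeDynamicExampleShape() {
  return xla::ShapeUtil::MakeShape(xla::F32, /*dimensions=*/{8, 128},
                                   /*dynamic_dimensions=*/{true, false});
}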
 
diff --git a/third_party/xla/xla/service/dot_merger.cc b/third_party/xla/xla/service/dot_merger.cc
index de29064..020ade5 100644
--- a/third_party/xla/xla/service/dot_merger.cc
+++ b/third_party/xla/xla/service/dot_merger.cc
@@ -19,6 +19,7 @@
 #include <set>
 #include <string>
 #include <utility>
+#include <vector>
 
 #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
 #include "xla/service/graphcycles/graphcycles.h"
@@ -304,10 +305,18 @@
   // them earlier because removing an instruction deletes it; we'd then have
   // dangling pointers in our hashtable!)
   absl::flat_hash_set<HloInstruction*> dead_instrs;
+  std::vector<HloInstruction*> keys;
+  keys.reserve(equivalence_classes.size());
   for (auto& kv : equivalence_classes) {
+    keys.push_back(kv.first);
+  }
+  absl::c_sort(keys, [](const HloInstruction* a, const HloInstruction* b) {
+    return a->unique_id() < b->unique_id();
+  });
+  for (auto key : keys) {
+    const auto& values = equivalence_classes[key];
     // For determinism, iterate in order of the instructions' IDs.
-    absl::InlinedVector<HloInstruction*, 16> dots(kv.second.begin(),
-                                                  kv.second.end());
+    absl::InlinedVector<HloInstruction*, 16> dots(values.begin(), values.end());
     absl::c_sort(dots, [](const HloInstruction* a, const HloInstruction* b) {
       return a->unique_id() < b->unique_id();
     });
diff --git a/third_party/xla/xla/service/dump.cc b/third_party/xla/xla/service/dump.cc
index 65a09e9..a490f9a 100644
--- a/third_party/xla/xla/service/dump.cc
+++ b/third_party/xla/xla/service/dump.cc
@@ -20,6 +20,7 @@
 #include <queue>
 #include <utility>
 
+#include "absl/functional/any_invocable.h"
 #include "absl/strings/ascii.h"
 #include "absl/strings/str_cat.h"
 #include "llvm/ADT/SmallString.h"
@@ -31,6 +32,7 @@
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/service/hlo_graph_dumper.h"
 #include "xla/service/hlo_proto_util.h"
+#include "xla/status.h"
 #include "xla/util.h"
 #include "tsl/lib/io/zlib_compression_options.h"
 #include "tsl/lib/io/zlib_outputbuffer.h"
@@ -472,7 +474,7 @@
       if (!rendered_graph.ok()) {
         VLOG(1) << "Skipping fusion visualization"
                 << " for computation " << computation->name()
-                << " due to: " << rendered_graph.status().ToString();
+                << " due to: " << rendered_graph.status();
         continue;
       }
       file_paths.push_back(DumpToFileInDirImpl(
@@ -634,7 +636,10 @@
 
 void DumpProtobufToFile(const tsl::protobuf::Message& proto,
                         const DebugOptions& debug_options,
-                        absl::string_view filename) {
+                        absl::string_view filename,
+                        absl::AnyInvocable<StatusOr<std::string>(
+                            tsl::Env*, const tsl::protobuf::Message&)>
+                            text_formatter) {
   CanonicalDebugOptions opts(debug_options);
   tsl::Env* env = tsl::Env::Default();
   const std::string& dir = opts.dump_to;
@@ -642,31 +647,45 @@
     auto status = env->RecursivelyCreateDir(dir);
     if (!status.ok()) {
       LOG(ERROR) << "Could not create directory " << dir
-                 << " for dumping XLA execution options: " << status;
+                 << " for dumping: " << status;
       return;
     }
   }
-  if (env->IsDirectory(dir).ok()) {
-    const std::string path = tsl::io::JoinPath(dir, filename);
-    Status status;
-    if (opts.dump_as_text) {
-      status = tsl::WriteTextProto(env, absl::StrCat(path, ".txt"), proto);
+  if (!env->IsDirectory(dir).ok()) {
+    return;
+  }
+  const std::string path = tsl::io::JoinPath(dir, filename);
+  Status status;
+  if (opts.dump_as_text) {
+    if (text_formatter) {
+      auto written_proto = text_formatter(env, proto);
+      if (!written_proto.status().ok()) {
+        LOG(ERROR) << "Failure with custom proto text formatting function. "
+                   << "Could not write XLA data to " << filename << ": "
+                   << written_proto.status();
+        return;
+      }
+      status = tsl::WriteStringToFile(env, absl::StrCat(path, ".txt"),
+                                      written_proto.value());
     } else {
-      status = tsl::WriteBinaryProto(env, absl::StrCat(path, ".pb"), proto);
+      status = tsl::WriteTextProto(env, absl::StrCat(path, ".txt"), proto);
     }
-    if (!status.ok()) {
-      LOG(ERROR) << "Could not write XLA debug data to " << filename << ": "
-                 << status;
-    }
+  } else {
+    status = tsl::WriteBinaryProto(env, absl::StrCat(path, ".pb"), proto);
+  }
+  if (!status.ok()) {
+    LOG(ERROR) << "Could not write XLA data to " << filename << ": " << status;
   }
 }
 
-void DumpPerModuleProtobufToFile(const HloModule& module,
-                                 const tsl::protobuf::Message& proto,
-                                 const DebugOptions& debug_options,
-                                 absl::string_view name) {
+void DumpPerModuleProtobufToFile(
+    const HloModule& module, const tsl::protobuf::Message& proto,
+    const DebugOptions& debug_options, absl::string_view name,
+    absl::AnyInvocable<StatusOr<std::string>(tsl::Env*,
+                                             const tsl::protobuf::Message&)>
+        text_formatter) {
   const std::string filename = FilenameFor(module, TimestampFor(module), name);
-  DumpProtobufToFile(proto, debug_options, filename);
+  DumpProtobufToFile(proto, debug_options, filename, std::move(text_formatter));
 }
 
 void DumpHloModuleIfEnabled(const HloModule& module, string_view name) {
diff --git a/third_party/xla/xla/service/dump.h b/third_party/xla/xla/service/dump.h
index 6e1cab9..86244ba 100644
--- a/third_party/xla/xla/service/dump.h
+++ b/third_party/xla/xla/service/dump.h
@@ -83,17 +83,23 @@
 
 // Dumps the given protobuf to the given filename if dumping is enabled.
 // Exactly where and in what formats it's dumped is determined by the debug
-// options.
+// options. Optionally accepts a custom serialization function used when the
+// proto is dumped as text.
 void DumpProtobufToFile(const tsl::protobuf::Message& proto,
                         const DebugOptions& debug_options,
-                        absl::string_view filename);
+                        absl::string_view filename,
+                        absl::AnyInvocable<StatusOr<std::string>(
+                            tsl::Env*, const tsl::protobuf::Message&)>
+                            text_formatter = nullptr);
 
 // Similar to above, but the filename depends on module's information and the
-// given name.
-void DumpPerModuleProtobufToFile(const HloModule& module,
-                                 const tsl::protobuf::Message& proto,
-                                 const DebugOptions& debug_options,
-                                 absl::string_view name);
+// given name. Also accepts the optional text serialization function.
+void DumpPerModuleProtobufToFile(
+    const HloModule& module, const tsl::protobuf::Message& proto,
+    const DebugOptions& debug_options, absl::string_view name,
+    absl::AnyInvocable<StatusOr<std::string>(tsl::Env*,
+                                             const tsl::protobuf::Message&)>
+        text_formatter = nullptr);
 
 // Dumps the given HLO module if dumping is enabled for the module. Exactly
 // where and in what formats it's dumped is determined by the module's config.
diff --git a/third_party/xla/xla/service/dynamic_dimension_inference.cc b/third_party/xla/xla/service/dynamic_dimension_inference.cc
index bfce53c..5e76198 100644
--- a/third_party/xla/xla/service/dynamic_dimension_inference.cc
+++ b/third_party/xla/xla/service/dynamic_dimension_inference.cc
@@ -17,7 +17,9 @@
 
 #include <cstdint>
 #include <functional>
+#include <memory>
 #include <string>
+#include <tuple>
 #include <utility>
 #include <vector>
 
@@ -25,6 +27,7 @@
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
+#include "absl/functional/function_ref.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/strings/match.h"
@@ -32,6 +35,7 @@
 #include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "xla/comparison_util.h"
 #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
 #include "xla/hlo/ir/dynamic_parameter_binding.h"
@@ -41,10 +45,13 @@
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/literal.h"
 #include "xla/literal_util.h"
 #include "xla/service/call_inliner.h"
 #include "xla/service/dynamic_window_utils.h"
 #include "xla/service/hlo_creation_utils.h"
+#include "xla/service/hlo_dataflow_analysis.h"
+#include "xla/service/hlo_value.h"
 #include "xla/service/tuple_util.h"
 #include "xla/service/while_util.h"
 #include "xla/shape.h"
@@ -58,79 +65,92 @@
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
+
 namespace xla {
 
 namespace {
 // Replace `narrow_comp` with a new computation with `wide_shape` as input.
-StatusOr<HloComputation*> WidenComputation(HloComputation* narrow_comp,
-                                           const Shape& wide_shape) {
+StatusOr<std::pair<HloComputation*, CallInliner::InlinedInstructionMap>>
+WidenComputation(HloComputation* narrow_comp, const Shape& wide_shape) {
   TF_RET_CHECK(wide_shape.IsTuple());
   const Shape& narrow_shape = narrow_comp->parameter_instruction(0)->shape();
   if (Shape::Equal()(wide_shape, narrow_shape)) {
     // No need to widen the computation.
-    return narrow_comp;
+    return std::make_pair(narrow_comp, CallInliner::InlinedInstructionMap());
   }
   HloComputation* wide_comp = [&]() {
     HloComputation::Builder builder(absl::StrCat("wide.", narrow_comp->name()));
-    builder.AddInstruction(
-        HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
+    builder.AddInstruction(HloInstruction::CreateParameter(
+        0, wide_shape,
+        absl::StrCat("wide.", narrow_comp->parameter_instruction(0)->name())));
     return narrow_comp->parent()->AddEmbeddedComputation(builder.Build());
   }();
 
   HloInstruction* wide_parameter = wide_comp->parameter_instruction(0);
   HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
-      wide_parameter, narrow_shape.tuple_shapes_size());
+      wide_parameter, narrow_shape.tuple_shapes_size(),
+      absl::StrCat("renarrowed.",
+                   narrow_comp->parameter_instruction(0)->name()));
   HloInstruction* call_narrow_comp = wide_comp->AddInstruction(
       HloInstruction::CreateCall(narrow_comp->root_instruction()->shape(),
                                  {truncated_parameter}, narrow_comp));
   wide_comp->set_root_instruction(call_narrow_comp,
                                   /*accept_different_shape=*/true);
-  TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_comp).status());
-  return wide_comp;
+  TF_ASSIGN_OR_RETURN(auto inline_map, CallInliner::Inline(call_narrow_comp));
+  return std::make_pair(wide_comp, std::move(inline_map));
 }
 }  // namespace
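
Since WidenComputation now also returns the CallInliner map, a caller unpacks both pieces; a minimal sketch assuming `narrow_comp` and `wide_shape` are already in scope (illustrative, not part of the patch):

  TF_ASSIGN_OR_RETURN(auto widened, WidenComputation(narrow_comp, wide_shape));
  HloComputation* wide_comp = widened.first;
  const CallInliner::InlinedInstructionMap& inline_map = widened.second;
  // inline_map relates instructions of narrow_comp to their inlined clones
  // inside wide_comp, which is what the new return type exposes.
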
 
-class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
+class DynamicDimensionInferenceVisitor : public DfsHloRewriteVisitor {
  public:
   explicit DynamicDimensionInferenceVisitor(
       const DynamicParameterBinding& param_bindings,
-      DynamicDimensionInference* parent,
+      HloDataflowAnalysis& dataflow_analysis, DynamicDimensionInference* parent,
       DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler,
-      DynamicDimensionInference::ShapeCheckMode shape_check_mode)
+      DynamicDimensionInference::ShapeCheckMode shape_check_mode,
+      DynamicDimensionInference::AssertionGenerator assertion_generator)
       : param_bindings_(param_bindings),
+        dataflow_analysis_(dataflow_analysis),
         parent_(parent),
         custom_call_handler_(std::move(custom_call_handler)),
-        shape_check_mode_(shape_check_mode) {}
+        shape_check_mode_(shape_check_mode),
+        assertion_generator_(assertion_generator) {}
 
   Status DefaultAction(HloInstruction* hlo) override;
 
-  static Status Run(HloComputation* computation,
-                    const DynamicParameterBinding& param_bindings,
-                    DynamicDimensionInference* parent,
-                    DynamicDimensionInference::CustomCallInferenceHandler
-                        custom_call_handler = nullptr,
-                    DynamicDimensionInference::ShapeCheckMode shape_check_mode =
-                        DynamicDimensionInference::ShapeCheckMode::kIgnore,
-                    const DynamicDimensionInference::AssertionGenerator&
-                        assertion_generator = nullptr) {
+  static StatusOr<bool> Run(
+      HloComputation* computation, HloDataflowAnalysis& dataflow_analysis,
+      const DynamicParameterBinding& param_bindings,
+      DynamicDimensionInference* parent,
+      DynamicDimensionInference::CustomCallInferenceHandler
+          custom_call_handler = nullptr,
+      DynamicDimensionInference::ShapeCheckMode shape_check_mode =
+          DynamicDimensionInference::ShapeCheckMode::kIgnore,
+      const DynamicDimensionInference::AssertionGenerator& assertion_generator =
+          nullptr) {
     if (!HloInstruction::IsThreadIncluded(computation->execution_thread(),
                                           parent->execution_threads_)) {
-      return OkStatus();
+      return false;
     }
-    DynamicDimensionInferenceVisitor visitor(param_bindings, parent,
-                                             std::move(custom_call_handler),
-                                             shape_check_mode);
+    DynamicDimensionInferenceVisitor visitor(
+        param_bindings, dataflow_analysis, parent,
+        std::move(custom_call_handler), shape_check_mode, assertion_generator);
 
     TF_RETURN_IF_ERROR(computation->Accept(&visitor));
     if (visitor.shape_assertion_ != nullptr) {
       CHECK(assertion_generator);
       assertion_generator(visitor.shape_assertion_);
     }
-    return OkStatus();
+    return visitor.changed();
   }
 
   Status HandleParameter(HloInstruction* hlo) override;
 
+  Status HandleInfeed(HloInstruction* hlo) override;
+
+  Status HandleConstant(HloInstruction* hlo) override;
+
   Status HandleReduce(HloInstruction* hlo) override;
 
   Status HandleDot(HloInstruction* hlo) override;
@@ -197,6 +217,8 @@
 
   Status HandleAsyncStart(HloInstruction* hlo) override;
 
+  Status HandleAsyncDone(HloInstruction* hlo) override;
+
  private:
   using OperandDynamicDimensionFn = absl::FunctionRef<Status(
       HloInstruction* operand, ShapeIndex index, int64_t dimension,
@@ -205,6 +227,13 @@
   using DynamicDimensionFn = std::function<Status(
       ShapeIndex index, int64_t dimension, HloInstruction* dynamic_size)>;
 
+  void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index,
+                      int64_t dim, HloInstruction* size,
+                      bool clear_dynamic_dimension = true);
+
+  void SetDynamicSizes(HloInstruction* inst, const ShapeIndex& index,
+                       absl::Span<HloInstruction* const> sizes);
+
   Status HandleDynamicConvolutionForward(HloInstruction* hlo,
                                          int64_t operand_index,
                                          int64_t dimension,
@@ -231,6 +260,22 @@
   Status ForEachDynamicDimension(HloInstruction* inst,
                                  const DynamicDimensionFn& fn);
 
+  bool CanInfer(HloInstruction* hlo) { return parent_->CanInfer(hlo); }
+
+  // Return true unless all users of the instruction can consume a dynamic shape
+  // (including uses across control flow, but only within the same thread). The
+  // given `ShapeIndex` is the leaf array returned by the given instruction that
+  // will be considered.
+  StatusOr<bool> RequiresPadToStatic(HloInstruction* instr,
+                                     ShapeIndex shape_index);
+
+  // Insert pad-to-static after `inst` if `inst` has dynamic dimensions in it
+  // and `RequiresPadToStatic` is true for all leaves. If the instruction
+  // produces a tuple, each tuple component will be considered independently.
+  // Returns the original instruction, with all arrays converted to static
+  // shapes.
+  Status InsertPadToStaticOnInstruction(HloInstruction* inst);
+
   // Insert shape check to make sure `dim1` is equal to `dim2`. If
   // support_implicit_broadcast is true, the check will pass if either of them
   // is 1, even if they are different.
@@ -245,6 +290,8 @@
   // The dynamic parameter bindings of this computation.
   const DynamicParameterBinding& param_bindings_;
 
+  HloDataflowAnalysis& dataflow_analysis_;
+
   // A pointer to DynamicDimensionInference, used to update the dynamic mapping.
   DynamicDimensionInference* parent_;
 
@@ -256,8 +303,36 @@
 
   // Value which has to be `true` for the shapes to match.
   HloInstruction* shape_assertion_ = nullptr;
+
+  DynamicDimensionInference::AssertionGenerator assertion_generator_;
 };
 
+void DynamicDimensionInferenceVisitor::SetDynamicSize(
+    HloInstruction* inst, const ShapeIndex& index, int64_t dim,
+    HloInstruction* size, bool clear_dynamic_dimension) {
+  parent_->SetDynamicSize(inst, index, dim, size);
+  // Clear the dynamic dimension since we have recorded a dynamic size.
+  // If there are any dynamic dimensions left after DynamicPadder has completely
+  // run, we will raise an error.
+  if (clear_dynamic_dimension) {
+    ShapeUtil::GetMutableSubshape(inst->mutable_shape(), index)
+        ->set_dynamic_dimension(dim, false);
+  }
+  MarkAsChanged();
+}
+
+void DynamicDimensionInferenceVisitor::SetDynamicSizes(
+    HloInstruction* inst, const ShapeIndex& index,
+    absl::Span<HloInstruction* const> sizes) {
+  const Shape& subshape = ShapeUtil::GetSubshape(inst->shape(), index);
+  CHECK(subshape.IsArray() && subshape.rank() == sizes.size());
+  for (int64_t dimension = 0; dimension < subshape.rank(); ++dimension) {
+    if (sizes[dimension] != nullptr) {
+      SetDynamicSize(inst, index, dimension, sizes[dimension]);
+    }
+  }
+}
+
 Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) {
   return ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
@@ -271,37 +346,120 @@
 
 Status DynamicDimensionInferenceVisitor::HandleGetTupleElement(
     HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   return ForEachOperandDynamicDimension(
-      hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
-               int64_t operand_index, HloInstruction* dynamic_size) {
-        if (hlo->tuple_index() == index[0]) {
-          ShapeIndex new_index(ShapeIndexView(index).subspan(1));
-          parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size);
+      hlo,
+      [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
+          int64_t operand_index, HloInstruction* dynamic_size) -> Status {
+        if (hlo->tuple_index() != index[0]) {
+          return OkStatus();
         }
+        ShapeIndex new_index(ShapeIndexView(index).subspan(1));
+        SetDynamicSize(hlo, new_index, dimension, dynamic_size);
         return OkStatus();
       });
 }
 
 Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) {
-  return ForEachOperandDynamicDimension(
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
+  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction*, ShapeIndex index, int64_t dimension,
                int64_t operand_index, HloInstruction* dynamic_size) {
         index.push_front(operand_index);
-        parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
+        SetDynamicSize(hlo, index, dimension, dynamic_size);
         return OkStatus();
-      });
+      }));
+  return OkStatus();
 }
 
 Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   return ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
                int64_t operand_index, HloInstruction* dynamic_size) {
         int64_t broadcast_dim = hlo->dimensions(dimension);
-        parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size);
+        SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size);
         return OkStatus();
       });
 }
 
+Status DynamicDimensionInferenceVisitor::HandleConstant(HloInstruction* hlo) {
+  if (!hlo->shape().is_dynamic()) {
+    return OkStatus();
+  }
+  auto* constant = Cast<HloConstantInstruction>(hlo);
+  ShapeTree<bool> do_pad(constant->shape(), false);
+  Shape padded_shape = constant->shape();
+  bool pad_any = false;
+  TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
+      &padded_shape, [&](Shape* subshape, const ShapeIndex& index) -> Status {
+        if (!subshape->IsArray()) {
+          return OkStatus();
+        }
+        TF_ASSIGN_OR_RETURN(bool requires_pad, RequiresPadToStatic(hlo, index));
+        if (requires_pad) {
+          pad_any = *do_pad.mutable_element(index) = true;
+          *subshape = ShapeUtil::MakeStaticShape(*subshape);
+        }
+        return OkStatus();
+      }));
+  if (!pad_any) {
+    return OkStatus();
+  }
+  Literal padded_literal(padded_shape);
+  do_pad.ForEachElement([&](const ShapeIndex& index, bool requires_pad) {
+    const Shape& subshape = ShapeUtil::GetSubshape(padded_shape, index);
+    if (!subshape.IsArray()) {
+      return OkStatus();
+    }
+    TF_RETURN_IF_ERROR(padded_literal.CopyFrom(constant->literal(), index,
+                                               index,
+                                               /*only_dynamic_bound=*/true));
+    if (!requires_pad) {
+      for (int64_t dimension = 0; dimension < subshape.rank(); ++dimension) {
+        if (subshape.is_dynamic_dimension(dimension)) {
+          padded_literal.SetDynamicSize(
+              dimension, index,
+              constant->literal().GetDynamicSize(dimension, index));
+        }
+      }
+    }
+    return OkStatus();
+  });
+  auto* padded_constant = hlo->AddInstruction(
+      HloInstruction::CreateConstant(std::move(padded_literal)));
+  TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(padded_constant));
+  SetVisited(*padded_constant);
+  TF_RETURN_IF_ERROR(do_pad.ForEachElementWithStatus(
+      [&](const ShapeIndex& index, bool requires_pad) -> Status {
+        if (!requires_pad) {
+          return OkStatus();
+        }
+        const Shape& subshape =
+            ShapeUtil::GetSubshape(constant->shape(), index);
+        TF_RET_CHECK(subshape.IsArray());
+        for (int64_t dimension = 0; dimension < subshape.rank(); ++dimension) {
+          if (!subshape.is_dynamic_dimension(dimension)) {
+            continue;
+          }
+          HloInstruction* dynamic_size = hlo->AddInstruction(
+              HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
+                  constant->literal().GetDynamicSize(dimension, index))));
+          SetVisited(*dynamic_size);
+          SetDynamicSize(padded_constant, index, dimension, dynamic_size);
+        }
+        return OkStatus();
+      }));
+  MarkAsChanged();
+  return OkStatus();
+}
+
 Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
   if (hlo->custom_call_target() == "PadToStatic") {
     for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
@@ -313,103 +471,89 @@
         // returns the padded data output and the dynamic sizes of input
         // dimensions.
         ShapeIndex data_output = {0};
-        parent_->SetDynamicSize(hlo, data_output, i, dynamic_size);
+        SetDynamicSize(hlo, data_output, i, dynamic_size);
       }
     }
     return OkStatus();
   }
+
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
+
   if (custom_call_handler_) {
-    return custom_call_handler_(hlo, parent_);
-  }
-
-  if (hlo->custom_call_target() == "DynamicConvolutionForward") {
-    // If input feature is dynamic and kernel feature is static, we can infer
-    // that input feature is also static.
-    // E.g.,:
-    // lhs = [B, X, Y, ?]
-    // rhs = [X, Y, I, O]
-    // dim_labels = b01f_01io
-    // We can infer that the dynamic dimension in rhs is static I.
-    const ConvolutionDimensionNumbers& dnums =
-        hlo->convolution_dimension_numbers();
-    HloInstruction* input_feature = parent_->GetDynamicSize(
-        hlo->mutable_operand(0), {}, dnums.input_feature_dimension());
-    HloInstruction* kernel_feature = parent_->GetDynamicSize(
-        hlo->mutable_operand(1), {}, dnums.kernel_input_feature_dimension());
-
-    if (input_feature != nullptr && kernel_feature == nullptr) {
-      if (hlo->mutable_operand(0)->shape().dimensions(
-              dnums.input_feature_dimension()) ==
-          hlo->mutable_operand(1)->shape().dimensions(
-              dnums.kernel_input_feature_dimension()))
-        parent_->SetDynamicSize(hlo->mutable_operand(0), {},
-                                dnums.input_feature_dimension(), nullptr);
-    }
-  }
-  return ForEachOperandDynamicDimension(
-      hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
-               int64_t operand_index, HloInstruction* dynamic_size) {
-        // Resize custom call should propagate dynamic batch (0) and channel (3)
-        // dimensions.
-        if (hlo->custom_call_target() == "SliceToDynamic" ||
-            hlo->custom_call_target() == "Sharding" ||
-            (absl::StartsWith(hlo->custom_call_target(), "Resize") &&
-             (dimension == 0 || dimension == 3))) {
-          parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
-          return OkStatus();
-        }
-        if (hlo->custom_call_target() == "DynamicReduceWindowSamePadding") {
-          if (hlo->operand_count() > 2) {
-            return Unimplemented(
-                "DynamicReduceWindowSamePadding doesn't support variadic "
-                "reduce window %s",
-                hlo->ToString());
-          }
-          return HandleDynamicWindowSamePadding(hlo, dynamic_size,
-                                                operand_index, dimension);
-        }
-
-        if (hlo->custom_call_target() == "DynamicSelectAndScatterSamePadding") {
-          if (operand_index == 1) {
-            // Operand 0 (input) determines dynamic output size. We ignore the
-            // dynamic size in the operand 1 (output gradient).
+    TF_RETURN_IF_ERROR(custom_call_handler_(hlo, parent_));
+  } else {
+    TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
+        hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
+                 int64_t operand_index, HloInstruction* dynamic_size) {
+          // Resize custom call should propagate dynamic batch (0) and channel
+          // (3) dimensions.
+          if (hlo->custom_call_target() == "SliceToDynamic" ||
+              hlo->custom_call_target() == "Sharding" ||
+              (absl::StartsWith(hlo->custom_call_target(), "Resize") &&
+               (dimension == 0 || dimension == 3))) {
+            SetDynamicSize(hlo, {}, dimension, dynamic_size);
             return OkStatus();
           }
-          parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
-          return OkStatus();
-        }
+          if (hlo->custom_call_target() == "DynamicReduceWindowSamePadding") {
+            if (hlo->operand_count() > 2) {
+              return Unimplemented(
+                  "DynamicReduceWindowSamePadding doesn't support variadic "
+                  "reduce window %s",
+                  hlo->ToString());
+            }
+            return HandleDynamicWindowSamePadding(hlo, dynamic_size,
+                                                  operand_index, dimension);
+          }
 
-        if (hlo->custom_call_target() == "DynamicConvolutionInputGrad") {
-          return HandleDynamicConvolutionInputGrad(hlo, operand_index,
-                                                   dimension);
-        }
+          if (hlo->custom_call_target() ==
+              "DynamicSelectAndScatterSamePadding") {
+            if (operand_index == 1) {
+              // Operand 0 (input) determines the dynamic output size. We ignore
+              // the dynamic size in operand 1 (the output gradient).
+              return OkStatus();
+            }
+            SetDynamicSize(hlo, {}, dimension, dynamic_size);
+            return OkStatus();
+          }
 
-        if (hlo->custom_call_target() == "DynamicConvolutionKernelGrad") {
-          return HandleDynamicConvolutionKernelGrad(hlo, operand_index,
-                                                    dimension);
-        }
+          if (hlo->custom_call_target() == "DynamicConvolutionInputGrad") {
+            return HandleDynamicConvolutionInputGrad(hlo, operand_index,
+                                                     dimension);
+          }
 
-        if (hlo->custom_call_target() == "DynamicConvolutionForward") {
-          return HandleDynamicConvolutionForward(hlo, operand_index, dimension,
-                                                 dynamic_size);
-        }
-        return Unimplemented(
-            "CustomCall \"%s\" is not supported to have a dynamic dimension",
-            hlo->custom_call_target());
-      });
+          if (hlo->custom_call_target() == "DynamicConvolutionKernelGrad") {
+            return HandleDynamicConvolutionKernelGrad(hlo, operand_index,
+                                                      dimension);
+          }
+
+          if (hlo->custom_call_target() == "DynamicConvolutionForward") {
+            return HandleDynamicConvolutionForward(hlo, operand_index,
+                                                   dimension, dynamic_size);
+          }
+          return Unimplemented(
+              "CustomCall \"%s\" is not supported to have a dynamic dimension",
+              hlo->custom_call_target());
+        }));
+  }
+
+  return InsertPadToStaticOnInstruction(hlo);
 }
 
 Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   return ForEachOperandDynamicDimension(
       hlo,
       [&](HloInstruction* operand, ShapeIndex index, int64_t dynamic_dimension,
           int64_t operand_index, HloInstruction* dynamic_size) {
         HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
         if (sort->values_count() == 0) {
-          parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size);
+          SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size);
         } else {
-          parent_->SetDynamicSize(hlo, {operand_index}, dynamic_dimension,
-                                  dynamic_size);
+          SetDynamicSize(hlo, {operand_index}, dynamic_dimension, dynamic_size);
         }
 
         return OkStatus();
@@ -417,6 +561,9 @@
 }
 
 Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   return ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
                int64_t operand_index, HloInstruction* dynamic_size) {
@@ -465,141 +612,175 @@
             hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
                 dynamic_size_adjusted->shape(), HloOpcode::kAdd,
                 dynamic_size_adjusted, adjustment));
-        parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted);
+        SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted);
         return OkStatus();
       });
 }
 
 Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
-  return ForEachOperandDynamicDimension(
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
+  auto* reduce = Cast<HloReduceInstruction>(hlo);
+  int64_t rank = -1;
+  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+      reduce->shape(),
+      [&](const Shape& subshape, const ShapeIndex& index) -> Status {
+        if (!subshape.IsArray()) {
+          return OkStatus();
+        }
+        if (rank < 0) {
+          rank = subshape.rank();
+        } else {
+          TF_RET_CHECK(rank == subshape.rank());
+        }
+        return OkStatus();
+      }));
+  TF_RET_CHECK(rank >= 0);
+  absl::InlinedVector<HloInstruction*, 4> dynamic_sizes(rank, nullptr);
+
+  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
                int64_t operand_index, HloInstruction* dynamic_size) {
-        auto* reduce = Cast<HloReduceInstruction>(hlo);
         int64_t operand_count = reduce->operand_count();
         CHECK_EQ(operand_count % 2, 0);
         if (operand_index >= reduce->input_count()) {
           // Init values doesn't have dynamic size.
           return OkStatus();
         }
-        if ((absl::c_count(reduce->dimensions(), dimension) != 0)) {
+        if (absl::c_count(reduce->dimensions(), dimension) != 0) {
           // Dimension is to be reduced, stop tracing.
           return OkStatus();
         }
 
         // Find out the new dynamic dimension after reduce.
         int64_t dimensions_not_reduced_count = 0;
-        for (int i = 0; i < operand->shape().rank(); ++i) {
+        for (int64_t i = 0; i < operand->shape().rank(); ++i) {
           if (dimension == i) {
             // The dimensions of all data operands of a variadic reduce have
             // to be the same.  This means that if one operand of variadic
             // reduce has a dynamic dimension, we set all outputs to use the
             // same dynamic size in corresponding dimensions.
-            ShapeUtil::ForEachSubshape(
-                reduce->shape(),
-                [&](const Shape& subshape, ShapeIndex reduce_result_index) {
-                  if (!ShapeUtil::IsLeafIndex(reduce->shape(),
-                                              reduce_result_index)) {
-                    return;
-                  }
-                  parent_->SetDynamicSize(reduce, reduce_result_index,
-                                          dimensions_not_reduced_count,
-                                          dynamic_size);
-                });
-
+            dynamic_sizes[dimensions_not_reduced_count] = dynamic_size;
             return OkStatus();
           }
-          if (absl::c_count(reduce->dimensions(), i) == 0) {
+          if (!absl::c_linear_search(reduce->dimensions(), i)) {
             dimensions_not_reduced_count++;
           }
         }
 
         return OkStatus();
+      }));
+
+  ShapeUtil::ForEachSubshape(
+      reduce->shape(), [&](const Shape& subshape, ShapeIndex shape_index) {
+        if (!subshape.IsArray()) {
+          return;
+        }
+        SetDynamicSizes(reduce, shape_index, dynamic_sizes);
       });
+
+  return OkStatus();
 }
 
 Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
-  return ForEachOperandDynamicDimension(hlo, [&](HloInstruction* operand,
-                                                 ShapeIndex operand_shape_index,
-                                                 int64_t operand_dimension,
-                                                 int64_t operand_index,
-                                                 HloInstruction* dynamic_size) {
-    // There are three types of dimensions in a dot:
-    // A. batch dims
-    // B. contracting dims
-    // C. non-batch non-contracting dims.
-    // The output dimensions of a dot has three parts with the following
-    // order:
-    // [(type A), (lhs type C), (rhs type C)]
-    //
-    // Note that both lhs and rhs have the same dimension sizes for batch,
-    // but the dimension index could be different.
-    //
-    // Given one dynamic input dimension, either lhs or rhs, we use a
-    // mapping to find the corresponding output dimension.
-    HloInstruction* dot = hlo;
-    const DotDimensionNumbers& dimension_numbers = dot->dot_dimension_numbers();
-    // A map from the operand dimensions to result dimension.
-    absl::flat_hash_map<int64_t, int64_t> result_dim_mapping;
-    int64_t current_result_dims = 0;
-
-    bool lhs = operand_index == 0;
-
-    // The first loop keep tracks of batch dimension. RHS and LHS could have
-    // different batch dimension numbers.
-    if (lhs) {
-      for (int64_t i : dimension_numbers.lhs_batch_dimensions()) {
-        result_dim_mapping[i] = current_result_dims++;
-      }
-    } else {
-      for (int64_t i : dimension_numbers.rhs_batch_dimensions()) {
-        result_dim_mapping[i] = current_result_dims++;
-      }
-    }
-
-    // Handle dimensions in the lhs.
-    for (int64_t i = 0; i < dot->operand(0)->shape().rank(); i++) {
-      // Look for non-contracting and non-batching dimension.
-      if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
-                                i)) {
-        continue;
-      }
-      if (absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
-        continue;
-      }
-      if (lhs) {
-        result_dim_mapping[i] = current_result_dims;
-      }
-      current_result_dims++;
-    }
-
-    // Handle dimensions in the rhs.
-    for (int64_t i = 0; i < dot->operand(1)->shape().rank(); i++) {
-      // Look for non-contracting and non-batching dimension.
-      if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
-                                i)) {
-        continue;
-      }
-      if (absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
-        continue;
-      }
-      if (!lhs) {
-        result_dim_mapping[i] = current_result_dims;
-      }
-      current_result_dims++;
-    }
-
-    // Check if the operand dim is in the result shape. If so, add another
-    // work item to trace that dimension.
-    auto iter = result_dim_mapping.find(operand_dimension);
-    if (iter != result_dim_mapping.end()) {
-      parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
-    }
-
+  if (!CanInfer(hlo)) {
     return OkStatus();
-  });
+  }
+  absl::InlinedVector<HloInstruction*, 4> dynamic_sizes(hlo->shape().rank(),
+                                                        nullptr);
+  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
+      hlo,
+      [&](HloInstruction* operand, ShapeIndex operand_shape_index,
+          int64_t operand_dimension, int64_t operand_index,
+          HloInstruction* dynamic_size) -> Status {
+        // There are three types of dimensions in a dot:
+        // A. batch dims
+        // B. contracting dims
+        // C. non-batch non-contracting dims.
+        // The output dimensions of a dot have three parts in the following
+        // order:
+        // [(type A), (lhs type C), (rhs type C)]
+        //
+        // Note that both lhs and rhs have the same dimension sizes for batch,
+        // but the dimension index could be different.
+        //
+        // Given one dynamic input dimension, either lhs or rhs, we use a
+        // mapping to find the corresponding output dimension.
+        HloInstruction* dot = hlo;
+        const DotDimensionNumbers& dimension_numbers =
+            dot->dot_dimension_numbers();
+        // A map from the operand dimensions to result dimension.
+        absl::flat_hash_map<int64_t, int64_t> result_dim_mapping;
+        int64_t current_result_dims = 0;
+
+        bool lhs = operand_index == 0;
+
+        // The first loop keeps track of batch dimensions. RHS and LHS could have
+        // different batch dimension numbers.
+        if (lhs) {
+          for (int64_t i : dimension_numbers.lhs_batch_dimensions()) {
+            result_dim_mapping[i] = current_result_dims++;
+          }
+        } else {
+          for (int64_t i : dimension_numbers.rhs_batch_dimensions()) {
+            result_dim_mapping[i] = current_result_dims++;
+          }
+        }
+
+        // Handle dimensions in the lhs.
+        for (int64_t i = 0; i < dot->operand(0)->shape().rank(); i++) {
+          // Look for non-contracting and non-batching dimension.
+          if (absl::c_linear_search(
+                  dimension_numbers.lhs_contracting_dimensions(), i)) {
+            continue;
+          }
+          if (absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(),
+                                    i)) {
+            continue;
+          }
+          if (lhs) {
+            result_dim_mapping[i] = current_result_dims;
+          }
+          current_result_dims++;
+        }
+
+        // Handle dimensions in the rhs.
+        for (int64_t i = 0; i < dot->operand(1)->shape().rank(); i++) {
+          // Look for non-contracting and non-batching dimension.
+          if (absl::c_linear_search(
+                  dimension_numbers.rhs_contracting_dimensions(), i)) {
+            continue;
+          }
+          if (absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
+                                    i)) {
+            continue;
+          }
+          if (!lhs) {
+            result_dim_mapping[i] = current_result_dims;
+          }
+          current_result_dims++;
+        }
+
+        // Check if the operand dim is in the result shape. If so, add another
+        // work item to trace that dimension.
+        auto iter = result_dim_mapping.find(operand_dimension);
+        if (iter != result_dim_mapping.end()) {
+          dynamic_sizes[iter->second] = dynamic_size;
+        }
+
+        return OkStatus();
+      }));
+
+  SetDynamicSizes(hlo, {}, dynamic_sizes);
+
+  return OkStatus();
 }
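
For orientation, a concrete instance of the dimension mapping implemented above (illustrative, not from the patch): with lhs [B, M, K] and rhs [B, K, N], batch dimensions {0}/{0} and contracting dimensions {2}/{1}, the dot's output is [B, M, N]; a dynamic size tracked on lhs dimension 1 (M) is mapped by result_dim_mapping to output dimension 1, so dynamic_sizes[1] records it and the final SetDynamicSizes call propagates it to the dot.
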
 
 Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   return ForEachOperandDynamicDimension(
       hlo,
       [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
@@ -611,13 +792,16 @@
             permuted_dim = i;
           }
         }
-        parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size);
+        SetDynamicSize(hlo, {}, permuted_dim, dynamic_size);
         return OkStatus();
       });
 }
 
 Status DynamicDimensionInferenceVisitor::HandleConvolution(
     HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   return ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
                int64_t operand_index, HloInstruction* dynamic_size) {
@@ -626,9 +810,8 @@
             conv->convolution_dimension_numbers();
         if (operand_index == 0) {
           if (dimension == dimension_numbers.input_batch_dimension()) {
-            parent_->SetDynamicSize(conv, {},
-                                    dimension_numbers.output_batch_dimension(),
-                                    dynamic_size);
+            SetDynamicSize(conv, {}, dimension_numbers.output_batch_dimension(),
+                           dynamic_size);
             return OkStatus();
           }
 
@@ -648,26 +831,37 @@
 
 Status DynamicDimensionInferenceVisitor::HandleConcatenate(
     HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   // First handle concatenate dimensions. We do this by iterating through all
   // operands while tracking both dynamic and static dimensions.
 
-  // static_size is used to keep track of the concated size of static
+  // static_size is used to keep track of the concatenated size of static
   // dimensions.
   int64_t static_size = 0;
   std::vector<HloInstruction*> dynamic_concat_dims;
   for (int64_t i = 0; i < hlo->operand_count(); ++i) {
-    HloInstruction* dynamic_size = parent_->GetDynamicSize(
-        hlo->mutable_operand(i), {}, hlo->concatenate_dimension());
-    if (dynamic_size == nullptr) {
+    HloInstruction* concat_dim_size = nullptr;
+    for (int64_t dimension = 0; dimension < hlo->operand(i)->shape().rank();
+         ++dimension) {
+      if (dimension == hlo->concatenate_dimension()) {
+        HloInstruction* dynamic_size =
+            parent_->GetDynamicSize(hlo->mutable_operand(i), {}, dimension);
+        concat_dim_size = dynamic_size;
+      }
+    }
+    if (concat_dim_size == nullptr) {
       // This is a static dimension.
       static_size +=
           hlo->operand(i)->shape().dimensions(hlo->concatenate_dimension());
     } else {
-      dynamic_concat_dims.push_back(dynamic_size);
+      dynamic_concat_dims.push_back(concat_dim_size);
     }
   }
   // If concat dimension is dynamic, calculate its size by summing up static
   // dims and dynamic dims together.
+  std::vector<HloInstruction*> dynamic_sizes(hlo->shape().rank(), nullptr);
   if (!dynamic_concat_dims.empty()) {
     HloInstruction* dim_size_total =
         hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
@@ -677,21 +871,26 @@
           HloInstruction::CreateBinary(dim_size_total->shape(), HloOpcode::kAdd,
                                        dim_size_total, dynamic_dim));
     }
-    parent_->SetDynamicSize(hlo, {}, hlo->concatenate_dimension(),
-                            dim_size_total);
+    dynamic_sizes[hlo->concatenate_dimension()] = dim_size_total;
   }
 
   // Simply pass through non-concat dynamic dimensions.
-  return ForEachOperandDynamicDimension(
-      hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
-               int64_t operand_index, HloInstruction* dynamic_size) {
+  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
+      hlo,
+      [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
+          int64_t operand_index, HloInstruction* dynamic_size) -> Status {
+        TF_RET_CHECK(index.empty());
         int64_t concatenate_dimension = hlo->concatenate_dimension();
         if (concatenate_dimension == dimension) {
           return OkStatus();
         }
-        parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
+        dynamic_sizes[dimension] = dynamic_size;
         return OkStatus();
-      });
+      }));
+
+  SetDynamicSizes(hlo, {}, dynamic_sizes);
+
+  return OkStatus();
 }
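
A worked case for the concatenate-dimension handling above (illustrative, not from the patch): concatenating three operands whose sizes along the concatenate dimension are 4 (static), d1, and d2 (dynamic) builds dim_size_total as a constant 4 followed by two kAdd instructions, i.e. 4 + d1 + d2, which becomes the output's dynamic size on that dimension; dynamic sizes on the other dimensions simply pass through.
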
 
 Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
@@ -706,22 +905,27 @@
   // the shape (although the value contains the real size of the dynamic
   // dimension of the input).
   int64_t dim = gds->dimension();
+  TF_RET_CHECK(dim < gds->operand(0)->shape().rank()) << gds->ToString();
   HloInstruction* operand = gds->mutable_operand(0);
-  HloInstruction* dynamic_size = parent_->GetDynamicSize(operand, {}, dim);
+  TF_RET_CHECK(dim < operand->shape().rank());
+  HloInstruction* replacement = parent_->GetDynamicSize(operand, {}, dim);
   HloComputation* computation = gds->parent();
-  if (dynamic_size != nullptr) {
-    TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(dynamic_size));
+  if (replacement == nullptr &&
+      !gds->operand(0)->shape().is_dynamic_dimension(dim)) {
+    TF_RET_CHECK(dim < gds->operand(0)->shape().rank());
+    int32_t size = gds->operand(0)->shape().dimensions(dim);
+    replacement = computation->AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(size)),
+        gds->name());
+  }
+
+  if (replacement != nullptr) {
+    TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(replacement));
     // The dependency between an instruction and its dynamic dimensions is not
     // modeled in the IR. As instr is being replaced by dynamic_size, also tell
     // dynamic dimension inference that the instruction is being replaced.
-    parent_->ReplaceAllDynamicDimensionUsesWith(gds, dynamic_size);
-  } else {
-    TF_RET_CHECK(dim < gds->operand(0)->shape().rank());
-    int32_t size = gds->operand(0)->shape().dimensions(dim);
-    HloInstruction* new_instr = computation->AddInstruction(
-        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(size)));
-    TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(new_instr));
-    parent_->ReplaceAllDynamicDimensionUsesWith(gds, new_instr);
+    parent_->ReplaceAllDynamicDimensionUsesWith(gds, replacement);
+    MarkAsChanged();
   }
   return OkStatus();
 }
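
As a concrete illustration of the replacement logic above (not part of the patch): a get-dimension-size of dimension 0 on an operand of shape f32[<=8] is replaced by whichever instruction currently tracks that dynamic size; if the dimension is static, e.g. f32[8] with no tracked size, it is replaced by an s32 constant 8. In either case both the users and the dynamic-dimension mapping are redirected to the replacement.
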
@@ -749,15 +953,19 @@
   if (!dimension_is_static) {
     // Propagate dynamic dimension indicated by this set dimension size
     // instruction.
-    parent_->SetDynamicSize(hlo, {}, hlo->dimension(), hlo->mutable_operand(1));
+    SetDynamicSize(hlo, {}, hlo->dimension(), hlo->mutable_operand(1),
+                   /*clear_dynamic_dimension=*/false);
   }
 
   // Also Propagate dynamic dimension already set by operands.
   TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
-      hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
-               int64_t operand_index, HloInstruction* dynamic_size) {
+      hlo,
+      [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
+          int64_t operand_index, HloInstruction* dynamic_size) -> Status {
+        TF_RET_CHECK(operand_index == 0);
         if (dimension != hlo->dimension()) {
-          parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
+          SetDynamicSize(hlo, index, dimension, dynamic_size,
+                         /*clear_dynamic_dimension=*/false);
         }
         return OkStatus();
       }));
@@ -768,14 +976,17 @@
 Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionForward(
     HloInstruction* hlo, int64_t operand_index, int64_t dimension,
     HloInstruction* dynamic_size) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   TF_RET_CHECK(operand_index == 0);
   const ConvolutionDimensionNumbers& dimension_numbers =
       hlo->convolution_dimension_numbers();
 
   if (dimension == dimension_numbers.input_batch_dimension()) {
     // Batch dimension is propagated without any changes.
-    parent_->SetDynamicSize(hlo, {}, dimension_numbers.output_batch_dimension(),
-                            dynamic_size);
+    SetDynamicSize(hlo, {}, dimension_numbers.output_batch_dimension(),
+                   dynamic_size);
     return OkStatus();
   }
 
@@ -793,8 +1004,8 @@
           dynamic_size, window_dim.size(), window_dim.window_dilation(),
           window_dim.stride(), hlo->padding_type());
       TF_RET_CHECK(window_dim.base_dilation() == 1);
-      parent_->SetDynamicSize(hlo, {}, output_spatial_dim,
-                              dynamic_window_dims.output_size);
+      SetDynamicSize(hlo, {}, output_spatial_dim,
+                     dynamic_window_dims.output_size);
       return OkStatus();
     }
   }
@@ -805,24 +1016,28 @@
 Status DynamicDimensionInferenceVisitor::HandleDynamicWindowSamePadding(
     HloInstruction* hlo, HloInstruction* dynamic_size, int64_t operand_index,
     int64_t dimension) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   const Window& window = hlo->window();
   const WindowDimension& window_dim = window.dimensions(dimension);
   if (!window_util::IsTrivialWindowDimension(window_dim)) {
     DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
         dynamic_size, window_dim.size(), window_dim.window_dilation(),
         window_dim.stride(), PaddingType::PADDING_SAME);
-    parent_->SetDynamicSize(hlo, {}, dimension,
-                            dynamic_window_dims.output_size);
-    return OkStatus();
+    SetDynamicSize(hlo, {}, dimension, dynamic_window_dims.output_size);
+  } else {
+    SetDynamicSize(hlo, {}, dimension, dynamic_size);
   }
 
-  parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
-
   return OkStatus();
 }
 
 Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionInputGrad(
     HloInstruction* hlo, int64_t operand_index, int64_t dimension) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   // The output size of convolution input grad is corresponding input size.
   HloInstruction* input_sizes = hlo->mutable_operand(0);
   HloComputation* comp = hlo->parent();
@@ -837,7 +1052,7 @@
                                   {dimension}, {dimension + 1}, {1}));
   HloInstruction* reshape = comp->AddInstruction(
       HloInstruction::CreateReshape(ShapeUtil::MakeScalarShape(S32), slice));
-  parent_->SetDynamicSize(hlo, {}, dimension, reshape);
+  SetDynamicSize(hlo, {}, dimension, reshape);
   return OkStatus();
 }
 
@@ -849,12 +1064,29 @@
 
 Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension(
     HloInstruction* hlo) {
-  return ForEachOperandDynamicDimension(
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
+  // TODO(b/298671312): This is ambiguous with respect to which operand provides
+  // the dynamic size.
+  ShapeTree<absl::InlinedVector<HloInstruction*, 2>> dynamic_sizes(
+      hlo->shape());
+  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
                int64_t operand_index, HloInstruction* dynamic_size) {
-        parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
+        const Shape& subshape = ShapeUtil::GetSubshape(hlo->shape(), index);
+        auto* element = dynamic_sizes.mutable_element(index);
+        element->resize(subshape.rank(), nullptr);
+        element->at(dimension) = dynamic_size;
         return OkStatus();
-      });
+      }));
+  dynamic_sizes.ForEachElement([&](const ShapeIndex& index, const auto& sizes) {
+    if (sizes.empty()) {
+      return;
+    }
+    SetDynamicSizes(hlo, index, sizes);
+  });
+  return OkStatus();
 }
 
 Status DynamicDimensionInferenceVisitor::HandleDomain(HloInstruction* hlo) {
@@ -864,7 +1096,7 @@
 Status DynamicDimensionInferenceVisitor::HandleAsyncStart(HloInstruction* hlo) {
   if (!HloInstruction::IsThreadIncluded(hlo->async_execution_thread(),
                                         parent_->execution_threads_)) {
-    // Async-start not included in specificed execution thread set will use
+    // Async-start not included in specified execution thread set will use
     // metadata-prefix version of dynamic shapes (result of slice-to-dynamic) so
     // there is no need to propagate dynamic dimension info.
     return OkStatus();
@@ -872,74 +1104,110 @@
   return DefaultAction(hlo);
 }
 
+Status DynamicDimensionInferenceVisitor::HandleAsyncDone(HloInstruction* hlo) {
+  if (!HloInstruction::IsThreadIncluded(hlo->async_execution_thread(),
+                                        parent_->execution_threads_)) {
+    // Other threads can return a dynamic shape directly, so we may need to
+    // insert PadToStatic.
+    return InsertPadToStaticOnInstruction(hlo);
+  }
+  return DefaultAction(hlo);
+}
+
 Status DynamicDimensionInferenceVisitor::HandleElementwiseUnary(
     HloInstruction* hlo) {
   return PassThroughDynamicDimension(hlo);
 }
 
 Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {
-  return PassThroughDynamicDimension(hlo);
+  return HandleElementwiseNary(hlo);
 }
 
 Status DynamicDimensionInferenceVisitor::HandleElementwiseNary(
     HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   HloComputation* comp = hlo->parent();
-  return ForEachOperandDynamicDimension(
-      hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
-               int64_t operand_index, HloInstruction* dynamic_size) {
-        HloInstruction* existing_size =
-            parent_->GetDynamicSize(hlo, index, dimension);
-        if (existing_size == nullptr || existing_size == dynamic_size) {
-          parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
-        } else {
-          TF_RETURN_IF_ERROR(
-              InsertShapeCheck(existing_size, dynamic_size,
-                               /*support_implicit_broadcast=*/true));
-
-          auto one = comp->AddInstruction(
-              HloInstruction::CreateConstant(LiteralUtil::One(S32)));
-
-          auto operand_needs_broadcast =
-              comp->AddInstruction(HloInstruction::CreateCompare(
-                  ShapeUtil::MakeShape(PRED, {}), dynamic_size, existing_size,
-                  ComparisonDirection::kLt));
-          auto is_one = comp->AddInstruction(HloInstruction::CreateCompare(
-              ShapeUtil::MakeShape(PRED, {}), dynamic_size, one,
-              ComparisonDirection::kEq));
-          operand_needs_broadcast =
-              comp->AddInstruction(HloInstruction::CreateBinary(
-                  ShapeUtil::MakeShape(PRED, {}), HloOpcode::kAnd, is_one,
-                  operand_needs_broadcast));
-
-          auto existing_needs_broadcast =
-              comp->AddInstruction(HloInstruction::CreateCompare(
-                  ShapeUtil::MakeShape(PRED, {}), existing_size, dynamic_size,
-                  ComparisonDirection::kLt));
-          is_one = comp->AddInstruction(HloInstruction::CreateCompare(
-              ShapeUtil::MakeShape(PRED, {}), existing_size, one,
-              ComparisonDirection::kEq));
-          existing_needs_broadcast =
-              comp->AddInstruction(HloInstruction::CreateBinary(
-                  ShapeUtil::MakeShape(PRED, {}), HloOpcode::kAnd, is_one,
-                  existing_needs_broadcast));
-
-          auto needs_broadcast =
-              comp->AddInstruction(HloInstruction::CreateBinary(
-                  ShapeUtil::MakeShape(PRED, {}), HloOpcode::kOr,
-                  operand_needs_broadcast, existing_needs_broadcast));
-          auto max_size = comp->AddInstruction(HloInstruction::CreateBinary(
-              ShapeUtil::MakeScalarShape(S32), HloOpcode::kMaximum,
-              dynamic_size, existing_size));
-          auto min_size = comp->AddInstruction(HloInstruction::CreateBinary(
-              ShapeUtil::MakeScalarShape(S32), HloOpcode::kMinimum,
-              dynamic_size, existing_size));
-          auto select_size = comp->AddInstruction(HloInstruction::CreateTernary(
-              ShapeUtil::MakeScalarShape(S32), HloOpcode::kSelect,
-              needs_broadcast, max_size, min_size));
-          parent_->SetDynamicSize(hlo, index, dimension, select_size);
-        }
+  // First find all the dynamic sizes of the operands, and arrange them by
+  // dimension.
+  absl::InlinedVector<absl::InlinedVector<HloInstruction*, 2>, 2> operand_sizes(
+      hlo->shape().rank(),
+      absl::InlinedVector<HloInstruction*, 2>(hlo->operand_count(), nullptr));
+  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
+      hlo,
+      [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
+          int64_t operand_index, HloInstruction* dynamic_size) -> Status {
+        TF_RET_CHECK(index.empty());
+        operand_sizes[dimension][operand_index] = dynamic_size;
         return OkStatus();
-      });
+      }));
+
+  absl::InlinedVector<HloInstruction*, 2> existing_sizes(hlo->shape().rank(),
+                                                         nullptr);
+  for (int operand_index = 0; operand_index < hlo->operand_count();
+       ++operand_index) {
+    for (int64_t dimension = 0; dimension < hlo->shape().rank(); ++dimension) {
+      HloInstruction* dynamic_size = operand_sizes[dimension][operand_index];
+      if (dynamic_size == nullptr) {
+        continue;
+      }
+      HloInstruction* existing_size = existing_sizes[dimension];
+      if (existing_size == nullptr) {
+        existing_sizes[dimension] = dynamic_size;
+      } else if (existing_sizes[dimension] != dynamic_size) {
+        TF_RETURN_IF_ERROR(
+            InsertShapeCheck(existing_size, dynamic_size,
+                             /*support_implicit_broadcast=*/true));
+
+        auto one = comp->AddInstruction(
+            HloInstruction::CreateConstant(LiteralUtil::One(S32)));
+
+        auto operand_needs_broadcast =
+            comp->AddInstruction(HloInstruction::CreateCompare(
+                ShapeUtil::MakeShape(PRED, {}), dynamic_size, existing_size,
+                ComparisonDirection::kLt));
+        auto is_one = comp->AddInstruction(HloInstruction::CreateCompare(
+            ShapeUtil::MakeShape(PRED, {}), dynamic_size, one,
+            ComparisonDirection::kEq));
+        operand_needs_broadcast =
+            comp->AddInstruction(HloInstruction::CreateBinary(
+                ShapeUtil::MakeShape(PRED, {}), HloOpcode::kAnd, is_one,
+                operand_needs_broadcast));
+
+        auto existing_needs_broadcast =
+            comp->AddInstruction(HloInstruction::CreateCompare(
+                ShapeUtil::MakeShape(PRED, {}), existing_size, dynamic_size,
+                ComparisonDirection::kLt));
+        is_one = comp->AddInstruction(HloInstruction::CreateCompare(
+            ShapeUtil::MakeShape(PRED, {}), existing_size, one,
+            ComparisonDirection::kEq));
+        existing_needs_broadcast =
+            comp->AddInstruction(HloInstruction::CreateBinary(
+                ShapeUtil::MakeShape(PRED, {}), HloOpcode::kAnd, is_one,
+                existing_needs_broadcast));
+
+        auto needs_broadcast =
+            comp->AddInstruction(HloInstruction::CreateBinary(
+                ShapeUtil::MakeShape(PRED, {}), HloOpcode::kOr,
+                operand_needs_broadcast, existing_needs_broadcast));
+        auto max_size = comp->AddInstruction(HloInstruction::CreateBinary(
+            ShapeUtil::MakeScalarShape(S32), HloOpcode::kMaximum, dynamic_size,
+            existing_size));
+        auto min_size = comp->AddInstruction(HloInstruction::CreateBinary(
+            ShapeUtil::MakeScalarShape(S32), HloOpcode::kMinimum, dynamic_size,
+            existing_size));
+        auto select_size = comp->AddInstruction(HloInstruction::CreateTernary(
+            ShapeUtil::MakeScalarShape(S32), HloOpcode::kSelect,
+            needs_broadcast, max_size, min_size));
+        existing_sizes[dimension] = select_size;
+      }
+    }
+  }
+
+  SetDynamicSizes(hlo, {}, existing_sizes);
+
+  return OkStatus();
 }
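
A quick worked case for the size selection above (illustrative, not from the patch): if one operand's dynamic size on a dimension is 1 and the other's is d, needs_broadcast is true and the selected size is max(1, d) = d; if the sizes are equal, min and max coincide and the select is a no-op; sizes that differ with neither equal to 1 are caught by the runtime check inserted through InsertShapeCheck, subject to the configured shape-check mode.
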
 
 Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
@@ -953,18 +1221,28 @@
 
 Status DynamicDimensionInferenceVisitor::HandleDynamicReshape(
     HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   HloDynamicReshapeInstruction* dynamic_reshape =
       Cast<HloDynamicReshapeInstruction>(hlo);
   for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
     if (hlo->shape().is_dynamic_dimension(i)) {
-      parent_->SetDynamicSize(hlo, {}, i, dynamic_reshape->dim_sizes(i));
+      SetDynamicSize(hlo, {}, i, dynamic_reshape->dim_sizes(i));
     }
   }
   return OkStatus();
 }
 
-Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
+Status DynamicDimensionInferenceVisitor::HandleReshape(
+    HloInstruction* const hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   VLOG(2) << "Handle reshape: " << hlo->ToString() << "\n";
+
+  absl::InlinedVector<HloInstruction*, 2> dynamic_sizes(hlo->shape().rank(),
+                                                        nullptr);
   using ReshapeGroup = std::pair<int64_t, int64_t>;
   using ReshapeGroupPair = std::pair<ReshapeGroup, ReshapeGroup>;
   auto is_reverse_reshape_group_pair =
@@ -1053,8 +1331,7 @@
                       auto hlo_dimension_index = op_dynamic_dimension -
                                                  orig_reshape_pair.first.first +
                                                  reshape_pair.second.first;
-                      parent_->SetDynamicSize(hlo, {}, hlo_dimension_index,
-                                              dynamic_size);
+                      dynamic_sizes[hlo_dimension_index] = dynamic_size;
                     }
                     return OkStatus();
                   }));
@@ -1131,10 +1408,11 @@
       dynamic_size = comp->AddInstruction(HloInstruction::CreateBinary(
           dynamic_size->shape(), HloOpcode::kDivide, dynamic_size,
           size_without_inferred_dim_hlo));
-      parent_->SetDynamicSize(hlo, {}, hlo->inferred_dimension(), dynamic_size);
+      dynamic_sizes[hlo->inferred_dimension()] = dynamic_size;
       VLOG(3)
-          << "Need to decopose a dynamic reshape to flatten-unflatten pair. "
+          << "Need to decompose a dynamic reshape to flatten-unflatten pair. "
           << comp->parent()->ToString();
+      SetDynamicSizes(hlo, {}, dynamic_sizes);
       return OkStatus();
     }
     return InternalError(
@@ -1143,12 +1421,12 @@
         hlo->ToString());
   }
 
-  return ForEachOperandDynamicDimension(
+  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
       hlo,
       [&](HloInstruction* operand, ShapeIndex index,
           int64_t input_dynamic_dimension, int64_t operand_index,
           HloInstruction* operand_dynamic_size) -> Status {
-        HloInstruction* reshape = hlo;
+        HloInstruction* const reshape = hlo;
         if (reshape->shape().rank() == 0) {
           VLOG(0) << "Reshaping a dynamic dimension into a scalar, which has "
                      "undefined behavior when input size is 0. The offending "
@@ -1273,8 +1551,7 @@
 
         if (input_dim_size == output_dim_size) {
           // Simply forward dynamic dimension.
-          parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
-                                  operand_dynamic_size);
+          dynamic_sizes[output_dynamic_dimension] = operand_dynamic_size;
         }
 
         if (input_dim_size > output_dim_size) {
@@ -1290,8 +1567,7 @@
                   operand_dynamic_size->shape(), HloOpcode::kDivide,
                   operand_dynamic_size, divisor_hlo));
 
-          parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
-                                  new_dynamic_size);
+          dynamic_sizes[output_dynamic_dimension] = new_dynamic_size;
         }
 
         if (input_dim_size < output_dim_size) {
@@ -1309,7 +1585,7 @@
           //
           //
           HloInstruction* output_dynamic_size =
-              parent_->GetDynamicSize(reshape, {}, output_dynamic_dimension);
+              dynamic_sizes[output_dynamic_dimension];
           if (output_dynamic_size == nullptr) {
             output_dynamic_size =
                 hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
@@ -1328,17 +1604,25 @@
               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
                   output_dynamic_size->shape(), HloOpcode::kMultiply,
                   new_dynamic_size, operand_dynamic_size));
-          parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
-                                  new_dynamic_size);
+          dynamic_sizes[output_dynamic_dimension] = new_dynamic_size;
         }
 
         return OkStatus();
-      });
+      }));
+
+  SetDynamicSizes(hlo, {}, dynamic_sizes);
+
+  return OkStatus();
 }
 
 Status DynamicDimensionInferenceVisitor::HandleReduceWindow(
     HloInstruction* hlo) {
-  return ForEachOperandDynamicDimension(
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
+  ShapeTree<absl::InlinedVector<HloInstruction*, 2>> dynamic_sizes(
+      hlo->shape());
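+  // Accumulate the dynamic sizes per output leaf (reduce-window may be
+  // variadic and produce a tuple), then commit them with SetDynamicSizes
+  // below.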
+  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
                int64_t operand_index, HloInstruction* dynamic_size) {
         auto* reduce_window = Cast<HloReduceWindowInstruction>(hlo);
@@ -1368,16 +1652,30 @@
                                           reduce_window_result_index)) {
                 return;
               }
-              parent_->SetDynamicSize(reduce_window, reduce_window_result_index,
-                                      dimension, dynamic_size);
+              auto* leaf_dynamic_sizes =
+                  dynamic_sizes.mutable_element(reduce_window_result_index);
+              leaf_dynamic_sizes->resize(subshape.rank(), nullptr);
+              leaf_dynamic_sizes->at(dimension) = dynamic_size;
             });
 
         return OkStatus();
+      }));
+  dynamic_sizes.ForEachElement(
+      [&](const ShapeIndex& shape_index,
+          const absl::InlinedVector<HloInstruction*, 2>& sizes) {
+        if (sizes.empty()) {
+          return;
+        }
+        SetDynamicSizes(hlo, shape_index, sizes);
       });
+  return OkStatus();
 }
 
 Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter(
     HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   return ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
                int64_t operand_index, HloInstruction* dynamic_size) {
@@ -1386,25 +1684,48 @@
           // dynamic size in the operand 1 (output gradient).
           return OkStatus();
         }
-        parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
+        SetDynamicSize(hlo, {}, dimension, dynamic_size);
 
         return OkStatus();
       });
 }
 
 Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   return ForEachOperandDynamicDimension(
-      hlo, [&](HloInstruction* operand, ShapeIndex /*index*/, int64_t dimension,
-               int64_t /*operand_index*/, HloInstruction* dynamic_size) {
-        if (hlo->slice_starts(dimension) != 0 ||
-            hlo->slice_strides(dimension) != 1 ||
-            hlo->slice_limits(dimension) !=
-                operand->shape().dimensions(dimension)) {
-          // Slicing a partial element out eliminates the dynamic dimension.
+      hlo,
+      [&](HloInstruction* operand, ShapeIndex /*index*/, int64_t dimension,
+          int64_t /*operand_index*/, HloInstruction* dynamic_size) -> Status {
+        int64_t start = hlo->slice_starts(dimension);
+        int64_t limit = hlo->slice_limits(dimension);
+        int64_t stride = hlo->slice_strides(dimension);
+        int64_t size = CeilOfRatio<int64_t>(limit - start, stride);
+        if (size == 1) {
+          TF_RET_CHECK(!hlo->shape().is_dynamic_dimension(dimension));
+          // Slicing a single element out eliminates the dynamic dimension.
           return OkStatus();
         }
 
-        parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
+        TF_RET_CHECK(hlo->shape().is_dynamic_dimension(dimension));
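+        // The sliced dynamic size is ceil((dynamic_size - start) / stride);
+        // the subtraction and add-then-divide below compute exactly that with
+        // integer HLO ops.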
+        if (start != 0) {
+          dynamic_size = hlo->AddInstruction(HloInstruction::CreateBinary(
+              dynamic_size->shape(), HloOpcode::kSubtract, dynamic_size,
+              hlo->AddInstruction(HloInstruction::CreateConstant(
+                  LiteralUtil::CreateR0<int32_t>(start)))));
+        }
+        if (stride != 1) {
+          dynamic_size = hlo->AddInstruction(HloInstruction::CreateBinary(
+              dynamic_size->shape(), HloOpcode::kAdd, dynamic_size,
+              hlo->AddInstruction(HloInstruction::CreateConstant(
+                  LiteralUtil::CreateR0<int32_t>(stride - 1)))));
+          dynamic_size = hlo->AddInstruction(HloInstruction::CreateBinary(
+              dynamic_size->shape(), HloOpcode::kDivide, dynamic_size,
+              hlo->AddInstruction(HloInstruction::CreateConstant(
+                  LiteralUtil::CreateR0<int32_t>(stride)))));
+        }
+        SetDynamicSize(hlo, {}, dimension, dynamic_size);
 
         return OkStatus();
       });
@@ -1412,22 +1733,30 @@
 
 Status DynamicDimensionInferenceVisitor::HandleDynamicSlice(
     HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   return ForEachOperandDynamicDimension(
-      hlo, [&](HloInstruction*, ShapeIndex /*index*/, int64_t dimension,
-               int64_t /*operand_index*/, HloInstruction* dynamic_size) {
+      hlo,
+      [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
+          int64_t operand_index, HloInstruction* dynamic_size) -> Status {
+        // Slicing a single element out kills the dynamic dimension.
+        if (hlo->shape().dimensions(dimension) == 1) {
+          return OkStatus();
+        }
         if (hlo->shape().dimensions(dimension) !=
             hlo->operand(0)->shape().dimensions(dimension)) {
-          // Slicing a single element out kills the dynamic dimension.
-          if (hlo->shape().dimensions(dimension) == 1) {
-            return OkStatus();
-          }
           return Unimplemented(
               "Dynamic dimension propagation on DynamicSlice where a partial "
               "dimension is selected %s",
               hlo->ToString());
         }
 
-        parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
+        // Only the base operand should be dynamic (since the rest are scalars).
+        TF_RET_CHECK(operand_index == 0);
+
+        TF_RET_CHECK(index.empty());
+        SetDynamicSize(hlo, {}, dimension, dynamic_size);
 
         return OkStatus();
       });
@@ -1435,10 +1764,17 @@
 
 Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice(
     HloInstruction* hlo) {
-  return ForEachOperandDynamicDimension(
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
+  absl::InlinedVector<HloInstruction*, 2> output_dynamic_sizes(
+      hlo->shape().rank(), nullptr);
+  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
       hlo,
-      [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64_t dimension,
-          int64_t operand_index, HloInstruction* dynamic_size) {
+      [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
+          int64_t operand_index, HloInstruction* dynamic_size) -> Status {
+        TF_RET_CHECK(index.empty());
+
         if (hlo->shape().dimensions(dimension) !=
             hlo->operand(0)->shape().dimensions(dimension)) {
           return Unimplemented(
@@ -1456,13 +1792,16 @@
           // a partial update, no need to set the output dynamic dimension.
           //
           // The dynamic shape in `update` doesn't change output dynamic shape.
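+          // The partial update leaves this output dimension at its full
+          // static size, so clear the dynamic bit instead of recording a size.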
+          hlo->mutable_shape()->set_dynamic_dimension(dimension, false);
           return OkStatus();
         }
 
-        parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
+        output_dynamic_sizes[dimension] = dynamic_size;
 
         return OkStatus();
-      });
+      }));
+  SetDynamicSizes(hlo, {}, output_dynamic_sizes);
+  return OkStatus();
 }
 
 Status DynamicDimensionInferenceVisitor::HandleReverse(HloInstruction* hlo) {
@@ -1470,13 +1809,19 @@
 }
 
 Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
-  return ForEachOperandDynamicDimension(
-      hlo, [&](HloInstruction* operand, ShapeIndex /*index*/,
-               int64_t input_dynamic_dimension, int64_t operand_index,
-               HloInstruction* dynamic_size) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
+  absl::InlinedVector<HloInstruction*, 2> output_dynamic_sizes(
+      hlo->shape().rank(), nullptr);
+  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
+      hlo,
+      [&](HloInstruction* operand, ShapeIndex /*index*/,
+          int64_t input_dynamic_dimension, int64_t operand_index,
+          HloInstruction* dynamic_size) -> Status {
         const GatherDimensionNumbers& gather_dims =
             hlo->gather_dimension_numbers();
-        if (operand_index != 1) {
+        if (operand_index == 0) {
           if (hlo->gather_slice_sizes()[input_dynamic_dimension] == 1) {
             // Gathering a size 1 dimension out of a dynamic dimension removes
             // the dynamicity.
@@ -1484,26 +1829,32 @@
           }
           if (hlo->gather_slice_sizes()[input_dynamic_dimension] ==
               operand->shape().dimensions(input_dynamic_dimension)) {
-            // Gathering a full-sized dimension out of a dynamic dimension
-            // propagates the dynamicity to output.
-            int64_t output_dimension = input_dynamic_dimension;
-            for (int64_t collapsed_dim : gather_dims.collapsed_slice_dims()) {
-              if (collapsed_dim < input_dynamic_dimension) {
-                // This output dimension is collapsed.
-                output_dimension--;
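+            // Gathering the full dimension propagates the dynamicity: walk
+            // the output offset dims in order, pairing each with the next
+            // non-collapsed operand dimension, to find where the dynamic
+            // operand dimension lands in the output.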
+            int64_t operand_dimension = 0;
+            for (int64_t output_dimension : gather_dims.offset_dims()) {
+              TF_RET_CHECK(output_dimension < hlo->shape().rank());
+              while (operand_dimension < operand->shape().rank() &&
+                     absl::c_linear_search(gather_dims.collapsed_slice_dims(),
+                                           operand_dimension)) {
+                ++operand_dimension;
               }
+              TF_RET_CHECK(operand_dimension < operand->shape().rank());
+              if (operand_dimension == input_dynamic_dimension) {
+                output_dynamic_sizes[output_dimension] = dynamic_size;
+                return OkStatus();
+              }
+              ++operand_dimension;
             }
-            parent_->SetDynamicSize(hlo, {}, output_dimension, dynamic_size);
-            return OkStatus();
+            return InternalError("Invalid instruction: %s", hlo->ToString());
           }
           return Unimplemented(
               "Detects a dynamic dimension on the data input of gather, which "
               "is not supported: %s, %lld",
               hlo->ToString(), input_dynamic_dimension);
         }
-        // A mapping from output to input batch dim number. -1 means not a batch
-        // dimension.
         int64_t indices_rank = hlo->operand(1)->shape().rank();
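+        // If index_vector_dim equals the rank of the indices, the indices are
+        // implicitly treated as having a trailing dimension of size 1.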
+        if (gather_dims.index_vector_dim() == indices_rank) {
+          ++indices_rank;
+        }
         int64_t output_rank = hlo->shape().rank();
 
         // indices_dim is an iterator over indices dimensions.
@@ -1516,7 +1867,7 @@
               indices_dim++;
             }
             if (indices_dim++ == input_dynamic_dimension) {
-              parent_->SetDynamicSize(hlo, {}, output_dim, dynamic_size);
+              output_dynamic_sizes[output_dim] = dynamic_size;
               return OkStatus();
             }
           }
@@ -1527,11 +1878,16 @@
             "Detects a non-batch dynamic dimension of gather, "
             "which is not supported: %s",
             hlo->ToString());
-      });
+      }));
+  SetDynamicSizes(hlo, {}, output_dynamic_sizes);
+  return OkStatus();
 }
 
 Status DynamicDimensionInferenceVisitor::HandleConditional(
     HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   // Conditionals are handled by producing additional inputs and outputs of
   // the conditional instruction.
   std::vector<HloComputation*> new_branch_computations;
@@ -1540,7 +1896,7 @@
   // dynamic dimension size out by adding additional root element. A mapping
   // from the root instruction's dynamic dimension index (represented by a shape
   // index as output index and a int64_t dimension number) to output index
-  // (represented by an int64_t) is tracked for the conditional intsruction (all
+  // (represented by an int64_t) is tracked for the conditional instruction (all
   // branches should have the same mapping).
   ShapeTree<absl::flat_hash_map<int64_t, int64_t>> dynamic_output_mapping(
       hlo->shape());
@@ -1589,14 +1945,20 @@
     HloComputation* branch_computation = hlo->branch_computation(branch_index);
 
     HloComputation* new_computation = branch_computation;
+    CallInliner::InlinedInstructionMap inline_map;
     HloInstruction* new_operand = hlo->mutable_operand(operand_index);
+    Shape new_param_shape =
+        branch_computation->parameter_instruction(0)->shape();
     if (!operands_to_add.empty()) {
       TF_RET_CHECK(original_input->shape().IsTuple());
       need_rewrite = true;
       new_operand = TupleUtil::AppendSuffix(original_input, operands_to_add);
+      for (HloInstruction* operand : operands_to_add) {
+        ShapeUtil::AppendShapeToTuple(operand->shape(), &new_param_shape);
+      }
       TF_ASSIGN_OR_RETURN(
-          new_computation,
-          WidenComputation(branch_computation, new_operand->shape()));
+          std::tie(new_computation, inline_map),
+          WidenComputation(branch_computation, new_param_shape));
     }
     // Set the dynamic dimensions for the newly created branch computation's
     // parameters so that the hlos inside the computation can see dynamic
@@ -1606,7 +1968,7 @@
         hlo, operand_index,
         [&](HloInstruction*, ShapeIndex index, int64_t dimension,
             int64_t operand_index, HloInstruction* dynamic_size) {
-          DynamicParameterBinding::DynamicParameter dynamic_parameter{
+          DynamicParameterBinding::DynamicSizeParameter dynamic_parameter{
               0, {dynamic_size_to_operand_id_index_map[dynamic_size]}};
           DynamicParameterBinding::DynamicDimension dynamic_dimension{
               0, {index}, dimension};
@@ -1617,8 +1979,23 @@
         }));
     VLOG(2) << "dynamic_parameter_binding for conditional branch"
             << dynamic_parameter_binding;
-    TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
-        new_computation, dynamic_parameter_binding, parent_));
+
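+    // Propagate the dynamic sizes recorded on the original branch
+    // instructions to their clones in the widened computation, using the
+    // inliner's instruction map.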
+    for (auto [old_inst, new_inst] : inline_map) {
+      parent_->CopyMapping(
+          /*from=*/old_inst,
+          /*to=*/new_inst,
+          /*dynamic_size_map=*/&inline_map);
+    }
+
+    TF_ASSIGN_OR_RETURN(
+        bool changed,
+        DynamicDimensionInferenceVisitor::Run(
+            new_computation, dataflow_analysis_, dynamic_parameter_binding,
+            parent_, custom_call_handler_, shape_check_mode_,
+            assertion_generator_));
+    if (changed) {
+      MarkAsChanged();
+    }
 
     new_branch_computations.push_back(new_computation);
     new_operands.push_back(new_operand);
@@ -1713,9 +2090,10 @@
               HloInstruction::CreateGetTupleElement(
                   ShapeUtil::MakeScalarShape(S32), new_conditional,
                   output_index));
-          parent_->SetDynamicSize(new_conditional, index, dim, dynamic_size);
-          parent_->SetDynamicSize(new_conditional_extracted, index, dim,
-                                  dynamic_size);
+          SetDynamicSize(new_conditional, index, dim, dynamic_size,
+                         /*clear_dynamic_dimension=*/false);
+          SetDynamicSize(new_conditional_extracted, index, dim, dynamic_size,
+                         /*clear_dynamic_dimension=*/false);
         }
       });
 
@@ -1724,20 +2102,27 @@
   TF_RETURN_IF_ERROR(hlo->parent()->RemoveInstruction(hlo));
   SetVisited(*new_conditional);
   SetVisited(*new_conditional_extracted);
+  MarkAsChanged();
   return OkStatus();
 }
 
 Status DynamicDimensionInferenceVisitor::HandleMap(HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   return HandleElementwiseNary(hlo);
 }
 
 Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   return ForEachOperandDynamicDimension(
       hlo,
       [&](HloInstruction* operand, ShapeIndex dynamic_index, int64_t dimension,
           int64_t operand_index, HloInstruction* operand_dynamic_size) {
         if (operand_index == 0) {
-          parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size);
+          SetDynamicSize(hlo, {}, dimension, operand_dynamic_size);
           return OkStatus();
         }
 
@@ -1797,101 +2182,101 @@
 }
 
 Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
+  if (!CanInfer(hlo)) {
+    return OkStatus();
+  }
   // If the output of the kWhile contains dynamic dimension, we send
   // dynamic dimension size into the while body by adding additional root/body
   // element. A mapping from the root instruction's dynamic dimension index
   // (represented by a shape index as output index and an int64_t dimension
   // number) to output index (represented by an int64_t) is tracked for the
   // while instruction.
+  Shape original_shape = hlo->shape();
   ShapeTree<absl::flat_hash_map<int64_t, int64_t>> dynamic_output_mapping(
-      hlo->shape());
+      original_shape);
   std::vector<HloInstruction*> operands_to_add;
-  const int original_tuple_count = hlo->shape().tuple_shapes_size();
+  const int original_tuple_count = original_shape.tuple_shapes_size();
   int operand_count = original_tuple_count;
+  // Bind the dynamic sizes of the while operands so they can be threaded
+  // through the body and condition as extra tuple elements.
+  DynamicParameterBinding binding_for_while;
   TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
-      hlo, [&](HloInstruction*, ShapeIndex index, int64_t dim, int64_t,
-               HloInstruction* dynamic_size) {
+      hlo,
+      [&](HloInstruction* operand, ShapeIndex index, int64_t dim,
+          int64_t operand_num, HloInstruction* dynamic_size) -> Status {
+        TF_RET_CHECK(operand_num == 0);
         operands_to_add.push_back(dynamic_size);
         dynamic_output_mapping.mutable_element(index)->emplace(dim,
-                                                               operand_count++);
+                                                               operand_count);
+        DynamicParameterBinding::DynamicDimension dynamic_dimension{
+            /*parameter_num=*/0,
+            /*parameter_index=*/index,
+            /*dimension=*/dim,
+        };
+        DynamicParameterBinding::DynamicSizeParameter dynamic_size_param{
+            /*parameter_num=*/0,
+            /*parameter_index=*/{operand_count},
+        };
+        TF_RETURN_IF_ERROR(
+            binding_for_while.Bind(dynamic_size_param, dynamic_dimension));
+        ++operand_count;
         return OkStatus();
       }));
-  ShapeUtil::ForEachSubshape(
-      hlo->while_body()->root_instruction()->shape(),
-      [&](const Shape& subshape, const ShapeIndex& index) {
-        if (!subshape.IsArray()) {
-          return;
-        }
-        for (int64_t dim = 0; dim < subshape.rank(); ++dim) {
-          if (subshape.is_dynamic_dimension(dim)) {
-            if (!dynamic_output_mapping.mutable_element(index)->contains(dim)) {
-              // This dynamic dimension doesn't come from operand, but is
-              // generated in the middle of the while body. Its initial size
-              // should be static.
-              operands_to_add.push_back(hlo->parent()->AddInstruction(
-                  HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
-                      subshape.dimensions(dim)))));
-              dynamic_output_mapping.mutable_element(index)->emplace(
-                  dim, operand_count++);
-            }
-          }
-        }
-      });
-  DynamicParameterBinding binding_for_while;
-  if (!operands_to_add.empty()) {
-    // Only replace the while loop if there are new parameters to add.
-    HloInstruction* old_tuple_operand = hlo->mutable_operand(0);
-    TF_ASSIGN_OR_RETURN(
-        WhileUtil::MakeInstructionsLiveInResult result,
-        WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add));
-    // WhileUtil creates a new while hlo and tuple. Update the dynamic size
-    // mapping for the newly created tuple.
-    HloInstruction* new_tuple_operand =
-        result.new_while_instr->mutable_operand(0);
-    parent_->CopyMapping(/*from=*/old_tuple_operand,
-                         /*to=*/new_tuple_operand);
-    hlo = result.new_while_instr;
-    // We have replaced the while loop, now set the dynamic dimensions for the
-    // newly created while loop so that the hlos that consumes the while loop
-    // can see the dynamic dimensions. Also sets the dynamic parameter binding
-    // for running inference in the while loop.
-    TF_RETURN_IF_ERROR(dynamic_output_mapping.ForEachElementWithStatus(
-        [&](const ShapeIndex& index,
-            const absl::flat_hash_map<int64_t, int64_t>& dim_to_size) {
-          for (auto key : dim_to_size) {
-            int64_t dimension = key.first;
-            const int64_t output_dynamic_size_index = key.second;
-            DynamicParameterBinding::DynamicParameter dynamic_parameter{
-                0, {output_dynamic_size_index}};
-            DynamicParameterBinding::DynamicDimension dynamic_dimension{
-                0, index, dimension};
-            TF_RETURN_IF_ERROR(
-                binding_for_while.Bind(dynamic_parameter, dynamic_dimension));
-            // This is the updated output dynamic size coming out of hlo while
-            // loop.
-            HloInstruction* output_dynamic_size = hlo->parent()->AddInstruction(
-                HloInstruction::CreateGetTupleElement(
-                    ShapeUtil::MakeScalarShape(S32), hlo,
-                    output_dynamic_size_index));
-            parent_->SetDynamicSize(result.replacement_instr, index, dimension,
-                                    output_dynamic_size);
-          }
-          return OkStatus();
-        }));
-    // Set the replacement instruction as visited to avoid visiting it again.
-    SetVisited(*result.replacement_instr);
-  }
-  // Run inference in while body and condition.
-  TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
-      hlo->while_body(), binding_for_while, parent_));
-  TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
-      hlo->while_condition(), binding_for_while, parent_));
-
   if (operands_to_add.empty()) {
-    // No dynamic dimension in the inputs and outputs.
     return OkStatus();
   }
 
+  HloInstruction* old_tuple_operand = hlo->mutable_operand(0);
+  HloInstruction* old_body_root = hlo->while_body()->root_instruction();
+  TF_ASSIGN_OR_RETURN(WhileUtil::MakeInstructionsLiveInResult result,
+                      WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add));
+  TF_RET_CHECK(result.replacement_instr->opcode() == HloOpcode::kTuple);
+  // WhileUtil creates a new while hlo and tuple. Update the dynamic size
+  // mapping for the newly created tuple.
+  HloInstruction* new_tuple_operand =
+      result.new_while_instr->mutable_operand(0);
+  parent_->CopyMapping(/*from=*/old_tuple_operand,
+                       /*to=*/new_tuple_operand);
+
+  hlo = result.new_while_instr;
+
+  // Set the replacement instruction as visited to avoid visiting it again.
+  SetVisited(*hlo);
+
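+  // WhileUtil cloned the body and condition; translate the recorded dynamic
+  // sizes onto the cloned instructions using the clone maps it returned.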
+  for (auto [old_inst, new_inst] : result.while_body_instruction_map) {
+    parent_->CopyMapping(
+        /*from=*/old_inst,
+        /*to=*/new_inst,
+        /*dynamic_size_map=*/&result.while_body_instruction_map);
+  }
+  // MakeInstructionsLiveIn does not include the new root tuple in the
+  // instruction map, so we have to copy the mapping here.
+  parent_->CopyMapping(/*from=*/old_body_root,
+                       /*to=*/hlo->while_body()->root_instruction(),
+                       &result.while_body_instruction_map);
+  for (auto [old_inst, new_inst] : result.while_condition_instruction_map) {
+    parent_->CopyMapping(
+        /*from=*/old_inst,
+        /*to=*/new_inst,
+        /*dynamic_size_map=*/&result.while_condition_instruction_map);
+  }
+
+  // Rerun inference on the body and condition now that we have added dynamic
+  // size parameters.
+  TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
+                         hlo->while_body(), dataflow_analysis_,
+                         binding_for_while, parent_, custom_call_handler_,
+                         shape_check_mode_, assertion_generator_)
+                         .status());
+  TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
+                         hlo->while_condition(), dataflow_analysis_,
+                         binding_for_while, parent_, custom_call_handler_,
+                         shape_check_mode_, assertion_generator_)
+                         .status());
+
   // The dynamic dimension size could have been changed in the loop body (e.g, A
   // loop that inserts items in a stack, the stack size increases with each
   // iteration). Rewrite the dynamic dimension size at the root.
@@ -1902,10 +2287,10 @@
   // Original non-dynamic-dim operands of root are pass-through.
   for (int i = 0; i < original_tuple_count; ++i) {
     new_root_operands[i] =
-        hlo->while_body()->AddInstruction(HloInstruction::CreateGetTupleElement(
+        body_root->AddInstruction(HloInstruction::CreateGetTupleElement(
             body_root->shape().tuple_shapes(i), body_root, i));
   }
-  // Add dynamic dimension size as new parameters.
+  // Add the dynamic dimension sizes as new outputs of the while loop body.
   TF_RETURN_IF_ERROR(dynamic_output_mapping.ForEachElementWithStatus(
       [&](const ShapeIndex& index,
           const absl::flat_hash_map<int64_t, int64_t>& dim_to_size) -> Status {
@@ -1923,37 +2308,75 @@
   }
   HloInstruction* new_body_root = hlo->while_body()->AddInstruction(
       HloInstruction::CreateTuple(new_root_operands));
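+  // Forward the dynamic sizes recorded on the old body root to the new root
+  // tuple and to the per-element GTEs that feed it.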
+  for (int i = 0; i < original_tuple_count; ++i) {
+    TF_RETURN_IF_ERROR(ForEachDynamicDimension(
+        body_root,
+        [&](ShapeIndex index, int64_t dimension,
+            HloInstruction* dynamic_size) -> Status {
+          SetDynamicSize(new_body_root, index, dimension, dynamic_size);
+          if (index.empty() || index.front() != i) {
+            return OkStatus();
+          }
+          index.pop_front();
+          SetDynamicSize(new_root_operands[i], index, dimension, dynamic_size);
+          return OkStatus();
+        }));
+  }
   hlo->while_body()->set_root_instruction(new_body_root);
-  return OkStatus();
+  MarkAsChanged();
+
+  // Record the dynamic sizes of while loop output.
+  return dynamic_output_mapping.ForEachElementWithStatus(
+      [&](const ShapeIndex& index,
+          const absl::flat_hash_map<int64_t, int64_t>& dim_to_size) -> Status {
+        for (auto [dimension, output_index] : dim_to_size) {
+          HloInstruction* dynamic_size = hlo->AddInstruction(
+              HloInstruction::CreateGetTupleElement(hlo, output_index));
+          SetDynamicSize(result.replacement_instr, index, dimension,
+                         dynamic_size);
+          ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index)
+              ->set_dynamic_dimension(dimension, false);
+          TF_RET_CHECK(!index.empty());
+          HloInstruction* gte =
+              result.replacement_instr->mutable_operand(index.front());
+          TF_RET_CHECK(gte->opcode() == HloOpcode::kGetTupleElement);
+          TF_RET_CHECK(gte->operand(0) == hlo);
+          ShapeUtil::GetMutableSubshape(gte->mutable_shape(),
+                                        ShapeIndexView(index).subspan(1))
+              ->set_dynamic_dimension(dimension, false);
+        }
+        return OkStatus();
+      });
 }
 
 Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) {
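+  // Entry parameters have no caller-provided bindings; recover their dynamic
+  // sizes at runtime by padding them to static instead.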
+  if (hlo->parent()->IsEntryComputation()) {
+    TF_RET_CHECK(param_bindings_.empty());
+    return InsertPadToStaticOnInstruction(hlo);
+  }
+
   return param_bindings_.ForEachBinding(
-      [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
-          const DynamicParameterBinding::DynamicDimension& dynamic_dimension) {
-        if (dynamic_dimension.parameter_num != hlo->parameter_number()) {
-          return OkStatus();
+      [&](const DynamicParameterBinding::DynamicSizeParameter& dynamic_size,
+          const DynamicParameterBinding::DynamicDimension& dynamic_dimension)
+          -> Status {
+        if (dynamic_dimension.parameter_num == hlo->parameter_number()) {
+          SetDynamicSize(
+              hlo, dynamic_dimension.parameter_index,
+              dynamic_dimension.dimension,
+              TupleUtil::AddGetTupleElements(HloPosition{
+                  /*instruction=*/hlo->parent()->parameter_instruction(
+                      dynamic_size.parameter_num),
+                  /*index=*/dynamic_size.parameter_index,
+              }));
         }
-        HloComputation* computation = hlo->parent();
-        HloInstruction* target_parameter =
-            computation->parameter_instruction(dynamic_dimension.parameter_num);
-
-        HloInstruction* dynamic_size =
-            computation->parameter_instruction(dynamic_parameter.parameter_num);
-        for (int64_t i : dynamic_parameter.parameter_index) {
-          dynamic_size =
-              computation->AddInstruction(HloInstruction::CreateGetTupleElement(
-                  ShapeUtil::GetSubshape(dynamic_size->shape(), {i}),
-                  dynamic_size, i));
-        }
-
-        parent_->SetDynamicSize(target_parameter,
-                                dynamic_dimension.parameter_index,
-                                dynamic_dimension.dimension, dynamic_size);
         return OkStatus();
       });
 }
 
+Status DynamicDimensionInferenceVisitor::HandleInfeed(HloInstruction* hlo) {
+  return InsertPadToStaticOnInstruction(hlo);
+}
+
 Status DynamicDimensionInferenceVisitor::ForEachDynamicDimension(
     HloInstruction* inst, const DynamicDimensionFn& fn) {
   auto iter = parent_->per_hlo_dynamic_dimensions_.find(inst);
@@ -1969,6 +2392,183 @@
   return OkStatus();
 }
 
+StatusOr<bool> DynamicDimensionInferenceVisitor::RequiresPadToStatic(
+    HloInstruction* instr, ShapeIndex shape_index) {
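+  // Decide whether any use of this (dynamic) leaf array requires it to be
+  // materialized in padded-to-static form.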
+  TF_RET_CHECK(ShapeUtil::IsLeafIndex(instr->shape(), shape_index))
+      << instr->shape() << " @ " << shape_index;
+  if (ShapeUtil::GetSubshape(instr->shape(), shape_index).is_static()) {
+    return false;
+  }
+  auto uses =
+      dataflow_analysis_.GetValueDefinedAt(instr, shape_index).GetUses();
+  for (const auto& use : uses) {
+    if (use.instruction->opcode() == HloOpcode::kAsyncStart ||
+        use.instruction->opcode() == HloOpcode::kAsyncUpdate ||
+        use.instruction->opcode() == HloOpcode::kAsyncDone ||
+        use.instruction->opcode() == HloOpcode::kCall ||
+        use.instruction->opcode() == HloOpcode::kTuple ||
+        use.instruction->opcode() == HloOpcode::kGetTupleElement ||
+        use.instruction->opcode() == HloOpcode::kConditional) {
+      // These uses do not require padding, as they do not operate on the
+      // data.
+      continue;
+    }
+    if (use.instruction->opcode() == HloOpcode::kWhile) {
+      TF_RET_CHECK(use.operand_number == 0);
+      HloInstruction* root = use.instruction->while_body()->root_instruction();
+      if (parent_->HasDynamicDimension(root, use.operand_index)) {
+        return true;
+      }
+      continue;
+    }
+    if (use.instruction->opcode() == HloOpcode::kSetDimensionSize) {
+      // The dynamic size cannot itself be dynamic.
+      TF_RET_CHECK(use.operand_number == 0);
+      // SetDimensionSize will be removed, so the array must be padded if it
+      // is a user of the array.
+      return true;
+    }
+    if (use.instruction->opcode() == HloOpcode::kGetDimensionSize) {
+      return true;
+    }
+    if (use.instruction->opcode() != HloOpcode::kCustomCall ||
+        use.instruction->custom_call_target() != "PadToStatic") {
+      if (parent_->op_supports_dynamism_handler_ == nullptr) {
+        return true;
+      }
+      if (parent_->op_supports_dynamism_handler_(use.instruction) ==
+          OpDynamismSupport::kNoSupport) {
+        return true;
+      }
+    }
+  }
+
+  // Don't do pad-to-static.
+  return false;
+}
+
+// Insert pad-to-static after `inst` if `inst` has dynamic dimensions in it.
+// If the instruction produces a tuple, each tuple component will be considered
+// independently.
+Status DynamicDimensionInferenceVisitor::InsertPadToStaticOnInstruction(
+    HloInstruction* inst) {
+  if (inst->shape().is_static()) {
+    return OkStatus();
+  }
+
+  // Decide which leaf arrays need to be padded.
+  ShapeTree<bool> needs_pad(inst->shape(), false);
+  bool any_needs_pad = false;
+  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+      inst->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) {
+        if (subshape.IsTuple()) {
+          return OkStatus();
+        }
+        TF_ASSIGN_OR_RETURN(bool do_pad,
+                            RequiresPadToStatic(inst, shape_index));
+        if (do_pad) {
+          *needs_pad.mutable_element(shape_index) = true;
+          any_needs_pad = true;
+        }
+        return OkStatus();
+      }));
+
+  if (!any_needs_pad) {
+    return OkStatus();
+  }
+
+  auto users = inst->users();
+
+  ShapeTree<HloInstruction*> gtes =
+      TupleUtil::DisassembleTupleInstruction(inst);
+
+  // Add PadToStatic to the leaf arrays and record the dynamic dimensions.
+  ShapeTree<HloInstruction*> padded(inst->shape(), nullptr);
+  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapePostOrderWithStatus(
+      inst->shape(),
+      [&](const Shape& subshape, const ShapeIndex& shape_index) -> Status {
+        HloInstruction* element = gtes.element(shape_index);
+        SetVisited(*gtes.element(shape_index));
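+        // Tuples are rebuilt from their already-processed children (the walk
+        // is post-order), forwarding the children's dynamic sizes.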
+        if (subshape.IsTuple()) {
+          absl::InlinedVector<HloInstruction*, 2> children;
+          ShapeIndex child_index = shape_index;
+          for (int i = 0; i < subshape.tuple_shapes_size(); ++i) {
+            child_index.push_back(i);
+            children.push_back(padded.element(child_index));
+            child_index.pop_back();
+          }
+          HloInstruction* tuple =
+              element->AddInstruction(HloInstruction::CreateVariadic(
+                  subshape, HloOpcode::kTuple, children));
+          TF_CHECK_OK(ForEachOperandDynamicDimension(
+              tuple,
+              [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
+                  int64_t operand_index, HloInstruction* dynamic_size) {
+                index.push_front(operand_index);
+                SetDynamicSize(tuple, index, dimension, dynamic_size);
+                return OkStatus();
+              }));
+          *padded.mutable_element(shape_index) = tuple;
+          return OkStatus();
+        }
+        if (needs_pad.element(shape_index)) {
+          // The output of PadToStatic is a tuple. Element 0 is the data
+          // output, which has the same shape as the input but with all
+          // dimensions static; element i (for i >= 1) is the dynamic size of
+          // input dimension i - 1.
+          Shape data_output_shape =
+              ShapeUtil::MakeStaticShape(element->shape());  // 0th element.
+          Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
+          for (int64_t i = 0; i < element->shape().rank(); ++i) {
+            ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
+                                          &output_shape);
+          }
+          HloInstruction* pad_to_static = inst->parent()->AddInstruction(
+              HloInstruction::CreateCustomCall(output_shape, {element},
+                                               "PadToStatic"),
+              absl::StrCat(element->name(), ".padded"));
+          SetVisited(*pad_to_static);
+          HloInstruction* data_output = inst->parent()->AddInstruction(
+              HloInstruction::CreateGetTupleElement(data_output_shape,
+                                                    pad_to_static, 0),
+              absl::StrCat(element->name(), ".data"));
+          SetVisited(*data_output);
+          for (int64_t i = 0; i < element->shape().rank(); ++i) {
+            if (!element->shape().is_dynamic_dimension(i)) {
+              continue;
+            }
+            HloInstruction* dynamic_size_output =
+                inst->parent()->AddInstruction(
+                    HloInstruction::CreateGetTupleElement(
+                        output_shape.tuple_shapes(i + 1), pad_to_static, i + 1),
+                    absl::StrCat(element->name(), ".size"));
+            SetVisited(*dynamic_size_output);
+            SetDynamicSize(data_output, {}, i, dynamic_size_output,
+                           /*clear_dynamic_dimension=*/false);
+          }
+          *padded.mutable_element(shape_index) = data_output;
+        } else {
+          *padded.mutable_element(shape_index) = element;
+        }
+        return OkStatus();
+      }));
+
+  HloInstruction* result = padded.element({});
+
+  // Replace all uses of the original instruction with the padded outputs.
+  for (auto user : users) {
+    for (int64_t i : user->OperandIndices(inst)) {
+      TF_RETURN_IF_ERROR(user->ReplaceOperandWith(i, result));
+    }
+  }
+  if (inst->IsRoot()) {
+    inst->parent()->set_root_instruction(result);
+  }
+
+  MarkAsChanged();
+
+  return OkStatus();
+}
+
 Status DynamicDimensionInferenceVisitor::InsertShapeCheck(
     HloInstruction* dim1, HloInstruction* dim2,
     bool support_implicit_broadcast) {
@@ -2029,28 +2629,39 @@
                                                const ShapeIndex& index,
                                                int64_t dim,
                                                HloInstruction* size) {
+  CHECK_NE(inst, nullptr);
+  CHECK_NE(size, nullptr);
   VLOG(1) << "Set dimension inst " << inst->ToString() << " index "
           << index.ToString() << "@" << dim << " to " << size->ToShortString();
-  Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index);
+  const Shape& subshape = ShapeUtil::GetSubshape(inst->shape(), index);
   CHECK(!subshape.IsTuple()) << "Can't set a tuple shape to dynamic dimension";
   CHECK(dim < subshape.rank() && dim >= 0)
       << "Asked to set invalid dynamic dimension. Shape: "
       << subshape.ToString() << ", Dimension: " << dim;
   DynamicDimension dynamic_dimension{inst, index, dim};
-  // Updating a dynamic dimension twice overwrites the previous one.
-  dynamic_mapping_[dynamic_dimension] = size;
+  // If we have already set the dynamic size, it should be the same.
+  auto [it, inserted] = dynamic_mapping_.try_emplace(dynamic_dimension, size);
+  if (!inserted) {
+    CHECK_EQ(size, it->second) << "old: " << it->second->ToShortString()
+                               << ", new: " << size->ToShortString();
+  }
   auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst);
   iter.first->second.emplace(dynamic_dimension);
 }
 
-void DynamicDimensionInference::CopyMapping(HloInstruction* from,
-                                            HloInstruction* to) {
+void DynamicDimensionInference::CopyMapping(
+    HloInstruction* from, HloInstruction* to,
+    const absl::flat_hash_map<HloInstruction*, HloInstruction*>*
+        dynamic_size_map) {
   auto iter = per_hlo_dynamic_dimensions_.find(from);
   if (iter != per_hlo_dynamic_dimensions_.end()) {
     for (auto& dynamic_dimension : iter->second) {
       HloInstruction* dynamic_size =
           GetDynamicSize(dynamic_dimension.inst, dynamic_dimension.index,
                          dynamic_dimension.dim);
+      if (dynamic_size_map != nullptr) {
+        dynamic_size = dynamic_size_map->at(dynamic_size);
+      }
       SetDynamicSize(to, dynamic_dimension.index, dynamic_dimension.dim,
                      dynamic_size);
     }
@@ -2059,15 +2670,17 @@
 
 /* static */
 StatusOr<DynamicDimensionInference> DynamicDimensionInference::Run(
-    HloModule* module, CustomCallInferenceHandler custom_call_handler,
+    HloModule* module, OpSupportsDynamismHandler op_supports_dynamism_handler,
+    CustomCallInferenceHandler custom_call_handler,
     ShapeCheckMode shape_check_mode,
     const AssertionGenerator& assertion_generator,
     const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  DynamicDimensionInference inference(module, std::move(custom_call_handler),
-                                      shape_check_mode, assertion_generator,
-                                      execution_threads);
+  DynamicDimensionInference inference(
+      module, std::move(op_supports_dynamism_handler),
+      std::move(custom_call_handler), shape_check_mode, assertion_generator,
+      execution_threads);
   TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions());
-  return inference;
+  return std::move(inference);
 }
 
 std::string DynamicDimensionInference::ToString() const {
@@ -2085,19 +2698,37 @@
 }
 
 DynamicDimensionInference::DynamicDimensionInference(
-    HloModule* module, CustomCallInferenceHandler custom_call_handler,
+    HloModule* module, OpSupportsDynamismHandler op_supports_dynamism_handler,
+    CustomCallInferenceHandler custom_call_handler,
     ShapeCheckMode shape_check_mode, AssertionGenerator assertion_generator,
     const absl::flat_hash_set<absl::string_view>& execution_threads)
     : module_(module),
+      op_supports_dynamism_handler_(std::move(op_supports_dynamism_handler)),
       custom_call_handler_(std::move(custom_call_handler)),
       shape_check_mode_(shape_check_mode),
       assertion_generator_(assertion_generator),
       execution_threads_(execution_threads) {}
 
 Status DynamicDimensionInference::AnalyzeDynamicDimensions() {
-  return DynamicDimensionInferenceVisitor::Run(
-      module_->entry_computation(), {}, this, custom_call_handler_,
-      shape_check_mode_, assertion_generator_);
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
+      HloDataflowAnalysis::Run(*module_, /*ssa_form=*/false,
+                               /*bitcast_defines_value=*/true,
+                               /*can_share_buffer=*/nullptr,
+                               /*forwards_value=*/nullptr, execution_threads_));
+  for (HloComputation* computation : module_->MakeComputationPostOrder()) {
+    if (!HloInstruction::IsThreadIncluded(computation->execution_thread(),
+                                          execution_threads_)) {
+      continue;
+    }
+    TF_ASSIGN_OR_RETURN(
+        bool changed,
+        DynamicDimensionInferenceVisitor::Run(
+            computation, *dataflow_analysis, {}, this, custom_call_handler_,
+            shape_check_mode_, assertion_generator_));
+    changed_ |= changed;
+  }
+  return OkStatus();
 }
 
 void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith(
@@ -2116,7 +2747,7 @@
 Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
                                                      HloInstruction* new_inst,
                                                      const ShapeIndex& index) {
-  CHECK(Shape::Equal()(inst->shape(), new_inst->shape()));
+  TF_RET_CHECK(ShapeUtil::Compatible(inst->shape(), new_inst->shape()));
 
   for (int64_t dim = 0; dim < inst->shape().rank(); ++dim) {
     DynamicDimension dynamic_dimension_new{new_inst, index, dim};
@@ -2140,7 +2771,7 @@
     if (subshape.IsTuple()) {
       return;
     }
-    if (ShapeIndexView(subindex).first(index.size()) != index) {
+    if (ShapeIndexView(subindex).subspan(0, index.size()) != index) {
       return;
     }
     for (int64_t i = 0; i < subshape.dimensions_size(); ++i) {
@@ -2153,11 +2784,21 @@
   return has_dynamic_dim;
 }
 
-Status DynamicDimensionInference::Update(HloInstruction* inst) {
-  DynamicParameterBinding parameter_binding;
-  DynamicDimensionInferenceVisitor visitor(
-      parameter_binding, this, custom_call_handler_, shape_check_mode_);
-  return inst->Visit(&visitor);
+Shape DynamicDimensionInference::GetDynamicShape(HloInstruction* inst) {
+  Shape shape = inst->shape();
+  ShapeUtil::ForEachMutableSubshape(
+      &shape, [&](Shape* subshape, const ShapeIndex& index) {
+        if (!subshape->IsArray()) {
+          return;
+        }
+        for (int64_t dimension = 0; dimension < subshape->rank(); ++dimension) {
+          if (GetDynamicSize(inst, index, dimension) != nullptr) {
+            subshape->set_dynamic_dimension(dimension, true);
+          }
+        }
+      });
+
+  return shape;
 }
 
 HloInstruction* DynamicDimensionInference::GetDynamicSize(
@@ -2180,9 +2821,56 @@
   const int64_t rank = ShapeUtil::GetSubshape(inst->shape(), index).rank();
   std::vector<HloInstruction*> result(rank, nullptr);
   for (int64_t i = 0; i < rank; ++i) {
-    result[i] = GetDynamicSize(inst, {}, i);
+    result[i] = GetDynamicSize(inst, index, i);
   }
   return result;
 }
 
+bool DynamicDimensionInference::CanInfer(HloInstruction* hlo) {
+  // If the result shape is static, there are no dynamic dimensions to infer.
+  // However, if there are called computations, we may need to run inference on
+  // them.  Similarly, custom calls can do anything based on the user callbacks.
+  if (hlo->shape().is_static() && hlo->called_computations().empty() &&
+      hlo->opcode() != HloOpcode::kCustomCall) {
+    return false;
+  }
+  // The dimensions of all operands must either be 1) not dynamic, or 2) have
+  // a recorded dynamic size.  Only SetDimensionSize and custom-call operands
+  // may have a dimension that is both dynamic in the shape and has a recorded
+  // size.
+  bool ok = true;
+  for (int64_t operand_index = 0; operand_index < hlo->operand_count();
+       ++operand_index) {
+    ShapeUtil::ForEachSubshape(
+        hlo->operand(operand_index)->shape(),
+        [&](const Shape& subshape, const ShapeIndex& shape_index) {
+          if (!subshape.IsArray()) {
+            return;
+          }
+          for (int64_t dimension = 0; dimension < subshape.rank();
+               ++dimension) {
+            bool shape_is_dynamic = subshape.is_dynamic_dimension(dimension);
+            bool dynamic_size_recorded =
+                GetDynamicSize(hlo->operand(operand_index), shape_index,
+                               dimension) != nullptr;
+            if (shape_is_dynamic && !dynamic_size_recorded) {
+              VLOG(2) << "cannot infer " << hlo->ToShortString()
+                      << " because operand " << operand_index << " ("
+                      << hlo->operand(operand_index)->ToShortString() << ")"
+                      << " subshape " << shape_index.ToString()
+                      << " is missing dynamic size for dimension " << dimension;
+              ok = false;
+            }
+            // Sanity check that we have cleared the dynamic dimension on the
+            // shape if we have recorded the dynamic size.
+            CHECK(hlo->operand(operand_index)->opcode() ==
+                      HloOpcode::kSetDimensionSize ||
+                  hlo->operand(operand_index)->opcode() ==
+                      HloOpcode::kCustomCall ||
+                  !shape_is_dynamic || !dynamic_size_recorded);
+          }
+        });
+  }
+  return ok;
+}
+
 }  // namespace xla
diff --git a/third_party/xla/xla/service/dynamic_dimension_inference.h b/third_party/xla/xla/service/dynamic_dimension_inference.h
index 0631936..681ba70 100644
--- a/third_party/xla/xla/service/dynamic_dimension_inference.h
+++ b/third_party/xla/xla/service/dynamic_dimension_inference.h
@@ -16,22 +16,45 @@
 #ifndef XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_
 #define XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_
 
+#include <cstdint>
 #include <functional>
-#include <memory>
+#include <map>
+#include <set>
 #include <string>
+#include <tuple>
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
-#include "absl/types/span.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/status.h"
 #include "xla/statusor.h"
-#include "xla/types.h"
 
 namespace xla {
 
+// Each instruction can support dynamic lowering in one of three modes.
+enum OpDynamismSupport : uint8_t {
+  // There is no support for dynamic lowering -- the dynamic padder will make
+  // sure the input to this op has a static bound by rewriting the op (e.g.,
+  // extra space in a reduce_sum will be padded with 0).
+  kNoSupport = 0,
+  // The op can take either dynamic input or static input.
+  kOptional,
+  // The op only has a dynamic lowering; the dynamic padder will make sure the
+  // input to this op is in dynamic form.
+  kRequired,
+};
+
+// Returns the degree to which the given instruction supports native dynamic
+// lowering. If it does, the dynamic padder will not attempt to pad it.
+using OpSupportsDynamismHandler =
+    std::function<OpDynamismSupport(HloInstruction*)>;
+
 // DynamicDimensionInference analyzes each HLO instruction in a graph and
 // inferences which dimensions are dynamic and which scalar instructions
 // represent the runtime real size of those dynamic dimensions.
@@ -56,6 +79,7 @@
 
   static StatusOr<DynamicDimensionInference> Run(
       HloModule* module,
+      OpSupportsDynamismHandler op_supports_dynamism_handler = nullptr,
       CustomCallInferenceHandler custom_call_handler = nullptr,
       ShapeCheckMode shape_check_mode = ShapeCheckMode::kIgnore,
       const AssertionGenerator& assertion_generator = nullptr,
@@ -98,15 +122,23 @@
   void ReplaceAllDynamicDimensionUsesWith(HloInstruction* replace,
                                           HloInstruction* with);
 
-  // Update dynamic dimension inference to analyze `inst`. Useful to
-  // incrementally track new instructions added after initial run.
-  Status Update(HloInstruction* inst);
+  // Returns the shape of the given instruction, with dimensions marked
+  // dynamic wherever a dynamic size has been recorded for them.
+  Shape GetDynamicShape(HloInstruction* inst);
+
+  // Returns true iff all dynamic dimensions on the operands of the given
+  // instruction have inferred dynamic sizes.
+  bool CanInfer(HloInstruction* hlo);
+
+  // Returns true iff DynamicDimensionInferenceVisitor made changes to the
+  // module.
+  bool changed() const { return changed_; }
 
   friend class DynamicDimensionInferenceVisitor;
 
  private:
   explicit DynamicDimensionInference(
-      HloModule* module, CustomCallInferenceHandler custom_call_handler,
+      HloModule* module, OpSupportsDynamismHandler op_supports_dynamism_handler,
+      CustomCallInferenceHandler custom_call_handler,
       ShapeCheckMode shape_check_mode, AssertionGenerator assertion_generator,
       const absl::flat_hash_set<absl::string_view>& execution_threads_);
 
@@ -151,7 +183,13 @@
   // Copies the internal mapping from instruction `from` to instruction `to`.
   // This is useful when an instruction is replaced by the other during the
   // inferencing process.
-  void CopyMapping(HloInstruction* from, HloInstruction* to);
+  // For cases where the `from` and `to` instructions are in different
+  // computations, a `dynamic_size_map` can be provided which maps the dynamic
+  // size instructions in the `from` computation into the corresponding
+  // instruction in the `to` computation.
+  void CopyMapping(HloInstruction* from, HloInstruction* to,
+                   const absl::flat_hash_map<HloInstruction*, HloInstruction*>*
+                       dynamic_size_map = nullptr);
 
   // AnalyzeDynamicDimensions starts the analysis of the dynamic dimensions in
   // module_.
@@ -172,6 +210,8 @@
       ConstHloInstructionMap<std::set<DynamicDimension>>;
   PerHloDynamicDimensions per_hlo_dynamic_dimensions_;
 
+  OpSupportsDynamismHandler op_supports_dynamism_handler_;
+
   // A handler for custom calls.
   CustomCallInferenceHandler custom_call_handler_;
 
@@ -180,6 +220,8 @@
 
   AssertionGenerator assertion_generator_;
 
+  bool changed_ = false;
+
   const absl::flat_hash_set<absl::string_view>& execution_threads_;
 };
 
diff --git a/third_party/xla/xla/service/dynamic_dimension_inference_test.cc b/third_party/xla/xla/service/dynamic_dimension_inference_test.cc
index 615c6ad..0ca66d5 100644
--- a/third_party/xla/xla/service/dynamic_dimension_inference_test.cc
+++ b/third_party/xla/xla/service/dynamic_dimension_inference_test.cc
@@ -32,6 +32,7 @@
 #include "xla/tests/hlo_test_base.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/lib/core/status_test_util.h"
+#include "tsl/platform/statusor.h"
 #include "tsl/platform/test_benchmark.h"
 
 namespace op = xla::testing::opcode_matchers;
@@ -46,15 +47,16 @@
   }
 
   Status RunInference(
+      OpSupportsDynamismHandler op_supports_dynamism_handler = nullptr,
       DynamicDimensionInference::CustomCallInferenceHandler handler = nullptr,
       DynamicDimensionInference::ShapeCheckMode shape_check_mode =
           DynamicDimensionInference::ShapeCheckMode::kIgnore,
       const DynamicDimensionInference::AssertionGenerator& assertion_generator =
           nullptr) {
-    TF_ASSIGN_OR_RETURN(
-        DynamicDimensionInference inference,
-        DynamicDimensionInference::Run(module_.get(), handler, shape_check_mode,
-                                       assertion_generator));
+    TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference,
+                        DynamicDimensionInference::Run(
+                            module_.get(), op_supports_dynamism_handler,
+                            handler, shape_check_mode, assertion_generator));
 
     inference_ = std::make_unique<DynamicDimensionInference>(inference);
     return OkStatus();
@@ -158,7 +160,7 @@
 TEST_F(DynamicDimensionInferenceTest, ReduceTestI) {
   auto builder = HloComputation::Builder(TestName());
   auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
-  auto reduce_shape = ShapeUtil::MakeShape(F32, {2});
+  auto reduce_shape = ShapeUtil::MakeShape(F32, {2}, {true});
   auto dynamic_shape =
       ShapeUtil::MakeShape(F32, {1, 2, 2}, {false, true, false});
 
@@ -190,7 +192,7 @@
   // Same as ReduceTestI, but only reduce one dimension.
   auto builder = HloComputation::Builder(TestName());
   auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
-  auto reduce_shape = ShapeUtil::MakeShape(F32, {1, 2});
+  auto reduce_shape = ShapeUtil::MakeShape(F32, {1, 2}, {false, true});
   auto dynamic_shape =
       ShapeUtil::MakeShape(F32, {1, 2, 2}, {false, false, true});
 
@@ -223,32 +225,35 @@
   // Handle variadic reduce where output is a tuple.
   auto builder = HloComputation::Builder(TestName());
   auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
-  auto reduce_shape = ShapeUtil::MakeShape(F32, {1, 2});
+  auto reduce_shape = ShapeUtil::MakeShape(F32, {1, 2}, {false, true});
   auto dynamic_shape =
       ShapeUtil::MakeShape(F32, {1, 2, 2}, {false, false, true});
 
-  auto data_param_dynamic = builder.AddInstruction(
+  auto data_param_1 = builder.AddInstruction(
       HloInstruction::CreateParameter(0, input_shape, "data_param"));
-  auto data_param_static = builder.AddInstruction(
+  auto data_param_2 = builder.AddInstruction(
       HloInstruction::CreateParameter(1, input_shape, "data_param.2"));
   auto size_param = builder.AddInstruction(
       HloInstruction::CreateParameter(2, scalar_shape_, "size_param"));
-  data_param_dynamic =
+  auto data_param_dynamic_1 =
       builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
-          dynamic_shape, data_param_dynamic, size_param, 2));
+          dynamic_shape, data_param_1, size_param, 2));
+  auto data_param_dynamic_2 =
+      builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
+          dynamic_shape, data_param_2, size_param, 2));
 
-  auto dynamic_negate = builder.AddInstruction(HloInstruction::CreateUnary(
-      input_shape, HloOpcode::kNegate, data_param_dynamic));
+  auto dynamic_negate_1 = builder.AddInstruction(HloInstruction::CreateUnary(
+      dynamic_shape, HloOpcode::kNegate, data_param_dynamic_1));
 
-  auto static_negate = builder.AddInstruction(HloInstruction::CreateUnary(
-      input_shape, HloOpcode::kNegate, data_param_static));
+  auto dynamic_negate_2 = builder.AddInstruction(HloInstruction::CreateUnary(
+      dynamic_shape, HloOpcode::kNegate, data_param_dynamic_2));
 
   auto init = builder.AddInstruction(
       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
 
   auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
       ShapeUtil::MakeTupleShape({reduce_shape, reduce_shape}),
-      {dynamic_negate, static_negate}, {init, init}, {1}, GetAddTuple()));
+      {dynamic_negate_1, dynamic_negate_2}, {init, init}, {1}, GetAddTuple()));
 
   module_->AddEntryComputation(builder.Build());
 
@@ -267,10 +272,11 @@
   constexpr int zdim = 1;
   auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
   auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
-  auto xz_shape = ShapeUtil::MakeShape(F32, {xdim, zdim});
   auto xy_dynamic_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}, {true, true});
   auto yz_dynamic_shape =
       ShapeUtil::MakeShape(F32, {ydim, zdim}, {true, false});
+  auto xz_dynamic_shape =
+      ShapeUtil::MakeShape(F32, {xdim, zdim}, {true, false});
 
   auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
       /*parameter_number=*/0, xy_shape, "A"));
@@ -280,7 +286,8 @@
       /*parameter_number=*/2, scalar_shape_, "size_param"));
 
   a_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
-      xy_dynamic_shape, a_param, size_param, 0));
+      ShapeUtil::MakeShape(F32, xy_shape.dimensions(), {true, false}), a_param,
+      size_param, 0));
   a_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
       xy_dynamic_shape, a_param, size_param, 1));
   b_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
@@ -290,7 +297,7 @@
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
   auto dot = builder.AddInstruction(
-      HloInstruction::CreateDot(xz_shape, a_param, b_param, dot_dnums,
+      HloInstruction::CreateDot(xz_dynamic_shape, a_param, b_param, dot_dnums,
                                 HloTestBase::DefaultPrecisionConfig(2)));
 
   module_->AddEntryComputation(builder.Build());
@@ -305,7 +312,8 @@
   auto builder = HloComputation::Builder(TestName());
   auto lhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
   auto rhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
-  auto output_shape = ShapeUtil::MakeShape(F32, {4, 2, 128, 128});
+  auto output_shape =
+      ShapeUtil::MakeShape(F32, {4, 2, 128, 128}, {true, false, false, false});
   auto lhs_shape_dynamic =
       ShapeUtil::MakeShape(F32, {4, 128, 2, 8}, {true, false, false, false});
 
@@ -358,11 +366,14 @@
       /*parameter_number=*/2, scalar_shape_, "size_param"));
 
   a_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
-      lhs_shape_dynamic, a_param, size_param, 0));
+      ShapeUtil::MakeShape(F32, lhs_shape.dimensions(),
+                           {true, false, false, false}),
+      a_param, size_param, 0));
   a_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
       lhs_shape_dynamic, a_param, size_param, 1));
   b_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
-      rhs_shape_dynamic, b_param, size_param, 0));
+      ShapeUtil::MakeShape(F32, rhs_shape.dimensions(), {true, false, false}),
+      b_param, size_param, 0));
   b_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
       rhs_shape_dynamic, b_param, size_param, 1));
 
@@ -392,8 +403,9 @@
   constexpr int zdim = 1;
   auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
   auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
-  auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim});
   auto xy_shape_dynamic = ShapeUtil::MakeShape(F32, {xdim, ydim}, {true, true});
+  auto zx_shape_dynamic =
+      ShapeUtil::MakeShape(F32, {zdim, xdim}, {false, true});
 
   auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
       /*parameter_number=*/0, xy_shape, "A"));
@@ -403,7 +415,8 @@
       /*parameter_number=*/2, scalar_shape_, "size_param"));
 
   a_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
-      xy_shape_dynamic, a_param, size_param, 0));
+      ShapeUtil::MakeShape(F32, xy_shape.dimensions(), {true, false}), a_param,
+      size_param, 0));
   a_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
       xy_shape_dynamic, a_param, size_param, 1));
 
@@ -418,7 +431,7 @@
   Window window;
 
   auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
-      zx_shape, a_param, b_param, /*feature_group_count=*/1,
+      zx_shape_dynamic, a_param, b_param, /*feature_group_count=*/1,
       /*batch_group_count=*/1, window, dnums,
       HloTestBase::DefaultPrecisionConfig(2)));
 
@@ -434,7 +447,7 @@
   // Test the ability to trace unmodified dimensions
   auto builder = HloComputation::Builder(TestName());
   auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
-  auto output_shape = ShapeUtil::MakeShape(F32, {3, 2, 1});
+  auto output_shape = ShapeUtil::MakeShape(F32, {3, 2, 1}, {true, true, true});
   auto dynamic_shape = ShapeUtil::MakeShape(F32, {1, 2, 3}, {true, true, true});
 
   auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -447,9 +460,11 @@
       /*parameter_number=*/3, scalar_shape_, "size_param"));
 
   a_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
-      dynamic_shape, a_param, size_param_1, 0));
+      ShapeUtil::MakeShape(F32, {1, 2, 3}, {true, false, false}), a_param,
+      size_param_1, 0));
   a_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
-      dynamic_shape, a_param, size_param_2, 1));
+      ShapeUtil::MakeShape(F32, {1, 2, 3}, {true, true, false}), a_param,
+      size_param_2, 1));
   a_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
       dynamic_shape, a_param, size_param_3, 2));
 
@@ -469,7 +484,7 @@
   // Test the ability to trace unmodified dimensions
   auto builder = HloComputation::Builder(TestName());
   auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
-  auto output_shape = ShapeUtil::MakeShape(F32, {3, 1, 2});
+  auto output_shape = ShapeUtil::MakeShape(F32, {3, 1, 2}, {true, true, true});
   auto dynamic_shape = ShapeUtil::MakeShape(F32, {1, 2, 3}, {true, true, true});
 
   auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -482,9 +497,11 @@
       /*parameter_number=*/3, scalar_shape_, "size_param"));
 
   a_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
-      dynamic_shape, a_param, size_param_1, 0));
+      ShapeUtil::MakeShape(F32, {1, 2, 3}, {true, false, false}), a_param,
+      size_param_1, 0));
   a_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
-      dynamic_shape, a_param, size_param_2, 1));
+      ShapeUtil::MakeShape(F32, {1, 2, 3}, {true, true, false}), a_param,
+      size_param_2, 1));
   a_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
       dynamic_shape, a_param, size_param_3, 2));
 
@@ -504,7 +521,8 @@
   // Test the ability to trace unmodified reshape dimensions.
   auto builder = HloComputation::Builder(TestName());
   auto input_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6});
-  auto output_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3});
+  auto output_shape = ShapeUtil::MakeShape(
+      F32, {6, 4, 1, 5, 2, 3}, {false, true, false, true, false, false});
   auto dynamic_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6},
                                             {false, false, true, true, false});
 
@@ -514,7 +532,9 @@
       /*parameter_number=*/1, scalar_shape_, "size_param"));
 
   a_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
-      dynamic_shape, a_param, size_param, 2));
+      ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6},
+                           {false, false, true, false, false}),
+      a_param, size_param, 2));
   a_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
       dynamic_shape, a_param, size_param, 3));
 
@@ -538,7 +558,8 @@
   // input.
   auto builder = HloComputation::Builder(TestName());
   auto input_shape = ShapeUtil::MakeShape(F32, {4, 5});
-  auto output_shape = ShapeUtil::MakeShape(F32, {1, 4, 5});
+  auto output_shape =
+      ShapeUtil::MakeShape(F32, {1, 4, 5}, {true, false, false});
   auto dynamic_shape = ShapeUtil::MakeShape(F32, {4, 5}, {true, false});
 
   auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -563,7 +584,7 @@
   // Test the ability to trace dimension combining.
   auto builder = HloComputation::Builder(TestName());
   auto input_shape = ShapeUtil::MakeShape(F32, {32, 10, 4});
-  auto output_shape = ShapeUtil::MakeShape(F32, {320, 4});
+  auto output_shape = ShapeUtil::MakeShape(F32, {320, 4}, {true, false});
   auto dynamic_shape =
       ShapeUtil::MakeShape(F32, {32, 10, 4}, {true, false, false});
 
@@ -640,7 +661,8 @@
   // Test the ability to trace broadcast dimension.
   auto builder = HloComputation::Builder(TestName());
   auto input_shape = ShapeUtil::MakeShape(F32, {2});
-  auto output_shape = ShapeUtil::MakeShape(F32, {3, 2, 4});
+  auto output_shape =
+      ShapeUtil::MakeShape(F32, {3, 2, 4}, {false, true, false});
   auto dynamic_shape = ShapeUtil::MakeShape(F32, {2}, {true});
 
   auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -669,7 +691,6 @@
   auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
   auto dynamic_shape =
       ShapeUtil::MakeShape(F32, {2, 4, 4}, {true, false, false});
-  auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
   auto tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape});
   auto dynamic_tuple_shape =
       ShapeUtil::MakeTupleShape({dynamic_shape, dynamic_shape});
@@ -898,7 +919,8 @@
   // Test the ability to trace reduce window batch dimensions.
   auto builder = HloComputation::Builder(TestName());
   auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
-  auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
+  auto output_shape =
+      ShapeUtil::MakeShape(F32, {2, 2, 2}, {true, false, false});
   auto dynamic_shape =
       ShapeUtil::MakeShape(F32, {2, 4, 4}, {true, false, false});
 
@@ -992,7 +1014,7 @@
       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
 
   auto* sns = builder.AddInstruction(HloInstruction::CreateSelectAndScatter(
-      input_shape, a_param, GetGe(), window, source, init, GetAdd()));
+      input_shape_dynamic, a_param, GetGe(), window, source, init, GetAdd()));
 
   module_->AddEntryComputation(builder.Build());
 
@@ -1072,7 +1094,7 @@
       0));
 
   auto* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
-      ShapeUtil::MakeShape(F32, {5, 1}), data_param, params,
+      ShapeUtil::MakeShape(F32, {5, 1}, {true, false}), data_param, params,
       /*slice_sizes=*/{5, 1}));
 
   module_->AddEntryComputation(builder.Build());
@@ -1209,7 +1231,7 @@
     handler_called = true;
     return OkStatus();
   };
-  TF_ASSERT_OK(RunInference(handler));
+  TF_ASSERT_OK(RunInference(/*op_supports_dynamism_handler=*/nullptr, handler));
 
   EXPECT_TRUE(handler_called);
 }
@@ -1320,6 +1342,7 @@
   TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo));
 
   TF_ASSERT_OK(RunInference(
+      /*op_supports_dynamism_handler=*/nullptr,
       /*handler=*/nullptr, DynamicDimensionInference::ShapeCheckMode::kRuntime,
       /*assertion_generator=*/[&](HloInstruction* constraint) {
         constraint->parent()->AddInstruction(HloInstruction::CreateCustomCall(
@@ -1339,5 +1362,400 @@
   EXPECT_TRUE(*filecheck_result);
 }
 
+TEST_F(DynamicDimensionInferenceTest, NestedControlFlow) {
+  // A module with heavily nested control flow that manipulates dynamic shapes.
+  const char* hlo = R"(
+HloModule tfcompile.377, entry_computation_layout={(s32[], f32[250]{0}, pred[], pred[], s32[], /*index=5*/pred[], s32[], pred[])->(f32[3]{0})}
+
+cond_2_Sum-reduction.17 {
+  x.18 = f32[] parameter(0)
+  y.19 = f32[] parameter(1)
+  ROOT add.20 = f32[] add(x.18, y.19)
+}
+
+cond_2_cond_true_214__.21 {
+  arg_tuple.22 = () parameter(0)
+  constant.23 = s32[] constant(1)
+  reshape.24 = s32[] reshape(constant.23)
+  ROOT tuple.25 = (s32[]) tuple(constant.23)
+}
+
+cond_2_cond_false_215__.26 {
+  arg_tuple.27 = () parameter(0)
+  constant.28 = s32[] constant(0)
+  reshape.29 = s32[] reshape(constant.28)
+  ROOT tuple.30 = (s32[]) tuple(constant.28)
+}
+
+cond_2_true_195__.31 {
+  arg_tuple.32 = (s32[], f32[250]{0}) parameter(0)
+  get-tuple-element.33 = s32[] get-tuple-element(arg_tuple.32), index=0
+  constant.35 = s32[] constant(20)
+  minimum.36 = s32[] minimum(get-tuple-element.33, constant.35)
+  reshape.37 = s32[1]{0} reshape(minimum.36)
+  concatenate.38 = s32[1]{0} concatenate(reshape.37), dimensions={0}
+  slice.48 = s32[1]{0} slice(concatenate.38), slice={[0:1]}
+  reshape.49 = s32[] reshape(reshape.37)
+  constant.43 = s32[] constant(0)
+  compare.50 = pred[] compare(minimum.36, constant.43), direction=LT
+  constant.44 = s32[] constant(250)
+  add.51 = s32[] add(constant.44, minimum.36)
+  select.52 = s32[] select(compare.50, add.51, minimum.36)
+  constant.45 = s32[1]{0} constant({0})
+  slice.46 = s32[1]{0} slice(constant.45), slice={[0:1]}
+  reshape.47 = s32[] reshape(slice.46)
+  subtract.53 = s32[] subtract(select.52, reshape.47)
+  maximum.54 = s32[] maximum(subtract.53, constant.43)
+  convert.55 = s32[] convert(maximum.54)
+  get-tuple-element.34 = f32[250]{0} get-tuple-element(arg_tuple.32), index=1
+  constant.39 = f32[] constant(0)
+  pad.40 = f32[500]{0} pad(get-tuple-element.34, constant.39), padding=0_250
+  constant.41 = s32[] constant(500)
+  set-dimension-size.42 = f32[500]{0} set-dimension-size(pad.40, constant.41), dimensions={0}
+  dynamic-slice.56 = f32[250]{0} dynamic-slice(set-dimension-size.42, reshape.47), dynamic_slice_sizes={250}
+  reshape.57 = f32[250]{0} reshape(dynamic-slice.56)
+  set-dimension-size.58 = f32[<=250]{0} set-dimension-size(dynamic-slice.56, maximum.54), dimensions={0}
+  constant.59 = f32[] constant(1)
+  broadcast.60 = f32[250]{0} broadcast(constant.59), dimensions={}
+  compare.61 = pred[<=250]{0} compare(set-dimension-size.58, broadcast.60), direction=GE
+  convert.62 = f32[<=250]{0} convert(compare.61)
+  convert.63 = f32[<=250]{0} convert(convert.62)
+  constant.64 = f32[] constant(0)
+  convert.65 = f32[] convert(constant.64)
+  reduce.66 = f32[] reduce(convert.62, constant.64), dimensions={0}, to_apply=cond_2_Sum-reduction.17
+  convert.67 = f32[] convert(reduce.66)
+  reshape.73 = f32[] reshape(reduce.66)
+  constant.68 = f32[] constant(6)
+  compare.69 = pred[] compare(reduce.66, constant.68), direction=GE
+  tuple.70 = () tuple()
+  conditional.71 = (s32[]) conditional(compare.69, tuple.70, tuple.70), true_computation=cond_2_cond_true_214__.21, false_computation=cond_2_cond_false_215__.26
+  get-tuple-element.72 = s32[] get-tuple-element(conditional.71), index=0
+  reshape.74 = s32[] reshape(get-tuple-element.72)
+  ROOT tuple.75 = (f32[], s32[]) tuple(reduce.66, get-tuple-element.72)
+} // cond_2_true_195__.31
+
+cond_2_false_196__.76 {
+  arg_tuple.77 = (s32[], f32[250]{0}) parameter(0)
+  constant.80 = f32[] constant(0)
+  reshape.82 = f32[] reshape(constant.80)
+  constant.81 = s32[] constant(0)
+  reshape.83 = s32[] reshape(constant.81)
+  ROOT tuple.84 = (f32[], s32[]) tuple(constant.80, constant.81)
+} // cond_2_false_196__.76
+
+cond_true_10__.85 {
+  arg_tuple.86 = (pred[], pred[], pred[]) parameter(0)
+  get-tuple-element.87 = pred[] get-tuple-element(arg_tuple.86), index=0
+  reshape.90 = pred[] reshape(get-tuple-element.87)
+  ROOT tuple.91 = (pred[]) tuple(get-tuple-element.87)
+}
+
+cond_cond_true_16__.92 {
+  arg_tuple.93 = (pred[], pred[]) parameter(0)
+  get-tuple-element.94 = pred[] get-tuple-element(arg_tuple.93), index=0
+  reshape.96 = pred[] reshape(get-tuple-element.94)
+  ROOT tuple.97 = (pred[]) tuple(get-tuple-element.94)
+}
+
+cond_cond_false_17__.98 {
+  arg_tuple.99 = (pred[], pred[]) parameter(0)
+  get-tuple-element.101 = pred[] get-tuple-element(arg_tuple.99), index=1
+  reshape.102 = pred[] reshape(get-tuple-element.101)
+  ROOT tuple.103 = (pred[]) tuple(get-tuple-element.101)
+}
+
+cond_false_11__.104 {
+  arg_tuple.105 = (pred[], pred[], pred[]) parameter(0)
+  get-tuple-element.107 = pred[] get-tuple-element(arg_tuple.105), index=1
+  get-tuple-element.108 = pred[] get-tuple-element(arg_tuple.105), index=2
+  tuple.109 = (pred[], pred[]) tuple(get-tuple-element.107, get-tuple-element.108)
+  conditional.110 = (pred[]) conditional(get-tuple-element.107, tuple.109, tuple.109), true_computation=cond_cond_true_16__.92, false_computation=cond_cond_false_17__.98
+  get-tuple-element.111 = pred[] get-tuple-element(conditional.110), index=0
+  reshape.112 = pred[] reshape(get-tuple-element.111)
+  ROOT tuple.113 = (pred[]) tuple(get-tuple-element.111)
+} // cond_false_11__.104
+
+cond_1_map_while_cond_true_82__.114 {
+  arg_tuple.115 = (f32[]) parameter(0)
+  constant.117 = f32[] constant(0)
+  reshape.118 = f32[] reshape(constant.117)
+  ROOT tuple.119 = (f32[]) tuple(constant.117)
+}
+
+cond_1_map_while_cond_cond_true_91__.120 {
+  constant.123 = f32[] constant(0.1)
+  arg_tuple.121 = (f32[]) parameter(0)
+  get-tuple-element.122 = f32[] get-tuple-element(arg_tuple.121), index=0
+  multiply.124 = f32[] multiply(constant.123, get-tuple-element.122)
+  constant.125 = f32[] constant(0)
+  add.126 = f32[] add(multiply.124, constant.125)
+  constant.127 = f32[] constant(0.9)
+  divide.128 = f32[] divide(add.126, constant.127)
+  reshape.129 = f32[] reshape(divide.128)
+  ROOT tuple.130 = (f32[]) tuple(divide.128)
+} // cond_1_map_while_cond_cond_true_91__.120
+
+cond_1_map_while_cond_cond_cond_true_106__.131 {
+  constant.134 = f32[] constant(0.8)
+  arg_tuple.132 = (f32[]) parameter(0)
+  get-tuple-element.133 = f32[] get-tuple-element(arg_tuple.132), index=0
+  multiply.135 = f32[] multiply(constant.134, get-tuple-element.133)
+  constant.136 = f32[] constant(-0.711)
+  add.137 = f32[] add(multiply.135, constant.136)
+  constant.138 = f32[] constant(0.09)
+  divide.139 = f32[] divide(add.137, constant.138)
+  reshape.140 = f32[] reshape(divide.139)
+  ROOT tuple.141 = (f32[]) tuple(divide.139)
+} // cond_1_map_while_cond_cond_cond_true_106__.131
+
+cond_1_map_while_cond_cond_cond_cond_true_121__.142 {
+  constant.145 = f32[] constant(0.2)
+  arg_tuple.143 = (f32[]) parameter(0)
+  get-tuple-element.144 = f32[] get-tuple-element(arg_tuple.143), index=0
+  multiply.146 = f32[] multiply(constant.145, get-tuple-element.144)
+  constant.147 = f32[] constant(-0.18)
+  add.148 = f32[] add(multiply.146, constant.147)
+  constant.149 = f32[] constant(0.02)
+  divide.150 = f32[] divide(add.148, constant.149)
+  reshape.151 = f32[] reshape(divide.150)
+  ROOT tuple.152 = (f32[]) tuple(divide.150)
+} // cond_1_map_while_cond_cond_cond_cond_true_121__.142
+
+cond_1_map_while_cond_cond_cond_cond_cond_true_136__.153 {
+  constant.156 = f32[] constant(0.1)
+  arg_tuple.154 = (f32[]) parameter(0)
+  get-tuple-element.155 = f32[] get-tuple-element(arg_tuple.154), index=0
+  multiply.157 = f32[] multiply(constant.156, get-tuple-element.155)
+  constant.158 = f32[] constant(108.788)
+  add.159 = f32[] add(multiply.157, constant.158)
+  constant.160 = f32[] constant(98.99)
+  divide.161 = f32[] divide(add.159, constant.160)
+  reshape.162 = f32[] reshape(divide.161)
+  ROOT tuple.163 = (f32[]) tuple(divide.161)
+} // cond_1_map_while_cond_cond_cond_cond_cond_true_136__.153
+
+cond_1_map_while_cond_cond_cond_cond_cond_false_137__.164 {
+  arg_tuple.165 = (f32[]) parameter(0)
+  constant.167 = f32[] constant(1.2)
+  reshape.168 = f32[] reshape(constant.167)
+  ROOT tuple.169 = (f32[]) tuple(constant.167)
+}
+
+cond_1_map_while_cond_cond_cond_cond_false_122__.170 {
+  arg_tuple.171 = (f32[]) parameter(0)
+  get-tuple-element.172 = f32[] get-tuple-element(arg_tuple.171), index=0
+  constant.173 = f32[] constant(100)
+  compare.174 = pred[] compare(get-tuple-element.172, constant.173), direction=LE
+  tuple.175 = (f32[]) tuple(get-tuple-element.172)
+  conditional.176 = (f32[]) conditional(compare.174, tuple.175, tuple.175), true_computation=cond_1_map_while_cond_cond_cond_cond_cond_true_136__.153, false_computation=cond_1_map_while_cond_cond_cond_cond_cond_false_137__.164
+  get-tuple-element.177 = f32[] get-tuple-element(conditional.176), index=0
+  reshape.178 = f32[] reshape(get-tuple-element.177)
+  ROOT tuple.179 = (f32[]) tuple(get-tuple-element.177)
+} // cond_1_map_while_cond_cond_cond_cond_false_122__.170
+
+cond_1_map_while_cond_cond_cond_false_107__.180 {
+  arg_tuple.181 = (f32[]) parameter(0)
+  get-tuple-element.182 = f32[] get-tuple-element(arg_tuple.181), index=0
+  constant.183 = f32[] constant(1.01)
+  compare.184 = pred[] compare(get-tuple-element.182, constant.183), direction=LE
+  tuple.185 = (f32[]) tuple(get-tuple-element.182)
+  conditional.186 = (f32[]) conditional(compare.184, tuple.185, tuple.185), true_computation=cond_1_map_while_cond_cond_cond_cond_true_121__.142, false_computation=cond_1_map_while_cond_cond_cond_cond_false_122__.170
+  get-tuple-element.187 = f32[] get-tuple-element(conditional.186), index=0
+  reshape.188 = f32[] reshape(get-tuple-element.187)
+  ROOT tuple.189 = (f32[]) tuple(get-tuple-element.187)
+} // cond_1_map_while_cond_cond_cond_false_107__.180
+
+cond_1_map_while_cond_cond_false_92__.190 {
+  arg_tuple.191 = (f32[]) parameter(0)
+  get-tuple-element.192 = f32[] get-tuple-element(arg_tuple.191), index=0
+  constant.193 = f32[] constant(0.99)
+  compare.194 = pred[] compare(get-tuple-element.192, constant.193), direction=LE
+  tuple.195 = (f32[]) tuple(get-tuple-element.192)
+  conditional.196 = (f32[]) conditional(compare.194, tuple.195, tuple.195), true_computation=cond_1_map_while_cond_cond_cond_true_106__.131, false_computation=cond_1_map_while_cond_cond_cond_false_107__.180
+  get-tuple-element.197 = f32[] get-tuple-element(conditional.196), index=0
+  reshape.198 = f32[] reshape(get-tuple-element.197)
+  ROOT tuple.199 = (f32[]) tuple(get-tuple-element.197)
+} // cond_1_map_while_cond_cond_false_92__.190
+
+cond_1_map_while_cond_false_83__.200 {
+  arg_tuple.201 = (f32[]) parameter(0)
+  get-tuple-element.202 = f32[] get-tuple-element(arg_tuple.201), index=0
+  constant.203 = f32[] constant(0.9)
+  compare.204 = pred[] compare(get-tuple-element.202, constant.203), direction=LE
+  tuple.205 = (f32[]) tuple(get-tuple-element.202)
+  conditional.206 = (f32[]) conditional(compare.204, tuple.205, tuple.205), true_computation=cond_1_map_while_cond_cond_true_91__.120, false_computation=cond_1_map_while_cond_cond_false_92__.190
+  get-tuple-element.207 = f32[] get-tuple-element(conditional.206), index=0
+  reshape.208 = f32[] reshape(get-tuple-element.207)
+  ROOT tuple.209 = (f32[]) tuple(get-tuple-element.207)
+} // cond_1_map_while_cond_false_83__.200
+
+cond_1_map_while_body_59__.210 {
+  arg_tuple.211 = (s32[], s32[], s32[], (f32[<=250]{0}, s32[]), s32[], /*index=5*/(f32[<=250]{0}, s32[])) parameter(0)
+  get-tuple-element.212 = s32[] get-tuple-element(arg_tuple.211), index=0
+  constant.218 = s32[] constant(1)
+  add.219 = s32[] add(get-tuple-element.212, constant.218)
+  reshape.239 = s32[] reshape(add.219)
+  get-tuple-element.213 = s32[] get-tuple-element(arg_tuple.211), index=1
+  reshape.240 = s32[] reshape(get-tuple-element.213)
+  get-tuple-element.214 = s32[] get-tuple-element(arg_tuple.211), index=2
+  constant.220 = s32[] constant(1)
+  add.221 = s32[] add(get-tuple-element.214, constant.220)
+  reshape.241 = s32[] reshape(add.221)
+  get-tuple-element.216 = s32[] get-tuple-element(arg_tuple.211), index=4
+  reshape.242 = s32[] reshape(get-tuple-element.216)
+  get-tuple-element.215 = (f32[<=250]{0}, s32[]) get-tuple-element(arg_tuple.211), index=3
+  get-tuple-element.235 = f32[<=250]{0} get-tuple-element(get-tuple-element.215), index=0
+  get-tuple-element.217 = (f32[<=250]{0}, s32[]) get-tuple-element(arg_tuple.211), index=5
+  get-tuple-element.223 = f32[<=250]{0} get-tuple-element(get-tuple-element.217), index=0
+  dynamic-slice.224 = f32[1]{0} dynamic-slice(get-tuple-element.223, get-tuple-element.214), dynamic_slice_sizes={1}
+  reshape.225 = f32[] reshape(dynamic-slice.224)
+  constant.226 = f32[] constant(0)
+  compare.227 = pred[] compare(reshape.225, constant.226), direction=LE
+  tuple.228 = (f32[]) tuple(reshape.225)
+  conditional.229 = (f32[]) conditional(compare.227, tuple.228, tuple.228), true_computation=cond_1_map_while_cond_true_82__.114, false_computation=cond_1_map_while_cond_false_83__.200
+  get-tuple-element.230 = f32[] get-tuple-element(conditional.229), index=0
+  reshape.233 = f32[1]{0} reshape(get-tuple-element.230)
+  dynamic-update-slice.236 = f32[<=250]{0} dynamic-update-slice(get-tuple-element.235, reshape.233, get-tuple-element.214)
+  get-tuple-element.237 = s32[] get-tuple-element(get-tuple-element.215), index=1
+  tuple.238 = (f32[<=250]{0}, s32[]) tuple(dynamic-update-slice.236, get-tuple-element.237)
+  ROOT tuple.243 = (s32[], s32[], s32[], (f32[<=250]{0}, s32[]), s32[], /*index=5*/(f32[<=250]{0}, s32[])) tuple(add.219, get-tuple-element.213, add.221, tuple.238, get-tuple-element.216, /*index=5*/get-tuple-element.217)
+} // cond_1_map_while_body_59__.210
+
+cond_wrapper.257 {
+  inputs.258 = (s32[], s32[], s32[], (f32[<=250]{0}, s32[]), s32[], /*index=5*/(f32[<=250]{0}, s32[])) parameter(0)
+  get-tuple-element.0 = s32[] get-tuple-element(inputs.258), index=0
+  get-tuple-element.1 = s32[] get-tuple-element(inputs.258), index=1
+  compare.0 = pred[] compare(get-tuple-element.0, get-tuple-element.1), direction=LT
+  get-tuple-element.2 = s32[] get-tuple-element(inputs.258), index=2
+  get-tuple-element.3 = s32[] get-tuple-element(inputs.258), index=4
+  compare.1 = pred[] compare(get-tuple-element.2, get-tuple-element.3), direction=LT
+  and.0 = pred[] and(compare.0, compare.1)
+  tuple.0 = (pred[]) tuple(and.0)
+  ROOT get-tuple-element.260 = pred[] get-tuple-element(tuple.0), index=0
+  reshape.0 = pred[] reshape(and.0)
+} // cond_wrapper.257
+
+cond_1_Sum-reduction.261 {
+  x.262 = f32[] parameter(0)
+  y.263 = f32[] parameter(1)
+  ROOT add.264 = f32[] add(x.262, y.263)
+}
+
+cond_1_true_36__.265 {
+  arg_tuple.266 = (s32[], f32[250]{0}) parameter(0)
+  get-tuple-element.267 = s32[] get-tuple-element(arg_tuple.266), index=0
+  reshape.269 = s32[1]{0} reshape(get-tuple-element.267)
+  concatenate.270 = s32[1]{0} concatenate(reshape.269), dimensions={0}
+  slice.280 = s32[1]{0} slice(concatenate.270), slice={[0:1]}
+  reshape.281 = s32[] reshape(reshape.269)
+  constant.275 = s32[] constant(0)
+  compare.282 = pred[] compare(get-tuple-element.267, constant.275), direction=LT
+  constant.276 = s32[] constant(250)
+  add.283 = s32[] add(constant.276, get-tuple-element.267)
+  select.284 = s32[] select(compare.282, add.283, get-tuple-element.267)
+  constant.277 = s32[1]{0} constant({0})
+  slice.278 = s32[1]{0} slice(constant.277), slice={[0:1]}
+  reshape.279 = s32[] reshape(slice.278)
+  subtract.285 = s32[] subtract(select.284, reshape.279)
+  maximum.286 = s32[] maximum(subtract.285, constant.275)
+  convert.287 = s32[] convert(maximum.286)
+  get-tuple-element.268 = f32[250]{0} get-tuple-element(arg_tuple.266), index=1
+  constant.271 = f32[] constant(0)
+  pad.272 = f32[500]{0} pad(get-tuple-element.268, constant.271), padding=0_250
+  constant.273 = s32[] constant(500)
+  set-dimension-size.274 = f32[500]{0} set-dimension-size(pad.272, constant.273), dimensions={0}
+  dynamic-slice.288 = f32[250]{0} dynamic-slice(set-dimension-size.274, reshape.279), dynamic_slice_sizes={250}
+  reshape.289 = f32[250]{0} reshape(dynamic-slice.288)
+  set-dimension-size.290 = f32[<=250]{0} set-dimension-size(dynamic-slice.288, maximum.286), dimensions={0}
+  get-dimension-size.291 = s32[] get-dimension-size(set-dimension-size.290), dimensions={0}
+  convert.292 = s32[] convert(get-dimension-size.291)
+  broadcast.293 = s32[1]{0} broadcast(get-dimension-size.291), dimensions={}
+  concatenate.294 = s32[1]{0} concatenate(broadcast.293), dimensions={0}
+  slice.295 = s32[1]{0} slice(concatenate.294), slice={[0:1]}
+  reshape.296 = s32[] reshape(broadcast.293)
+  constant.309 = s32[] constant(0)
+  constant.310 = s32[] constant(0)
+  constant.312 = f32[] constant(0)
+  broadcast.313 = f32[250]{0} broadcast(constant.312), dimensions={}
+  constant.302 = s32[] constant(0)
+  broadcast.303 = s32[250]{0} broadcast(constant.302), dimensions={}
+  set-dimension-size.304 = s32[<=250]{0} set-dimension-size(broadcast.303, get-dimension-size.291), dimensions={0}
+  get-dimension-size.311 = s32[] get-dimension-size(set-dimension-size.304), dimensions={0}
+  set-dimension-size.314 = f32[<=250]{0} set-dimension-size(broadcast.313, get-dimension-size.311), dimensions={0}
+  constant.315 = s32[] constant(0)
+  tuple.316 = (f32[<=250]{0}, s32[]) tuple(set-dimension-size.314, constant.315)
+  constant.305 = s32[] constant(250)
+  tuple.306 = (f32[<=250]{0}, s32[]) tuple(set-dimension-size.290, constant.305)
+  tuple.317 = (s32[], s32[], s32[], (f32[<=250]{0}, s32[]), s32[], /*index=5*/(f32[<=250]{0}, s32[])) tuple(constant.309, get-dimension-size.291, constant.310, tuple.316, get-dimension-size.291, /*index=5*/tuple.306)
+  while.318 = (s32[], s32[], s32[], (f32[<=250]{0}, s32[]), s32[], /*index=5*/(f32[<=250]{0}, s32[])) while(tuple.317), condition=cond_wrapper.257, body=cond_1_map_while_body_59__.210
+  get-tuple-element.319 = s32[] get-tuple-element(while.318), index=0
+  get-tuple-element.320 = s32[] get-tuple-element(while.318), index=1
+  get-tuple-element.321 = s32[] get-tuple-element(while.318), index=2
+  get-tuple-element.322 = (f32[<=250]{0}, s32[]) get-tuple-element(while.318), index=3
+  get-tuple-element.323 = s32[] get-tuple-element(while.318), index=4
+  get-tuple-element.324 = (f32[<=250]{0}, s32[]) get-tuple-element(while.318), index=5
+  tuple.325 = (s32[], s32[], s32[], (f32[<=250]{0}, s32[]), s32[], /*index=5*/(f32[<=250]{0}, s32[])) tuple(get-tuple-element.319, get-tuple-element.320, get-tuple-element.321, get-tuple-element.322, get-tuple-element.323, /*index=5*/get-tuple-element.324)
+  get-tuple-element.329 = (f32[<=250]{0}, s32[]) get-tuple-element(tuple.325), index=3
+  get-tuple-element.332 = f32[<=250]{0} get-tuple-element(get-tuple-element.329), index=0
+  convert.333 = f32[<=250]{0} convert(get-tuple-element.332)
+  constant.334 = f32[] constant(0)
+  convert.335 = f32[] convert(constant.334)
+  reduce.336 = f32[] reduce(get-tuple-element.332, constant.334), dimensions={0}, to_apply=cond_1_Sum-reduction.261
+  convert.337 = f32[] convert(reduce.336)
+  reshape.338 = f32[] reshape(reduce.336)
+  ROOT tuple.339 = (f32[]) tuple(reduce.336)
+} // cond_1_true_36__.265
+
+cond_1_false_37__.340 {
+  arg_tuple.341 = (s32[], f32[250]{0}) parameter(0)
+  constant.344 = f32[] constant(0)
+  reshape.345 = f32[] reshape(constant.344)
+  ROOT tuple.346 = (f32[]) tuple(constant.344)
+}
+
+ENTRY tfcompile.377 {
+  arg6.7 = s32[] parameter(6), parameter_replication={false}
+  arg0.1 = s32[] parameter(0), parameter_replication={false}
+  reshape.9 = s32[] reshape(arg0.1)
+  arg1.2 = f32[250]{0} parameter(1), parameter_replication={false}
+  reshape.10 = f32[250]{0} reshape(arg1.2)
+  arg2.3 = pred[] parameter(2), parameter_replication={false}
+  reshape.11 = pred[] reshape(arg2.3)
+  arg3.4 = pred[] parameter(3), parameter_replication={false}
+  reshape.12 = pred[] reshape(arg3.4)
+  arg4.5 = s32[] parameter(4), parameter_replication={false}
+  reshape.13 = s32[] reshape(arg4.5)
+  arg5.6 = pred[] parameter(5), parameter_replication={false}
+  reshape.14 = pred[] reshape(arg5.6)
+  arg7.8 = pred[] parameter(7), parameter_replication={false}
+  reshape.16 = pred[] reshape(arg7.8)
+  tuple.1 = (s32[], f32[250]{0}) tuple(arg0.1, arg1.2)
+  conditional.0 = (f32[], s32[]) conditional(arg2.3, tuple.1, tuple.1), true_computation=cond_2_true_195__.31, false_computation=cond_2_false_196__.76
+  get-tuple-element.4 = f32[] get-tuple-element(conditional.0), index=0
+  reshape.1 = f32[1]{0} reshape(get-tuple-element.4)
+  get-tuple-element.5 = s32[] get-tuple-element(conditional.0), index=1
+  convert.0 = f32[] convert(get-tuple-element.5)
+  reshape.2 = f32[1]{0} reshape(convert.0)
+  tuple.2 = (pred[], pred[], pred[]) tuple(arg3.4, arg5.6, arg7.8)
+  conditional.1 = (pred[]) conditional(arg3.4, tuple.2, tuple.2), true_computation=cond_true_10__.85, false_computation=cond_false_11__.104
+  get-tuple-element.6 = pred[] get-tuple-element(conditional.1), index=0
+  tuple.3 = (s32[], f32[250]{0}) tuple(arg4.5, arg1.2)
+  conditional.2 = (f32[]) conditional(get-tuple-element.6, tuple.3, tuple.3), true_computation=cond_1_true_36__.265, false_computation=cond_1_false_37__.340
+  get-tuple-element.7 = f32[] get-tuple-element(conditional.2), index=0
+  reshape.3 = f32[1]{0} reshape(get-tuple-element.7)
+  concatenate.0 = f32[3]{0} concatenate(reshape.1, reshape.2, reshape.3), dimensions={0}
+  tuple.4 = (f32[3]{0}) tuple(concatenate.0)
+  get-tuple-element.374 = f32[3]{0} get-tuple-element(tuple.4), index=0
+  reshape.375 = f32[3]{0} reshape(get-tuple-element.374)
+  ROOT tuple.376 = (f32[3]{0}) tuple(get-tuple-element.374)
+  reshape.4 = f32[3]{0} reshape(concatenate.0)
+} // tfcompile.377
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo));
+
+  TF_ASSERT_OK(RunInference());
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/third_party/xla/xla/service/dynamic_padder.cc b/third_party/xla/xla/service/dynamic_padder.cc
index 8077ee0..fd14724 100644
--- a/third_party/xla/xla/service/dynamic_padder.cc
+++ b/third_party/xla/xla/service/dynamic_padder.cc
@@ -14,41 +14,52 @@
 ==============================================================================*/
 #include "xla/service/dynamic_padder.h"
 
-#include <algorithm>
+#include <cstdint>
 #include <functional>
-#include <optional>
+#include <iterator>
+#include <set>
+#include <utility>
 #include <vector>
 
 #include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/functional/function_ref.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "xla/client/xla_builder.h"
 #include "xla/comparison_util.h"
 #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
+#include "xla/hlo/ir/dynamic_parameter_binding.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/literal.h"
 #include "xla/literal_util.h"
+#include "xla/service/call_graph.h"
 #include "xla/service/dynamic_dimension_inference.h"
 #include "xla/service/dynamic_window_utils.h"
 #include "xla/service/hlo_creation_utils.h"
 #include "xla/service/hlo_dce.h"
 #include "xla/service/pattern_matcher.h"
 #include "xla/service/shape_inference.h"
+#include "xla/service/tuple_util.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/status.h"
 #include "xla/status_macros.h"
+#include "xla/statusor.h"
 #include "xla/util.h"
 #include "xla/window_util.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/lib/monitoring/gauge.h"
 #include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 
@@ -302,9 +313,9 @@
   CHECK(inst != nullptr && dynamic_size != nullptr &&
         padding_scalar != nullptr);
   const Shape mask_shape =
-      ShapeUtil::ChangeElementType(inst->shape(), xla::S32);
+      ShapeUtil::MakeShape(xla::S32, inst->shape().dimensions());
   const Shape pred_shape =
-      ShapeUtil::ChangeElementType(inst->shape(), xla::PRED);
+      ShapeUtil::MakeShape(xla::PRED, inst->shape().dimensions());
   HloInstruction* iota =
       inst->AddInstruction(HloInstruction::CreateIota(mask_shape, dim));
 
@@ -313,11 +324,12 @@
   HloInstruction* pred = inst->AddInstruction(HloInstruction::CreateCompare(
       pred_shape, iota, broadcasted_effective_size, ComparisonDirection::kLt));
 
-  HloInstruction* broadcasted_identity_value = inst->AddInstruction(
-      HloInstruction::CreateBroadcast(inst->shape(), padding_scalar, {}));
-  HloInstruction* padded = inst->AddInstruction(
-      HloInstruction::CreateTernary(inst->shape(), HloOpcode::kSelect, pred,
-                                    inst, broadcasted_identity_value));
+  HloInstruction* broadcasted_identity_value =
+      inst->AddInstruction(HloInstruction::CreateBroadcast(
+          ShapeUtil::MakeStaticShape(inst->shape()), padding_scalar, {}));
+  HloInstruction* padded = inst->AddInstruction(HloInstruction::CreateTernary(
+      ShapeUtil::MakeStaticShape(inst->shape()), HloOpcode::kSelect, pred, inst,
+      broadcasted_identity_value));
   return padded;
 }
 
@@ -742,9 +754,11 @@
 
   // Temporarily removes the dynamic dimension of the reshape before we send
   // it to the sort -- we want the padded area to also participate in the
   // gather.
+  Shape reshape_static_shape = reshape->shape();
+  reshape_static_shape.set_dynamic_dimension(output_dim, false);
   HloInstruction* reshape_static =
       reshape->AddInstruction(HloInstruction::CreateSetDimensionSize(
-          reshape->shape(), reshape, static_dim_size, output_dim));
+          reshape_static_shape, reshape, static_dim_size, output_dim));
   std::vector<int64_t> gather_slice_sizes(output_shape.dimensions().begin(),
                                           output_shape.dimensions().end());
   gather_slice_sizes[output_dim] = 1;
@@ -898,10 +912,9 @@
   HloInstruction* dynamic_reverse =
       reverse->AddInstruction(HloInstruction::CreateDynamicSlice(
           reverse_shape, pad, start_indices, reverse_shape.dimensions()));
-  TF_RETURN_IF_ERROR(
-      reverse->parent()->ReplaceInstruction(reverse, dynamic_reverse));
   TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
       reverse, dynamic_reverse, {}));
+  TF_RETURN_IF_ERROR(reverse->ReplaceAllUsesWith(dynamic_reverse));
   return true;
 }
 
@@ -1041,8 +1054,6 @@
     DynamicDimensionInference* dynamic_dimension_inference) {
   HloInstruction* input = custom_call_conv->mutable_operand(0);
   HloInstruction* kernel = custom_call_conv->mutable_operand(1);
-  TF_RET_CHECK(kernel->shape().is_static());
-  TF_RET_CHECK(input->shape().is_dynamic());
   Window window = custom_call_conv->window();
   auto dnums = custom_call_conv->convolution_dimension_numbers();
   HloInstruction* zero =
@@ -1102,8 +1113,8 @@
     DynamicDimensionInference* dynamic_dimension_inference) {
   HloInstruction* activations = custom_call_conv->mutable_operand(0);
   HloInstruction* gradients = custom_call_conv->mutable_operand(1);
-  TF_RET_CHECK(activations->shape().is_dynamic());
-  TF_RET_CHECK(gradients->shape().is_dynamic());
+  TF_RET_CHECK(dynamic_dimension_inference->HasDynamicDimension(activations));
+  TF_RET_CHECK(dynamic_dimension_inference->HasDynamicDimension(gradients));
   Window window = custom_call_conv->window();
   auto dnums = custom_call_conv->convolution_dimension_numbers();
   HloInstruction* zero =
@@ -1369,12 +1380,13 @@
 
   Shape operand_shape =
       ShapeUtil::ChangeElementType(sort->operand(0)->shape(), S32);
-  HloInstruction* iota =
-      hlo->AddInstruction(HloInstruction::CreateIota(operand_shape, sort_dim));
+  Shape broadcast_shape = ShapeUtil::MakeStaticShape(operand_shape);
+  HloInstruction* iota = hlo->AddInstruction(
+      HloInstruction::CreateIota(broadcast_shape, sort_dim));
   HloInstruction* dynamic_size_broadcasted = hlo->AddInstruction(
-      HloInstruction::CreateBroadcast(operand_shape, dynamic_size, {}));
+      HloInstruction::CreateBroadcast(broadcast_shape, dynamic_size, {}));
   HloInstruction* lt = hlo->AddInstruction(HloInstruction::CreateCompare(
-      ShapeUtil::ChangeElementType(operand_shape, PRED), iota,
+      ShapeUtil::ChangeElementType(broadcast_shape, PRED), iota,
       dynamic_size_broadcasted, ComparisonDirection::kLt));
   sort->AppendOperand(lt);
 
@@ -1412,8 +1424,6 @@
       ShapeUtil::MakeScalarShape(PRED), HloOpcode::kAnd, inbound_lhs,
       sort_comp_or_out_of_bound_rhs));
   sort_comp->set_root_instruction(new_root);
-  Shape compare_shape =
-      ShapeUtil::ChangeElementType(sort->operand(0)->shape(), PRED);
   if (sort->shape().IsTuple()) {
     // For sort that is already tuple, simply add another result to the tuple.
     *sort->mutable_shape()->add_tuple_shapes() =
@@ -1473,8 +1483,7 @@
       // Broadcast [2, 5, 3]
       auto rewrite_operand = [&](HloInstruction* pred,
                                  HloInstruction* operand) -> HloInstruction* {
-        Shape static_shape = operand->shape();
-        static_shape.clear_dynamic_dimensions();
+        Shape static_shape = ShapeUtil::MakeStaticShape(operand->shape());
         pred = binary->AddInstruction(HloInstruction::CreateBroadcast(
             ShapeUtil::ChangeElementType(static_shape, PRED), pred, {}));
         Shape slice_shape = static_shape;
@@ -1725,8 +1734,9 @@
     int64_t num_elements = ShapeUtil::ElementsIn(operand->shape());
     Shape flattened_shape =
         ShapeUtil::MakeShape(operand->shape().element_type(), {num_elements});
-    HloInstruction* flatten = operand->AddInstruction(
-        HloInstruction::CreateReshape(flattened_shape, operand));
+    HloInstruction* flatten = operand->parent()->AddInstruction(
+        HloInstruction::CreateReshape(flattened_shape, operand),
+        absl::StrCat(reshape->name(), ".flatten"));
 
     HloInstruction* dynamic_size =
         operand->AddInstruction(HloInstruction::CreateConstant(
@@ -1748,8 +1758,10 @@
     }
     dynamic_dimension_inference->SetDynamicSize(flatten, {}, 0, dynamic_size);
 
-    HloInstruction* unflatten = reshape->AddInstruction(
-        HloInstruction::CreateReshape(reshape->shape(), flatten));
+    Shape unflattened_shape = ShapeUtil::MakeStaticShape(reshape->shape());
+    HloInstruction* unflatten = reshape->parent()->AddInstruction(
+        HloInstruction::CreateReshape(unflattened_shape, flatten),
+        absl::StrCat(reshape->name(), ".unflatten"));
     TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
         reshape, unflatten, {}));
 
@@ -1759,6 +1771,9 @@
     TF_ASSIGN_OR_RETURN(
         changed_unused,
         RewriteDynamicReshape(unflatten, dynamic_dimension_inference));
+
+    TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
+        reshape, unflatten, {}));
     TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(unflatten));
 
     return true;
@@ -1811,109 +1826,6 @@
   return changed;
 }
 
-// Insert pad-to-static after `inst` if `inst` has dynamic dimensions in it.
-// Recurse into tuple instructions.
-StatusOr<HloInstruction*> InsertPadToStaticOnInstruction(HloInstruction* inst) {
-  if (inst->shape().is_static()) {
-    return inst;
-  }
-  if (!inst->shape().IsTuple()) {
-    // The output shape of pad static is a tuple. The 0th element is the data
-    // output, which is the same as input shape, but without dynamic dimensions;
-    // i-th element is the dynamic dimension size for i-1th input dimension.
-    Shape data_output_shape = inst->shape();  // 0th element.
-    data_output_shape.clear_dynamic_dimensions();
-    Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
-    for (int64_t i = 0; i < inst->shape().rank(); ++i) {
-      ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
-                                    &output_shape);
-    }
-    HloInstruction* pad_to_static =
-        inst->AddInstruction(HloInstruction::CreateCustomCall(
-            output_shape, {inst}, "PadToStatic", ""));
-    HloInstruction* data_output =
-        inst->AddInstruction(HloInstruction::CreateGetTupleElement(
-            data_output_shape, pad_to_static, 0));
-    return data_output;
-  }
-
-  TF_RET_CHECK(inst->shape().IsTuple());
-  std::vector<HloInstruction*> static_tuple_elements;
-  for (int64_t i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
-    // For each tuple element, if it is static, pass it through. If it is
-    // dynamic, recursively call this function again.
-    HloInstruction* gte =
-        inst->AddInstruction(HloInstruction::CreateGetTupleElement(
-            inst->shape().tuple_shapes(i), inst, i));
-
-    if (gte->shape().is_static()) {
-      static_tuple_elements.push_back(gte);
-    } else {
-      TF_ASSIGN_OR_RETURN(HloInstruction * static_gte,
-                          InsertPadToStaticOnInstruction(gte));
-      static_tuple_elements.push_back(static_gte);
-    }
-  }
-
-  return inst->AddInstruction(
-      HloInstruction::CreateTuple(static_tuple_elements));
-}
-
-// Inserts PadToStatic for parameters and custom-calls which "materialize"
-// dynamic outputs given only static inputs.
-Status InsertPadToStaticAfterModuleInputs(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  std::vector<HloInstruction*> params;
-  HloComputation* entry = module->entry_computation();
-  for (HloComputation* comp : module->MakeNonfusionComputationsSorted()) {
-    if (!HloInstruction::IsThreadIncluded(comp->execution_thread(),
-                                          execution_threads)) {
-      continue;
-    }
-    for (HloInstruction* instr : comp->instructions()) {
-      auto should_do_pad_to_static =
-          [&execution_threads](HloInstruction* instr) {
-            for (auto user : instr->users()) {
-              if (user->opcode() == HloOpcode::kAsyncStart) {
-                if (HloInstruction::IsThreadIncluded(
-                        user->async_execution_thread(), execution_threads)) {
-                  return true;
-                }
-              } else {
-                return true;
-              }
-            }
-            // If there are users, they must be the bypassing cases. Don't do
-            // pad-to-static.
-            return instr->users().empty();
-          };
-
-      if (!instr->shape().is_static() &&
-          ((instr->opcode() == HloOpcode::kParameter && comp == entry) ||
-           instr->opcode() == HloOpcode::kCustomCall ||
-           instr->opcode() == HloOpcode::kAsyncDone) &&
-          should_do_pad_to_static(instr) &&
-          absl::c_all_of(instr->operands(), [&](HloInstruction* operand) {
-            return operand->shape().is_static();
-          })) {
-        LOG(ERROR) << "Inserting PadToStatic for instruction: "
-                   << instr->ToString();
-        auto users = instr->users();
-        TF_ASSIGN_OR_RETURN(HloInstruction * instr_static,
-                            InsertPadToStaticOnInstruction(instr));
-        for (auto* user : users) {
-          TF_RETURN_IF_ERROR(instr->ReplaceUseWith(user, instr_static));
-        }
-        if (instr == entry->root_instruction()) {
-          module->entry_computation()->set_root_instruction(instr_static);
-        }
-      }
-    }
-  }
-  return OkStatus();
-}
-
 // Remove all dynamic shapes between pad-to-static and slice-to-dynamic.
 //
 // After this visitor the entry computation then looks like:
@@ -1928,14 +1840,15 @@
 //  SliceToDynamic(dynamic)
 //    |
 // ROOT tuple (dynamic)
-class DynamicShapeRemovingVisitor : public DfsHloVisitorWithDefault {
+class DynamicShapeRemovingVisitor : public DfsHloRewriteVisitor {
  public:
   explicit DynamicShapeRemovingVisitor(
-      const DynamicPadderOptions::OpSupportsDynamismHandler&
-          op_supports_dynamism_handler,
-      DynamicDimensionInference* dynamic_dimension_inference)
+      const OpSupportsDynamismHandler& op_supports_dynamism_handler,
+      DynamicDimensionInference* dynamic_dimension_inference,
+      const absl::flat_hash_set<absl::string_view>& execution_threads)
       : op_supports_dynamism_handler_(op_supports_dynamism_handler),
-        dynamic_dimension_inference_(dynamic_dimension_inference) {}
+        dynamic_dimension_inference_(dynamic_dimension_inference),
+        execution_threads_(execution_threads) {}
 
   Status DefaultAction(HloInstruction* hlo) override;
 
@@ -1948,15 +1861,24 @@
   Status HandleInfeed(HloInstruction* hlo) override;
 
   Status HandleAsyncStart(HloInstruction* hlo) override;
+  Status HandleAsyncUpdate(HloInstruction* hlo) override;
   Status HandleAsyncDone(HloInstruction* hlo) override;
 
-  static Status Run(HloComputation* computation,
-                    const DynamicPadderOptions::OpSupportsDynamismHandler&
-                        op_supports_dynamism_handler,
-                    DynamicDimensionInference* dynamic_shape_inference,
-                    bool require_dynamic_output) {
+  Status HandleWhile(HloInstruction* hlo) override;
+  Status HandleConditional(HloInstruction* hlo) override;
+
+  Status HandleGetDimensionSize(HloInstruction* hlo) override;
+  Status HandleSetDimensionSize(HloInstruction* hlo) override;
+
+  static StatusOr<bool> Run(
+      HloComputation* computation,
+      const OpSupportsDynamismHandler& op_supports_dynamism_handler,
+      DynamicDimensionInference* dynamic_shape_inference,
+      const absl::flat_hash_set<absl::string_view>& execution_threads,
+      bool require_dynamic_output) {
     DynamicShapeRemovingVisitor visitor(op_supports_dynamism_handler,
-                                        dynamic_shape_inference);
+                                        dynamic_shape_inference,
+                                        execution_threads);
     TF_RETURN_IF_ERROR(computation->Accept(&visitor));
     // If the output is required to be in dynamic form, insert a
     // static-to-dynamic conversion as root.
@@ -1968,173 +1890,114 @@
         computation->set_root_instruction(new_root);
       }
     }
-    return OkStatus();
+    return visitor.changed();
   }
 
  private:
-  // If a tensor produced by `inst` is in dynamic form, convert it to static and
-  // returns the new instruction.
-  StatusOr<HloInstruction*> ConvertToStatic(HloInstruction* inst);
-
   // If a tensor produced by `inst` is in static form, converts it to dynamic
   // and returns the new instruction.
   StatusOr<HloInstruction*> ConvertToDynamic(HloInstruction* inst);
 
-  const DynamicPadderOptions::OpSupportsDynamismHandler&
-      op_supports_dynamism_handler_;
+  // Same as above, but for all of the instruction's operands.  The operands
+  // will be replaced by dynamic operands as needed.
+  Status ConvertOperandsToDynamic(HloInstruction* inst);
+
+  const OpSupportsDynamismHandler& op_supports_dynamism_handler_;
 
   DynamicDimensionInference* dynamic_dimension_inference_;
+
+  absl::flat_hash_set<absl::string_view> execution_threads_;
 };
 
 StatusOr<HloInstruction*> DynamicShapeRemovingVisitor::ConvertToDynamic(
     HloInstruction* inst) {
-  const Shape& shape = inst->shape();
-  if (shape.IsTuple()) {
-    std::vector<HloInstruction*> dynamic_operands;
-    for (int64_t i = 0; i < shape.tuple_shapes_size(); ++i) {
-      auto gte = inst->AddInstruction(HloInstruction::CreateGetTupleElement(
-          shape.tuple_shapes(i), inst, i));
-      if (dynamic_dimension_inference_->HasDynamicDimension(inst, {i})) {
-        TF_RETURN_IF_ERROR(dynamic_dimension_inference_->Update(gte));
-        TF_ASSIGN_OR_RETURN(auto dynamic, ConvertToDynamic(gte));
-        dynamic_operands.push_back(dynamic);
-      } else {
-        dynamic_operands.push_back(gte);
-      }
+  if (!dynamic_dimension_inference_->HasDynamicDimension(inst)) {
+    return OkStatus();
+  }
+  MarkAsChanged();
+  Shape shape = dynamic_dimension_inference_->GetDynamicShape(inst);
+  auto gtes = TupleUtil::DisassembleTupleInstruction(inst);
+
+  gtes.ForEachMutableElement([&](const ShapeIndex& index,
+                                 HloInstruction** element) {
+    const Shape& subshape = ShapeUtil::GetSubshape(shape, index);
+    if (!subshape.IsArray()) {
+      return;
     }
-    return inst->AddInstruction(HloInstruction::CreateTuple(dynamic_operands));
-  } else {
+    if (!dynamic_dimension_inference_->HasDynamicDimension(inst, index)) {
+      return;
+    }
     // Collect the data input, as well as dimension sizes, and feed them to
     // slice to dynamic to create a dynamic tensor.
-    Shape output_shape = shape;  // 0th element.
-    CHECK(output_shape.is_static());
     std::vector<HloInstruction*> slice_operand;
-    slice_operand.push_back(inst);
-    for (int64_t i = 0; i < output_shape.dimensions_size(); ++i) {
+    slice_operand.push_back(*element);
+    for (int64_t i = 0; i < subshape.dimensions_size(); ++i) {
       auto dimension_size =
-          dynamic_dimension_inference_->GetDynamicSize(inst, {}, i);
+          dynamic_dimension_inference_->GetDynamicSize(inst, index, i);
       if (dimension_size == nullptr) {
         dimension_size = inst->AddInstruction(HloInstruction::CreateConstant(
-            LiteralUtil::CreateR0<int32_t>(output_shape.dimensions(i))));
-      } else {
-        output_shape.set_dynamic_dimension(i, true);
+            LiteralUtil::CreateR0<int32_t>(subshape.dimensions(i))));
       }
       slice_operand.push_back(dimension_size);
     }
-    return inst->AddInstruction(HloInstruction::CreateCustomCall(
-        output_shape, slice_operand, "SliceToDynamic"));
-  }
+    *element = inst->AddInstruction(HloInstruction::CreateCustomCall(
+        subshape, slice_operand, "SliceToDynamic"));
+  });
+
+  return TupleUtil::AssembleTupleInstruction(inst->parent(), std::move(gtes));
 }
 
-StatusOr<HloInstruction*> DynamicShapeRemovingVisitor::ConvertToStatic(
+Status DynamicShapeRemovingVisitor::ConvertOperandsToDynamic(
     HloInstruction* inst) {
-  const Shape& shape = inst->shape();
-  CHECK(shape.is_dynamic());
-  if (shape.IsTuple()) {
-    std::vector<HloInstruction*> static_operands;
-    for (int64_t i = 0; i < shape.tuple_shapes_size(); ++i) {
-      auto gte = inst->AddInstruction(HloInstruction::CreateGetTupleElement(
-          shape.tuple_shapes(i), inst, i));
-      TF_RETURN_IF_ERROR(dynamic_dimension_inference_->Update(gte));
-      auto operand = inst->mutable_operand(i);
-      if (shape.tuple_shapes(i).is_dynamic()) {
-        TF_ASSIGN_OR_RETURN(auto static_inst, ConvertToStatic(gte));
-        static_operands.push_back(static_inst);
-      } else {
-        static_operands.push_back(operand);
-      }
+  for (int64_t i = 0; i < inst->operand_count(); ++i) {
+    auto operand = inst->mutable_operand(i);
+    if (dynamic_dimension_inference_->HasDynamicDimension(operand)) {
+      TF_ASSIGN_OR_RETURN(auto dynamic_operand,
+                          ConvertToDynamic(inst->mutable_operand(i)));
+      TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(i, dynamic_operand));
+      MarkAsChanged();
     }
-    return inst->AddInstruction(HloInstruction::CreateTuple(static_operands));
-  } else {
-    // The output shape of pad static is a tuple. The 0th element is the data
-    // output, which is the same as input shape, but without dynamic dimensions.
-    // i-th element is the dynamic dimension size for i-1th input dimension.
-    Shape data_output_shape = shape;  // 0th element.
-    data_output_shape.clear_dynamic_dimensions();
-    Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
-    for (int64_t i = 0; i < shape.rank(); ++i) {
-      ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
-                                    &output_shape);
-    }
-    HloInstruction* pad_to_static =
-        inst->AddInstruction(HloInstruction::CreateCustomCall(
-            output_shape, {inst}, "PadToStatic", ""));
-    HloInstruction* data_output =
-        inst->AddInstruction(HloInstruction::CreateGetTupleElement(
-            data_output_shape, pad_to_static, 0));
-    return data_output;
   }
+  return OkStatus();
 }
 
 Status DynamicShapeRemovingVisitor::DefaultAction(HloInstruction* hlo) {
-  const bool input_is_dynamic = absl::c_any_of(
-      hlo->operands(),
-      [](const HloInstruction* hlo) { return hlo->shape().is_dynamic(); });
-
   // By default, ops don't support dynamic lowering.
   OpDynamismSupport op_support = OpDynamismSupport::kNoSupport;
   if (op_supports_dynamism_handler_) {
     op_support = op_supports_dynamism_handler_(hlo);
   }
-  if (op_support == OpDynamismSupport::kNoSupport) {
-    for (auto* sub_computation : hlo->called_computations()) {
-      for (auto* param : sub_computation->parameter_instructions()) {
-        param->mutable_shape()->clear_dynamic_dimensions();
-      }
-    }
-  }
-  // If the input to an op is static and the op doesn't support
-  // dynamic output, remove dynamism in output -- dynamic_padder should have
-  // rewritten it to support static shapes.
-  if (!input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) {
-    hlo->mutable_shape()->clear_dynamic_dimensions();
-    return OkStatus();
-  }
-
-  // Op doesn't support dynamic tensor: For each operand rewrite dynamic input
-  // into static input using pad_to_static.
-  if (input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) {
-    VLOG(1) << "op doesn't support dynamic tensor: " << hlo->ToString();
-    for (int64_t i = 0; i < hlo->operand_count(); ++i) {
-      if (hlo->operand(i)->shape().is_dynamic()) {
-        TF_ASSIGN_OR_RETURN(auto static_operand,
-                            ConvertToStatic(hlo->mutable_operand(i)));
-        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, static_operand));
-      }
-    }
-    // This op doesn't support dynamic lowering so the op has to be static.
-    hlo->mutable_shape()->clear_dynamic_dimensions();
-    return OkStatus();
-  }
 
   // If the op requires dynamic tensor and input is static -- construct a
   // dynamic tensor from the static tensor to feed it.
-  if (!input_is_dynamic && op_support == OpDynamismSupport::kRequired) {
+  if (op_support == OpDynamismSupport::kRequired) {
     VLOG(1) << "op doesn't support static tensor: " << hlo->ToString();
-    for (int64_t i = 0; i < hlo->operand_count(); ++i) {
-      auto operand = hlo->mutable_operand(i);
-      if (dynamic_dimension_inference_->HasDynamicDimension(operand)) {
-        TF_ASSIGN_OR_RETURN(auto dynamic_operand,
-                            ConvertToDynamic(hlo->mutable_operand(i)));
-        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, dynamic_operand));
-      }
-    }
+    return ConvertOperandsToDynamic(hlo);
+  }
+
+  const bool input_is_dynamic = absl::c_any_of(
+      hlo->operands(),
+      [](const HloInstruction* hlo) { return hlo->shape().is_dynamic(); });
+
+  // If the input to an op is static, we are done.
+  if (!input_is_dynamic) {
     return OkStatus();
   }
 
+  // Op doesn't support dynamic tensor, but by now we should have already
+  // removed the dynamic dimensions for such ops.
+  TF_RET_CHECK(op_support != OpDynamismSupport::kNoSupport)
+      << "Dynamic input unexpectedly found for unsupported instruction: "
+      << hlo->ToString();
+
   return OkStatus();
 }
 
 Status DynamicShapeRemovingVisitor::HandleGetTupleElement(HloInstruction* hlo) {
-  *hlo->mutable_shape() =
-      hlo->operand(0)->shape().tuple_shapes(hlo->tuple_index());
   return OkStatus();
 }
 
 Status DynamicShapeRemovingVisitor::HandleTuple(HloInstruction* hlo) {
-  for (int64_t i = 0; i < hlo->operand_count(); ++i) {
-    *hlo->mutable_shape()->mutable_tuple_shapes(i) = hlo->operand(i)->shape();
-  }
   return OkStatus();
 }
 
@@ -2158,12 +2021,38 @@
 }
 
 Status DynamicShapeRemovingVisitor::HandleAsyncStart(HloInstruction* hlo) {
-  // async-start is handled specially in InsertPadToStaticAfterModuleInputs().
+  if (HloInstruction::IsThreadIncluded(hlo->async_execution_thread(),
+                                       execution_threads_)) {
+    return OkStatus();
+  }
+  return ConvertOperandsToDynamic(hlo);
+}
+
+Status DynamicShapeRemovingVisitor::HandleAsyncUpdate(HloInstruction* hlo) {
   return OkStatus();
 }
 
 Status DynamicShapeRemovingVisitor::HandleAsyncDone(HloInstruction* hlo) {
-  // async-done is handled specially in InsertPadToStaticAfterModuleInputs().
+  return OkStatus();
+}
+
+Status DynamicShapeRemovingVisitor::HandleWhile(HloInstruction* hlo) {
+  return OkStatus();
+}
+
+Status DynamicShapeRemovingVisitor::HandleConditional(HloInstruction* hlo) {
+  return OkStatus();
+}
+
+Status DynamicShapeRemovingVisitor::HandleGetDimensionSize(
+    HloInstruction* hlo) {
+  return OkStatus();
+}
+
+Status DynamicShapeRemovingVisitor::HandleSetDimensionSize(
+    HloInstruction* hlo) {
+  *hlo->mutable_shape() = hlo->operand(0)->shape();
+  hlo->mutable_shape()->set_dynamic_dimension(hlo->dimension(), false);
   return OkStatus();
 }
 
@@ -2172,17 +2061,22 @@
 StatusOr<bool> DynamicPadder::Run(
     HloModule* module,
     const absl::flat_hash_set<absl::string_view>& execution_threads) {
-  bool changed = false;
   VLOG(2) << "Pre DynamicPadder HLO:";
   XLA_VLOG_LINES(2, module->ToString());
-  TF_RETURN_IF_ERROR(
-      InsertPadToStaticAfterModuleInputs(module, execution_threads));
+
+  // Run DCE before inference, in case earlier passes left dead instructions
+  // that could cause us to insert PadToStatic when it isn't desired.
+  HloDCE dce;
+  TF_ASSIGN_OR_RETURN(bool changed, dce.Run(module, execution_threads));
+
   TF_ASSIGN_OR_RETURN(
       DynamicDimensionInference dynamic_dimension_inference,
       DynamicDimensionInference::Run(
-          module, options_.custom_call_handler, options_.shape_check_mode,
+          module, options_.op_supports_dynamism_handler,
+          options_.custom_call_handler, options_.shape_check_mode,
           options_.assertion_generator, execution_threads));
 
+  changed |= dynamic_dimension_inference.changed();
   std::vector<HloComputation*> computations =
       module->MakeComputationPostOrder(execution_threads);
 
@@ -2316,20 +2210,30 @@
   // There are ops that only support dynamic lowering and ops that only support
   // static lowering, add dynamic<->static tensor conversion around the boundary
   // between those ops, as well as the root instruction.
+  // DynamicDimensionInference can leave behind dead, partially inferred
+  // computations, but we want to ensure that ops that do not support dynamic
+  // shapes do not remain once the DynamicPadder is done.  So we filter out
+  // those computations using a CallGraph.
+  auto call_graph = CallGraph::Build(module, execution_threads);
   computations = module->MakeComputationPostOrder(execution_threads);
-  // Reverse postorder so that if caller doesn't support dynamic tensor (while,
-  // etc), change their called computation to only take static tensors.
+  // Reverse postorder so that if a caller doesn't support dynamic tensors, we
+  // change its called computations to only take static tensors.
   for (auto it = computations.rbegin(); it != computations.rend(); ++it) {
     HloComputation* computation = *it;
+    if (!call_graph->Dominates(module->entry_computation(), computation)) {
+      continue;
+    }
     // if slice_dynamic_output_ is set and this is entry computation, we need
     // the output tensor to be in dynamic form.
     bool require_dynamic_output = options_.slice_dynamic_output &&
                                   computation == module->entry_computation();
     changed |= require_dynamic_output;
-    TF_RETURN_IF_ERROR(DynamicShapeRemovingVisitor::Run(
-        computation, options_.op_supports_dynamism_handler,
-        &dynamic_dimension_inference,
-        /*require_dynamic_output=*/require_dynamic_output));
+    TF_ASSIGN_OR_RETURN(bool c,
+                        DynamicShapeRemovingVisitor::Run(
+                            computation, options_.op_supports_dynamism_handler,
+                            &dynamic_dimension_inference, execution_threads,
+                            /*require_dynamic_output=*/require_dynamic_output));
+    changed |= c;
   }
 
   if (changed) {
@@ -2338,6 +2242,9 @@
   }
 
   for (auto* computation : module->computations(execution_threads)) {
+    if (!call_graph->Dominates(module->entry_computation(), computation)) {
+      continue;
+    }
     for (auto instruction : computation->MakeInstructionPostOrder()) {
       TF_ASSIGN_OR_RETURN(
           bool c, ReplaceGetSize(instruction, &dynamic_dimension_inference));
@@ -2346,6 +2253,9 @@
   }
 
   for (auto* computation : module->computations(execution_threads)) {
+    if (!call_graph->Dominates(module->entry_computation(), computation)) {
+      continue;
+    }
     for (auto instruction : computation->MakeInstructionPostOrder()) {
       TF_ASSIGN_OR_RETURN(bool c, ReplaceSetSize(instruction));
       changed |= c;
@@ -2355,9 +2265,11 @@
     }
   }
 
-  HloDCE dce;
-  TF_ASSIGN_OR_RETURN(bool c, dce.Run(module, execution_threads));
-  changed |= c;
+  if (changed) {
+    HloDCE dce;
+    TF_ASSIGN_OR_RETURN(bool c, dce.Run(module, execution_threads));
+    changed |= c;
+  }
 
   VLOG(2) << "Post DynamicPadder HLO:";
   XLA_VLOG_LINES(2, module->ToString());
diff --git a/third_party/xla/xla/service/dynamic_padder.h b/third_party/xla/xla/service/dynamic_padder.h
index b35d667..ac6c388 100644
--- a/third_party/xla/xla/service/dynamic_padder.h
+++ b/third_party/xla/xla/service/dynamic_padder.h
@@ -38,27 +38,15 @@
 // Dynamic_padder removes dynamic shapes from the entry computation, and inserts
 // custom calls (with dynamic shapes), which are lowered by specialized
 // emitters: PadToStatic and SliceToDynamic.
-
-// Each instruction can have one of the three modes in supporting dynamic
-// lowering.
-enum OpDynamismSupport {
-  // There is no support for dynamic lowering -- dynamic padder will make sure
-  // the input to that op has static bound by rewriting the op (e.g, extra space
-  // in reduce_sum will be padded with 0).
-  kNoSupport = 0,
-  // The op can take either dynamic input or static input.
-  kOptional,
-  // The op only has a dynamic lowering, dynamic padder will make sure the input
-  // to this op is in dynamic form.
-  kRequired,
-};
+//
+// Note that it is not currently possible to send the output of PadToStatic
+// across thread boundaries, and such shapes will be sent across the boundary in
+// dynamic form. The DynamicPadder should be run separately for each thread that
+// requires static shapes, and the dynamic shapes will be padded within the
+// thread's computation.
 
 struct DynamicPadderOptions {
-  // Returns true if given instruction supports native dynamic lowering. If
-  // so, dynamic padder will not attempt to pad it.
-  using OpSupportsDynamismHandler =
-      std::function<OpDynamismSupport(HloInstruction*)>;
-
+  // Determines the form of dynamism supported by an HLO op.
   OpSupportsDynamismHandler op_supports_dynamism_handler = nullptr;
 
   // Instructs how to infer output dynamic dimensions of custom calls.
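
Note: for readers wiring this up, a minimal configuration sketch showing how a backend might supply an OpSupportsDynamismHandler through DynamicPadderOptions, mirroring the handler used in the tests later in this change. This is illustrative only; it assumes the usual XLA includes and namespace, and the "MyDynamicOp" custom-call target is hypothetical, not part of this change.

// Illustrative sketch: mark one hypothetical custom call as requiring a
// dynamic lowering; everything else keeps the default static lowering.
DynamicPadderOptions options;
options.slice_dynamic_output = true;
options.op_supports_dynamism_handler = [](HloInstruction* hlo) {
  if (hlo->opcode() == HloOpcode::kCustomCall &&
      hlo->custom_call_target() == "MyDynamicOp") {  // hypothetical target
    return OpDynamismSupport::kRequired;
  }
  return OpDynamismSupport::kNoSupport;
};
DynamicPadder padder(std::move(options));
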
diff --git a/third_party/xla/xla/service/dynamic_padder_test.cc b/third_party/xla/xla/service/dynamic_padder_test.cc
index edb0b46..b8113de 100644
--- a/third_party/xla/xla/service/dynamic_padder_test.cc
+++ b/third_party/xla/xla/service/dynamic_padder_test.cc
@@ -15,14 +15,25 @@
 
 #include "xla/service/dynamic_padder.h"
 
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
 #include "absl/strings/str_replace.h"
+#include "absl/types/span.h"
 #include "xla/client/xla_builder.h"
+#include "xla/error_spec.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/utils/hlo_matchers.h"
 #include "xla/literal.h"
+#include "xla/literal_util.h"
+#include "xla/service/algebraic_simplifier.h"
 #include "xla/service/dynamic_dimension_inference.h"
 #include "xla/service/dynamic_dimension_simplifier.h"
 #include "xla/service/hlo_dce.h"
@@ -30,8 +41,11 @@
 #include "xla/service/pattern_matcher.h"
 #include "xla/service/pattern_matcher_gmock.h"
 #include "xla/service/tuple_simplifier.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
+#include "xla/status.h"
 #include "xla/status_macros.h"
+#include "xla/statusor.h"
 #include "xla/test.h"
 #include "xla/test_helpers.h"
 #include "xla/tests/client_library_test_base.h"
@@ -42,6 +56,9 @@
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/lib/core/status_test_util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
 #include "tsl/platform/test_benchmark.h"
 #include "tsl/protobuf/error_codes.pb.h"
 
@@ -90,13 +107,27 @@
     return module;
   }
 
-  StatusOr<bool> RunPadder(bool slice_dynamic_output = false) {
+  StatusOr<bool> RunPadder(
+      bool slice_dynamic_output = false,
+      OpSupportsDynamismHandler op_supports_dynamism_handler =
+          OpHasDynamismSupport,
+      DynamicDimensionInference::CustomCallInferenceHandler
+          custom_call_handler = CustomCallDynamicDimensionInference) {
     DynamicPadderOptions options;
     options.slice_dynamic_output = slice_dynamic_output;
-    options.custom_call_handler = CustomCallDynamicDimensionInference;
-    options.op_supports_dynamism_handler = OpHasDynamismSupport;
+    options.op_supports_dynamism_handler =
+        std::move(op_supports_dynamism_handler);
+    options.custom_call_handler = std::move(custom_call_handler);
     DynamicPadder padder(std::move(options));
-    return RunHloPass(&padder, module_.get());
+    TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(&padder, module_.get()));
+    if (!changed) return false;
+    // Dynamic padder can add redundant tuple/get-tuple-element and copy
+    // instructions.
+    TupleSimplifier tuple_simplifier;
+    TF_RETURN_IF_ERROR(RunHloPass(&tuple_simplifier, module_.get()).status());
+    AlgebraicSimplifier alg_simplifier(AlgebraicSimplifierOptions{});
+    TF_RETURN_IF_ERROR(RunHloPass(&alg_simplifier, module_.get()).status());
+    return true;
   }
 
   void ExpectPadded(const HloInstruction* inst) {
@@ -167,8 +198,8 @@
   data_param = builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
       dynamic_shape, data_param, size_param, 2));
 
-  auto negate = builder.AddInstruction(
-      HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param));
+  auto negate = builder.AddInstruction(HloInstruction::CreateUnary(
+      dynamic_shape, HloOpcode::kNegate, data_param));
 
   auto init = builder.AddInstruction(
       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
@@ -282,8 +313,10 @@
   //   SliceToDynamic // Root require dynamic form tensor.
 
   auto* root = module_->entry_computation()->root_instruction();
-  EXPECT_THAT(root,
-              op::CustomCall({"SliceToDynamic"}, op::Negate(), op::Constant()));
+  // The final result should use the dynamic size provided by PadToStatic.
+  EXPECT_THAT(root, op::CustomCall(
+                        {"SliceToDynamic"}, op::Negate(),
+                        op::GetTupleElement(op::CustomCall({"PadToStatic"}))));
   HloInstruction* negate = root->mutable_operand(0);
   EXPECT_THAT(
       negate,
@@ -294,7 +327,7 @@
       module_->entry_computation()->GetInstructionWithName("custom-call.1");
   EXPECT_THAT(custom_call_1,
               op::CustomCall({"OpWithDynamicLowering"},
-                             op::Tuple(op::GetTupleElement(),
+                             op::Tuple(op::Constant(),
                                        op::CustomCall({"SliceToDynamic"}))));
 }
 
@@ -377,7 +410,7 @@
   constexpr int zdim = 1;
   auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
   auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
-  auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim});
+  auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim}, {false, true});
 
   auto dynamic_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}, {true, false});
 
@@ -416,7 +449,7 @@
 TEST_F(DynamicPadderTest, ReduceWindowNoPadForTrivialWindow) {
   auto builder = HloComputation::Builder(TestName());
   auto input_shape = ShapeUtil::MakeShape(F32, {4, 5});
-  auto reduce_shape = ShapeUtil::MakeShape(F32, {3, 5});
+  auto reduce_shape = ShapeUtil::MakeShape(F32, {3, 5}, {false, true});
   auto dynamic_shape = ShapeUtil::MakeShape(F32, {4, 5}, {false, true});
 
   auto input = builder.AddInstruction(
@@ -463,7 +496,7 @@
   input_dynamic.1 = s32[4,<=5] set-dimension-size(input.1, size_param.0), dimensions={1}
   init.0 = f32[] constant(0.0)
   init.1 = s32[] constant(0)
-  ROOT output = (f32[3, 5], s32[3, 5]) reduce-window(input_dynamic.0, input_dynamic.1, init.0, init.1), window={size=2x1 pad=0_0x0_0}, to_apply=add_f32
+  ROOT output = (f32[3, <=5], s32[3, <=5]) reduce-window(input_dynamic.0, input_dynamic.1, init.0, init.1), window={size=2x1 pad=0_0x0_0}, to_apply=add_f32
 }
 )";
 
@@ -512,11 +545,7 @@
   TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status());
 
   EXPECT_THAT(module_->entry_computation()->root_instruction(),
-              GmockMatch(m::CustomCall(
-                  {"SliceToDynamic"},
-                  m::GetTupleElement(m::CustomCall(
-                      {"PadToStatic"}, m::CustomCall({"UnknownOp"}))),
-                  m::Op())));
+              GmockMatch(m::CustomCall({"UnknownOp"})));
 }
 
 TEST_F(DynamicPadderTest, WhileLoopDynamicShapeChangeToStatic) {
@@ -580,6 +609,115 @@
                                        ShapeUtil::MakeScalarShape(S32)}));
 }
 
+TEST_F(DynamicPadderTest, WhileLoopCarriesRequiredDynamicShape) {
+  // Test a while loop that carries dynamic shapes.
+  // This module is similar to an on-device training loop with gradients delayed
+  // by a step. Dynamic shapes are only touched by ops with dynamic lowerings,
+  // so they should not be padded.
+  const std::string hlo_text = R"(
+HloModule WhileLoopCarriesRequiredDynamicShape
+
+%cond {
+  param = (f32[1024], f32[<=64], f32[32], f32[<=64], f32[32], s32[], s32[], token[]) parameter(0)
+  current = s32[] get-tuple-element(param), index=5
+  last = s32[] get-tuple-element(param), index=6
+  ROOT result = pred[] compare(current, last), direction=LT
+}
+
+%body {
+  param = (f32[1024], f32[<=64], f32[32], f32[<=64], f32[32], s32[], s32[], token[]) parameter(0)
+  var = f32[1024] get-tuple-element(param), index=0
+  input0 = f32[<=64] get-tuple-element(param), index=1
+  grad0 = f32[32] get-tuple-element(param), index=2
+  input1 = f32[<=64] get-tuple-element(param), index=3
+  act1 = f32[32] get-tuple-element(param), index=4
+
+  grad1 = f32[32] custom-call(act1), custom_call_target="ComputeGradients"
+
+  var1 = f32[1024] custom-call(var, input0, grad0), custom_call_target="ApplyGradients", output_to_operand_aliasing={{}: (0, {})}
+
+  token2 = token[] get-tuple-element(param), index=7
+  infeed2 = (f32[<=64], token[]) infeed(token2)
+  input2 = f32[<=64] get-tuple-element(infeed2), index=0
+  act2 = f32[32] custom-call(var1, input2), custom_call_target="ComputeActivations"
+
+  current = s32[] get-tuple-element(param), index=5
+  constant1 = s32[] constant(1)
+  add = s32[] add(current, constant1)
+
+  last = s32[] get-tuple-element(param), index=6
+  token3 = token[] get-tuple-element(infeed2), index=1
+  ROOT result = (f32[1024], f32[<=64], f32[32], f32[<=64], f32[32], s32[], s32[], token[]) tuple(var1, input1, grad1, input2, act2, add, last, token3)
+}
+
+ENTRY main {
+  last = s32[] parameter(0)
+  var = f32[1024] parameter(1)
+
+  token0 = token[] after-all()
+  infeed0 = (f32[<=64], token[]) infeed(token0)
+  input0 = f32[<=64] get-tuple-element(infeed0), index=0
+  act0 = f32[32] custom-call(var, input0), custom_call_target="ComputeActivations"
+
+  grad0 = f32[32] custom-call(act0), custom_call_target="ComputeGradients"
+  token1 = token[] get-tuple-element(infeed0), index=1
+  infeed1 = (f32[<=64], token[]) infeed(token1)
+  input1 = f32[<=64] get-tuple-element(infeed1), index=0
+  act1 = f32[32] custom-call(var, input1), custom_call_target="ComputeActivations"
+
+  token2 = token[] get-tuple-element(infeed1), index=1
+
+  zero = s32[] constant(0)
+  tuple = (f32[1024], f32[<=64], f32[32]{0}, f32[<=64], f32[32]{0}, s32[], s32[], token[]) tuple(var, input0, grad0, input1, act1, zero, last, token2)
+  while = (f32[1024], f32[<=64], f32[32]{0}, f32[<=64], f32[32]{0}, s32[], s32[], token[]) while(tuple), condition=%cond, body=%body
+
+  ROOT result = f32[1024] get-tuple-element(while), index=0
+}
+)";
+
+  module_ = GetHloModule(hlo_text);
+
+  auto op_supports_dynamism = [](HloInstruction* hlo) {
+    if (hlo->opcode() != HloOpcode::kCustomCall) {
+      return OpDynamismSupport::kNoSupport;
+    }
+    if (hlo->custom_call_target() == "ComputeActivations" ||
+        hlo->custom_call_target() == "ApplyGradients") {
+      return OpDynamismSupport::kRequired;
+    }
+    return OpDynamismSupport::kNoSupport;
+  };
+  auto custom_call_handler = [](HloInstruction* hlo,
+                                DynamicDimensionInference* inference) {
+    return OkStatus();
+  };
+  TF_ASSERT_OK(
+      RunPadder(
+          /*slice_dynamic_output=*/true,
+          /*op_supports_dynamism_handler=*/std::move(op_supports_dynamism),
+          /*custom_call_handler=*/std::move(custom_call_handler))
+          .status());
+  XLA_LOG_LINES(1, module_->ToString());
+
+  for (HloComputation* computation : module_->computations()) {
+    for (HloInstruction* instruction : computation->instructions()) {
+      if (instruction->opcode() == HloOpcode::kCustomCall) {
+        EXPECT_NE(instruction->custom_call_target(), "PadToStatic");
+        EXPECT_NE(instruction->custom_call_target(), "SliceToDynamic");
+        if (instruction->custom_call_target() == "ComputeActivations") {
+          EXPECT_TRUE(instruction->operand(1)->shape().is_dynamic());
+        } else if (instruction->custom_call_target() == "ApplyGradients") {
+          EXPECT_TRUE(instruction->operand(1)->shape().is_dynamic());
+        }
+      } else if (instruction->opcode() == HloOpcode::kWhile) {
+        const Shape& shape = instruction->shape();
+        EXPECT_TRUE(shape.tuple_shapes(1).is_dynamic());
+        EXPECT_TRUE(shape.tuple_shapes(3).is_dynamic());
+      }
+    }
+  }
+}
+
 TEST_F(DynamicPadderTest, HandleReshapeCheckPastReshape) {
   // Two different sizes.
   auto hlo_text = R"(
@@ -589,10 +727,10 @@
   p1 = s32[] parameter(1)
   p2 = f32[432,337]{1,0:T(8,128)} parameter(2)
   p0_dynamic = f32[<=4,511,432] set-dimension-size(p0, p1), dimensions={0}
-  reshape.4179 = f32[2044,432]{1,0} reshape(p0_dynamic)
-   dot.4180 = f32[2044,337]{1,0} dot(reshape.4179, p2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  transpose.4181 = f32[2044,337]{1,0} transpose(dot.4180), dimensions={0,1}
-  ROOT reshape.4183 = f32[4,511,337]{2,1,0} reshape(transpose.4181)
+  reshape.4179 = f32[<=2044,432]{1,0} reshape(p0_dynamic)
+  dot.4180 = f32[<=2044,337]{1,0} dot(reshape.4179, p2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  transpose.4181 = f32[<=2044,337]{1,0} transpose(dot.4180), dimensions={0,1}
+  ROOT reshape.4183 = f32[<=4,511,337]{2,1,0} reshape(transpose.4181)
 })";
   module_ = GetHloModule(hlo_text);
   // Set up dynamic parameter binding.
@@ -809,9 +947,9 @@
 ENTRY main {
   param = s32[3, 2, 1] parameter(0)
   size = s32[] constant(1)
-  param_padded = s32[3, 2, 1] set-dimension-size(param, size), dimensions={1}
+  param_padded = s32[3, <=2, 1] set-dimension-size(param, size), dimensions={1}
   index = s32[] constant(1)
-  gather = s32[2,1]{1,0} gather(param_padded, index),
+  gather = s32[<=2,1]{1,0} gather(param_padded, index),
               offset_dims={0,1},
               collapsed_slice_dims={0},
               start_index_map={0},
@@ -898,8 +1036,8 @@
 ENTRY main {
   param = s32[5] parameter(0)
   const = s32[] constant(3)
-  param_padded = s32[5] set-dimension-size(param, const), dimensions={0}
-  clamp = s32[5] clamp(param_padded, param_padded, param_padded)
+  param_padded = s32[<=5] set-dimension-size(param, const), dimensions={0}
+  clamp = s32[<=5] clamp(param_padded, param_padded, param_padded)
   init = s32[] constant(0)
   ROOT reduce = s32[] reduce(clamp, init),
       dimensions={0},
@@ -931,7 +1069,7 @@
   size = s32[] constant(2)
   param_padded_0 = s32[<=3] set-dimension-size(param_0, size), dimensions={0}
   param_padded_2 = s32[<=3] set-dimension-size(param_2, size), dimensions={0}
-  ROOT %concatenate = s32[9]
+  ROOT %concatenate = s32[<=9]
     concatenate(s32[<=3] param_padded_0, s32[<=3] param_1, s32[<=3] param_padded_2),
     dimensions={0}
 }
@@ -1055,8 +1193,8 @@
 ENTRY main {
   param = s32[1, 2, 5, 1] parameter(0)
   const = s32[] constant(3)
-  param_padded = s32[1, 2, 5, 1] set-dimension-size(param, const), dimensions={2}
-  reshaped = s32[10] reshape(param_padded)
+  param_padded = s32[1, 2, <=5, 1] set-dimension-size(param, const), dimensions={2}
+  reshaped = s32[<=10] reshape(param_padded)
   init = s32[] constant(0)
   ROOT reduce = s32[] reduce(reshaped, init),
       dimensions={0},
@@ -1085,7 +1223,7 @@
 ENTRY main {
   param = s32[5] parameter(0)
   const = s32[] constant(3)
-  param_padded = s32[5] set-dimension-size(param, const), dimensions={0}
+  param_padded = s32[<=5] set-dimension-size(param, const), dimensions={0}
   ROOT slice = s32[1]{0} slice(param_padded), slice={[0:1]}
 }
 )";
@@ -1114,9 +1252,9 @@
 ENTRY main {
   param = s32[12] parameter(0)
   const = s32[] constant(8)
-  param_padded = s32[12] set-dimension-size(param, const), dimensions={0}
+  param_padded = s32[<=12] set-dimension-size(param, const), dimensions={0}
   // Second dimension is dynamic.
-  reshaped = s32[2, 3, 2] reshape(param_padded), inferred_dimension=1
+  reshaped = s32[2, <=3, 2] reshape(param_padded), inferred_dimension=1
   init = s32[] constant(0)
   ROOT reduce = s32[2, 2] reduce(reshaped, init),
       dimensions={1},
@@ -1161,9 +1299,9 @@
 ENTRY main {
   param = s32[2, 6] parameter(0)
   const = s32[] constant(4)
-  param_padded = s32[2, 6] set-dimension-size(param, const), dimensions={1}
+  param_padded = s32[2, <=6] set-dimension-size(param, const), dimensions={1}
   // Third dimension is dynamic.
-  reshaped = s32[2, 2, 3] reshape(param_padded), inferred_dimension=2
+  reshaped = s32[2, 2, <=3] reshape(param_padded), inferred_dimension=2
   init = s32[] constant(0)
   ROOT reduce = s32[2, 2] reduce(reshaped, init),
       dimensions={2},
@@ -1206,9 +1344,9 @@
 ENTRY main {
   param = s32[6, 2] parameter(0)
   const = s32[] constant(4)
-  param_padded = s32[6, 2] set-dimension-size(param, const), dimensions={0}
+  param_padded = s32[<=6, 2] set-dimension-size(param, const), dimensions={0}
   // Second dimension is dynamic.
-  reshaped = s32[2, 3, 2] reshape(param_padded), inferred_dimension=1
+  reshaped = s32[2, <=3, 2] reshape(param_padded), inferred_dimension=1
   init = s32[] constant(0)
   ROOT reduce = s32[2, 2] reduce(reshaped, init),
       dimensions={1},
@@ -1250,7 +1388,7 @@
   one = f32[] constant(1)
   kernel = f32[1,5,1]{2,1,0} broadcast(f32[] one), dimensions={}
   param_dynamic = f32[1,1,<=5] set-dimension-size(param, const), dimensions={2}
-  ROOT conv = f32[1, 1, 1]{2,1,0} custom-call(f32[1, 1, <=5] param_dynamic, f32[1,5,1]{2,1,0} kernel),
+  ROOT conv = f32[1, 1, 1]{2,1,0} custom-call(f32[1, 1, <=5] param_dynamic, f32[1,<=5,1]{2,1,0} kernel),
                              window={size=1 pad=0_0},
                              dim_labels=b0f_0io->b0f,
                              padding_type=PADDING_VALID,
@@ -1307,8 +1445,8 @@
 ENTRY main {
   param = s32[1, 2, 5, 1] parameter(0)
   const = s32[] constant(3)
-  param_padded = s32[1, 2, 5, 1] set-dimension-size(param, const), dimensions={2}
-  reshaped = s32[2, 5] reshape(param_padded)
+  param_padded = s32[1, 2, <=5, 1] set-dimension-size(param, const), dimensions={2}
+  reshaped = s32[2, <=5] reshape(param_padded)
   init = s32[] constant(0)
   ROOT reduce = s32[2] reduce(reshaped, init),
       dimensions={1},
@@ -1342,9 +1480,9 @@
   param = s32[1, 2, 5, 1] parameter(0)
   size = s32[] constant(0)
 // First dimension is dynamic.
-  param_padded = s32[1, 2, 5, 1] set-dimension-size(param, size),
+  param_padded = s32[<=1, 2, 5, 1] set-dimension-size(param, size),
     dimensions={0}
-  reshaped = s32[10] reshape(param_padded)
+  reshaped = s32[<=10] reshape(param_padded)
   init = s32[] constant(0)
   ROOT reduce = s32[] reduce(reshaped, init),
       dimensions={0},
@@ -1517,10 +1655,10 @@
   one = s32[] constant(1)
   // content of the stack is the stack index broadcasted.
   new_data = s32[1, 2] broadcast(s32[] stack_size), dimensions={}
-  new_stack_buffer = s32[<=4, 2] dynamic-update-slice(stack_buffer, new_data, stack_size, zero)
   new_stack_size = s32[] add(stack_size, one)
-  new_stack_buffer_dynamic = s32[<=4, 2]set-dimension-size(new_stack_buffer, new_stack_size), dimensions={0}
-  ROOT new_stack = (s32[<=4,2]) tuple(new_stack_buffer_dynamic)
+  new_stack_buffer = s32[<=4, 2] set-dimension-size(stack_buffer, new_stack_size), dimensions={0}
+  new_stack = s32[<=4, 2] dynamic-update-slice(new_stack_buffer, new_data, stack_size, zero)
+  ROOT new_stack_tuple = (s32[<=4,2]) tuple(new_stack)
 }
 
 condition {
@@ -1688,9 +1826,11 @@
 ENTRY entry {
   one = s32[] constant(1)
   zero = s32[] constant(0)
+  four = s32[] constant(4)
   stack_buffer_input = s32[4, 2] broadcast(s32[] one), dimensions={}
-  input_tuple = (s32[4, 2]) tuple(stack_buffer_input)
-  while = (s32[4, 2]) while(input_tuple), body=body, condition=condition
+  stack_buffer_dynamic = s32[<=4, 2] set-dimension-size(stack_buffer_input, four), dimensions={0}
+  input_tuple = (s32[<=4, 2]) tuple(stack_buffer_dynamic)
+  while = (s32[<=4, 2]) while(input_tuple), body=body, condition=condition
   stack_buffer = s32[<=4, 2] get-tuple-element(while), index=0
   ROOT reduce = s32[2] reduce(stack_buffer, zero),
     dimensions={0},
@@ -1726,11 +1866,11 @@
 ENTRY main {
   param = s32[2, 3, 3] parameter(0)
   size = s32[] constant(2)
-  param_padded_partial = s32[2, 3, 3] set-dimension-size(param, size),
+  param_padded_partial = s32[2, <=3, 3] set-dimension-size(param, size),
     dimensions={1}
-  param_padded = s32[2, 3, 3] set-dimension-size(param_padded_partial, size),
+  param_padded = s32[2, 3, <=3] set-dimension-size(param_padded_partial, size),
     dimensions={2}
-  reshaped = s32[18] reshape(param_padded)
+  reshaped = s32[<=18] reshape(param_padded)
   init = s32[] constant(0)
   ROOT reduce = s32[] reduce(reshaped, init),
       dimensions={0},
@@ -1922,9 +2062,9 @@
 ENTRY main {
   param = s32[4] parameter(0)
   size = s32[] constant(3)
-  param_dynamic_size = s32[4] set-dimension-size(param, size),
+  param_dynamic_size = s32[<=4] set-dimension-size(param, size),
     dimensions={0}
-  ROOT sort = s32[4]{0} sort(s32[4]{0} %param_dynamic_size),
+  ROOT sort = s32[<=4]{0} sort(s32[<=4]{0} %param_dynamic_size),
     dimensions={0}, is_stable=false, to_apply=%compare-greater-than
 }
 )";
@@ -2088,12 +2228,12 @@
 ENTRY main {
   param = s32[3] parameter(0)
   size = s32[] constant(2)
-  param_dynamic_size = s32[3] set-dimension-size(param, size),
+  param_dynamic_size = s32[<=3] set-dimension-size(param, size),
     dimensions={0}
-  sort = (s32[3]{0}, s32[3]{0}) sort(s32[3]{0} %param_dynamic_size,
-                                     s32[3]{0} %param_dynamic_size),
+  sort = (s32[<=3]{0}, s32[<=3]{0}) sort(s32[<=3]{0} %param_dynamic_size,
+                                         s32[<=3]{0} %param_dynamic_size),
     dimensions={0}, is_stable=true, to_apply=%compare-greater-than
-  ROOT get-tuple-element = s32[3]{0} get-tuple-element((s32[3]{0}, s32[3]{0}) %sort),
+  ROOT get-tuple-element = s32[<=3]{0} get-tuple-element((s32[<=3]{0}, s32[<=3]{0}) %sort),
     index=0
 }
 )";
diff --git a/third_party/xla/xla/service/dynamic_parameter_binding_test.cc b/third_party/xla/xla/service/dynamic_parameter_binding_test.cc
index 43992f5..55e8b37 100644
--- a/third_party/xla/xla/service/dynamic_parameter_binding_test.cc
+++ b/third_party/xla/xla/service/dynamic_parameter_binding_test.cc
@@ -16,19 +16,16 @@
 #include "xla/hlo/ir/dynamic_parameter_binding.h"
 
 #include <memory>
+#include <optional>
 #include <string>
 
-#include "absl/algorithm/container.h"
+#include <gtest/gtest.h>
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/service/hlo_dce.h"
-#include "xla/service/hlo_memory_scheduler.h"
-#include "xla/service/hlo_ordering.h"
 #include "xla/shape_util.h"
 #include "xla/tests/hlo_test_base.h"
-#include "xla/types.h"
 #include "tsl/lib/core/status_test_util.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 namespace {
@@ -52,11 +49,11 @@
   DynamicParameterBinding binding;
 
   TF_EXPECT_OK(
-      binding.Bind(DynamicParameterBinding::DynamicParameter{0, {}},
+      binding.Bind(DynamicParameterBinding::DynamicSizeParameter{0, {}},
                    DynamicParameterBinding::DynamicDimension{1, {}, 0}));
 
   auto test = [&](const DynamicParameterBinding& binding) {
-    std::optional<DynamicParameterBinding::DynamicParameter> param =
+    std::optional<DynamicParameterBinding::DynamicSizeParameter> param =
         binding.GetBinding(
             DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1,
                                                       /*parameter_index=*/{},
@@ -64,7 +61,7 @@
     EXPECT_TRUE(param);
     EXPECT_EQ(param->parameter_num, 0);
     EXPECT_EQ(param->parameter_index, ShapeIndex({}));
-    TF_EXPECT_OK(binding.Verify(*module));
+    TF_EXPECT_OK(binding.Verify(*module->entry_computation()));
   };
   test(binding);
 }
@@ -88,11 +85,11 @@
   DynamicParameterBinding binding;
 
   TF_EXPECT_OK(
-      binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}},
+      binding.Bind(DynamicParameterBinding::DynamicSizeParameter{0, {0}},
                    DynamicParameterBinding::DynamicDimension{0, {1}, 0}));
 
   auto test = [&](const DynamicParameterBinding& binding) {
-    std::optional<DynamicParameterBinding::DynamicParameter> param =
+    std::optional<DynamicParameterBinding::DynamicSizeParameter> param =
         binding.GetBinding(
             DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
                                                       /*parameter_index=*/{1},
@@ -101,7 +98,7 @@
     EXPECT_TRUE(param);
     EXPECT_EQ(param->parameter_num, 0);
     EXPECT_EQ(param->parameter_index, ShapeIndex({0}));
-    TF_EXPECT_OK(binding.Verify(*module));
+    TF_EXPECT_OK(binding.Verify(*module->entry_computation()));
   };
   test(binding);
 }
@@ -125,15 +122,15 @@
   DynamicParameterBinding binding;
 
   TF_EXPECT_OK(
-      binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}},
+      binding.Bind(DynamicParameterBinding::DynamicSizeParameter{0, {0}},
                    DynamicParameterBinding::DynamicDimension{0, {1}, 0}));
 
   TF_EXPECT_OK(
-      binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}},
+      binding.Bind(DynamicParameterBinding::DynamicSizeParameter{0, {0}},
                    DynamicParameterBinding::DynamicDimension{0, {1}, 1}));
 
   auto test = [&](const DynamicParameterBinding& binding) {
-    std::optional<DynamicParameterBinding::DynamicParameter> param =
+    std::optional<DynamicParameterBinding::DynamicSizeParameter> param =
         binding.GetBinding(
             DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
                                                       /*parameter_index=*/{1},
@@ -143,7 +140,7 @@
     EXPECT_EQ(param->parameter_num, 0);
     EXPECT_EQ(param->parameter_index, ShapeIndex({0}));
 
-    std::optional<DynamicParameterBinding::DynamicParameter> param2 =
+    std::optional<DynamicParameterBinding::DynamicSizeParameter> param2 =
 
         binding.GetBinding(
             DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
@@ -152,7 +149,7 @@
     EXPECT_TRUE(param2);
     EXPECT_EQ(param2->parameter_num, 0);
     EXPECT_EQ(param2->parameter_index, ShapeIndex({0}));
-    TF_EXPECT_OK(binding.Verify(*module));
+    TF_EXPECT_OK(binding.Verify(*module->entry_computation()));
   };
 
   test(binding);
diff --git a/third_party/xla/xla/service/elemental_ir_emitter_test.cc b/third_party/xla/xla/service/elemental_ir_emitter_test.cc
index 9ca6f0a..9823d91 100644
--- a/third_party/xla/xla/service/elemental_ir_emitter_test.cc
+++ b/third_party/xla/xla/service/elemental_ir_emitter_test.cc
@@ -13,6 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "absl/strings/string_view.h"
+#include "xla/error_spec.h"
 #include "xla/execution_options_util.h"
 #include "xla/service/hlo_parser.h"
 #include "xla/status_macros.h"
@@ -48,6 +50,18 @@
   }
 };
 
+class ElementalIrEmitterExecutionTestWithoutFastMinMax
+    : public ElementalIrEmitterExecutionTest {
+ protected:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options =
+        ElementalIrEmitterExecutionTest::GetDebugOptionsForTest();
+    debug_options.set_xla_cpu_enable_fast_min_max(false);
+    debug_options.set_xla_gpu_enable_fast_min_max(false);
+    return debug_options;
+  }
+};
+
 XLA_TEST_F(ElementalIrEmitterExecutionTest, DotFusion) {
   const std::string hlo_text = R"(
 HloModule FusedDot
@@ -669,5 +683,133 @@
   RunTest(hlo_text, {});
 }
 
+XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax,
+           MinimumHandlesNaNsOnTheLeft) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  neg1 = f32[] constant(-1)
+  neg1s = f32[5,5] broadcast(neg1), dimensions={}
+  nans = f32[5,5] sqrt(neg1s)
+  ROOT min = f32[5,5] minimum(nans, neg1s)
+})";
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax,
+           MinimumHandlesNaNsOnTheRight) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  neg1 = f32[] constant(-1)
+  neg1s = f32[5,5] broadcast(neg1), dimensions={}
+  nans = f32[5,5] sqrt(neg1s)
+  ROOT min = f32[5,5] minimum(neg1s, nans)
+})";
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax,
+           MaximumHandlesNaNsOnTheLeft) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  neg1 = f32[] constant(-1)
+  neg1s = f32[5,5] broadcast(neg1), dimensions={}
+  nans = f32[5,5] sqrt(neg1s)
+  ROOT max = f32[5,5] maximum(nans, neg1s)
+})";
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax,
+           MaximumHandlesNaNsOnTheRight) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  neg1 = f32[] constant(-1)
+  neg1s = f32[5,5] broadcast(neg1), dimensions={}
+  nans = f32[5,5] sqrt(neg1s)
+  ROOT max = f32[5,5] maximum(neg1s, nans)
+})";
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax,
+           MinimumReturnsLHS) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  zero = f32[] constant(0)
+  zeros = f32[5,5] broadcast(zero), dimensions={}
+  one = f32[] constant(1)
+  ones = f32[5,5] broadcast(one), dimensions={}
+  ROOT min = f32[5,5] minimum(zeros, ones)
+})";
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3,
+                                                /*arel=*/1e-3}));
+}
+
+XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax,
+           MinimumReturnsRHS) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  zero = f32[] constant(0)
+  zeros = f32[5,5] broadcast(zero), dimensions={}
+  one = f32[] constant(1)
+  ones = f32[5,5] broadcast(one), dimensions={}
+  ROOT min = f32[5,5] minimum(ones, zeros)
+})";
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3,
+                                                /*arel=*/1e-3}));
+}
+
+XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax,
+           MaximumReturnsLHS) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  zero = f32[] constant(0)
+  zeros = f32[5,5] broadcast(zero), dimensions={}
+  one = f32[] constant(1)
+  ones = f32[5,5] broadcast(one), dimensions={}
+  ROOT max = f32[5,5] maximum(ones, zeros)
+})";
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3,
+                                                /*arel=*/1e-3}));
+}
+
+XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax,
+           MaximumReturnsRHS) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  zero = f32[] constant(0)
+  zeros = f32[5,5] broadcast(zero), dimensions={}
+  one = f32[] constant(1)
+  ones = f32[5,5] broadcast(one), dimensions={}
+  ROOT max = f32[5,5] maximum(zeros, ones)
+})";
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3,
+                                                /*arel=*/1e-3}));
+}
+
 }  // namespace
 }  // namespace xla
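
Note: the new tests above pin down minimum/maximum behavior once fast min/max is disabled. As a rough reference only (an assumption about the intended semantics, not the emitter's actual lowering), the expected behavior is NaN-propagating on either side, and otherwise matches ordinary min/max:

#include <algorithm>
#include <cmath>
#include <limits>

// Sketch of the NaN-propagating semantics the tests above expect when
// xla_cpu_enable_fast_min_max / xla_gpu_enable_fast_min_max are false:
// a NaN on either side wins; otherwise behave like std::min / std::max.
float MinimumRef(float a, float b) {
  return (std::isnan(a) || std::isnan(b))
             ? std::numeric_limits<float>::quiet_NaN()
             : std::min(a, b);
}
float MaximumRef(float a, float b) {
  return (std::isnan(a) || std::isnan(b))
             ? std::numeric_limits<float>::quiet_NaN()
             : std::max(a, b);
}
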
diff --git a/third_party/xla/xla/service/float_normalization.cc b/third_party/xla/xla/service/float_normalization.cc
index 267a923..84774a3 100644
--- a/third_party/xla/xla/service/float_normalization.cc
+++ b/third_party/xla/xla/service/float_normalization.cc
@@ -335,6 +335,9 @@
 
   std::vector<HloComputation*> low_precision_called_comps;
   for (auto* comp : hlo->called_computations()) {
+    if (comp->IsCollectiveCalledComputation()) {
+      continue;
+    }
     bool comp_has_low_precision = false;
     if (comp->root_instruction()->shape().element_type() ==
         HighPrecisionType()) {
@@ -411,6 +414,9 @@
 
   std::vector<HloComputation*> low_precision_called_comps;
   for (auto* comp : hlo->called_computations()) {
+    if (comp->IsCollectiveCalledComputation()) {
+      continue;
+    }
     bool comp_has_low_precision = false;
     high_prec_count += CountSubshapesWithMatchingType(
         comp->root_instruction()->shape(), HighPrecisionType());
@@ -549,6 +555,11 @@
                         ", before:\n" + module->ToString());
   FloatNormalizationVisitor visitor(float_support_, this);
   for (auto* comp : module->MakeComputationPostOrder(execution_threads)) {
+    if (comp->IsCollectiveCalledComputation()) {
+      XLA_VLOG_LINES(2, "Skip processing collective called computation: " +
+                            comp->ToString());
+      continue;
+    }
     TF_RETURN_IF_ERROR(comp->Accept(&visitor));
   }
   XLA_VLOG_LINES(2, "FloatNormalization::Run() for " +
diff --git a/third_party/xla/xla/service/float_normalization_test.cc b/third_party/xla/xla/service/float_normalization_test.cc
index 3a41960..2d6a976 100644
--- a/third_party/xla/xla/service/float_normalization_test.cc
+++ b/third_party/xla/xla/service/float_normalization_test.cc
@@ -76,6 +76,38 @@
   }
 };
 
+// A test float support class that doesn't support any compute ops in low
+// precision but does support some collectives.
+class TestFloatNoComputeSupport : public FloatSupport {
+ public:
+  explicit TestFloatNoComputeSupport(PrimitiveType low_precision_type)
+      : FloatSupport(low_precision_type) {}
+  ~TestFloatNoComputeSupport() override = default;
+
+  bool SupportsLowPrecisionOperand(const HloInstruction& hlo,
+                                   int64_t operand_index) const override {
+    if (hlo.opcode() == HloOpcode::kTuple ||
+        hlo.opcode() == HloOpcode::kGetTupleElement ||
+        hlo.opcode() == HloOpcode::kAllToAll ||
+        hlo.opcode() == HloOpcode::kAllReduce ||
+        hlo.opcode() == HloOpcode::kReduceScatter) {
+      return true;
+    }
+    return false;
+  }
+
+  bool SupportsLowPrecisionOutput(const HloInstruction& hlo) const override {
+    if (hlo.opcode() == HloOpcode::kTuple ||
+        hlo.opcode() == HloOpcode::kGetTupleElement ||
+        hlo.opcode() == HloOpcode::kAllToAll ||
+        hlo.opcode() == HloOpcode::kAllReduce ||
+        hlo.opcode() == HloOpcode::kReduceScatter) {
+      return true;
+    }
+    return false;
+  }
+};
+
 class FloatNormalizationTest : public HloTestBase {
  protected:
   FloatNormalizationTest()
@@ -485,4 +517,135 @@
   EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert);
 }
 
+class FloatNormalizationNoComputeSupportTest : public FloatNormalizationTest {
+ protected:
+  bool Normalize(HloModule* module, PrimitiveType low_precision_type = BF16) {
+    TestFloatNoComputeSupport float_support(low_precision_type);
+    FloatNormalization normalization(&float_support);
+
+    StatusOr<bool> result = normalization.Run(module);
+    EXPECT_IS_OK(result.status());
+
+    HloVerifier verifier(/*layout_sensitive=*/false,
+                         /*allow_mixed_precision=*/true);
+    EXPECT_IS_OK(verifier.Run(module).status());
+
+    return result.value();
+  }
+};
+
+TEST_F(FloatNormalizationNoComputeSupportTest,
+       NoNormalizationForToApplyMultiOutputAllReduce) {
+  auto module = CreateNewVerifiedModule();
+  HloComputation::Builder sum_builder("sum");
+  auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, ShapeUtil::MakeShape(BF16, {}), "x"));
+  auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, ShapeUtil::MakeShape(BF16, {}), "y"));
+  sum_builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(BF16, {}), HloOpcode::kAdd, x, y));
+  HloComputation* reduction =
+      module->AddEmbeddedComputation(sum_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  Shape bf16_shape_a = ShapeUtil::MakeShape(BF16, {2, 4});
+  Shape bf16_shape_b = ShapeUtil::MakeShape(BF16, {16, 16});
+
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, bf16_shape_a, "a"));
+  HloInstruction* b = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, bf16_shape_b, "b"));
+
+  HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
+      ShapeUtil::MakeTupleShape({bf16_shape_a, bf16_shape_b}), {a, b},
+      reduction,
+      /*replica_groups=*/{},
+      /*constrain_layout=*/false,
+      /*channel_id=*/std::nullopt,
+      /*use_global_device_ids=*/false));
+  builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(bf16_shape_b, crs, 1));
+
+  auto computation = module->AddEntryComputation(builder.Build());
+  // Since we skip processing the to_apply region, nothing should change in
+  // the original HLO.
+  EXPECT_FALSE(Normalize(module.get()));
+  EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16);
+  EXPECT_EQ(crs->operand(1)->shape().element_type(), BF16);
+  EXPECT_EQ(crs->to_apply()->root_instruction()->opcode(), HloOpcode::kAdd);
+  EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), BF16);
+}
+
+TEST_F(FloatNormalizationNoComputeSupportTest,
+       NoNormalizationForToApplyAllReduce) {
+  auto module = CreateNewVerifiedModule();
+  HloComputation::Builder sum_builder("sum");
+  auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, ShapeUtil::MakeShape(BF16, {}), "x"));
+  auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, ShapeUtil::MakeShape(BF16, {}), "y"));
+  sum_builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(BF16, {}), HloOpcode::kAdd, x, y));
+  HloComputation* reduction =
+      module->AddEmbeddedComputation(sum_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  Shape bf16_shape_a = ShapeUtil::MakeShape(BF16, {2, 4});
+
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, bf16_shape_a, "a"));
+
+  HloInstruction* crs = builder.AddInstruction(
+      HloInstruction::CreateAllReduce(bf16_shape_a, {a}, reduction,
+                                      /*replica_groups=*/{},
+                                      /*constrain_layout=*/false,
+                                      /*channel_id=*/std::nullopt,
+                                      /*use_global_device_ids=*/false));
+
+  auto computation = module->AddEntryComputation(builder.Build());
+  // Since we skip processing the to_apply region, nothing should change in
+  // the original HLO.
+  EXPECT_FALSE(Normalize(module.get()));
+  EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16);
+  EXPECT_EQ(crs->operand(0)->shape().element_type(), BF16);
+  EXPECT_EQ(crs->to_apply()->root_instruction()->opcode(), HloOpcode::kAdd);
+}
+
+TEST_F(FloatNormalizationNoComputeSupportTest,
+       NoNormalizationForToApplyReduceScatter) {
+  auto module = CreateNewVerifiedModule();
+  HloComputation::Builder sum_builder("sum");
+  auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, ShapeUtil::MakeShape(BF16, {}), "x"));
+  auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, ShapeUtil::MakeShape(BF16, {}), "y"));
+  sum_builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(BF16, {}), HloOpcode::kAdd, x, y));
+  HloComputation* reduction =
+      module->AddEmbeddedComputation(sum_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  Shape bf16_shape_a = ShapeUtil::MakeShape(BF16, {2, 4});
+  Shape bf16_shape_scattered = ShapeUtil::MakeShape(BF16, {1, 4});
+
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, bf16_shape_a, "a"));
+
+  HloInstruction* crs =
+      builder.AddInstruction(HloInstruction::CreateReduceScatter(
+          bf16_shape_scattered, {a}, reduction,
+          /*replica_groups=*/{},
+          /*constrain_layout=*/false,
+          /*channel_id=*/std::nullopt,
+          /*use_global_device_ids=*/false, /*scatter_dimension*/ 0));
+
+  auto computation = module->AddEntryComputation(builder.Build());
+  // Since we skip processing the to_apply region, nothing should change in
+  // the original HLO.
+  EXPECT_FALSE(Normalize(module.get()));
+  EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16);
+  EXPECT_EQ(crs->operand(0)->shape().element_type(), BF16);
+  EXPECT_EQ(crs->to_apply()->root_instruction()->opcode(), HloOpcode::kAdd);
+}
+
 }  // namespace xla
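
Note: for contrast with the skipped to_apply regions tested above, here is a rough sketch of the convert sandwich FloatNormalization would otherwise apply to an op that lacks low-precision support. NormalizeBinaryToF32 is a hypothetical helper written for illustration only and is not the pass's actual implementation; it assumes the xla namespace and the headers listed.

#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/shape_util.h"

// Hypothetical helper: upcast BF16 operands to F32, compute in F32, and
// convert the result back to the original low-precision type.
HloInstruction* NormalizeBinaryToF32(HloComputation* comp, HloOpcode opcode,
                                     HloInstruction* x, HloInstruction* y) {
  const Shape f32_shape = ShapeUtil::ChangeElementType(x->shape(), F32);
  HloInstruction* xf =
      comp->AddInstruction(HloInstruction::CreateConvert(f32_shape, x));
  HloInstruction* yf =
      comp->AddInstruction(HloInstruction::CreateConvert(f32_shape, y));
  HloInstruction* out_f32 = comp->AddInstruction(
      HloInstruction::CreateBinary(f32_shape, opcode, xf, yf));
  return comp->AddInstruction(
      HloInstruction::CreateConvert(x->shape(), out_f32));
}
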
diff --git a/third_party/xla/xla/service/generic_transfer_manager_test.cc b/third_party/xla/xla/service/generic_transfer_manager_test.cc
index 4dce741..107ad62 100644
--- a/third_party/xla/xla/service/generic_transfer_manager_test.cc
+++ b/third_party/xla/xla/service/generic_transfer_manager_test.cc
@@ -15,6 +15,7 @@
 
 #include "xla/service/generic_transfer_manager.h"
 
+#include <cstddef>
 #include <cstdint>
 #include <optional>
 #include <utility>
@@ -88,6 +89,23 @@
   EXPECT_EQ(absl::Span<uint16_t>(device_ptr, expected.size()), expected);
 }
 
+MATCHER_P2(MaskedValuesEqual, mask, expected, "") {
+  if (arg.size() != expected.size()) {
+    *result_listener << "argument sizes do not match";
+    return false;
+  }
+  for (size_t i = 0; i < expected.size(); ++i) {
+    const auto v1 = arg[i] & mask;
+    const auto v2 = expected[i] & mask;
+    if (v1 != v2) {
+      *result_listener << "mismatch at position " << i << ", " << v1 << " vs "
+                       << v2;
+      return false;
+    }
+  }
+  return true;
+}
+
 TEST_F(GenericTransferManagerTest, TransferLiteralToDeviceInt4) {
   Literal literal =
       LiteralUtil::CreateR2<s4>({{s4{1}, s4{-2}}, {s4{-3}, s4{4}}});
@@ -104,8 +122,10 @@
     std::vector<int8_t> expected =
         pack ? std::vector<int8_t>{static_cast<int8_t>(0x1e),
                                    static_cast<int8_t>(0xd4)}
-             : std::vector<int8_t>{1, (-2) & 0xf, (-3) & 0xf, 4};
-    EXPECT_EQ(absl::Span<int8_t>(device_ptr, expected.size()), expected);
+             : std::vector<int8_t>{1, -2, -3, 4};
+    // Ignore high bits in equality comparisons.
+    EXPECT_THAT(absl::Span<int8_t>(device_ptr, expected.size()),
+                MaskedValuesEqual(pack ? 0xFF : 0x0F, expected));
   }
 }
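
Note: the packed int4 expectation above assumes two s4 values per byte with the first element in the high nibble. PackS4Pair below is a hypothetical helper that reproduces the packed bytes in the test, for illustration only:

#include <cstdint>

// Packs two signed 4-bit values into one byte, first value in the high nibble.
// With the test data {1, -2, -3, 4}: PackS4Pair(1, -2) == 0x1e and
// PackS4Pair(-3, 4) == static_cast<int8_t>(0xd4), matching the expectation.
int8_t PackS4Pair(int8_t hi, int8_t lo) {
  return static_cast<int8_t>(((hi & 0xF) << 4) | (lo & 0xF));
}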
 
diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD
index 3973819..00f1d5e 100644
--- a/third_party/xla/xla/service/gpu/BUILD
+++ b/third_party/xla/xla/service/gpu/BUILD
@@ -3,7 +3,8 @@
 
 load("//xla/tests:build_defs.bzl", "xla_test")
 load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
-load("//xla:xla.bzl", "xla_cc_test", "xla_export_hlo_deps")
+load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library")
+load("//xla:xla.bzl", "xla_cc_test", "xla_cub_deps", "xla_export_hlo_deps")
 load(
     "//xla/service/gpu:build_defs.bzl",
     "build_cub_sort_kernels",
@@ -129,10 +130,12 @@
     tags = tf_cuda_tests_tags(),
     deps = [
         "//xla:debug_options_flags",
+        "//xla:status",
         "//xla:status_macros",
         "//xla:test_helpers",
         "//xla/client:xla_builder",
         "//xla/client/lib:constants",
+        "//xla/ffi",
         "//xla/runtime:custom_call",
         "//xla/runtime:custom_call_registry",
         "//xla/runtime:executable",
@@ -149,6 +152,8 @@
         "//xla/stream_executor/gpu:gpu_types_header",
         "//xla/tests:client_library_test_base",
         "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
         "@local_tsl//tsl/lib/core:status_test_util",
         "@local_tsl//tsl/platform:test",
     ] + if_cuda_is_configured([
@@ -260,13 +265,14 @@
     name = "ir_emitter_unnested",
     srcs = ["ir_emitter_unnested.cc"],
     hdrs = ["ir_emitter_unnested.h"],
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]) + if_rocm_hipblaslt([
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
+        "TENSORFLOW_USE_ROCM=1",
+    ]) + if_rocm_hipblaslt([
         "TF_HIPBLASLT=1",
     ]),
     visibility = ["//visibility:public"],
     deps = [
         ":backend_configs_cc",
-        ":fft_thunk",
         ":gemm_thunk",
         ":gpu_asm_opts_util",
         ":gpu_constants",
@@ -296,6 +302,8 @@
         "//xla:statusor",
         "//xla:util",
         "//xla:xla_data_proto_cc",
+        "//xla/ffi",
+        "//xla/ffi/api:c_api",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/utils:hlo_query",
         "//xla/mlir_hlo",
@@ -303,12 +311,18 @@
         "//xla/mlir_hlo:lhlo_gpu",
         "//xla/mlir_hlo:transforms_gpu_passes",
         "//xla/service:buffer_assignment",
+        "//xla/service:custom_call_status",
         "//xla/service:custom_call_target_registry",
         "//xla/service:name_uniquer",
         "//xla/service/gpu/fusions",
+        "//xla/service/gpu/fusions:fusion_emitter",
+        "//xla/service/gpu/fusions:input_slices",
+        "//xla/service/gpu/fusions:loop",
         "//xla/service/gpu/fusions:thunk_util",
         "//xla/service/gpu/fusions:tiling_util",
+        "//xla/service/gpu/fusions:transpose",
         "//xla/service/gpu/runtime3:custom_call_thunk",
+        "//xla/service/gpu/runtime3:fft_thunk",
         "//xla/service/llvm_ir:buffer_assignment_util",
         "//xla/service/llvm_ir:dynamic_update_slice_util",
         "//xla/service/llvm_ir:fused_ir_emitter",
@@ -327,6 +341,7 @@
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:span",
@@ -352,7 +367,7 @@
         "@local_tsl//tsl/protobuf:dnn_proto_cc",
     ] + if_gpu_is_configured([
         ":cub_sort_thunk",
-        ":cublas_lt_matmul_thunk",
+        ":gpublas_lt_matmul_thunk",
         ":ir_emitter_triton",
         "//xla/service/gpu/runtime3:cholesky_thunk",
         "//xla/service/gpu/runtime3:triangular_solve_thunk",
@@ -503,6 +518,7 @@
         ":gpu_device_info_for_tests",
         ":ir_emission_utils",
         ":ir_emitter_triton",
+        ":matmul_utils",
         "//xla:autotuning_proto_cc",
         "//xla:error_spec",
         "//xla:status_macros",
@@ -604,6 +620,7 @@
         ":gpu_fusible",
         ":instruction_fusion",
         ":ir_emission_utils",
+        ":matmul_utils",
         ":split_k_gemm_rewriter",
         ":stream_executor_util",
         "@com_google_absl//absl/algorithm:container",
@@ -635,6 +652,7 @@
         "//xla/stream_executor:device_memory",
         "//xla/stream_executor",
         "//xla/stream_executor/gpu:redzone_allocator",
+        "@local_tsl//tsl/lib/core:bits",
         "@local_tsl//tsl/platform:blocking_counter",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:errors",
@@ -660,6 +678,7 @@
         ":autotuner_util",
         ":backend_configs_cc",
         ":gemm_rewriter_triton",
+        ":matmul_utils",
         ":triton_autotuner",
         "//xla:autotuning_proto_cc",
         "//xla:error_spec",
@@ -931,7 +950,6 @@
         ":backend_configs_cc",
         ":buffer_allocations",
         ":cusolver_context",
-        ":fft_thunk",
         ":gemm_thunk",
         ":gpu_asm_opts_util",
         ":gpu_constants",
@@ -976,6 +994,7 @@
         "//xla/service/gpu/runtime:executable",
         "//xla/service/gpu/runtime:support",
         "//xla/service/gpu/runtime3:custom_call_thunk",
+        "//xla/service/gpu/runtime3:fft_thunk",
         "//xla/stream_executor",
         "//xla/stream_executor:blas",
         "//xla/stream_executor:device_description",
@@ -1110,21 +1129,50 @@
     ],
 )
 
+cuda_library(
+    name = "gpu_prim_cuda",
+    hdrs = ["gpu_prim_cuda.h"],
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
+    visibility = ["//visibility:public"],
+    deps = [
+        "@eigen_archive//:eigen3",
+        "@local_tsl//tsl/platform:bfloat16",
+    ] + if_cuda_is_configured(xla_cub_deps()),
+)
+
+cc_library(
+    name = "gpu_prim_rocm",
+    hdrs = ["gpu_prim_rocm.h"],
+    local_defines = if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]),
+    visibility = ["//visibility:public"],
+    deps = [
+        "@eigen_archive//:eigen3",
+        "@local_tsl//tsl/platform:bfloat16",
+    ] + if_rocm_is_configured([
+        "@local_config_rocm//rocm:rocprim",
+    ]),
+)
+
 cc_library(
     name = "cub_sort_thunk",
     srcs = if_gpu_is_configured(["cub_sort_thunk.cc"]),
     hdrs = if_gpu_is_configured(["cub_sort_thunk.h"]),
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
+        "TENSORFLOW_USE_ROCM=1",
+    ]),
     visibility = ["//visibility:public"],
     deps = if_gpu_is_configured([
         ":buffer_allocations",
         ":thunk",
         "@com_google_absl//absl/log:check",
-        "@local_config_cuda//cuda:cuda_headers",
         "//xla/service:buffer_assignment",
         "//xla:shape_util",
         "//xla/stream_executor:device_memory",
         "//xla:status",
+        "//xla:statusor",
+        "//xla:util",
         "//xla:xla_data_proto_cc",
+        "@local_tsl//tsl/platform:errors",
     ] + [":cub_sort_kernel_" + suffix for suffix in get_cub_sort_kernel_types()]),
 )
 
@@ -1132,34 +1180,15 @@
     name = "cub_sort_kernel",
     srcs = if_gpu_is_configured(["cub_sort_kernel.cu.cc"]),
     hdrs = if_gpu_is_configured(["cub_sort_kernel.h"]),
-    types = get_cub_sort_kernel_types(),
-    deps = if_gpu_is_configured([
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/strings",
-        "@local_config_cuda//cuda:cuda_headers",  #cub
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
+        "TENSORFLOW_USE_ROCM=1",
     ]),
-)
-
-cc_library(
-    name = "fft_thunk",
-    srcs = ["fft_thunk.cc"],
-    hdrs = ["fft_thunk.h"],
-    visibility = ["//visibility:public"],
-    deps = [
-        ":buffer_allocations",
-        ":thunk",
-        "//xla:types",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/hlo/ir:hlo",
-        "//xla/service:buffer_assignment",
-        "//xla/stream_executor",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:status",
-    ],
+    types = get_cub_sort_kernel_types(),
+    deps = if_cuda_is_configured([
+        ":gpu_prim_cuda",
+    ]) + if_rocm_is_configured([
+        ":gpu_prim_rocm",
+    ]),
 )
 
 cc_library(
@@ -1174,6 +1203,7 @@
         ":ir_emission_utils",
         ":matmul_utils",
         "//xla:shape_util",
+        "//xla:status",
         "//xla:status_macros",
         "//xla:statusor",
         "//xla:xla_data_proto_cc",
@@ -1187,6 +1217,7 @@
         "//xla/stream_executor/gpu:gpu_blas_lt",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:statusor",
@@ -1289,12 +1320,15 @@
     srcs = ["split_k_gemm_rewriter_test.cc"],
     deps = [
         ":gemm_rewriter_triton",
+        ":matmul_utils",
         ":split_k_gemm_rewriter",
         "//xla:autotuning_proto_cc",
         "//xla:shape_util",
         "//xla:xla_data_proto_cc",
         "//xla:xla_proto_cc",
         "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_verifier",
+        "//xla/service:layout_assignment",
         "//xla/service:pattern_matcher",
         "//xla/service:pattern_matcher_gmock",
         "//xla/tests:hlo_test_base",
@@ -1358,13 +1392,9 @@
 )
 
 cc_library(
-    name = "cublas_lt_matmul_thunk",
-    srcs = if_cuda_is_configured(["cublas_lt_matmul_thunk.cc"]) + if_rocm_is_configured([
-        "cublas_lt_matmul_thunk.cc",
-    ]),
-    hdrs = if_cuda_is_configured(["cublas_lt_matmul_thunk.h"]) + if_rocm_is_configured([
-        "cublas_lt_matmul_thunk.h",
-    ]),
+    name = "gpublas_lt_matmul_thunk",
+    srcs = if_gpu_is_configured(["gpublas_lt_matmul_thunk.cc"]),
+    hdrs = if_gpu_is_configured(["gpublas_lt_matmul_thunk.h"]),
     local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
         "TENSORFLOW_USE_ROCM=1",
     ]),
@@ -1378,11 +1408,6 @@
         "//xla/stream_executor",
         "@local_tsl//tsl/platform:logging",
         "@local_tsl//tsl/platform:statusor",
-    ]) + if_cuda_is_configured([
-        "//xla/stream_executor/cuda:cublas_lt_header",
-        "//xla/stream_executor/cuda:cublas_plugin",
-    ]) + if_rocm_is_configured([
-        "//xla/stream_executor/rocm:hipblas_lt_header",
     ]),
 )
 
@@ -1529,7 +1554,9 @@
     deps = [
         ":backend_configs_cc",
         ":ir_emission_utils",
+        "//xla:autotuning_proto_cc",
         "//xla:shape_util",
+        "//xla:status",
         "//xla:status_macros",
         "//xla:statusor",
         "//xla:types",
@@ -1542,6 +1569,7 @@
         "//xla/stream_executor/gpu:gpu_blas_lt",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:status",
         "@local_tsl//tsl/platform:statusor",
@@ -1790,6 +1818,30 @@
 )
 
 cc_library(
+    name = "gpu_sort_rewriter",
+    srcs = if_cuda_is_configured(["gpu_sort_rewriter.cc"]),
+    hdrs = if_cuda_is_configured(["gpu_sort_rewriter.h"]),
+    visibility = ["//visibility:public"],
+    deps = [
+        ":cub_sort_thunk",
+        ":cublas_cudnn",
+        "//xla:comparison_util",
+        "//xla:shape_util",
+        "//xla:statusor",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:hlo_pass",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
     name = "move_copy_to_users",
     srcs = ["move_copy_to_users.cc"],
     hdrs = ["move_copy_to_users.h"],
@@ -1839,6 +1891,27 @@
     ],
 )
 
+xla_test(
+    name = "gpu_sort_rewriter_test",
+    srcs = if_cuda_is_configured(["gpu_sort_rewriter_test.cc"]),
+    backends = ["gpu"],
+    tags = ["no_oss"],
+    deps = [
+        ":cublas_cudnn",
+        ":gpu_sort_rewriter",
+        "//xla:statusor",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/tests:hlo_test_base",
+        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/platform:test",
+    ],
+)
+
 cc_library(
     name = "cusolver_context",
     srcs = if_gpu_is_configured(["cusolver_context.cc"]),
@@ -2428,6 +2501,7 @@
     hdrs = ["gpu_all_gather_optimizer.h"],
     visibility = ["//visibility:public"],
     deps = [
+        "//xla:shape_util",
         "//xla:statusor",
         "//xla/hlo/ir:hlo",
         "//xla/service:collective_ops_utils",
@@ -2465,7 +2539,6 @@
         ":buffer_sharing",
         ":executable_proto_cc",
         ":gpu_constants",
-        ":gpu_convert_async_collectives_to_sync",
         ":gpu_executable",
         ":ir_emitter_context",
         ":ir_emitter_unnested",
@@ -2494,6 +2567,7 @@
         "//xla/translate/mhlo_to_hlo:location_exporter",
         "//xla/translate/mhlo_to_lhlo_with_xla",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:AsmParser",
         "@llvm-project//llvm:Support",
@@ -2772,6 +2846,7 @@
         "//xla/service:zero_sized_hlo_elimination",
         "//xla/service/gpu/model:gpu_cost_model_stats_collection",
         "//xla/service/gpu/model:gpu_hlo_cost_analysis",
+        "//xla/service:sub_byte_normalization",
         "//xla/service/llvm_ir:llvm_util",
         "//xla/service/spmd:collective_permute_motion",
         "//xla/service/spmd:stateful_rng_spmd_partitioner",
@@ -2795,6 +2870,7 @@
     ]) + xla_export_hlo_deps() + [
         ":fusion_pipeline",
         ":prepare_hlo_for_ir_emitting_pipeline",
+        "//xla:status",
         "@local_tsl//tsl/lib/monitoring:counter",
     ],
 )
@@ -2879,6 +2955,7 @@
         ":gpu_conv_rewriter",
         ":gpu_executable",
         ":gpu_layout_assignment",
+        ":gpu_sort_rewriter",
         ":ir_emission_utils",
         ":metrics",
         ":target_constants",
@@ -3173,8 +3250,13 @@
     deps = [
         ":backend_configs_cc",
         ":cublas_cudnn",
+        "//xla:shape_util",
+        "//xla:status",
+        "//xla:statusor",
+        "//xla:util",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/utils:hlo_query",
+        "//xla/service:buffer_value",
         "//xla/service:hlo_memory_scheduler",
         "//xla/service:hlo_pass_pipeline",
         "//xla/service:latency_hiding_scheduler",
@@ -3182,6 +3264,11 @@
         "//xla/service:profile_guided_latency_estimator",
         "//xla/service/gpu/model:analytical_latency_estimator",
         "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
@@ -3199,13 +3286,23 @@
     tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_hlo_schedule",
+        "//xla:shape_util",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/utils:hlo_query",
+        "//xla/service:backend",
         "//xla/service:gpu_plugin",
+        "//xla/service:hlo_module_config",
+        "//xla/service:hlo_ordering",
         "//xla/stream_executor:device_description",
+        "//xla/tests:filecheck",
         "//xla/tests:hlo_test_base",
         "//xla/tests:test_utils",
         "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/log",
+        "@com_google_googletest//:gtest",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc",
     ],
 )
@@ -3298,6 +3395,7 @@
         "//xla/stream_executor:device_description",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:ir_headers",
         "@local_tsl//tsl/platform:macros",
@@ -3362,9 +3460,12 @@
     deps = [
         ":backend_configs_cc",
         ":cublas_cudnn",
+        ":hlo_fusion_analysis",
         ":ir_emission_utils",
         "//xla:shape_util",
         "//xla/hlo/ir:hlo",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:device_description_proto_cc",
         "@com_google_absl//absl/container:flat_hash_set",
     ],
 )
@@ -3919,7 +4020,8 @@
         ":custom_call_test",
         # copybara:uncomment ":gpu_aot_compilation_test",
         # copybara:uncomment "//platforms/xla/tests/internal:xfeed_test_gpu",
-        # copybara:uncomment "//third_party/py/jax/experimental/jax2tf/tests:primitives_test_gpu",
+        # TODO(anlunx): Re-enable when the FFI mechanism is available in Thunk-based runtime.
+        # copybara:uncomment # "//third_party/py/jax/experimental/jax2tf/tests:primitives_test_gpu",
         # copybara:uncomment "//third_party/py/jax/tests:pmap_test_gpu",
         # copybara:uncomment "//tensorflow/compiler/tests:fft_test_gpu",
         "//xla/python:xla_client_test_gpu",
@@ -4354,13 +4456,15 @@
         ":gpu_constants",
         ":ir_emission_utils",
         "//xla:shape_util",
+        "//xla:status",
         "//xla:statusor",
+        "//xla/hlo/ir:hlo",
         "//xla/mlir_hlo:lhlo",
-        "//xla/mlir_hlo:transforms_gpu_passes",
         "//xla/service:buffer_assignment",
+        "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:errors",
     ],
 )
 
@@ -4514,3 +4618,13 @@
         "@local_tsl//tsl/platform:statusor",
     ],
 )
+
+cc_library(
+    name = "gpu_symbol_repository",
+    hdrs = ["gpu_symbol_repository.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//xla:xla_proto_cc",
+        "//xla/service:symbol_repository",
+    ],
+)
diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc
index 52306d0..a5c3fd7 100644
--- a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc
+++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc
@@ -113,12 +113,10 @@
 
 Status AMDGPUCompiler::OptimizeHloPostLayoutAssignment(
     HloModule* hlo_module, se::StreamExecutor* stream_exec,
-    const CompileOptions& options, const GpuTargetConfig& gpu_target_config,
-    const AutotuneResults* autotune_results,
+    const CompileOptions& options, const TargetConfig& gpu_target_config,
     tsl::thread::ThreadPool* thread_pool) {
   TF_RETURN_IF_ERROR(GpuCompiler::OptimizeHloPostLayoutAssignment(
-      hlo_module, stream_exec, options, gpu_target_config, autotune_results,
-      thread_pool));
+      hlo_module, stream_exec, options, gpu_target_config, thread_pool));
 
   HloPassPipeline post_pipeline("AMDGPU post-layout_assignment");
 
diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.h b/third_party/xla/xla/service/gpu/amdgpu_compiler.h
index cbc2765..fb2a2e1 100644
--- a/third_party/xla/xla/service/gpu/amdgpu_compiler.h
+++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.h
@@ -40,8 +40,7 @@
 
   Status OptimizeHloPostLayoutAssignment(
       HloModule* hlo_module, se::StreamExecutor* stream_exec,
-      const CompileOptions& options, const GpuTargetConfig& gpu_target_config,
-      const AutotuneResults* autotune_results,
+      const CompileOptions& options, const TargetConfig& gpu_target_config,
       tsl::thread::ThreadPool* thread_pool) override;
 
   bool RequiresCollectiveScheduleLinearizer(
diff --git a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc b/third_party/xla/xla/service/gpu/autotuner_compile_util.cc
index ed91d49..6965937 100644
--- a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc
+++ b/third_party/xla/xla/service/gpu/autotuner_compile_util.cc
@@ -88,7 +88,7 @@
   // Avoid using another thread pool.
   opts_.set_xla_gpu_force_compilation_parallelism(1);
   // Avoid using GPU graphs as we don't want to measure graph construction time.
-  opts_.set_xla_gpu_graph_level(0);
+  opts_.clear_xla_gpu_enable_command_buffer();
   // Disable experimental XLA:GPU runtime.
   opts_.set_xla_gpu_enable_gpu2_runtime(false);
   opts_.set_xla_embed_ir_in_executable(false);
diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto
index a113dc7..1c90e1a 100644
--- a/third_party/xla/xla/service/gpu/backend_configs.proto
+++ b/third_party/xla/xla/service/gpu/backend_configs.proto
@@ -94,6 +94,12 @@
   }
 
   Epilogue epilogue = 13;
+
+  optional int64 lhs_stride = 14;
+  optional int64 rhs_stride = 15;
+
+  optional bool grad_x = 16;
+  optional bool grad_y = 17;
 }
 
 // Backend config for bitcast operation generated from MLIR MHLO dialect.
diff --git a/third_party/xla/xla/service/gpu/buffer_sharing.cc b/third_party/xla/xla/service/gpu/buffer_sharing.cc
index d02cbc7..6442159 100644
--- a/third_party/xla/xla/service/gpu/buffer_sharing.cc
+++ b/third_party/xla/xla/service/gpu/buffer_sharing.cc
@@ -21,13 +21,18 @@
 #include <utility>
 
 #include "absl/container/flat_hash_set.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/device_description.pb.h"
 
 namespace xla {
 namespace gpu {
@@ -35,7 +40,8 @@
 std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
                                              const HloInstruction* operand,
                                              const ShapeIndex& user_index) {
-  if (user->opcode() != HloOpcode::kFusion) {
+  const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(user);
+  if (fusion == nullptr) {
     return std::nullopt;
   }
 
@@ -65,10 +71,21 @@
     }
   }
 
+  // Allow multiple output users if they end in reductions.
+  // This only works for the reduction emitter, as it calculates the reduction
+  // first, i.e. before processing other outputs (that may overwrite the input).
+  stream_executor::GpuDeviceInfoProto device_info;
+  stream_executor::DeviceDescription device_description(device_info);
+  auto analysis = HloFusionAnalysis::Create(fusion, &device_description);
+  bool is_reduction_emitter = analysis->GetEmitterFusionKind() ==
+                              HloFusionAnalysis::EmitterFusionKind::kReduction;
+  const HloInstruction* reduction_hero =
+      is_reduction_emitter ? analysis->FindHeroReduction()
+                           : nullptr;
+
   // We need to make sure that the fusion parameter is accessed in the same
-  // iteration order as the fusion output. Also, there should not be two fusion
-  // outputs that consume the fusion parameter, because we do not want to share
-  // the same fusion operand with two different fusion outputs. To make sure
+  // iteration order as the fusion output. Also, there should not be any other
+  // fusion output that accesses it in a different iteration order. To make sure
   // that the iteration order is the same, we only allow ops on the path from
   // fusion parameter to fusion output which are elementwise (no copy) or
   // bitcast or an elementwise dynamic update slice (i.e. with the first operand
@@ -88,16 +105,21 @@
   q.push(fusion_param);
   visited.insert(fusion_param);
   bool found_path_to_output = false;
+  int reached_root = 0;
   while (!q.empty()) {
     HloInstruction* hlo_operand = q.front();
     q.pop();
     if (hlo_operand == output) {
       found_path_to_output = true;
-      // The output should have at most 1 user: the tuple op (in case of a
-      // multi-output fusion)
-      if (hlo_operand->user_count() > 1) {
+      // We still need to process the users of 'hlo_operand'. There can be other
+      // reduction users in addition to the tuple user.
+      if (hlo_operand->user_count() > 1 && !is_reduction_emitter) {
         return false;
       }
+    }
+    // Reduction emitter processes the reduction first, so the values below it
+    // will not interfere with buffer sharing.
+    if (hlo_operand == reduction_hero) {
       continue;
     }
     for (HloInstruction* hlo : hlo_operand->users()) {
@@ -134,7 +156,8 @@
       } else if ((!hlo->IsElementwiseOnOperand(
                       hlo->operand_index(hlo_operand)) ||
                   hlo->opcode() == HloOpcode::kCopy) &&
-                 hlo->opcode() != HloOpcode::kBitcast) {
+                 hlo->opcode() != HloOpcode::kBitcast &&
+                 hlo->opcode() != HloOpcode::kTuple && hlo != reduction_hero) {
         // This check also catches the case that we reach a different fusion
         // output, as that fusion output would have a tuple op as user, which we
         // do not allow here.
@@ -151,9 +174,12 @@
           return false;
         }
       }
+      if (hlo->IsRoot()) {
+        ++reached_root;
+      }
     }
   }
-  return found_path_to_output;
+  return found_path_to_output && (user_index.empty() || reached_root == 1);
 }
 
 std::optional<bool> CanShareBufferHint(const HloInstruction* user,
diff --git a/third_party/xla/xla/service/gpu/build_defs.bzl b/third_party/xla/xla/service/gpu/build_defs.bzl
index 4d361f2..bca5e3b 100644
--- a/third_party/xla/xla/service/gpu/build_defs.bzl
+++ b/third_party/xla/xla/service/gpu/build_defs.bzl
@@ -2,6 +2,7 @@
 """
 
 load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library")
+load("@local_config_rocm//rocm:build_defs.bzl", "rocm_copts")  # copybara:comment
 
 def get_cub_sort_kernel_types(name = ""):
     """ List of supported types for CUB sort kernels.
@@ -29,12 +30,13 @@
         "u64_b64",
     ]
 
-def build_cub_sort_kernels(name, types, **kwargs):
+def build_cub_sort_kernels(name, types, local_defines = [], **kwargs):
     """ Create build rules for all CUB sort kernels.
     """
     for suffix in types:
         cuda_library(
             name = name + "_" + suffix,
-            local_defines = ["CUB_TYPE_" + suffix.upper()],
+            local_defines = local_defines + ["CUB_TYPE_" + suffix.upper()],
+            copts = rocm_copts(),  # copybara:comment
             **kwargs
         )
diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc
index 5a32f7c..7cf4d40 100644
--- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc
+++ b/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc
@@ -227,7 +227,7 @@
   builder.AddInstruction(HloInstruction::CreateTuple(new_instructions));
 
   BuildCommandBufferResult result = {builder.Build(), parameters_map,
-                                     inst_to_tuple_index_map};
+                                     inst_to_tuple_index_map, instructions_map};
   return result;
 }
 
@@ -270,7 +270,45 @@
           HloInstruction::CreateGetTupleElement(call_command_buffer, i));
     }
 
+    // Remove instructions in the command buffer sequence.
+    bool first_inst = true;
     for (HloInstruction* inst : seq.instructions()) {
+      // Replace the first instruction in the sequence with the command buffer call.
+      // Removal of the rest of the instructions in the sequence is handled by
+      // HloSchedule::Update().
+      if (first_inst) {
+        first_inst = false;
+        HloInstructionSequence& sequence =
+            module->schedule().GetOrCreateSequence(entry);
+        sequence.replace_instruction(inst, call_command_buffer);
+      }
+
+      // Forward control dependencies to the new instruction inside the command
+      // buffer. If the dependent instruction is not captured by the command
+      // buffer, forward the dependency to the command buffer call instead.
+      HloInstruction* new_inst = result.instructions_map[inst];
+      for (HloInstruction* predecessor : inst->control_predecessors()) {
+        if (auto it = result.instructions_map.find(predecessor);
+            it != result.instructions_map.end()) {
+          HloInstruction* new_predecessor = it->second;
+          TF_RETURN_IF_ERROR(new_predecessor->AddControlDependencyTo(new_inst));
+        } else {
+          TF_RETURN_IF_ERROR(
+              predecessor->AddControlDependencyTo(call_command_buffer));
+        }
+      }
+      for (HloInstruction* successor : inst->control_successors()) {
+        if (auto it = result.instructions_map.find(successor);
+            it != result.instructions_map.end()) {
+          HloInstruction* new_successor = it->second;
+          TF_RETURN_IF_ERROR(new_inst->AddControlDependencyTo(new_successor));
+        } else {
+          TF_RETURN_IF_ERROR(
+              call_command_buffer->AddControlDependencyTo(successor));
+        }
+      }
+      TF_RETURN_IF_ERROR(inst->DropAllControlDeps());
+
       int64_t tuple_index = result.inst_to_tuple_index_map[inst];
       TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(results[tuple_index]));
       TF_RETURN_IF_ERROR(entry->RemoveInstruction(inst));
diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h b/third_party/xla/xla/service/gpu/command_buffer_scheduling.h
index c59c070..601e72c 100644
--- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h
+++ b/third_party/xla/xla/service/gpu/command_buffer_scheduling.h
@@ -96,6 +96,10 @@
     // the original instruction to the tuple index of the result that replaces
     // the original instruction.
     absl::flat_hash_map<HloInstruction*, int64_t> inst_to_tuple_index_map;
+
+    // Map original instructions to their clones in the command buffer
+    // computation.
+    absl::flat_hash_map<HloInstruction*, HloInstruction*> instructions_map;
   };
 
   // Builds a computation from the instruction sequence. Used values constructed
diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc
index d30b66d..2a556b6 100644
--- a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc
+++ b/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc
@@ -357,6 +357,67 @@
   EXPECT_EQ(inst_to_tuple_index_map[instructions[4]], 2);
 }
 
+TEST_F(CommandBufferSchedulingTest, RelayControlDependencies) {
+  const char* hlo = R"(
+      HloModule TestModule, is_scheduled=true
+
+      %fused_computation (param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      %fused_computation.1 (param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      %fused_computation.2 (param_0: s32[], param_1: s32[]) -> s32[] {
+        %p0 = s32[] parameter(0)
+        %p1 = s32[] parameter(1)
+        ROOT %add = s32[] add(s32[] %p0, s32[] %p1)
+      }
+
+      ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+        %a = s32[] parameter(0)
+        %b = s32[] parameter(1)
+        %custom-call = s32[] custom-call(), custom_call_target="some target"
+        %fusion = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation, control-predecessors={%custom-call}
+        %fusion.1 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.1, control-predecessors={%fusion}
+        %custom-call.1 = s32[] custom-call(), custom_call_target="some target"
+        %fusion.2 = s32[] fusion(s32[] %a, s32[] %b), kind=kLoop, calls=%fused_computation.2, control-predecessors={%fusion.1}
+        ROOT %custom-call.2 = s32[] custom-call(), custom_call_target="some target"
+      })";
+
+  const char* expected = R"(
+// CHECK: %command_buffer (param: s32[], param.1: s32[]) -> (s32[], s32[]) {
+// CHECK:   %param = s32[] parameter(0)
+// CHECK:   %param.1 = s32[] parameter(1)
+// CHECK:   %fusion.3 = s32[] fusion(%param, %param.1), kind=kLoop, calls=%fused_computation
+// CHECK:   %fusion.4 = s32[] fusion(%param, %param.1), kind=kLoop, calls=%fused_computation.1, control-predecessors={%fusion.3}
+// CHECK:   ROOT %tuple = (s32[], s32[]) tuple(%fusion.3, %fusion.4)
+// CHECK: }
+//
+// CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] {
+// CHECK:   %a = s32[] parameter(0)
+// CHECK:   %b = s32[] parameter(1)
+// CHECK:   %custom-call = s32[] custom-call(), custom_call_target="some target"
+// CHECK:   %call = (s32[], s32[]) call(%a, %b), to_apply=%command_buffer, control-predecessors={%custom-call}
+// CHECK:   %get-tuple-element = s32[] get-tuple-element(%call), index=0
+// CHECK:   %get-tuple-element.1 = s32[] get-tuple-element(%call), index=1
+// CHECK:   %custom-call.1 = s32[] custom-call(), custom_call_target="some target"
+// CHECK:   %fusion.2 = s32[] fusion(%a, %b), kind=kLoop, calls=%fused_computation.2, control-predecessors={%call}
+// CHECK:   ROOT %custom-call.2 = s32[] custom-call(), custom_call_target="some target"
+// CHECK: })";
+
+  RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(), expected,
+                            [](HloModule* module) {
+                              EXPECT_TRUE(module->has_schedule());
+                              TF_CHECK_OK(module->schedule().Verify());
+                            });
+}
+
 }  // namespace
 
 }  // namespace xla::gpu
diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc
index d8a807f..1be565e 100644
--- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc
+++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc
@@ -28,8 +28,10 @@
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_cat.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/DiagnosticPrinter.h"
@@ -122,8 +124,18 @@
   mlir::PassManager pm(module->getName(), mlir::PassManager::Nesting::Implicit);
   pm.enableVerifier(should_verify);
 
+  absl::flat_hash_set<DebugOptions::CommandBufferCmdType> command_types;
+  for (int command_type_num : debug_options.xla_gpu_enable_command_buffer()) {
+    if (!DebugOptions::CommandBufferCmdType_IsValid(command_type_num)) {
+      return InternalError("Invalid command buffer command type");
+    }
+    DebugOptions::CommandBufferCmdType command_type =
+        static_cast<DebugOptions::CommandBufferCmdType>(command_type_num);
+    command_types.insert(command_type);
+  }
+
   GpuPipelineOpts opts;
-  opts.gpu_graph_level = debug_options.xla_gpu_graph_level();
+  opts.command_types = command_types;
   opts.min_graph_size = debug_options.xla_gpu_graph_min_graph_size();
   opts.enable_concurrent_region =
       debug_options.xla_gpu_graph_enable_concurrent_region();
@@ -198,9 +210,12 @@
                             module_str);
   }
 
+  // Collect allocation indices for handling graph capture functions.
+  auto allocation_indices = GetAllocationIndices(mlir_module);
+
   return std::make_unique<GpuRuntimeProgram>(
       entry_function_name.str(), std::move(module_str), buffer_sizes.vec(),
-      module_config.debug_options());
+      std::move(allocation_indices), module_config.debug_options());
 }
 
 StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
@@ -352,8 +367,12 @@
 
   absl::flat_hash_map<const mlir::Operation*, const xla::HloInstruction*>
       operation_map;
+
+  // Store the allocations in the order of the LMHLO buffer arguments.
+  std::vector<const BufferAllocation*> ordered_allocations;
   TF_RETURN_IF_ERROR(HloToLhloModule(*results->buffer_assignment, *hlo_module,
-                                     *mlir_module, &operation_map));
+                                     *mlir_module, &ordered_allocations,
+                                     &operation_map));
 
   results->module_name =
       mlir::mhlo::GetDebugNameFromLocation(mlir_module->getLoc());
@@ -365,22 +384,39 @@
   auto entry_function = mlir::cast<mlir::func::FuncOp>(
       mlir_module->lookupSymbol(hlo_module->entry_computation()->name()));
 
-  // TODO(b/304613751): Add this flag to xla flags.
-  constexpr bool emit_ir_from_hlo = false;
+  bool emit_from_hlo = !IsXlaRuntimeExecutableEnabled(hlo_module->config());
+
+  std::vector<BufferAllocation> mlir_allocations;
+  absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo> mlir_output_info;
+  Shape mlir_output_shape;
+  TF_RETURN_IF_ERROR(GetMlirAllocationInfo(
+      entry_function, &mlir_allocations, &mlir_output_info, &mlir_output_shape,
+      &results->entry_func_attrs));
 
   IrEmitterContext ir_emitter_context(
-      hlo_module, emit_ir_from_hlo ? results->buffer_assignment.get() : nullptr,
-      platform_name, gpu_device_info, mlir_context.get(),
-      results->llvm_module.get(), emit_ir_from_hlo);
+      hlo_module, results->buffer_assignment.get(), platform_name,
+      gpu_device_info, mlir_context.get(), results->llvm_module.get(),
+      emit_from_hlo);
 
-  if (emit_ir_from_hlo) {
-    TF_RET_CHECK(!IsXlaRuntimeExecutableEnabled(hlo_module->config()));
-    results->allocations = results->buffer_assignment->Allocations();
+  std::vector<BufferAllocation*> allocations;
+  if (emit_from_hlo) {
+    results->output_shape = hlo_module->result_shape();
+    TF_ASSIGN_OR_RETURN(
+        results->output_info,
+        GetOutputInfo(*hlo_module, *results->buffer_assignment));
+    TF_RET_CHECK(mlir_allocations.size() == ordered_allocations.size());
+    ir_emitter_context.set_allocations(ordered_allocations);
+    results->use_original_allocations = true;
   } else {
-    TF_RETURN_IF_ERROR(GetMlirAllocationInfo(
-        entry_function, &results->allocations, &results->output_info,
-        &results->output_shape, &results->entry_func_attrs));
-    ir_emitter_context.set_allocations(results->allocations);
+    results->allocations = std::move(mlir_allocations);
+    results->output_shape = mlir_output_shape;
+    results->output_info = mlir_output_info;
+    allocations.reserve(results->allocations.size());
+    for (auto& allocation : results->allocations) {
+      allocations.push_back(&allocation);
+    }
+    ir_emitter_context.set_allocations(allocations);
+    results->use_original_allocations = false;
   }
 
   auto ir_emitter = IrEmitterUnnested::Create(&ir_emitter_context);
@@ -411,16 +447,16 @@
     RecordHloToLlvmDuration(end_usecs - start_usecs);
   }
 
-  // Sizes of all buffers required for running XLA module.
-  std::vector<int64_t> buffer_sizes;
-  llvm::transform(
-      results->allocations, std::back_inserter(buffer_sizes),
-      [](const BufferAllocation& allocation) { return allocation.size(); });
-
   // TODO(ezhulenev): Remove the FP8 check once https://reviews.llvm.org/D140088
   // is submitted. Currently we can't emit LLVM IR with fp8 types.
   if (IsXlaRuntimeExecutableEnabled(hlo_module->config()) &&
       !HasFp8(*hlo_module)) {
+    // Sizes of all buffers required for running XLA module.
+    std::vector<int64_t> buffer_sizes;
+    llvm::transform(
+        results->allocations, std::back_inserter(buffer_sizes),
+        [](const BufferAllocation& allocation) { return allocation.size(); });
+
     TF_ASSIGN_OR_RETURN(
         results->executable,
         LowerToJitRt(*mlir_module, entry_function.getName(), buffer_sizes,
diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h
index a670553..39a2a04 100644
--- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h
+++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h
@@ -48,6 +48,11 @@
   absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo> output_info;
   Shape output_shape;
   std::string module_name;
+
+  // If true, the compiled module uses buffer allocations owned by
+  // buffer_assignment. Otherwise the compiled module uses buffer allocations
+  // stored in allocations.
+  bool use_original_allocations;
 };
 
 // Removes all globals from the given module that are both uninitialized and
diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc
index 627bee5..ee767e2 100644
--- a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc
+++ b/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc
@@ -191,6 +191,7 @@
           config.output_descriptor,
           /* output_data = */ DeviceMemoryBase(nullptr), config.conv_desc,
           use_fallback, nullptr, numeric_options, &runners));
+
       for (auto& runner : runners) {
         TF_ASSIGN_OR_RETURN(
             auto runner_cache,
@@ -230,8 +231,9 @@
       params.config->input_descriptor, params.input_buf,
       params.config->filter_descriptor, params.filter_buf,
       params.config->output_descriptor, params.output_buf,
-      params.config->conv_desc, /* use_fallback = */ false, scratch_allocator,
-      numeric_options, &runners));
+      params.config->conv_desc,
+      /* use_fallback = */ false, scratch_allocator, numeric_options,
+      &runners));
 
   return runners;
 }
diff --git a/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc b/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc
index 6fcdd77..99ab262 100644
--- a/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc
+++ b/third_party/xla/xla/service/gpu/cub_sort_kernel.cu.cc
@@ -18,69 +18,77 @@
 #include <cstddef>
 #include <cstdint>
 
-#include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
-#include "cub/device/device_radix_sort.cuh"
+#if GOOGLE_CUDA
+#include "xla/service/gpu/gpu_prim_cuda.h"
+#elif TENSORFLOW_USE_ROCM
+#include "xla/service/gpu/gpu_prim_rocm.h"
+#endif  // TENSORFLOW_USE_ROCM
 
 namespace xla {
 namespace gpu {
 namespace {
 
+#if GOOGLE_CUDA
+#define CHK_GPU_ERR(err)            \
+  if (err != cudaSuccess) {         \
+    return cudaGetErrorString(err); \
+  }
+#elif TENSORFLOW_USE_ROCM
+#define CHK_GPU_ERR(err)           \
+  if (err != hipSuccess) {         \
+    return hipGetErrorString(err); \
+  }
+#endif
+
 template <typename KeyT>
-absl::Status CubSortKeys(void* d_temp_storage, size_t& temp_bytes,
-                         const void* d_keys_in, void* d_keys_out,
-                         size_t num_items, bool descending) {
-  cudaError_t err =
+const char* CubSortKeys(void* d_temp_storage, size_t& temp_bytes,
+                        const void* d_keys_in, void* d_keys_out,
+                        size_t num_items, bool descending) {
+  auto err =
       descending
-          ? cub::DeviceRadixSort::SortKeysDescending<KeyT>(
+          ? gpuprim::DeviceRadixSort::SortKeysDescending<KeyT>(
                 d_temp_storage, temp_bytes, static_cast<const KeyT*>(d_keys_in),
                 static_cast<KeyT*>(d_keys_out), num_items)
-          : cub::DeviceRadixSort::SortKeys<KeyT>(
+          : gpuprim::DeviceRadixSort::SortKeys<KeyT>(
                 d_temp_storage, temp_bytes, static_cast<const KeyT*>(d_keys_in),
                 static_cast<KeyT*>(d_keys_out), num_items);
-  if (err != 0) {
-    return absl::InvalidArgumentError(
-        absl::StrCat("CUB error: ", cudaGetErrorString(err)));
-  }
-  return absl::OkStatus();
+  CHK_GPU_ERR(err)
+  return nullptr;
 }
 
 template <typename KeyT, typename ValT>
-absl::Status CubSortPairs(void* d_temp_storage, size_t& temp_bytes,
-                          const void* d_keys_in, void* d_keys_out,
-                          const void* d_values_in, void* d_values_out,
-                          size_t num_items, bool descending) {
-  cudaError_t err =
+const char* CubSortPairs(void* d_temp_storage, size_t& temp_bytes,
+                         const void* d_keys_in, void* d_keys_out,
+                         const void* d_values_in, void* d_values_out,
+                         size_t num_items, bool descending) {
+  auto err =
       descending
-          ? cub::DeviceRadixSort::SortPairsDescending<KeyT, ValT>(
+          ? gpuprim::DeviceRadixSort::SortPairsDescending<KeyT, ValT>(
                 d_temp_storage, temp_bytes, static_cast<const KeyT*>(d_keys_in),
                 static_cast<KeyT*>(d_keys_out),
                 static_cast<const ValT*>(d_values_in),
                 static_cast<ValT*>(d_values_out), num_items)
-          : cub::DeviceRadixSort::SortPairs<KeyT, ValT>(
+          : gpuprim::DeviceRadixSort::SortPairs<KeyT, ValT>(
                 d_temp_storage, temp_bytes, static_cast<const KeyT*>(d_keys_in),
                 static_cast<KeyT*>(d_keys_out),
                 static_cast<const ValT*>(d_values_in),
                 static_cast<ValT*>(d_values_out), num_items);
-  if (err != 0) {
-    return absl::InvalidArgumentError(
-        absl::StrCat("CUB error: ", cudaGetErrorString(err)));
-  }
-  return absl::OkStatus();
+  CHK_GPU_ERR(err)
+  return nullptr;
 }
 
 }  // namespace
 
-#define XLA_CUB_DEFINE_SORT_KEYS(suffix, type)                                \
-  absl::Status CubSortKeys_##suffix(void* d_temp_storage, size_t& temp_bytes, \
-                                    const void* d_keys_in, void* d_keys_out,  \
-                                    size_t num_items, bool descending) {      \
-    return CubSortKeys<type>(d_temp_storage, temp_bytes, d_keys_in,           \
-                             d_keys_out, num_items, descending);              \
+#define XLA_CUB_DEFINE_SORT_KEYS(suffix, type)                               \
+  const char* CubSortKeys_##suffix(void* d_temp_storage, size_t& temp_bytes, \
+                                   const void* d_keys_in, void* d_keys_out,  \
+                                   size_t num_items, bool descending) {      \
+    return CubSortKeys<type>(d_temp_storage, temp_bytes, d_keys_in,          \
+                             d_keys_out, num_items, descending);             \
   }
 
 #define XLA_CUB_DEFINE_SORT_PAIRS(suffix, type1, type2)                      \
-  absl::Status CubSortPairs_##suffix(                                        \
+  const char* CubSortPairs_##suffix(                                         \
       void* d_temp_storage, size_t& temp_bytes, const void* d_keys_in,       \
       void* d_keys_out, const void* d_values_in, void* d_values_out,         \
       size_t num_items, bool descending) {                                   \
@@ -91,7 +99,11 @@
 
 // Floating point types.
 #ifdef CUB_TYPE_BF16
+#if GOOGLE_CUDA
 XLA_CUB_DEFINE_SORT_KEYS(bf16, __nv_bfloat16)
+#elif TENSORFLOW_USE_ROCM
+XLA_CUB_DEFINE_SORT_KEYS(bf16, hip_bfloat16)
+#endif
 #endif
 #ifdef CUB_TYPE_F16
 XLA_CUB_DEFINE_SORT_KEYS(f16, __half)
diff --git a/third_party/xla/xla/service/gpu/cub_sort_kernel.h b/third_party/xla/xla/service/gpu/cub_sort_kernel.h
index 1d8fb0e..621489b 100644
--- a/third_party/xla/xla/service/gpu/cub_sort_kernel.h
+++ b/third_party/xla/xla/service/gpu/cub_sort_kernel.h
@@ -19,18 +19,20 @@
 #include <cstddef>
 #include <cstdint>
 
-#include "absl/status/status.h"
-
 namespace xla {
 namespace gpu {
 
-#define XLA_CUB_DECLARE_SORT_KEYS(suffix)                                     \
-  absl::Status CubSortKeys_##suffix(void* d_temp_storage, size_t& temp_bytes, \
-                                    const void* d_keys_in, void* d_keys_out,  \
-                                    size_t num_items, bool descending);
+// Returns nullptr if no error, otherwise the error message as a null-terminated
+// string (cudaGetErrorString or similar).
+#define XLA_CUB_DECLARE_SORT_KEYS(suffix)                                    \
+  const char* CubSortKeys_##suffix(void* d_temp_storage, size_t& temp_bytes, \
+                                   const void* d_keys_in, void* d_keys_out,  \
+                                   size_t num_items, bool descending);
 
+// Returns nullptr if no error, otherwise the error message as a null-terminated
+// string (cudaGetErrorString or similar).
 #define XLA_CUB_DECLARE_SORT_PAIRS(suffix)                             \
-  absl::Status CubSortPairs_##suffix(                                  \
+  const char* CubSortPairs_##suffix(                                   \
       void* d_temp_storage, size_t& temp_bytes, const void* d_keys_in, \
       void* d_keys_out, const void* d_values_in, void* d_values_out,   \
       size_t num_items, bool descending);
diff --git a/third_party/xla/xla/service/gpu/cub_sort_thunk.cc b/third_party/xla/xla/service/gpu/cub_sort_thunk.cc
index 733aeff..762f822 100644
--- a/third_party/xla/xla/service/gpu/cub_sort_thunk.cc
+++ b/third_party/xla/xla/service/gpu/cub_sort_thunk.cc
@@ -16,6 +16,7 @@
 #include "xla/service/gpu/cub_sort_thunk.h"
 
 #include <cstddef>
+#include <cstdint>
 #include <functional>
 #include <memory>
 #include <optional>
@@ -29,8 +30,11 @@
 #include "xla/service/gpu/cub_sort_kernel.h"
 #include "xla/service/gpu/thunk.h"
 #include "xla/status.h"
+#include "xla/statusor.h"
 #include "xla/stream_executor/device_memory.h"
+#include "xla/util.h"
 #include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
 
 namespace xla {
 namespace gpu {
@@ -39,8 +43,8 @@
 // Template class for sorting a single tensor.
 class CubSortKeysImpl : public CubSortRunnerInterface {
  public:
-  using SortKeysFn =
-      std::function<Status(void*, size_t&, const void*, void*, size_t, bool)>;
+  using SortKeysFn = std::function<const char*(void*, size_t&, const void*,
+                                               void*, size_t, bool)>;
 
   explicit CubSortKeysImpl(SortKeysFn sort_keys_fn, PrimitiveType type)
       : sort_keys_fn_(sort_keys_fn), type_(type) {}
@@ -51,6 +55,7 @@
              bool descending) override;
   Status Run(const Thunk::ExecuteParams& params,
              const CubSortThunk* thunk) override;
+  StatusOr<int64_t> GetScratchSize(int64_t num_items) override;
 
  private:
   SortKeysFn sort_keys_fn_;
@@ -66,8 +71,14 @@
   size_t num_items = input_keys.size() * 8 / primitive_util::BitWidth(type_);
   CHECK(input_values.is_null());
   CHECK(output_values.is_null());
-  return sort_keys_fn_(scratch.opaque(), temp_bytes, input_keys.opaque(),
-                       output_keys.opaque(), num_items, descending);
+  const char* error =
+      sort_keys_fn_(scratch.opaque(), temp_bytes, input_keys.opaque(),
+                    output_keys.opaque(), num_items, descending);
+  if (error != nullptr) {
+    return absl::InvalidArgumentError(
+        absl::StrCat("CubSortKeys error: ", error));
+  }
+  return absl::OkStatus();
 }
 
 Status CubSortKeysImpl::Run(const Thunk::ExecuteParams& params,
@@ -78,11 +89,22 @@
              allocs.GetDeviceAddress(thunk->scratch()), thunk->descending());
 }
 
+StatusOr<int64_t> CubSortKeysImpl::GetScratchSize(int64_t num_items) {
+  size_t temp_bytes = 0;
+  const char* error =
+      sort_keys_fn_(nullptr, temp_bytes, nullptr, nullptr, num_items, false);
+  if (error != nullptr) {
+    return absl::InvalidArgumentError(
+        absl::StrCat("CubSortKeys error: ", error));
+  }
+  return temp_bytes;
+}
+
 // Template class for sorting a pair of tensors.
 class CubSortPairsImpl : public CubSortRunnerInterface {
  public:
-  using SortPairsFn = std::function<Status(void*, size_t&, const void*, void*,
-                                           const void*, void*, size_t, bool)>;
+  using SortPairsFn = std::function<const char*(
+      void*, size_t&, const void*, void*, const void*, void*, size_t, bool)>;
 
   explicit CubSortPairsImpl(SortPairsFn sort_pairs_fn, PrimitiveType type)
       : sort_pairs_fn_(sort_pairs_fn), type_(type) {}
@@ -93,6 +115,7 @@
              bool descending) override;
   Status Run(const Thunk::ExecuteParams& params,
              const CubSortThunk* thunk) override;
+  StatusOr<int64_t> GetScratchSize(int64_t num_items) override;
 
  private:
   SortPairsFn sort_pairs_fn_;
@@ -106,9 +129,14 @@
                              se::DeviceMemoryBase scratch, bool descending) {
   size_t temp_bytes = scratch.size();
   size_t num_items = input_keys.size() * 8 / primitive_util::BitWidth(type_);
-  return sort_pairs_fn_(scratch.opaque(), temp_bytes, input_keys.opaque(),
-                        output_keys.opaque(), input_values.opaque(),
-                        output_values.opaque(), num_items, descending);
+  const char* error = sort_pairs_fn_(
+      scratch.opaque(), temp_bytes, input_keys.opaque(), output_keys.opaque(),
+      input_values.opaque(), output_values.opaque(), num_items, descending);
+  if (error != nullptr) {
+    return absl::InvalidArgumentError(
+        absl::StrCat("CubSortPairs error: ", error));
+  }
+  return absl::OkStatus();
 }
 
 Status CubSortPairsImpl::Run(const Thunk::ExecuteParams& params,
@@ -121,7 +149,18 @@
              allocs.GetDeviceAddress(thunk->scratch()), thunk->descending());
 }
 
-std::unique_ptr<CubSortRunnerInterface> CreateCubSortRunner(
+StatusOr<int64_t> CubSortPairsImpl::GetScratchSize(int64_t num_items) {
+  size_t temp_bytes = 0;
+  const char* error = sort_pairs_fn_(nullptr, temp_bytes, nullptr, nullptr,
+                                     nullptr, nullptr, num_items, false);
+  if (error != nullptr) {
+    return absl::InvalidArgumentError(
+        absl::StrCat("CubSortPairs error: ", error));
+  }
+  return temp_bytes;
+}
+
+StatusOr<std::unique_ptr<CubSortRunnerInterface>> CreateCubSortRunner(
     PrimitiveType type) {
   switch (type) {
     case F16:
@@ -147,18 +186,20 @@
     case U64:
       return std::make_unique<CubSortKeysImpl>(CubSortKeys_u64, U64);
     default:
-      CHECK(false) << "Unsupported type of the sort kernel: "
-                   << primitive_util::LowercasePrimitiveTypeName(type);
+      return InvalidArgument("Unsupported type of the sort kernel: %s",
+                             primitive_util::LowercasePrimitiveTypeName(type));
   }
 }
 
-std::unique_ptr<CubSortRunnerInterface> CreateCubSortRunner(
+StatusOr<std::unique_ptr<CubSortRunnerInterface>> CreateCubSortRunner(
     PrimitiveType key_type, PrimitiveType value_type) {
   // Values can be of any type of 16/32/64 bit width.
   int valueWidth = primitive_util::BitWidth(value_type);
-  CHECK(valueWidth == 16 || valueWidth == 32 || valueWidth == 64)
-      << "Unsupported value type of the sort kernel: "
-      << primitive_util::LowercasePrimitiveTypeName(value_type);
+  if (valueWidth != 16 && valueWidth != 32 && valueWidth != 64) {
+    return InvalidArgument(
+        "Unsupported value type of the sort kernel: %s",
+        primitive_util::LowercasePrimitiveTypeName(value_type));
+  }
 
   // Only unsigned integer types could be used for keys.
   switch (key_type) {
@@ -187,26 +228,28 @@
       }
       return std::make_unique<CubSortPairsImpl>(CubSortPairs_u64_b64, U64);
     default:
-      CHECK(false) << "Unsupported key type of the sort kernel: "
-                   << primitive_util::LowercasePrimitiveTypeName(key_type);
+      return InvalidArgument(
+          "Unsupported key type of the sort kernel: %s",
+          primitive_util::LowercasePrimitiveTypeName(key_type));
   }
 }
 
-std::unique_ptr<CubSortRunnerInterface> CreateCubSortRunner(
-    PrimitiveType type, std::optional<PrimitiveType> value_type) {
+}  // namespace
+
+StatusOr<std::unique_ptr<CubSortRunnerInterface>>
+CubSortRunnerInterface::Create(PrimitiveType type,
+                               std::optional<PrimitiveType> value_type) {
   return value_type.has_value() ? CreateCubSortRunner(type, *value_type)
                                 : CreateCubSortRunner(type);
 }
 
-}  // namespace
-
 CubSortThunk::CubSortThunk(ThunkInfo thunk_info, PrimitiveType type,
                            std::optional<PrimitiveType> value_type,
                            std::vector<BufferAllocation::Slice> operands,
                            std::vector<BufferAllocation::Slice> results,
                            BufferAllocation::Slice scratch, bool descending)
     : Thunk(Thunk::kCubSort, thunk_info),
-      runner_(CreateCubSortRunner(type, value_type)),
+      runner_(CubSortRunnerInterface::Create(type, value_type).value()),
       operands_(std::move(operands)),
       results_(std::move(results)),
       scratch_(scratch),
@@ -218,7 +261,7 @@
                   se::DeviceMemoryBase output_keys,
                   se::DeviceMemoryBase output_values,
                   se::DeviceMemoryBase scratch, bool descending) {
-  auto runner = CreateCubSortRunner(type, value_type);
+  auto runner = CubSortRunnerInterface::Create(type, value_type).value();
   return runner->Run(input_keys, input_values, output_keys, output_values,
                      scratch, descending);
 }
diff --git a/third_party/xla/xla/service/gpu/cub_sort_thunk.h b/third_party/xla/xla/service/gpu/cub_sort_thunk.h
index 80ae1d4..d79fb0e 100644
--- a/third_party/xla/xla/service/gpu/cub_sort_thunk.h
+++ b/third_party/xla/xla/service/gpu/cub_sort_thunk.h
@@ -16,6 +16,7 @@
 #ifndef XLA_SERVICE_GPU_CUB_SORT_THUNK_H_
 #define XLA_SERVICE_GPU_CUB_SORT_THUNK_H_
 
+#include <cstdint>
 #include <memory>
 #include <optional>
 #include <vector>
@@ -23,6 +24,7 @@
 #include "xla/service/buffer_assignment.h"
 #include "xla/service/gpu/thunk.h"
 #include "xla/status.h"
+#include "xla/statusor.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/xla_data.pb.h"
 
@@ -39,6 +41,10 @@
                      se::DeviceMemoryBase scratch, bool descending) = 0;
   virtual Status Run(const Thunk::ExecuteParams& params,
                      const class CubSortThunk* thunk) = 0;
+  virtual StatusOr<int64_t> GetScratchSize(int64_t num_items) = 0;
+
+  static StatusOr<std::unique_ptr<CubSortRunnerInterface>> Create(
+      PrimitiveType type, std::optional<PrimitiveType> value_type);
 };
 
 class CubSortThunk : public Thunk {
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc
index ba7c951..099e677 100644
--- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc
+++ b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc
@@ -1315,19 +1315,14 @@
       dbias_user = user;
     }
   }
-  HloInstruction* reduce;
   auto ConsumeExtraConvert = [](HloInstruction** instr) {
     Match((*instr)->users()[0], m::Convert(instr, m::Op()).WithOneUse());
     return true;
   };
   // user_count == 1 && (reduce-> {convert} ->bitcast)
   return user_count == 1 &&
-         Match(dbias_user, m::Reduce(&reduce, m::Op(), m::Op()).WithOneUse()) &&
-         ConsumeExtraConvert(&reduce) &&
-         Match(reduce->users()[0],
-               m::AnyOf<HloInstruction>(m::Reshape(dbias, m::Op()),
-                                        m::Bitcast(dbias, m::Op()))
-                   .WithOneUse());
+         Match(dbias_user, m::Reduce(dbias, m::Op(), m::Op()).WithOneUse()) &&
+         (*dbias)->shape().rank() == 3 && ConsumeExtraConvert(dbias);
 }
 
 StatusOr<bool> FuseBwdMultiHeadedAttentionBlock(
@@ -1459,12 +1454,24 @@
   output_shapes.push_back(ShapeUtil::MakeShape(U8, {0}));
 
   HloInstruction* dbias = nullptr;
-  if (d_intermediate &&
-      IsDbiasOnlyUserBesidesGradGemm(d_intermediate, bmm_1_grad_1, bmm_1_grad_2,
-                                     &dbias)) {
-    output_shapes.push_back(dbias->shape());
+  if (d_intermediate) {
+    if (IsDbiasOnlyUserBesidesGradGemm(d_intermediate, bmm_1_grad_1,
+                                       bmm_1_grad_2, &dbias)) {
+      // The cuDNN kernel only outputs dbias in the shape [1, num_heads, seq,
+      // seq], so we add a leading dimension of 1 to the existing dbias shape.
+      std::vector<int64_t> dbias_shape_vector =
+          SpanToVector(dbias->shape().dimensions());
+      dbias_shape_vector.insert(dbias_shape_vector.begin(), 1);
+      Shape cudnn_dbias_shape = ShapeUtil::MakeShape(
+          dbias->shape().element_type(), dbias_shape_vector);
+      output_shapes.push_back(cudnn_dbias_shape);
+    } else {
+      VLOG(2) << "Intermediate gradient has other users outside of gradient "
+                 "gemms and dbias"
+              << " which is not supported by CUDNN for now. Skipping.";
+      return false;
+    }
   }
-
   Shape call_shape = ShapeUtil::MakeTupleShape(output_shapes);
   HloInstruction* fmha_bwd_call =
       comp->AddInstruction(HloInstruction::CreateCustomCall(
@@ -1485,12 +1492,23 @@
   TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
       bmm_2_grad_1, HloInstruction::CreateGetTupleElement(bmm_2_grad_1->shape(),
                                                           fmha_bwd_call, 2)));
-  // d_intermediate tensor
+
   if (dbias) {
-    // does not really need d_intermediate
-    TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
-        dbias, HloInstruction::CreateGetTupleElement(dbias->shape(),
-                                                     fmha_bwd_call, 5)));
+    // Reshape the fmha dbias output back to the original user's input shape.
+    // If the reshape doesn't involve physical data movement, the
+    // algebraic simplifier can change it to a no-op bitcast.
+    Shape original_shape = dbias->shape();
+    HloInstruction* dbias_user = dbias->users()[0];
+    HloInstruction* cudnn_dbias_output =
+        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+            output_shapes.back(), fmha_bwd_call, 5));
+    HloInstruction* reshape_dbias = comp->AddInstruction(
+        HloInstruction::CreateReshape(original_shape, cudnn_dbias_output));
+    TF_RETURN_IF_ERROR(dbias_user->ReplaceOperandWith(
+        dbias_user->operand_index(dbias), reshape_dbias));
+
+    TF_RETURN_IF_ERROR(
+        comp->ReplaceInstructionWithDifferentShape(dbias, cudnn_dbias_output));
   }
   return true;
 }
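Taken together, the hunks above mean the backward custom call now returns dbias with a leading dimension of 1, and the rewriter reshapes it back to the shape the original dbias user expects. A standalone sketch of the shape bookkeeping, with illustrative dimensions that are not taken from this change:

  // Suppose the user's dbias is rank-3, e.g. bf16[16,256,256].
  Shape user_dbias_shape = ShapeUtil::MakeShape(BF16, {16, 256, 256});

  // cuDNN emits dbias as [1, num_heads, seq, seq], so prepend a 1.
  std::vector<int64_t> cudnn_dims = SpanToVector(user_dbias_shape.dimensions());
  cudnn_dims.insert(cudnn_dims.begin(), 1);  // -> {1, 16, 256, 256}
  Shape cudnn_dbias_shape =
      ShapeUtil::MakeShape(user_dbias_shape.element_type(), cudnn_dims);

  // The rewriter then emits get-tuple-element(fmha_bwd_call, 5) with
  // cudnn_dbias_shape and a reshape back to user_dbias_shape at the use site.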
diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc
index 06ba3ef..1545b4a 100644
--- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc
+++ b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc
@@ -1721,8 +1721,10 @@
                   m::Transpose(m::Transpose(m::GetTupleElement(
                                    m::CustomCall({backward_target}), 2)))
                       .WithShape(BF16, {16, 256, 16, 64}),
-                  m::GetTupleElement(  // dbias
-                      m::CustomCall({backward_target}), dbias_index)),
+                  m::Reshape(
+                      m::Reshape(m::GetTupleElement(  // dbias
+                          m::CustomCall({backward_target}), dbias_index)))
+                      .WithShape(BF16, {1, 16, 256, 256})),
               0)),
           m::Op(), m::Op(), m::Op(), m::Op())));
   TF_ASSERT_OK_AND_ASSIGN(auto config,
@@ -1909,8 +1911,10 @@
                   m::Transpose(m::Transpose(m::GetTupleElement(
                                    m::CustomCall({backward_target}), 2)))
                       .WithShape(F16, {16, 256, 16, 64}),
-                  m::GetTupleElement(  // dbias
-                      m::CustomCall({backward_target}), dbias_index)),
+                  m::Reshape(
+                      m::Reshape(m::GetTupleElement(  // dbias
+                          m::CustomCall({backward_target}), dbias_index)))
+                      .WithShape(F16, {1, 16, 256, 256})),
               0)),
           m::Op(), m::Op(), m::Op(), m::Op())));
   TF_ASSERT_OK_AND_ASSIGN(auto config,
@@ -2695,6 +2699,200 @@
   EXPECT_NEAR(config.dropout_rate(), 0, 1e-2);
 }
 
+TEST_F(CudnnFusedMhaRewriterTestHloTest,
+       BF16TrainingBmm1ScaleBiasSoftmaxDropoutBmm2DbiasShouldHaveUserShape) {
+  const char* module_str = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0},bf16[1,16,256,256]{3,2,1,0},pred[16,1,256,256]{3,2,1,0},bf16[16,256,16,64]{3,2,1,0})->(bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[1,16,256,256]{3,2,1,0})}
+
+region_0.54 {
+  Arg_0.55 = bf16[] parameter(0)
+  Arg_1.56 = bf16[] parameter(1)
+  ROOT maximum.57 = bf16[] maximum(Arg_0.55, Arg_1.56)
+}
+
+region_1.66 {
+  Arg_0.67 = f32[] parameter(0)
+  Arg_1.68 = f32[] parameter(1)
+  ROOT add.69 = f32[] add(Arg_0.67, Arg_1.68)
+}
+
+region_2.114 {
+  Arg_0.115 = bf16[] parameter(0)
+  Arg_1.116 = bf16[] parameter(1)
+  ROOT add.117 = bf16[] add(Arg_0.115, Arg_1.116)
+}
+
+ENTRY main.146 {
+  Arg_2.3 = bf16[16,256,16,64]{3,2,1,0} parameter(2), sharding={replicated}
+  copy = bf16[16,256,16,64]{1,3,2,0} copy(Arg_2.3), sharding={replicated}
+  transpose.5 = bf16[16,16,64,256]{3,2,1,0} transpose(copy), dimensions={0,2,3,1}
+  Arg_0.1 = bf16[16,256,16,64]{3,2,1,0} parameter(0), sharding={replicated}
+  copy.1 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_0.1), sharding={replicated}
+  transpose = bf16[16,16,256,64]{3,2,1,0} transpose(copy.1), dimensions={0,2,1,3}
+  Arg_1.2 = bf16[16,256,16,64]{3,2,1,0} parameter(1), sharding={replicated}
+  copy.2 = bf16[16,256,16,64]{1,3,2,0} copy(Arg_1.2), sharding={replicated}
+  transpose.1 = bf16[16,16,64,256]{3,2,1,0} transpose(copy.2), dimensions={0,2,3,1}
+  dot = bf16[16,16,256,256]{3,2,1,0} dot(transpose, transpose.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  Arg_4.5 = pred[16,1,256,256]{3,2,1,0} parameter(4), sharding={replicated}
+  convert.35 = s32[16,1,256,256]{3,2,1,0} convert(Arg_4.5)
+  constant.28 = s32[] constant(0)
+  broadcast.29 = s32[16,1,256,256]{3,2,1,0} broadcast(constant.28), dimensions={}
+  compare.36 = pred[16,1,256,256]{3,2,1,0} compare(convert.35, broadcast.29), direction=GT
+  constant.30 = bf16[] constant(0)
+  broadcast.1 = bf16[16,1,256,256]{3,2,1,0} broadcast(constant.30), dimensions={}
+  constant.10 = bf16[] constant(-9.999e+09)
+  broadcast.3 = bf16[16,1,256,256]{3,2,1,0} broadcast(constant.10), dimensions={}
+  select.39 = bf16[16,1,256,256]{3,2,1,0} select(compare.36, broadcast.1, broadcast.3)
+  reshape.41 = bf16[16,256,256]{2,1,0} reshape(select.39)
+  broadcast.42 = bf16[16,16,256,256]{3,2,1,0} broadcast(reshape.41), dimensions={0,2,3}
+  Arg_3.4 = bf16[1,16,256,256]{3,2,1,0} parameter(3), sharding={replicated}
+  reshape.44 = bf16[16,256,256]{2,1,0} reshape(Arg_3.4)
+  broadcast.45 = bf16[16,16,256,256]{3,2,1,0} broadcast(reshape.44), dimensions={1,2,3}
+  add.46 = bf16[16,16,256,256]{3,2,1,0} add(broadcast.42, broadcast.45)
+  add.53 = bf16[16,16,256,256]{3,2,1,0} add(dot, add.46)
+  constant.31 = bf16[] constant(-inf)
+  reduce.58 = bf16[16,16,256]{2,1,0} reduce(add.53, constant.31), dimensions={3}, to_apply=region_0.54
+  broadcast.62 = bf16[16,16,256,256]{3,2,1,0} broadcast(reduce.58), dimensions={0,1,2}
+  subtract.63 = bf16[16,16,256,256]{3,2,1,0} subtract(add.53, broadcast.62)
+  exponential.64 = bf16[16,16,256,256]{3,2,1,0} exponential(subtract.63)
+  convert.65 = f32[16,16,256,256]{3,2,1,0} convert(exponential.64)
+  constant.11 = f32[] constant(0)
+  reduce.70 = f32[16,16,256]{2,1,0} reduce(convert.65, constant.11), dimensions={3}, to_apply=region_1.66
+  convert.4 = bf16[16,16,256]{2,1,0} convert(reduce.70)
+  broadcast.75 = bf16[16,16,256,256]{3,2,1,0} broadcast(convert.4), dimensions={0,1,2}
+  divide.76 = bf16[16,16,256,256]{3,2,1,0} divide(exponential.64, broadcast.75)
+  constant.22 = u32[1]{0} constant({255383827})
+  constant.21 = u32[1]{0} constant({267815257})
+  constant.2 = u32[1]{0} constant({0})
+  constant.23 = u32[1]{0} constant({3213575472})
+  custom-call.49 = (u32[1]{0}, u32[1]{0}) custom-call(constant.22, constant.21, constant.2, constant.23), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[1]{0}, u32[1]{0}, u32[1]{0}, u32[1]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\001\000\000\000\000\000\000\000"
+  get-tuple-element.50 = u32[1]{0} get-tuple-element(custom-call.49), index=0
+  reshape.80 = u32[] reshape(get-tuple-element.50)
+  broadcast.84 = u32[32768]{0} broadcast(reshape.80), dimensions={}
+  get-tuple-element.51 = u32[1]{0} get-tuple-element(custom-call.49), index=1
+  reshape.81 = u32[] reshape(get-tuple-element.51)
+  broadcast.85 = u32[32768]{0} broadcast(reshape.81), dimensions={}
+  iota.79 = u32[65536]{0} iota(), iota_dimension=0
+  slice.82 = u32[32768]{0} slice(iota.79), slice={[0:32768]}
+  slice.83 = u32[32768]{0} slice(iota.79), slice={[32768:65536]}
+  custom-call.86 = (u32[32768]{0}, u32[32768]{0}) custom-call(broadcast.84, broadcast.85, slice.82, slice.83), custom_call_target="cu_threefry2x32", operand_layout_constraints={u32[32768]{0}, u32[32768]{0}, u32[32768]{0}, u32[32768]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config="\000\200\000\000\000\000\000\000"
+  get-tuple-element.87 = u32[32768]{0} get-tuple-element(custom-call.86), index=0
+  get-tuple-element.88 = u32[32768]{0} get-tuple-element(custom-call.86), index=1
+  concatenate.89 = u32[65536]{0} concatenate(get-tuple-element.87, get-tuple-element.88), dimensions={0}
+  constant.17 = u32[] constant(9)
+  broadcast.13 = u32[65536]{0} broadcast(constant.17), dimensions={}
+  shift-right-logical.0 = u32[65536]{0} shift-right-logical(concatenate.89, broadcast.13)
+  constant.15 = u32[] constant(1065353216)
+  broadcast.21 = u32[65536]{0} broadcast(constant.15), dimensions={}
+  or.0 = u32[65536]{0} or(shift-right-logical.0, broadcast.21)
+  bitcast-convert.0 = f32[65536]{0} bitcast-convert(or.0)
+  constant.3 = f32[] constant(-1)
+  broadcast.30 = f32[65536]{0} broadcast(constant.3), dimensions={}
+  add.1 = f32[65536]{0} add(bitcast-convert.0, broadcast.30)
+  broadcast.31 = f32[65536]{0} broadcast(constant.11), dimensions={}
+  maximum.0 = f32[65536]{0} maximum(add.1, broadcast.31)
+  constant.9 = f32[] constant(0.9)
+  broadcast.32 = f32[65536]{0} broadcast(constant.9), dimensions={}
+  compare.0 = pred[65536]{0} compare(maximum.0, broadcast.32), direction=LT
+  constant = bf16[] constant(1.109)
+  broadcast.33 = bf16[65536]{0} broadcast(constant), dimensions={}
+  broadcast.34 = bf16[65536]{0} broadcast(constant.30), dimensions={}
+  select.2 = bf16[65536]{0} select(compare.0, broadcast.33, broadcast.34)
+  reshape.39 = bf16[16,16,256]{2,1,0} reshape(select.2)
+  broadcast.9 = bf16[16,16,256,256]{3,2,1,0} broadcast(reshape.39), dimensions={0,1,3}
+  multiply.101 = bf16[16,16,256,256]{3,2,1,0} multiply(divide.76, broadcast.9)
+  dot.1 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.5, multiply.101), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  transpose.103 = bf16[16,256,16,64]{1,3,2,0} transpose(dot.1), dimensions={0,3,1,2}
+  Arg_5.6 = bf16[16,256,16,64]{3,2,1,0} parameter(5), sharding={replicated}
+  copy.3 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_5.6), sharding={replicated}
+  transpose.4 = bf16[16,16,256,64]{3,2,1,0} transpose(copy.3), dimensions={0,2,1,3}
+  dot.2 = bf16[16,16,256,256]{3,2,1,0} dot(transpose.4, transpose.5), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  multiply.108 = bf16[16,16,256,256]{3,2,1,0} multiply(dot.2, broadcast.9)
+  divide.124 = bf16[16,16,256,256]{3,2,1,0} divide(multiply.108, broadcast.75)
+  constant.19 = bf16[] constant(1)
+  broadcast.24 = bf16[16,16,256]{2,1,0} broadcast(constant.19), dimensions={}
+  multiply.2 = bf16[16,16,256]{2,1,0} multiply(convert.4, convert.4)
+  divide.0 = bf16[16,16,256]{2,1,0} divide(broadcast.24, multiply.2)
+  broadcast.111 = bf16[16,16,256,256]{3,2,1,0} broadcast(divide.0), dimensions={0,1,2}
+  multiply.112 = bf16[16,16,256,256]{3,2,1,0} multiply(multiply.108, broadcast.111)
+  multiply.113 = bf16[16,16,256,256]{3,2,1,0} multiply(multiply.112, exponential.64)
+  reduce.118 = bf16[16,16,256]{2,1,0} reduce(multiply.113, constant.30), dimensions={3}, to_apply=region_2.114
+  negate.1 = bf16[16,16,256]{2,1,0} negate(reduce.118)
+  broadcast.11 = bf16[16,16,256,256]{3,2,1,0} broadcast(negate.1), dimensions={0,1,2}
+  add.133 = bf16[16,16,256,256]{3,2,1,0} add(divide.124, broadcast.11)
+  multiply.134 = bf16[16,16,256,256]{3,2,1,0} multiply(add.133, exponential.64)
+  copy.4 = bf16[16,256,16,64]{3,1,2,0} copy(Arg_1.2), sharding={replicated}
+  transpose.9 = bf16[16,16,256,64]{3,2,1,0} transpose(copy.4), dimensions={0,2,1,3}
+  dot.4 = bf16[16,16,256,64]{3,2,1,0} dot(multiply.134, transpose.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  transpose.144 = bf16[16,256,16,64]{3,1,2,0} transpose(dot.4), dimensions={0,2,1,3}
+  dot.3 = bf16[16,16,256,64]{3,2,1,0} dot(multiply.134, transpose), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  transpose.142 = bf16[16,256,16,64]{3,1,2,0} transpose(dot.3), dimensions={0,2,1,3}
+  copy.5 = bf16[16,256,16,64]{1,3,2,0} copy(Arg_5.6), sharding={replicated}
+  transpose.104 = bf16[16,16,64,256]{3,2,1,0} transpose(copy.5), dimensions={0,2,3,1}
+  dot.106 = bf16[16,16,64,256]{3,2,1,0} dot(transpose.104, multiply.101), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  transpose.107 = bf16[16,256,16,64]{1,3,2,0} transpose(dot.106), dimensions={0,3,1,2}
+  reduce.139 = bf16[16,256,256]{2,1,0} reduce(multiply.134, constant.30), dimensions={0}, to_apply=region_2.114
+  bitcast.111 = bf16[1,16,256,256]{3,2,1,0} bitcast(reduce.139)
+  all-reduce = bf16[1,16,256,256]{3,2,1,0} all-reduce(bitcast.111), channel_id=85, replica_groups={{0}}, to_apply=region_2.114
+  tuple.145 = (bf16[16,256,16,64]{1,3,2,0}, bf16[16,256,16,64]{3,1,2,0}, bf16[16,256,16,64]{3,1,2,0}, bf16[16,256,16,64]{1,3,2,0}, bf16[1,16,256,256]{3,2,1,0}) tuple(transpose.103, transpose.144, transpose.142, transpose.107, all-reduce)
+  get-tuple-element = bf16[16,256,16,64]{1,3,2,0} get-tuple-element(tuple.145), index=0
+  copy.6 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element)
+  get-tuple-element.1 = bf16[16,256,16,64]{3,1,2,0} get-tuple-element(tuple.145), index=1
+  copy.7 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element.1)
+  get-tuple-element.2 = bf16[16,256,16,64]{3,1,2,0} get-tuple-element(tuple.145), index=2
+  copy.8 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element.2)
+  get-tuple-element.3 = bf16[16,256,16,64]{1,3,2,0} get-tuple-element(tuple.145), index=3
+  copy.9 = bf16[16,256,16,64]{3,2,1,0} copy(get-tuple-element.3)
+  get-tuple-element.4 = bf16[1,16,256,256]{3,2,1,0} get-tuple-element(tuple.145), index=4
+  ROOT tuple = (bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[16,256,16,64]{3,2,1,0}, bf16[1,16,256,256]{3,2,1,0}) tuple(copy.6, copy.7, copy.8, copy.9, get-tuple-element.4)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+  CudnnFusedMHARewriter fusedMhaRewriter{
+      GetCudaComputeCapability(),
+      GetCudnnVersionWithDbiasAndMaskBwdInputSupport()};
+  TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status());
+
+  HloDCE dce;
+  TF_ASSERT_OK(RunHloPass(&dce, m.get()).status());
+
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape());
+
+  const HloInstruction* fmha;
+  const absl::string_view backward_target =
+      kCudnnfMHAScaleBiasSoftmaxDropoutBackwardCallTarget;
+  auto dbias_index = 5;
+  SCOPED_TRACE(m->ToString());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::Copy(m::GetTupleElement(
+              m::Tuple(
+                  m::Transpose().WithShape(BF16, {16, 256, 16, 64}),
+                  m::Transpose(m::GetTupleElement(
+                                   m::CustomCall(&fmha, {backward_target}), 0))
+                      .WithShape(BF16, {16, 256, 16, 64}),
+                  m::Transpose(
+                      m::GetTupleElement(m::CustomCall({backward_target}), 1))
+                      .WithShape(BF16, {16, 256, 16, 64}),
+                  m::Transpose(m::Transpose(m::GetTupleElement(
+                                   m::CustomCall({backward_target}), 2)))
+                      .WithShape(BF16, {16, 256, 16, 64}),
+                  m::AllReduce(m::Bitcast(
+                      m::Reshape(
+                          m::GetTupleElement(  // dbias
+                              m::CustomCall({backward_target}), dbias_index))
+                          .WithShape(BF16, {16, 256, 256})))),
+              0)),
+          m::Op(), m::Op(), m::Op(), m::Op())));
+  TF_ASSERT_OK_AND_ASSIGN(auto config,
+                          fmha->backend_config<CudnnfMHABackendConfig>());
+  EXPECT_EQ(fmha->operands().size(), 5);
+  EXPECT_NEAR(config.dropout_rate(), 0.1, 1e-2);
+}
+
 }  // anonymous namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/custom_call_test.cc b/third_party/xla/xla/service/gpu/custom_call_test.cc
index d151d4e..d1ed305 100644
--- a/third_party/xla/xla/service/gpu/custom_call_test.cc
+++ b/third_party/xla/xla/service/gpu/custom_call_test.cc
@@ -13,9 +13,12 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <cstdint>
 #include <sstream>
 #include <string>
 
+#include "absl/strings/str_cat.h"
+
 #if GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cuda.h"
 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
@@ -26,8 +29,10 @@
 #define PLATFORM "ROCM"
 #endif
 
+#include "absl/status/status.h"
 #include "xla/client/lib/constants.h"
 #include "xla/client/xla_builder.h"
+#include "xla/ffi/ffi.h"
 #include "xla/runtime/custom_call.h"
 #include "xla/runtime/custom_call_registry.h"
 #include "xla/runtime/executable.h"
@@ -39,6 +44,7 @@
 #include "xla/service/gpu/runtime/custom_call_registry.h"
 #include "xla/service/gpu/runtime/support.h"
 #include "xla/service/service_executable_run_options.h"
+#include "xla/status.h"
 #include "xla/stream_executor/gpu/gpu_types.h"
 #include "xla/test_helpers.h"
 #include "xla/tests/client_library_test_base.h"
@@ -334,82 +340,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Custom calls based on XLA runtime modules.
-//===----------------------------------------------------------------------===//
-
-struct TestModule : runtime::StatelessModule {
-  TestModule() : StatelessModule("TestModule") {}
-
-  // Check that we can use absl::Status to return errors back to the caller.
-  static absl::Status AlwaysFail(runtime::StridedMemrefView arg) {
-    return absl::InternalError("Uh oh, too bad");
-  }
-
-  // Check that we can get access to the stream and launch on device.
-  static absl::Status Memcpy(const ServiceExecutableRunOptions* run_options,
-                             runtime::FlatMemrefView src,
-                             runtime::FlatMemrefView dst) {
-    se::DeviceMemoryBase src_mem(src.data);
-    se::DeviceMemoryBase dst_mem(dst.data);
-
-    if (src.size_in_bytes != dst.size_in_bytes) {
-      return absl::InternalError("Size in bytes must match");
-    }
-
-    run_options->stream()->ThenMemcpyD2D(&dst_mem, src_mem, src.size_in_bytes);
-    return absl::OkStatus();
-  }
-
-  // Write bindings for custom calls and register with runtime.
-  void Export(runtime::DynamicCustomCallRegistry& registry) const final {
-    registry.Register(runtime::CustomCall::Bind("test.always_fail")
-                          .Arg<runtime::StridedMemrefView>()
-                          .To(AlwaysFail));
-
-    registry.Register(runtime::CustomCall::Bind("test.memcpy")
-                          .UserData<const ServiceExecutableRunOptions*>()
-                          .Arg<runtime::FlatMemrefView>()
-                          .Arg<runtime::FlatMemrefView>()
-                          .To(Memcpy));
-  }
-};
-
-XLA_REGISTER_RUNTIME_MODULE(std::make_unique<TestModule>());
-
-TEST_F(CustomCallTest, ExportedAlwaysFail) {
-  // TODO(ezhulenev): Remove once XLA runtime is enabled by default.
-  mutable_debug_options()->set_xla_gpu_enable_xla_runtime_executable(true);
-
-  XlaBuilder b(TestName());
-  CustomCall(&b, "test.always_fail", /*operands=*/{},
-             ShapeUtil::MakeShape(F32, {}), /*opaque=*/"",
-             /*has_side_effect=*/false,
-             /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
-             /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
-             /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
-  auto status = Execute(&b, {}).status();
-  EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
-  VLOG(0) << status.message();
-  EXPECT_THAT(status.message(), ::testing::HasSubstr("Uh oh, too bad"));
-}
-
-TEST_F(CustomCallTest, ExportedMemcpy) {
-  // TODO(ezhulenev): Remove once XLA runtime is enabled by default.
-  mutable_debug_options()->set_xla_gpu_enable_xla_runtime_executable(true);
-
-  XlaBuilder b(TestName());
-  CustomCall(&b, "test.memcpy",
-             /*operands=*/{Broadcast(ConstantR0WithType(&b, F32, 42.0), {128})},
-             ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"",
-             /*has_side_effect=*/false,
-             /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
-             /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
-             /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
-  TF_ASSERT_OK_AND_ASSIGN(auto result, ExecuteAndTransfer(&b, {}));
-  EXPECT_THAT(result.data<float>(), ::testing::Each(42));
-}
-
-//===----------------------------------------------------------------------===//
 // XLA runtime custom calls provides type-safe custom call API
 //===----------------------------------------------------------------------===//
 
@@ -422,13 +352,13 @@
 
 // (1) Declare custom call implementations as static functions.
 
-static absl::Status AlwaysFailImpl(runtime::StridedMemrefView arg) {
-  return absl::InternalError("Uh oh, too bad");
+static absl::Status AlwaysFailImpl(runtime::MemrefView arg, int32_t value) {
+  return absl::InternalError(absl::StrCat("Uh oh, wrong value: ", value));
 }
 
 static absl::Status MemcpyImpl(const ServiceExecutableRunOptions* run_options,
-                               runtime::StridedMemrefView src,
-                               runtime::StridedMemrefView dst) {
+                               runtime::MemrefView src,
+                               runtime::MemrefView dst) {
   auto src_mem = gpu::GetDeviceAddress(src);
   auto dst_mem = gpu::GetDeviceAddress(dst);
   run_options->stream()->ThenMemcpyD2D(&dst_mem, src_mem, src_mem.size());
@@ -439,20 +369,57 @@
 // declared signature matches function handlers, and at run time we check that
 // passed arguments match the signature (number of arguments and their types).
 
+// TODO(ezhulenev): Remove these custom calls once we switch to thunks runtime.
+
 XLA_RUNTIME_DEFINE_CUSTOM_CALL(
     AlwaysFail, AlwaysFailImpl, runtime::CustomCall::RuntimeChecks::kDefault,
     runtime::CustomCall::Bind("__gpu$xla.gpu.ext.always_fail")
-        .Arg<runtime::StridedMemrefView>()  // arg
+        .Arg<runtime::MemrefView>()  // arg
+        .Attr<int32_t>("value")      // value
 );
 
 XLA_RUNTIME_DEFINE_CUSTOM_CALL(
     Memcpy, MemcpyImpl, runtime::CustomCall::RuntimeChecks::kDefault,
     runtime::CustomCall::Bind("__gpu$xla.gpu.ext.memcpy")
         .UserData<const ServiceExecutableRunOptions*>()
-        .Arg<runtime::StridedMemrefView>()  // src
-        .Arg<runtime::StridedMemrefView>()  // dst
+        .Arg<runtime::MemrefView>()  // src
+        .Arg<runtime::MemrefView>()  // dst
 );
 
+// (3) Declare FFI handlers as adaptors for legacy XLA runtime custom calls.
+//
+// TODO(ezhulenev): This is a long-term replacement for "legacy" custom calls
+// (custom calls with void** arguments) and for the type-safe XLA runtime
+// custom calls (see above). XLA FFI unifies internal custom calls (static
+// linking) with external custom calls (dynamically loaded libraries). Make
+// this the only example once it's fully supported.
+
+namespace impl {
+static Status AlwaysFail(ffi::Buffer arg, int32_t value) {
+  return AlwaysFailImpl(arg, value);
+}
+
+static Status Memcpy(const ServiceExecutableRunOptions* run_options,
+                     ffi::Buffer src, ffi::Buffer dst) {
+  return MemcpyImpl(run_options, src, dst);
+}
+}  // namespace impl
+
+XLA_FFI_DEFINE_HANDLER(kAlwaysFail, impl::AlwaysFail,
+                       ffi::Ffi::Bind()
+                           .Arg<ffi::Buffer>()      // arg
+                           .Attr<int32_t>("value")  // value
+);
+
+XLA_FFI_DEFINE_HANDLER(kMemcpy, impl::Memcpy,
+                       ffi::Ffi::Bind()
+                           .Ctx<ServiceExecutableRunOptions>()
+                           .Arg<ffi::Buffer>()  // src
+                           .Arg<ffi::Buffer>()  // dst
+);
+
+// (4) Register custom calls handlers with XLA runtime.
+
 static void RegisterCustomCalls(runtime::DirectCustomCallRegistry& registry) {
   registry.Register("__gpu$xla.gpu.ext.always_fail", AlwaysFail);
   registry.Register("__gpu$xla.gpu.ext.memcpy", Memcpy);
@@ -460,27 +427,27 @@
 
 XLA_GPU_REGISTER_RUNTIME_CUSTOM_CALL(RegisterCustomCalls);
 
-TEST_F(CustomCallTest, RuntimeCustomCallAlwaysFail) {
-  // TODO(ezhulenev): Remove once XLA runtime is enabled by default.
-  mutable_debug_options()->set_xla_gpu_enable_xla_runtime_executable(true);
+// (5) Register XLA FFI handlers with XLA runtime.
 
+XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__gpu$xla.gpu.ext.always_fail",
+                         kAlwaysFail);
+XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__gpu$xla.gpu.ext.memcpy",
+                         kMemcpy);
+
+TEST_F(CustomCallTest, RuntimeCustomCallAlwaysFail) {
   XlaBuilder b(TestName());
   CustomCall(&b, "__gpu$xla.gpu.ext.always_fail", /*operands=*/{},
-             ShapeUtil::MakeShape(F32, {}), /*opaque=*/"",
+             ShapeUtil::MakeShape(F32, {}), /*opaque=*/"{value = 42 : i32}",
              /*has_side_effect=*/false,
              /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
              /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
              /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
   auto status = Execute(&b, {}).status();
   EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
-  VLOG(0) << status.message();
-  EXPECT_THAT(status.message(), ::testing::HasSubstr("Uh oh, too bad"));
+  EXPECT_THAT(status.message(), ::testing::HasSubstr("Uh oh, wrong value: 42"));
 }
 
 TEST_F(CustomCallTest, ExportedFfiMemcpy) {
-  // TODO(ezhulenev): Remove once XLA runtime is enabled by default.
-  mutable_debug_options()->set_xla_gpu_enable_xla_runtime_executable(true);
-
   XlaBuilder b(TestName());
   CustomCall(&b, "__gpu$xla.gpu.ext.memcpy",
              /*operands=*/{Broadcast(ConstantR0WithType(&b, F32, 42.0), {128})},
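For reference, the rewritten test handlers above pair an ffi::Buffer argument with an int32_t attribute; on the builder side the attribute travels in the opaque string ("{value = 42 : i32}"). A hypothetical extra handler following the same binding pattern, with an invented name and trivial semantics, might look like:

  static Status EchoValueImpl(ffi::Buffer arg, int32_t value) {
    return absl::InternalError(absl::StrCat("echo: ", value));
  }

  XLA_FFI_DEFINE_HANDLER(kEchoValue, EchoValueImpl,
                         ffi::Ffi::Bind()
                             .Arg<ffi::Buffer>()      // arg
                             .Attr<int32_t>("value")  // value
  );

  XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__gpu$xla.gpu.ext.echo_value",
                           kEchoValue);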
diff --git a/third_party/xla/xla/service/gpu/fusion_merger.cc b/third_party/xla/xla/service/gpu/fusion_merger.cc
index 56e4ee5f..3e44bc4 100644
--- a/third_party/xla/xla/service/gpu/fusion_merger.cc
+++ b/third_party/xla/xla/service/gpu/fusion_merger.cc
@@ -275,7 +275,9 @@
   }
 
   GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
-      producer, &*cost_analysis_, producer->users());
+      producer, &*cost_analysis_,
+      GpuPerformanceModelOptions::ForModule(producer->GetModule()),
+      producer->users());
   if (t.time_fused > t.time_unfused) {
     ++num_fail_slower_if_fused_;
     return "will execute slower if fused";
diff --git a/third_party/xla/xla/service/gpu/fusion_process_dump.proto b/third_party/xla/xla/service/gpu/fusion_process_dump.proto
index 80b6cc7..0fc3794 100644
--- a/third_party/xla/xla/service/gpu/fusion_process_dump.proto
+++ b/third_party/xla/xla/service/gpu/fusion_process_dump.proto
@@ -3,14 +3,45 @@
 package xla.gpu;
 
 message FusionStep {
-  // Name of the resulting fusion. Can be the same as producer or consumer.
-  string fusion_name = 1;
+  message Fusion {
+    // Name of the resulting fusion. Can be the same as producer or consumer.
+    string fusion_name = 1;
 
-  // Name of the producer instruction before fusion.
-  string producer_name = 2;
+    // Name of the producer instruction before fusion.
+    string producer_name = 2;
 
-  // Name of the consumer instruction before fusion.
-  string consumer_name = 3;
+    // Name of the consumer instruction before fusion.
+    string consumer_name = 3;
+  }
+
+  message UpdatePriority {
+    // The name of the producer whose priority was updated.
+    string producer_name = 1;
+    // The names of all of the producer's consumers.
+    repeated string consumer_names = 2;
+
+    // The time (in microseconds) to execute the epilogue of each consumer
+    // (the producer's HLO) and to read the producer's inputs in each consumer.
+    float us_fused = 3;
+    // The time (in microseconds) to execute the producer and to read the
+    // producer's outputs from the consumers when unfused.
+    float us_unfused = 4;
+  }
+
+  message ProducerIneligible {
+    // The name of the producer.
+    string producer_name = 1;
+    // The reason why this producer cannot be fused.
+    string reason = 2;
+  }
+
+  oneof step {
+    Fusion fusion = 4;
+    ProducerIneligible producer_ineligible = 5;
+    UpdatePriority update_priority = 6;
+  }
+
+  reserved 1 to 3;
 }
 
 message FusionProcessDumpProto {
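The restructured FusionStep is a tagged union over the three step kinds above. A minimal sketch of building one step, assuming the usual protoc-generated header for this file; the values here are made up:

  #include "xla/service/gpu/fusion_process_dump.pb.h"

  xla::gpu::FusionStep MakeUpdatePriorityStep() {
    xla::gpu::FusionStep step;
    // Selects the update_priority arm of the oneof.
    auto* update = step.mutable_update_priority();
    update->set_producer_name("producer.1");
    update->add_consumer_names("consumer.1");
    update->set_us_fused(1.5f);    // estimated fused time
    update->set_us_unfused(2.0f);  // estimated unfused time
    return step;
  }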
diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD
index a810592..7fdd991 100644
--- a/third_party/xla/xla/service/gpu/fusions/BUILD
+++ b/third_party/xla/xla/service/gpu/fusions/BUILD
@@ -53,6 +53,7 @@
         "@llvm-project//llvm:ir_headers",
         "@llvm-project//mlir:IR",
         "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
     ],
 )
 
@@ -124,8 +125,11 @@
     hdrs = ["thunk_util.h"],
     visibility = ["//visibility:public"],
     deps = [
+        "//xla:literal",
         "//xla:shape_util",
         "//xla:statusor",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:buffer_assignment",
         "//xla/service/gpu:gpu_executable",
         "//xla/service/gpu:ir_emission_utils",
         "//xla/service/gpu:ir_emitter_context",
@@ -162,6 +166,7 @@
         "//xla/service/gpu:ir_emitter_context",
         "//xla/service/gpu:kernel_reuse_cache",
         "//xla/service/gpu:parallel_loop_emitter",
+        "//xla/service/gpu:reduction_utils",
         "//xla/service/gpu:target_util",
         "//xla/service/gpu:thunk",
         "//xla/service/llvm_ir:fused_ir_emitter",
@@ -175,6 +180,7 @@
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
         "@llvm-project//llvm:ir_headers",
+        "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Support",
         "@local_tsl//tsl/platform:logging",
         "@local_tsl//tsl/platform:status",
diff --git a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc
index 5b5c1b3..fd3d730 100644
--- a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc
+++ b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc
@@ -34,6 +34,7 @@
 #include "xla/service/llvm_ir/ir_array.h"
 #include "xla/service/llvm_ir/llvm_util.h"
 #include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 namespace gpu {
@@ -180,11 +181,15 @@
     IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter,
     mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion,
     KernelReuseCache& kernel_cache, llvm::IRBuilder<>* builder) const {
-  std::string suggested_kernel_name = GetIrNameFromLoc(fusion_op->getLoc());
+  std::string suggested_kernel_name = std::string(fusion.name());
 
-  TF_ASSIGN_OR_RETURN(
-      auto kernel_arguments,
-      KernelArguments::Create(ir_emitter_context.allocations(), fusion_op));
+  TF_ASSIGN_OR_RETURN(KernelArguments kernel_arguments,
+                      ir_emitter_context.emit_ir_from_hlo()
+                          ? KernelArguments::Create(
+                                ir_emitter_context.buffer_assignment(), &fusion)
+                          : KernelArguments::Create(
+                                ir_emitter_context.allocations(), fusion_op));
+
   auto* fused_computation = fusion.fused_instructions_computation();
 
   FusionEmissionResult result;
@@ -198,8 +203,8 @@
           llvm::Function* kernel;
           std::tie(kernel, inputs, outputs) = BuildKernelPrototype(
               ir_emitter_context, suggested_kernel_name,
-              kernel_arguments.args(), fusion_op.getInputBuffers().size(),
-              launch_dims, builder);
+              kernel_arguments.args(), fusion.operand_count(), launch_dims,
+              builder);
           TF_RETURN_IF_ERROR(EmitKernel(ir_emitter_context, elemental_emitter,
                                         fusion, launch_dims, std::move(inputs),
                                         std::move(outputs), builder, i));
@@ -215,9 +220,15 @@
               << entry->kernel_name;
     }
 
-    result.thunks.emplace_back(std::make_unique<KernelThunk>(
-        fusion_op, entry->kernel_name, kernel_arguments.args(), launch_dims,
-        entry->shmem_bytes));
+    if (ir_emitter_context.emit_ir_from_hlo()) {
+      result.thunks.emplace_back(std::make_unique<KernelThunk>(
+          &fusion, entry->kernel_name, kernel_arguments.args(), launch_dims,
+          entry->shmem_bytes));
+    } else {
+      result.thunks.emplace_back(std::make_unique<KernelThunk>(
+          fusion_op, entry->kernel_name, kernel_arguments.args(), launch_dims,
+          entry->shmem_bytes));
+    }
   }
 
   return result;
diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc
index 7f9de9b..f4b9d40 100644
--- a/third_party/xla/xla/service/gpu/fusions/fusions.cc
+++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc
@@ -16,6 +16,7 @@
 
 #include <memory>
 #include <optional>
+#include <vector>
 
 #include "absl/types/span.h"
 #include "mlir/IR/Value.h"  // from @llvm-project
@@ -34,7 +35,6 @@
 
 namespace xla {
 namespace gpu {
-namespace {
 
 bool IsSingleInstructionFusion(mlir::lmhlo::FusionOp fusion) {
   bool seen_instruction = false;
@@ -50,10 +50,9 @@
   return seen_instruction;
 }
 
-}  // namespace
-
 std::optional<std::unique_ptr<FusionInterface>> GetFusionEmitter(
-    HloFusionAnalysis& analysis, absl::Span<const BufferAllocation> allocations,
+    HloFusionAnalysis& analysis,
+    absl::Span<const BufferAllocation* const> allocations,
     mlir::lmhlo::FusionOp fusion_op) {
   switch (analysis.GetEmitterFusionKind()) {
     case HloFusionAnalysis::EmitterFusionKind::kInputSlices:
diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.h b/third_party/xla/xla/service/gpu/fusions/fusions.h
index 1899fea..82fc634 100644
--- a/third_party/xla/xla/service/gpu/fusions/fusions.h
+++ b/third_party/xla/xla/service/gpu/fusions/fusions.h
@@ -32,7 +32,8 @@
 // `allocations` may be empty and `fusion_op` may be nullptr if buffer
 // assignment didn't run yet.
 std::optional<std::unique_ptr<FusionInterface>> GetFusionEmitter(
-    HloFusionAnalysis& analysis, absl::Span<const BufferAllocation> allocations,
+    HloFusionAnalysis& analysis,
+    absl::Span<const BufferAllocation* const> allocations,
     mlir::lmhlo::FusionOp fusion_op);
 
 }  // namespace gpu
diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.cc b/third_party/xla/xla/service/gpu/fusions/reduction.cc
index 318e34f..4695076 100644
--- a/third_party/xla/xla/service/gpu/fusions/reduction.cc
+++ b/third_party/xla/xla/service/gpu/fusions/reduction.cc
@@ -17,6 +17,7 @@
 #include <cstdint>
 #include <functional>
 #include <memory>
+#include <optional>
 #include <string>
 #include <utility>
 #include <vector>
@@ -35,6 +36,7 @@
 #include "llvm/IR/Value.h"
 #include "llvm/Support/AtomicOrdering.h"
 #include "llvm/Support/Casting.h"
+#include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -54,6 +56,7 @@
 #include "xla/service/gpu/kernel_reuse_cache.h"
 #include "xla/service/gpu/kernel_thunk.h"
 #include "xla/service/gpu/parallel_loop_emitter.h"
+#include "xla/service/gpu/reduction_utils.h"
 #include "xla/service/gpu/target_util.h"
 #include "xla/service/gpu/thunk.h"
 #include "xla/service/llvm_ir/fused_ir_emitter.h"
@@ -324,21 +327,21 @@
 }
 
 StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunk(
-    IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion,
-    const HloComputation* fused_computation,
-    ElementalIrEmitter& elemental_emitter, KernelReuseCache& kernel_cache,
-    int output_index, llvm::IRBuilder<>* builder) {
-  auto reduce = mlir::dyn_cast_or_null<mlir::mhlo::ReduceOp>(
-      fusion.getFusionRoots()[output_index]);
-
+    IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
+    mlir::lmhlo::FusionOp fusion_op, const HloComputation* fused_computation,
+    const HloInstruction* fusion_root, ElementalIrEmitter& elemental_emitter,
+    KernelReuseCache& kernel_cache, int output_index,
+    llvm::IRBuilder<>* builder) {
+  const HloReduceInstruction* reduce =
+      DynCast<HloReduceInstruction>(fusion_root);
   TF_RET_CHECK(reduce);
-  TF_RET_CHECK(reduce.getNumResults() == 1);
 
-  mlir::Value init_value = reduce.getInitValues()[0];
-  mlir::Value dest = fusion.getOutputBuffers()[output_index];
-  TF_ASSIGN_OR_RETURN(std::optional<std::unique_ptr<Thunk>> constant_init_thunk,
-                      BuildConstantInitializerThunk(ir_emitter_context, fusion,
-                                                    init_value, dest));
+  const HloInstruction* init_value = reduce->init_values()[0];
+  mlir::Value dest = fusion_op.getOutputBuffers()[output_index];
+  TF_ASSIGN_OR_RETURN(
+      std::optional<std::unique_ptr<Thunk>> constant_init_thunk,
+      BuildConstantInitializerThunk(ir_emitter_context, fusion_op, &fusion,
+                                    init_value, dest));
   if (constant_init_thunk) {
     return *std::move(constant_init_thunk);
   }
@@ -370,11 +373,11 @@
                         fused_emitter.GetGenerator(*instr->operand(1)));
     TF_RETURN_IF_ERROR(ParallelLoopEmitter(generator, {outputs[output_index]},
                                            launch_dimensions, builder)
-                           .EmitLoop(GetIrNameFromLoc(fusion.getLoc())));
+                           .EmitLoop(fusion.name()));
     return OkStatus();
   };
 
-  return BuildKernelThunkForFusion(ir_emitter_context, kernel_cache, fusion,
+  return BuildKernelThunkForFusion(ir_emitter_context, kernel_cache, fusion_op,
                                    fused_computation, launch_dimensions,
                                    /*discriminator=*/
                                    absl::StrCat("init_", output_index),
@@ -975,12 +978,13 @@
     absl::Span<const HloInstruction* const> fusion_roots =
         analysis_.fusion_roots();
     for (int i = 0; i < fusion_roots.size(); ++i) {
-      if (IsReductionFromOrToContiguousDimensions(*fusion_roots[i])) {
+      const HloInstruction* fusion_root = fusion_roots[i];
+      if (IsReductionFromOrToContiguousDimensions(*fusion_root)) {
         TF_ASSIGN_OR_RETURN(
             result.thunks.emplace_back(),
-            BuildFusedInitializerThunk(ir_emitter_context, fusion_op,
-                                       fused_computation, elemental_emitter,
-                                       kernel_cache, i, builder));
+            BuildFusedInitializerThunk(
+                ir_emitter_context, fusion, fusion_op, fused_computation,
+                fusion_root, elemental_emitter, kernel_cache, i, builder));
       }
     }
   }
diff --git a/third_party/xla/xla/service/gpu/fusions/thunk_util.cc b/third_party/xla/xla/service/gpu/fusions/thunk_util.cc
index 98a66ca..bb16744 100644
--- a/third_party/xla/xla/service/gpu/fusions/thunk_util.cc
+++ b/third_party/xla/xla/service/gpu/fusions/thunk_util.cc
@@ -14,8 +14,10 @@
 ==============================================================================*/
 #include "xla/service/gpu/fusions/thunk_util.h"
 
+#include <cstdint>
 #include <memory>
 #include <optional>
+#include <vector>
 
 #include "absl/types/span.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
@@ -23,6 +25,11 @@
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/literal.h"
+#include "xla/service/buffer_assignment.h"
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/ir_emitter_context.h"
 #include "xla/service/gpu/memset_thunk.h"
@@ -80,36 +87,24 @@
 
 StatusOr<std::optional<std::unique_ptr<Thunk>>> BuildConstantInitializerThunk(
     IrEmitterContext& ir_emitter_context, mlir::Operation* op,
-    mlir::Value init_value, mlir::Value dest) {
-  mlir::DenseElementsAttr const_init;
-  if (auto get_global_memref =
-          mlir::dyn_cast_or_null<mlir::memref::GetGlobalOp>(
-              init_value.getDefiningOp())) {
-    auto global_memref =
-        mlir::SymbolTable::lookupNearestSymbolFrom<mlir::memref::GlobalOp>(
-            get_global_memref, get_global_memref.getNameAttr());
-    if (global_memref.getConstant() && global_memref.getInitialValue()) {
-      // If the initial value happens to be a constant, generate a specialized
-      // thunk.
-      const_init = global_memref.getInitialValue()
-                       .value()
-                       .cast<mlir::DenseElementsAttr>();
-    }
-  } else if (auto constant = mlir::dyn_cast_or_null<mlir::mhlo::ConstantOp>(
-                 init_value.getDefiningOp())) {
-    const_init = constant.getValue().dyn_cast<mlir::DenseElementsAttr>();
-  }
-
-  if (const_init) {
+    const HloInstruction* instr, const HloInstruction* init_value,
+    mlir::Value dest) {
+  if (const HloConstantInstruction* constant =
+          DynCast<HloConstantInstruction>(init_value)) {
+    const Literal& literal = constant->literal();
+    const uint8_t* data = static_cast<const uint8_t*>(literal.untyped_data());
     std::vector<uint8_t> literal_bytes;
-    TF_RETURN_IF_ERROR(
-        CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes));
+    for (int i = 0; i < literal.size_bytes(); i++) {
+      literal_bytes.push_back(static_cast<uint8_t>(data[i]));
+    }
 
     TF_ASSIGN_OR_RETURN(
-        auto dest_slice,
-        GetAllocationSlice(dest, ir_emitter_context.allocations()));
+        BufferAllocation::Slice dest_slice,
+        ir_emitter_context.emit_ir_from_hlo()
+            ? ir_emitter_context.buffer_assignment().GetUniqueSlice(instr, {})
+            : GetAllocationSlice(dest, ir_emitter_context.allocations()));
 
-    const Shape dest_shape = GetShape(dest);
+    const Shape dest_shape = instr->shape();
     return BuildConstantInitializerThunk(op, literal_bytes, dest, dest_slice,
                                          dest_shape);
   }
diff --git a/third_party/xla/xla/service/gpu/fusions/thunk_util.h b/third_party/xla/xla/service/gpu/fusions/thunk_util.h
index 0ac5dad..ca88d2f 100644
--- a/third_party/xla/xla/service/gpu/fusions/thunk_util.h
+++ b/third_party/xla/xla/service/gpu/fusions/thunk_util.h
@@ -18,6 +18,8 @@
 #include <memory>
 #include <optional>
 
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/service/gpu/ir_emitter_context.h"
 #include "xla/service/gpu/thunk.h"
 #include "xla/statusor.h"
@@ -29,7 +31,8 @@
 // empty optional if the value is not a constant.
 StatusOr<std::optional<std::unique_ptr<Thunk>>> BuildConstantInitializerThunk(
     IrEmitterContext& ir_emitter_context, mlir::Operation* op,
-    mlir::Value init_value, mlir::Value dest);
+    const HloInstruction* instr, const HloInstruction* init_value,
+    mlir::Value dest);
 
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc
index 1deae4d..20f9e70 100644
--- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc
+++ b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc
@@ -16,6 +16,7 @@
 #include "xla/service/gpu/gemm_algorithm_picker.h"
 
 #include <algorithm>
+#include <cstdint>
 #include <functional>
 #include <limits>
 #include <optional>
@@ -263,6 +264,16 @@
       AutotunerUtil::CreateBuffer(buffer_allocator, output_shape,
                                   autotune_config, rng_state));
 
+  int64_t workspace_size =
+      autotune_config.GetCudaComputeCapability().IsAtLeastHopper()
+          ? GemmConfig::kHopperWorkspace
+          : GemmConfig::kDefaultWorkspace;
+  TF_ASSIGN_OR_RETURN(
+      se::DeviceMemoryBase workspace_buffer,
+      AutotunerUtil::CreateBuffer(buffer_allocator,
+                                  ShapeUtil::MakeShape(S8, {workspace_size}),
+                                  autotune_config, rng_state));
+
   HloModuleConfig& hlo_module_config = gemm->GetModule()->mutable_config();
   AutotuneResult best_algorithm;
   if (IsCublasLtMatmul(*gemm)) {
@@ -342,8 +353,9 @@
               // success-ness is returned in
               // ProfileResult::is_valid.
               TF_RETURN_IF_ERROR(RunGemm(config, lhs_buffer, rhs_buffer,
-                                         output_buffer, deterministic_ops,
-                                         stream, algorithm, &profile_result));
+                                         output_buffer, workspace_buffer,
+                                         deterministic_ops, stream, algorithm,
+                                         &profile_result));
               return std::move(profile_result);
             }));
     if (best_algorithm.has_gemm()) {
@@ -362,6 +374,15 @@
                                 const AutotuneConfig& config) {
   VLOG(3) << "Loading the autotune result of GemmThunk " << gemm->ToString();
 
+  GemmBackendConfig gemm_config =
+      gemm->backend_config<GemmBackendConfig>().value();
+  // Degenerate gemms are replaced with a memzero operation; no need to autotune them.
+  if (gemm_config.alpha_real() == 0.0 && gemm_config.alpha_imag() == 0.0 &&
+      gemm_config.beta() == 0.0) {
+    VLOG(3) << "Skip degenerate gemm instruction auto tuning";
+    return false;
+  }
+
   AutotuneCacheKey key(config.GetModelStr(), *gemm);
 
   TF_ASSIGN_OR_RETURN(AutotuneResult algorithm,
@@ -370,8 +391,6 @@
                       }));
 
   se::CudaComputeCapability capability = config.GetCudaComputeCapability();
-  GemmBackendConfig gemm_config =
-      gemm->backend_config<GemmBackendConfig>().value();
   GemmBackendConfig updated_config = gemm_config;
 
   // We only set the 'algorithm' field on non-Ampere architectures, as for
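The early return added above skips autotuning for degenerate gemms. Spelled out: with alpha_real, alpha_imag, and beta all zero the call computes

    D = 0 * (A x B) + 0 * C = 0,

so, as the comment notes, it gets replaced by a memzero and no algorithm choice can affect the result.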
diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc
index 6fc7472a4..07a901f 100644
--- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc
+++ b/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc
@@ -93,7 +93,7 @@
   SCOPED_TRACE(m->ToString());
   HloInstruction* dot;
   ASSERT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::CustomCall(&dot)));
+              GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0)));
 
   TF_ASSERT_OK_AND_ASSIGN(GemmBackendConfig config,
                           dot->backend_config<GemmBackendConfig>());
@@ -163,7 +163,7 @@
   SCOPED_TRACE(m->ToString());
   HloInstruction* dot;
   ASSERT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::CustomCall(&dot)));
+              GmockMatch(m::GetTupleElement(m::CustomCall(&dot), 0)));
 
   TF_ASSERT_OK_AND_ASSIGN(GemmBackendConfig config,
                           dot->backend_config<GemmBackendConfig>());
diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/gemm_rewriter.cc
index 6913e04..0e41cfd 100644
--- a/third_party/xla/xla/service/gpu/gemm_rewriter.cc
+++ b/third_party/xla/xla/service/gpu/gemm_rewriter.cc
@@ -11,13 +11,15 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
-==============================================================================*/
+=
+=============================================================================*/
 
 #include "xla/service/gpu/gemm_rewriter.h"
 
 #include <algorithm>
 #include <array>
 #include <cmath>
+#include <cstdint>
 #include <limits>
 #include <memory>
 #include <optional>
@@ -42,6 +44,9 @@
 #include "xla/service/gpu/matmul_utils.h"
 #include "xla/service/hlo_creation_utils.h"
 #include "xla/service/pattern_matcher.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status.h"
 #include "xla/status_macros.h"
 #include "xla/statusor.h"
 #include "xla/stream_executor/blas.h"
@@ -487,6 +492,22 @@
         instr->dot_dimension_numbers();
     *gemm_backend_config.mutable_precision_config() = instr->precision_config();
 
+    HloInstruction *lhs = instr->mutable_operand(0);
+    HloInstruction *rhs = instr->mutable_operand(1);
+    auto attributes = instr->frontend_attributes().map();
+    gemm_backend_config.set_grad_x(attributes["grad_x"] == "true");
+    gemm_backend_config.set_grad_y(attributes["grad_y"] == "true");
+
+    int64_t lhs_batch_dims_size =
+        instr->dot_dimension_numbers().lhs_batch_dimensions_size();
+    int64_t lhs_stride = lhs->shape().dimensions(lhs_batch_dims_size) *
+                         lhs->shape().dimensions(lhs_batch_dims_size + 1);
+    int64_t rhs_stride = rhs->shape().dimensions(lhs_batch_dims_size) *
+                         rhs->shape().dimensions(lhs_batch_dims_size + 1);
+
+    gemm_backend_config.set_lhs_stride(lhs_stride);
+    gemm_backend_config.set_rhs_stride(rhs_stride);
+
     // First try to match the fp8 gemm pattern.
     TF_ASSIGN_OR_RETURN(bool supported_by_cublaslt,
                         GemmIsSupportedByCublasLt(*instr, gemm_backend_config));
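For reference, the stride fields set above are the element counts of one batch slice of each operand, i.e. the product of the two dimensions that follow the batch dimensions. A worked example with made-up shapes:

  // Hypothetical dot with one batch dimension:
  //   lhs = f32[8,128,64], rhs = f32[8,64,256], lhs_batch_dims_size = 1
  //   lhs_stride = 128 * 64  = 8192   // elements per lhs batch slice
  //   rhs_stride = 64  * 256 = 16384  // elements per rhs batch slice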
@@ -1799,7 +1820,8 @@
             dot_dims.rhs_contracting_dimensions(),
             /*output_shape=*/instr.shape(), gemm_backend_config.alpha_real(),
             gemm_backend_config.alpha_imag(), gemm_backend_config.beta(),
-            /*algorithm*/ std::nullopt, se::blas::kDefaultComputePrecision));
+            /*algorithm*/ std::nullopt, se::blas::kDefaultComputePrecision,
+            gemm_backend_config.grad_x(), gemm_backend_config.grad_y()));
 
     if (matrix_name == "lhs" || matrix_name == "a") {
       return gemm_config.lhs_layout.order == MatrixLayout::Order::kColumnMajor;
@@ -1938,10 +1960,85 @@
   }
 };
 
+// Rewriter that adds a workspace to legacy cuBLAS custom calls. We run it
+// separately after the gemm rewriter, so that we can do pattern matching
+// without having to match output tuples.
+class GemmWorkspaceRewriteVisitor : public DfsHloRewriteVisitor {
+ public:
+  explicit GemmWorkspaceRewriteVisitor(se::GpuComputeCapability gpu_version)
+      : gpu_version_(gpu_version) {}
+
+  Status HandleCustomCall(HloInstruction *instr) override {
+    if (instr->custom_call_target() != kGemmCallTarget ||
+        !instr->shape().IsArray()) {
+      return OkStatus();
+    }
+
+    auto *cuda_cc = std::get_if<se::CudaComputeCapability>(&gpu_version_);
+
+    // Pass a user-managed workspace to legacy cuBLAS operations, as
+    // otherwise cuBLAS will use its own internal pool, which would compete
+    // with the XLA allocator for device memory.
+    int64_t workspace = cuda_cc == nullptr ? 0
+                        : cuda_cc->IsAtLeastHopper()
+                            ? GemmConfig::kHopperWorkspace
+                            : GemmConfig::kDefaultWorkspace;
+
+    // We do not know the workspace size required by cuBLAS, but we can guess
+    // that in the worst case cuBLAS will transpose all operands into a tiled
+    // layout optimal for the tensor cores, so it doesn't make sense to
+    // allocate a larger workspace.
+    //
+    // TODO(ezhulenev): This is not based on any measurement, just common
+    // sense; we should tweak it to find the minimal workspace size.
+    int64_t operands_byte_size = 0;
+    for (auto &operand : instr->operands()) {
+      operands_byte_size += ShapeUtil::ByteSizeOf(operand->shape());
+    }
+    workspace = std::min(workspace, operands_byte_size);
+
+    // If CUDA graphs are disabled (a command buffer implementation detail),
+    // then we reset the workspace size to 0 and rely on cuBLAS to allocate
+    // the workspace from its own pool.
+    //
+    // TODO(ezhulenev): Remove this workaround; allocating the workspace
+    // explicitly should always be better than relying on cuBLAS.
+    bool cuda_graphs_disabled = instr->GetModule()
+                                    ->config()
+                                    .debug_options()
+                                    .xla_gpu_enable_command_buffer_size() == 0;
+    if (cuda_graphs_disabled) workspace = 0;
+
+    // Append workspace buffer to instruction outputs.
+    std::vector<Shape> output_shapes = {instr->shape()};
+    output_shapes.emplace_back(ShapeUtil::MakeShape(S8, {workspace}));
+    Shape output_shape = ShapeUtil::MakeTupleShape(output_shapes);
+
+    // Clone custom call with a new shape.
+    HloInstruction *new_call = instr->AddInstruction(
+        instr->CloneWithNewOperands(output_shape, instr->operands()));
+
+    // Update operand aliasing if it was a fused gemm with aliased output.
+    auto *custom_call = xla::Cast<HloCustomCallInstruction>(new_call);
+    if (!custom_call->output_to_operand_aliasing().empty()) {
+      custom_call->set_output_to_operand_aliasing({{{0}, {2, {}}}});
+    }
+
+    HloInstruction *get_output = instr->AddInstruction(
+        HloInstruction::CreateGetTupleElement(new_call, 0));
+    return ReplaceInstruction(instr, get_output);
+  }
+
+ private:
+  se::GpuComputeCapability gpu_version_;
+};
+
 StatusOr<bool> RunOnComputation(HloComputation *computation,
                                 se::GpuComputeCapability gpu_version) {
   GemmRewriterVisitor visitor(gpu_version);
   TF_RETURN_IF_ERROR(computation->Accept(&visitor));
+  GemmWorkspaceRewriteVisitor workspace_visitor(gpu_version);
+  TF_RETURN_IF_ERROR(computation->Accept(&workspace_visitor));
   return visitor.changed();
 }
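To summarize the workspace sizing in GemmWorkspaceRewriteVisitor, here is a standalone sketch; the helper function is hypothetical, and the GemmConfig constants are only referenced symbolically (their values live in matmul_utils.h):

  #include <algorithm>
  #include <cstdint>
  #include <numeric>
  #include <vector>
  #include "xla/service/gpu/matmul_utils.h"  // xla::gpu::GemmConfig

  // Cap the workspace at an architecture-dependent constant, never exceed the
  // total operand bytes, and fall back to cuBLAS's internal pool (size 0) when
  // command buffers are disabled.
  int64_t ChooseGemmWorkspaceBytes(bool is_at_least_hopper,
                                   const std::vector<int64_t>& operand_bytes,
                                   bool command_buffers_enabled) {
    if (!command_buffers_enabled) return 0;
    int64_t cap = is_at_least_hopper ? xla::gpu::GemmConfig::kHopperWorkspace
                                     : xla::gpu::GemmConfig::kDefaultWorkspace;
    int64_t operands_byte_size = std::accumulate(
        operand_bytes.begin(), operand_bytes.end(), int64_t{0});
    return std::min(cap, operands_byte_size);
  }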
 
diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc
index 1210f78..c4b1447 100644
--- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc
+++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc
@@ -760,8 +760,6 @@
 
   const HloInstruction* src =
       (direction == TransformDirection::kOutputToInput) ? hlo : hlo->operand(0);
-  const HloInstruction* dst =
-      (direction == TransformDirection::kOutputToInput) ? hlo->operand(0) : hlo;
   // Note: copying instead of using a const reference because
   // some operations (slice) will modify fragment properties in-place.
   Fragments src_fragments_order = dim_orders.at(src).TensorFragmentsOrder();
@@ -769,14 +767,6 @@
       ShapeUtil::IsEffectiveScalar(hlo->shape())) {
     return FusionDecision("Slice to scalar is not implemented yet.");
   }
-  DimOrderUpdates result;
-  if (hlo->opcode() == HloOpcode::kReduce || hlo->opcode() == HloOpcode::kPad) {
-    // Operand 1 (the neutral value or padding value) has to be a scalar.
-    result.map.insert({hlo->operand(1), DimensionOrder()});
-  }
-  DimensionOrder& dst_dim_order =
-      result.map.insert({dst, DimensionOrder()}).first->second;
-  Fragments& dst_fragments_order = dst_dim_order.TensorFragmentsOrder();
   // Every HLO dimension can correspond to a group of subdimensions in
   // dim_order_. For the easier handling of permutations: group dim_order_ by
   // dimension, apply permutations, then finally remove the grouping.
@@ -798,136 +788,151 @@
     CHECK_EQ(subdim_size_accumulator, dim_size);
     src_physical.push_back(subdim_group);
   }
+
   // Source physical -> source logical.
   std::vector<std::vector<Fragment*>> src_logical;
   src_logical.resize(src_physical.size());
   for (int i = 0; i < src_physical.size(); ++i) {
     src_logical[src->shape().layout().minor_to_major(i)] = src_physical[i];
   }
-  // Source logical -> destination logical.
-  std::vector<std::vector<Fragment*>> dst_logical;
-  if (hlo->opcode() == HloOpcode::kTranspose) {
-    const auto* transpose = Cast<HloTransposeInstruction>(hlo);
-    std::vector<int64_t> permutation(transpose->dimensions().cbegin(),
-                                     transpose->dimensions().cend());
-    if (direction == TransformDirection::kInputToOutput) {
-      permutation = InversePermutation(permutation);
-    }
-    dst_logical.resize(permutation.size());
-    for (int i = 0; i < permutation.size(); ++i) {
-      dst_logical[permutation[i]] = src_logical[i];
-    }
-  } else if (hlo->opcode() == HloOpcode::kBroadcast) {
-    const auto* broadcast = Cast<HloBroadcastInstruction>(hlo);
-    dst_logical.resize(broadcast->dimensions().size());
-    for (int i = 0; i < broadcast->dimensions().size(); ++i) {
-      dst_logical[i] = src_logical[broadcast->dimensions()[i]];
-    }
-  } else if (hlo->opcode() == HloOpcode::kReduce) {
-    const auto* reduce = Cast<HloReduceInstruction>(hlo);
-    dst_logical.resize(src_logical.size() + reduce->dimensions().size());
-    if (reduce->dimensions().size() != 1) {
-      return FusionDecision("Unsupported reduction.");
-    }
-    for (int i = 0; i < dst_logical.size(); ++i) {
-      if (i == reduce->dimensions().front()) {
-        // This way to assign the reduction dimension will only work for
-        // softmax fusions with known patterns for now. Generally a reduction
-        // should create a new tiled dimension.
-        dst_logical[i] = {&new_fragments.emplace_back(
-            std::get<SoftmaxProperties>(properties_)
-                .softmax_reduction_dimension,
-            reduce->operand(0)->shape().dimensions(i))};
-      } else {
-        dst_logical[i] = src_logical[i];
-      }
-    }
-  } else if (hlo->opcode() == HloOpcode::kCopy) {
-    // Copy preserves the logical shape, just permutes the layout.
-    CHECK(ShapeUtil::SameDimensions(src->shape(), dst->shape()));
-    dst_logical = src_logical;
-  } else if (hlo->opcode() == HloOpcode::kPad) {
-    const auto* pad = Cast<HloPadInstruction>(hlo);
-    dst_logical.resize(src_logical.size());
-    for (int i = 0; i < src_logical.size(); ++i) {
-      // This only handles the padding added by PadDotOperandsIfNeededForSplitK,
-      // which sets only edge_padding_high.
-      const int padding =
-          pad->padding_config().dimensions(i).edge_padding_high();
-      CHECK_EQ(pad->padding_config().dimensions(i).edge_padding_low(), 0);
-      CHECK_EQ(pad->padding_config().dimensions(i).interior_padding(), 0);
-      if (padding == 0) {
-        dst_logical[i] = src_logical[i];
-      } else {
-        // This case is executed for the contracting dimension when we run the
-        // TritonFusionAnalysis after the padding and the split-k transform are
-        // applied.
-        const std::vector<Fragment*>& fragments = src_logical[i];
-        // We must have 2 fragments at this point.
-        CHECK_EQ(fragments.size(), 2);
-        // The dst_dim_numbers must be the same for the 2 fragments of the
-        // contracting dimension after applying split-k.
-        CHECK_EQ(fragments[0]->dst_dim_number(),
-                 fragments[1]->dst_dim_number());
 
-        new_fragments.emplace_back(
-            fragments[0]->dst_dim_number(),
-            fragments[0]->full_size() * fragments[1]->full_size() - padding);
-        dst_logical[i] = {&new_fragments.back()};
+  HloInstruction::InstructionVector output;
+  output.push_back(const_cast<HloInstruction*>(hlo));
+  DimOrderUpdates result;
+  for (const HloInstruction* dst :
+       (direction == TransformDirection::kInputToOutput) ? output
+                                                         : hlo->operands()) {
+    DimensionOrder& dst_dim_order =
+        result.map.insert({dst, DimensionOrder()}).first->second;
+    // Source logical -> destination logical.
+    std::vector<std::vector<Fragment*>> dst_logical;
+    if (hlo->opcode() == HloOpcode::kTranspose) {
+      const auto* transpose = Cast<HloTransposeInstruction>(hlo);
+      std::vector<int64_t> permutation(transpose->dimensions().cbegin(),
+                                       transpose->dimensions().cend());
+      if (direction == TransformDirection::kInputToOutput) {
+        permutation = InversePermutation(permutation);
       }
-    }
-  } else if (hlo->opcode() == HloOpcode::kSlice) {
-    const auto slice = Cast<HloSliceInstruction>(hlo);
-    dst_logical.resize(src_logical.size());
-    for (int dim = 0; dim < src_logical.size(); ++dim) {
-      dst_logical[dim] = src_logical[dim];
-      if (slice->slice_limits(dim) - slice->slice_starts(dim) !=
-          dst->shape().dimensions(dim)) {
-        if (dst_logical[dim].size() > 1) {
-          return FusionDecision("Slicing of fragmented dimension.");
-        }
-        dst_logical[dim].front()->set_size(dst->shape().dimensions(dim));
-        dst_logical[dim].front()->set_slice(slice->slice_starts(dim),
-                                            slice->slice_limits(dim));
+      dst_logical.resize(permutation.size());
+      for (int i = 0; i < permutation.size(); ++i) {
+        dst_logical[permutation[i]] = src_logical[i];
       }
-    }
-  } else {
-    return FusionDecision("Function called on a wrong instruction.");
-  }
-  // Destination logical -> destination physical and ungroup subdimensions.
-  // Map original fragments to the resulting ones to derive their new
-  // logical ordering within each dimension.
-  absl::flat_hash_map<const Fragment*, int> src_to_dst;
-  FragmentOrders& dst_dim_fragments_order = dst_dim_order.DimFragmentsOrders();
-  // Remember which dimensions are present before a broadcast;
-  // skip cases when already present dimension is being expanded.
-  absl::flat_hash_set<int> dim_numbers_present_in_dst;
-  for (const int64_t dim_idx : dst->shape().layout().minor_to_major()) {
-    for (const Fragment* subdim : dst_logical[dim_idx]) {
-      dst_fragments_order.push_back(*subdim);
-      src_to_dst[subdim] = dst_fragments_order.size() - 1;
-      dim_numbers_present_in_dst.insert(subdim->dst_dim_number());
-      if (std::holds_alternative<SoftmaxProperties>(properties_) &&
-          subdim->dst_dim_number() == std::get<SoftmaxProperties>(properties_)
-                                          .softmax_reduction_dimension) {
-        dst_dim_fragments_order[subdim->dst_dim_number()].push_back(
-            dst_fragments_order.size() - 1);
+    } else if (hlo->opcode() == HloOpcode::kBroadcast) {
+      const auto* broadcast = Cast<HloBroadcastInstruction>(hlo);
+      dst_logical.resize(broadcast->dimensions().size());
+      for (int i = 0; i < broadcast->dimensions().size(); ++i) {
+        dst_logical[i] = src_logical[broadcast->dimensions()[i]];
       }
-    }
-  }
-  for (const auto& [dim_index, dim_sequence] :
-       dim_orders.at(src).DimFragmentsOrders()) {
-    for (const int fragment_number : dim_sequence) {
-      const auto it = src_to_dst.find(&src_fragments_order[fragment_number]);
-      if (it == src_to_dst.cend()) {
-        if (hlo->opcode() == HloOpcode::kBroadcast &&
-            src_fragments_order[fragment_number].full_size() > 1 &&
-            dim_numbers_present_in_dst.contains(dim_index)) {
-          return FusionDecision("Unsupported broadcast");
-        }
+    } else if (hlo->opcode() == HloOpcode::kReduce) {
+      // Operand 1 (the neutral value) has to be a scalar.
+      if (dst != hlo && hlo->operand_index(dst) == 1) {
         continue;
       }
-      dst_dim_fragments_order[dim_index].push_back(it->second);
+      const auto* reduce = Cast<HloReduceInstruction>(hlo);
+      dst_logical.resize(src_logical.size() + reduce->dimensions().size());
+      if (reduce->dimensions().size() != 1) {
+        return FusionDecision("Unsupported reduction.");
+      }
+      for (int i = 0; i < dst_logical.size(); ++i) {
+        if (i == reduce->dimensions().front()) {
+          // This way to assign the reduction dimension will only work for
+          // softmax fusions with known patterns for now. Generally a reduction
+          // should create a new tiled dimension.
+          dst_logical[i] = {&new_fragments.emplace_back(
+              std::get<SoftmaxProperties>(properties_)
+                  .softmax_reduction_dimension,
+              reduce->operand(0)->shape().dimensions(i))};
+        } else {
+          dst_logical[i] = src_logical[i];
+        }
+      }
+    } else if (hlo->opcode() == HloOpcode::kCopy) {
+      // Copy preserves the logical shape, just permutes the layout.
+      CHECK(ShapeUtil::SameDimensions(src->shape(), dst->shape()));
+      dst_logical = src_logical;
+    } else if (hlo->opcode() == HloOpcode::kPad) {
+      // Operand 1 (the padding value) has to be a scalar.
+      if (dst != hlo && hlo->operand_index(dst) == 1) {
+        continue;
+      }
+      const auto* pad = Cast<HloPadInstruction>(hlo);
+      dst_logical.resize(src_logical.size());
+      for (int i = 0; i < src_logical.size(); ++i) {
+        // This only handles the padding added by
+        // PadDotOperandsIfNeededForSplitK, which sets only edge_padding_high.
+        const int padding =
+            pad->padding_config().dimensions(i).edge_padding_high();
+        CHECK_EQ(pad->padding_config().dimensions(i).edge_padding_low(), 0);
+        CHECK_EQ(pad->padding_config().dimensions(i).interior_padding(), 0);
+        if (padding == 0) {
+          dst_logical[i] = src_logical[i];
+        } else {
+          // This case is executed for the contracting dimension when we run the
+          // TritonFusionAnalysis after the padding and the split-k transform
+          // are applied.
+          const std::vector<Fragment*>& fragments = src_logical[i];
+          // We must have 2 fragments at this point.
+          CHECK_EQ(fragments.size(), 2);
+          // The dst_dim_numbers must be the same for the 2 fragments of the
+          // contracting dimension after applying split-k.
+          CHECK_EQ(fragments[0]->dst_dim_number(),
+                   fragments[1]->dst_dim_number());
+
+          new_fragments.emplace_back(
+              fragments[0]->dst_dim_number(),
+              fragments[0]->full_size() * fragments[1]->full_size() - padding);
+          dst_logical[i] = {&new_fragments.back()};
+        }
+      }
+    } else if (hlo->opcode() == HloOpcode::kSlice) {
+      const auto slice = Cast<HloSliceInstruction>(hlo);
+      dst_logical.resize(src_logical.size());
+      for (int dim = 0; dim < src_logical.size(); ++dim) {
+        dst_logical[dim] = src_logical[dim];
+        if (slice->slice_limits(dim) - slice->slice_starts(dim) !=
+            dst->shape().dimensions(dim)) {
+          if (dst_logical[dim].size() > 1) {
+            return FusionDecision("Slicing of fragmented dimension.");
+          }
+          dst_logical[dim].front()->set_size(dst->shape().dimensions(dim));
+          dst_logical[dim].front()->set_slice(slice->slice_starts(dim),
+                                              slice->slice_limits(dim));
+        }
+      }
+    } else {
+      return FusionDecision("Function called on a wrong instruction.");
+    }
+    // Destination logical -> destination physical and ungroup subdimensions.
+    // Map original fragments to the resulting ones to derive their new
+    // logical ordering within each dimension.
+    absl::flat_hash_map<const Fragment*, int> src_to_dst;
+    Fragments& dst_fragments_order = dst_dim_order.TensorFragmentsOrder();
+    FragmentOrders& dst_dim_fragments_order =
+        dst_dim_order.DimFragmentsOrders();
+    // Remember which dimensions are present before a broadcast;
+    // skip cases when already present dimension is being expanded.
+    absl::flat_hash_set<int> dim_numbers_present_in_dst;
+    for (const int64_t dim_idx : dst->shape().layout().minor_to_major()) {
+      for (const Fragment* subdim : dst_logical[dim_idx]) {
+        dst_fragments_order.push_back(*subdim);
+        src_to_dst[subdim] = dst_fragments_order.size() - 1;
+        dim_numbers_present_in_dst.insert(subdim->dst_dim_number());
+      }
+    }
+    for (const auto& [dim_index, dim_sequence] :
+         dim_orders.at(src).DimFragmentsOrders()) {
+      for (const int fragment_number : dim_sequence) {
+        const auto it = src_to_dst.find(&src_fragments_order[fragment_number]);
+        if (it == src_to_dst.cend()) {
+          if (hlo->opcode() == HloOpcode::kBroadcast &&
+              src_fragments_order[fragment_number].full_size() > 1 &&
+              dim_numbers_present_in_dst.contains(dim_index)) {
+            return FusionDecision("Unsupported broadcast");
+          }
+          continue;
+        }
+        dst_dim_fragments_order[dim_index].push_back(it->second);
+      }
     }
   }
   return result;
diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc
index a245615..f29b12b 100644
--- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc
+++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc
@@ -752,6 +752,42 @@
             nullptr);
 }
 
+TEST_F(TritonSoftmaxAnalysisTest, BroadcastIntoBatchDimensionIsSupported) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+c {
+  p1 = f32[127]{0} parameter(0)
+  ROOT b = f32[125,127]{1,0} broadcast(p1), dimensions={1}
+}
+
+ENTRY e {
+  p0 = f32[127]{0} parameter(0)
+  ROOT t = f32[125,127]{1,0} fusion(p0), kind=kCustom, calls=c
+})"));
+  const HloComputation* computation =
+      module->entry_computation()->root_instruction()->called_computations()[0];
+  TF_ASSERT_OK_AND_ASSIGN(const auto analysis,
+                          TritonFusionAnalysis::Execute(*computation));
+  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT,
+                                 computation->root_instruction(), 0),
+              ElementsAre(FieldsAre(/*stride=*/1, /*count=*/127,
+                                    /*slice_start=*/0, /*slice_limit=*/127,
+                                    /*subfragments=*/ElementsAre(127))));
+  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT,
+                                 computation->root_instruction(), 1),
+              ElementsAre(FieldsAre(/*stride=*/127, /*count=*/125,
+                                    /*slice_start=*/0, /*slice_limit=*/125,
+                                    /*subfragments=*/ElementsAre(125))));
+  EXPECT_THAT(*analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT,
+                                 computation->parameter_instruction(0), 0),
+              ElementsAre(FieldsAre(/*stride=*/1, /*count=*/127,
+                                    /*slice_start=*/0, /*slice_limit=*/127,
+                                    /*subfragments=*/ElementsAre(127))));
+  EXPECT_EQ(analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT,
+                              computation->parameter_instruction(0), 1),
+            nullptr);
+}
+
 TEST_F(GemmRewriterTritonTest, HandleDotIfCublasRequiresPadding) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(R"(
diff --git a/third_party/xla/xla/service/gpu/gemm_thunk.cc b/third_party/xla/xla/service/gpu/gemm_thunk.cc
index 0abdc69..b774280 100644
--- a/third_party/xla/xla/service/gpu/gemm_thunk.cc
+++ b/third_party/xla/xla/service/gpu/gemm_thunk.cc
@@ -41,10 +41,13 @@
 Status GemmThunk::ExecuteOnStream(const ExecuteParams& params) {
   VLOG(3) << "Running GEMM thunk";
   const BufferAllocations& allocs = *params.buffer_allocations;
+  // TODO(ezhulenev): Pass a correct workspace. For now we ignore it as Thunks
+  // are disabled by default, and they do not interact with CUDA graphs.
+  se::DeviceMemoryBase workspace(nullptr, 0);
   return RunGemm(config_, allocs.GetDeviceAddress(lhs_buffer_),
                  allocs.GetDeviceAddress(rhs_buffer_),
-                 allocs.GetDeviceAddress(output_buffer_), deterministic_,
-                 params.stream);
+                 allocs.GetDeviceAddress(output_buffer_), workspace,
+                 deterministic_, params.stream);
 }
 
 Status GemmThunk::Initialize(se::StreamExecutor* executor,
diff --git a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc b/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc
index 29f52ff..ce25a01 100644
--- a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc
+++ b/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc
@@ -26,6 +26,7 @@
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/service/collective_ops_utils.h"
+#include "xla/shape_util.h"
 #include "xla/statusor.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/logging.h"
@@ -68,6 +69,12 @@
         continue;
       }
 
+      if (!ShapeUtil::Equal(left_all_gather->operand(0)->shape(),
+                            right_all_gather->operand(0)->shape())) {
+        VLOG(2) << "all-gather operands have different shapes";
+        continue;
+      }
+
       if (right_all_gather->user_count() != 1 ||
           left_all_gather->user_count() != 1) {
         VLOG(2) << "all-gather user_count > 1 ";
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc
index 75e3460..f4756e9 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.cc
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -178,6 +178,7 @@
 #include "xla/service/spmd/stateful_rng_spmd_partitioner.h"
 #include "xla/service/stable_sort_expander.h"
 #include "xla/service/stochastic_convert_decomposer.h"
+#include "xla/service/sub_byte_normalization.h"
 #include "xla/service/topk_rewriter.h"
 #include "xla/service/transpose_folding.h"
 #include "xla/service/tuple_simplifier.h"
@@ -471,6 +472,9 @@
     pipeline.AddPass<OperandUpcaster>(upcaster_filter);
     pipeline.AddPass<ResultCaster>(upcaster_filter);
 
+    pipeline.AddPass<SubByteNormalization>(
+        SubByteNormalization::SET_ELEMENT_SIZE);
+
     // Expand random number generation.
     pipeline.AddPass<RngExpander>();
     pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
@@ -727,6 +731,10 @@
     pipeline.AddPass<GpuLayoutAssignment>(
         hlo_module->mutable_entry_computation_layout(), stream_exec,
         &layout_constraints);
+    // Run SubByteNormalization because GpuLayoutAssignment may modify a
+    // Layout's element_size_in_bits field.
+    pipeline.AddPass<SubByteNormalization>(
+        SubByteNormalization::SET_ELEMENT_SIZE);
     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
   }
 
@@ -781,6 +789,9 @@
       pipeline.AddPass<AllReduceContiguous>();
     }
 
+    TF_RETURN_IF_ERROR(
+        AddCustomKernelReplacementPasses(&pipeline, debug_options));
+
     int32_t blueconnect_num_devices_per_host =
         debug_options.xla_gpu_all_reduce_blueconnect_num_devices_per_host();
     if (blueconnect_num_devices_per_host > 0) {
@@ -1532,7 +1543,7 @@
                             thunk_sequence.ToString());
   }
 
-  std::shared_ptr<const BufferAssignment> buffer_assignment;
+  std::shared_ptr<BufferAssignment> buffer_assignment;
   std::unique_ptr<BufferAssignmentProto> buffer_assignment_proto;
   std::function<std::string()> buffer_assignment_dumper = [] {
     return std::string();
@@ -1549,6 +1560,21 @@
     };
   }
 
+  std::vector<BufferAllocation> allocations;
+  if (compile_module_results.use_original_allocations) {
+    if (!options.is_autotuning_compilation) {
+      std::vector<BufferAllocation> original_allocations =
+          buffer_assignment->ReleaseAllocations();
+      allocations = std::move(original_allocations);
+    } else {
+      std::vector<BufferAllocation> original_allocations =
+          compile_module_results.buffer_assignment->ReleaseAllocations();
+      allocations = std::move(original_allocations);
+    }
+  } else {
+    allocations = std::move(compile_module_results.allocations);
+  }
+
   TF_ASSIGN_OR_RETURN(
       auto gpu_executable,
       GpuExecutable::Create(GpuExecutable::Params{
@@ -1564,7 +1590,7 @@
           /*output_info=*/std::move(compile_module_results.output_info),
           /*module_name=*/std::move(compile_module_results.module_name),
           /*output_shape=*/std::move(compile_module_results.output_shape),
-          /*allocations=*/std::move(compile_module_results.allocations),
+          /*allocations=*/std::move(allocations),
           /*enable_persistent_temp_buffers=*/
           module->config()
               .debug_options()
@@ -1761,7 +1787,6 @@
         HloPredicateIsOp<HloOpcode::kParameter, HloOpcode::kConstant,
                          HloOpcode::kBitcast, HloOpcode::kGetTupleElement>;
     pipeline.AddPass<GpuConvertAsyncCollectivesToSync>(is_nop);
-    pipeline.AddPass<OptimizationBarrierExpander>();
 
     TF_RETURN_IF_ERROR(pipeline.Run(module).status());
   }
@@ -1782,6 +1807,7 @@
         /*host_memory_offload_config=*/std::nullopt);
     HloRematerialization::RematerializationSizes sizes;
     pipeline.AddPass<HloRematerialization>(options, sizes);
+    pipeline.AddPass<OptimizationBarrierExpander>();
 
     TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));
     if (changed) {
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.h b/third_party/xla/xla/service/gpu/gpu_compiler.h
index 89079f8..1c32a9b 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.h
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.h
@@ -34,11 +34,13 @@
 #include "xla/service/hlo_dataflow_analysis.h"
 #include "xla/service/hlo_pass_pipeline.h"
 #include "xla/service/llvm_compiler.h"
+#include "xla/status.h"
 #include "xla/statusor.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/device_description.pb.h"
 #include "xla/stream_executor/stream_executor.h"
 #include "xla/util.h"
+#include "xla/xla.pb.h"
 
 namespace xla {
 namespace gpu {
@@ -106,7 +108,11 @@
 
   // An attached device is passed in via stream_exec. We get GPU configuration
   // from the attached device. GemmAlgorithmPicker and GpuConvAlgorithmPicker
-  // can run on the attached device.
+  // can run on the attached device. If you call this directly, follow it with
+  // RunBackend rather than Compile. To compile without an attached device,
+  // pass a nullptr stream_exec and set a TargetConfig in the CompileOptions,
+  // and then call CompileAheadOfTime. See service/xla_compile_main.cc for an
+  // example.
   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
       const CompileOptions& options) override;
@@ -186,6 +192,12 @@
     return OkStatus();
   }
 
+  // Add passes that convert HLO operations to custom kernels.
+  virtual Status AddCustomKernelReplacementPasses(
+      HloPassPipeline* pipeline, const DebugOptions& debug_options) {
+    return OkStatus();
+  }
+
  private:
   Status LoadAutotuneResultsFromFile(const DebugOptions& debug_options);
   Status SerializeAutotuneResultsToFile(const DebugOptions& debug_options);
diff --git a/third_party/xla/xla/service/gpu/gpu_conv_runner.h b/third_party/xla/xla/service/gpu/gpu_conv_runner.h
index ff85fce..748db93 100644
--- a/third_party/xla/xla/service/gpu/gpu_conv_runner.h
+++ b/third_party/xla/xla/service/gpu/gpu_conv_runner.h
@@ -200,8 +200,6 @@
   GenericConvRunner* runner_cache;
 };
 
-// This file contains low-level routines for running cudnn convolutions.
-
 // Calls into cudnn to run the specified convolution.
 //
 // We provide one overload which takes a scratch buffer, and another which takes
diff --git a/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc b/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc
index dc059e4..e3e6425 100644
--- a/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc
@@ -201,13 +201,14 @@
   param_1.1 = f32[2,3]{1,0} parameter(1)
   neg = f32[2,3]{1,0} negate(param_1.1)
   mul = f32[2,3]{1,0} multiply(param_0.1, neg)
-  ROOT tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}) tuple(mul, neg)
+  transpose = f32[3,2]{1,0} transpose(neg), dimensions={1,0}
+  ROOT tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[3,2]{1,0}) tuple(mul, neg, transpose)
 }
 
 ENTRY main {
   param_0 = f32[2,3]{1,0} parameter(0)
   param_1 = f32[2,3]{1,0} parameter(1)
-  ROOT fusion = (f32[2,3]{1,0}, f32[2,3]{1,0}) fusion(param_0, param_1), kind=kLoop, calls=fused_computation
+  ROOT fusion = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[3,2]{1,0}) fusion(param_0, param_1), kind=kLoop, calls=fused_computation
 }
 )";
 
@@ -216,7 +217,7 @@
   HloInstruction* fusion = module->entry_computation()->root_instruction();
   ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {0}));
   // The second operand cannot share the buffer with the second fusion output,
-  // because the 'neg' op is also used on the path to the first fusion output.
+  // because the 'neg' op is also used by a non-elementwise op.
   ExpectOptionalFalse(
       FusionCanShareBufferHint(fusion, fusion->operand(1), {1}));
   // The first operand cannot share the buffer with the second fusion output,
@@ -225,6 +226,39 @@
       FusionCanShareBufferHint(fusion, fusion->operand(0), {1}));
 }
 
+TEST_F(FusionCanShareBufferHintTest, BufferCanBeSharedReductionEmitter) {
+  constexpr char kModuleString[] = R"(
+HloModule TestModule
+
+%maximum {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %res = f32[] maximum(%lhs, %rhs)
+}
+
+%fused_computation {
+  %lhs = f32[3,40] parameter(0)
+  %rhs = f32[3,40] parameter(1)
+  %add = f32[3,40] add(%lhs, %rhs)
+  %bc = f32[120] bitcast(%add)
+  %init = f32[] constant(-inf)
+  %max = f32[] reduce(%bc, %init), dimensions={0}, to_apply=%maximum
+  ROOT %result = (f32[], f32[3,40]) tuple(%max, %add)
+}
+
+ENTRY %main {
+  %lhs = f32[3,40] parameter(0)
+  %rhs = f32[3,40] parameter(1)
+  ROOT %fusion = (f32[], f32[3,40]) fusion(%lhs, %rhs),
+      kind=kLoop, calls=%fused_computation
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  HloInstruction* fusion = module->entry_computation()->root_instruction();
+  ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {1}));
+}
+
 TEST_F(FusionCanShareBufferHintTest, BufferCanBeSharedScatterFusion) {
   const char* const kModuleString = R"(
     HloModule fusion
diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc
index a6aebab..a576d89 100644
--- a/third_party/xla/xla/service/gpu/gpu_executable.cc
+++ b/third_party/xla/xla/service/gpu/gpu_executable.cc
@@ -996,7 +996,7 @@
   // to make sure they are loaded in correct order.
   auto export_ops = llvm::to_vector(module.getOps<runtime::ExportOp>());
   llvm::sort(export_ops, [](runtime::ExportOp a, runtime::ExportOp b) {
-    return b.getOrdinal()->getSExtValue() < b.getOrdinal()->getSExtValue();
+    return a.getOrdinal()->getSExtValue() < b.getOrdinal()->getSExtValue();
   });
   for (runtime::ExportOp exported : export_ops) {
     TF_CHECK_OK(convert(
@@ -1023,6 +1023,32 @@
   return buffer_sizes;
 }
 
+// TODO(ezhulenev): This is a copy of `GetAllocationIndices` from
+// `mlir/backends/gpu/transforms/passes.h`. We can't depend on that file because
+// of a dependency cycle; this is a short-term workaround for the CUDA graph
+// capture bug. This code should not survive beyond Q1 2024.
+static std::vector<std::vector<int64_t>> GetAllocationIndices(
+    mlir::ModuleOp module) {
+  std::vector<std::vector<int64_t>> res;
+
+  mlir::SymbolTable sym_table(module);
+  for (auto op : module.getOps<runtime::ExportOp>()) {
+    unsigned ordinal = *op.ordinal();
+    if (ordinal >= res.size()) res.resize(ordinal + 1);
+
+    auto func = sym_table.lookup<mlir::func::FuncOp>(op.getFunctionRef());
+    res[ordinal].resize(func.getNumArguments(), -1);
+
+    for (unsigned i = 0; i < func.getNumArguments(); ++i) {
+      auto idx =
+          func.getArgAttrOfType<mlir::IntegerAttr>(i, "rt.allocation_index");
+      if (idx) res[ordinal][i] = idx.getInt();
+    }
+  }
+
+  return res;
+}
+
 StatusOr<std::unique_ptr<Executable>> GpuExecutable::LoadFromObjFile(
     std::shared_ptr<HloModule> hlo_module, absl::string_view obj_file,
     absl::string_view mlir_module,
@@ -1053,6 +1079,9 @@
   TF_ASSIGN_OR_RETURN(std::vector<int64_t> buffer_sizes,
                       GetBufferSizes(functions[0].signature));
 
+  // Get allocation indices from graph capture functions.
+  auto allocation_indices = GetAllocationIndices(*module);
+
   // Get the XLA module entrypoint function.
   auto func = mlir::cast<mlir::func::FuncOp>(module->lookupSymbol(entry));
 
@@ -1082,8 +1111,9 @@
   // Move runtime::Executable ownership to the GpuRuntimeExecutable.
   TF_ASSIGN_OR_RETURN(auto gpu_runtime_executable,
                       GpuRuntimeExecutable::Create(
-                          hlo_module->name(), buffer_sizes,
-                          std::move(*executable), std::move(debug_options)));
+                          hlo_module->name(), std::move(buffer_sizes),
+                          std::move(allocation_indices), std::move(*executable),
+                          std::move(debug_options)));
 
   // Construct GpuExecutable for the loaded XLA Runtime executable.
   std::string name = hlo_module->name();
diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc
index 43748b1..74c3d83 100644
--- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc
+++ b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc
@@ -82,7 +82,9 @@
                                      params.config->mask,
                                      params.config->activation,
                                      dropout_rate,
-                                     seed};
+                                     seed,
+                                     false,
+                                     false};
   TF_ASSIGN_OR_RETURN(auto *runner,
                       lazy_runner->GetOrCreateRunner(config, stream));
   return (*runner)(stream, options.profile_result, scratch_memory,
@@ -237,29 +239,37 @@
   if (params.config->seed) {
     seed = *params.config->seed;
   }
-
-  se::dnn::FusedMHABackwardOp::Config config{kind,
-                                             scale,
-                                             params.config->bmm1_grad_gemm1_rhs,
-                                             params.config->bmm1_grad_gemm2_rhs,
-                                             params.config->bmm2_grad_gemm1_lhs,
-                                             params.config->bmm2_grad_gemm2_rhs,
-                                             params.config->d_output,
-                                             params.config->d_bmm1_lhs,
-                                             params.config->d_bmm1_rhs,
-                                             params.config->d_bmm2_rhs,
-                                             params.config->d_s,
-                                             params.config->mask,
-                                             params.config->d_bias,
-                                             dropout_rate,
-                                             seed};
+  // TODO: Set is_flash_attention to its real value; it is set to false for now.
+  se::dnn::FusedMHABackwardOp::Config config{
+      kind,
+      scale,
+      params.config->bmm1_grad_gemm1_rhs,
+      params.config->bmm1_grad_gemm2_rhs,
+      params.config->bmm2_grad_gemm1_lhs,
+      params.config->bmm2_grad_gemm2_rhs,
+      params.config->d_output,
+      params.config->d_bmm1_lhs,
+      params.config->d_bmm1_rhs,
+      params.config->d_bmm2_rhs,
+      std::optional<TensorDescriptor>(params.config->d_s),
+      params.config->mask,
+      params.config->d_bias,
+      std::nullopt,
+      std::nullopt,
+      dropout_rate,
+      seed,
+      false,
+      false};
   TF_ASSIGN_OR_RETURN(auto *runner,
                       lazy_runner->GetOrCreateRunner(config, stream));
+  // TODO: Pass in the real softmax_sum, dQ_accum, and fwd_output buffers.
   return (*runner)(stream, options.profile_result, scratch_memory,
                    bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer,
                    bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer,
                    d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer,
-                   d_bmm2_rhs_buffer, d_s_buffer, mask_buffer, d_bias_buffer);
+                   d_bmm2_rhs_buffer, d_s_buffer, se::DeviceMemoryBase(),
+                   se::DeviceMemoryBase(), mask_buffer, d_bias_buffer,
+                   se::DeviceMemoryBase(), se::DeviceMemoryBase());
 }
 
 template <typename ElementType, typename BiasType, typename OutputType>
diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.cc b/third_party/xla/xla/service/gpu/gpu_fusible.cc
index b658c94..744dd1c 100644
--- a/third_party/xla/xla/service/gpu/gpu_fusible.cc
+++ b/third_party/xla/xla/service/gpu/gpu_fusible.cc
@@ -106,6 +106,28 @@
                                          instr.shape(), instr.dimensions()));
 }
 
+bool TransposesMinorDimension(const HloInstruction* instr) {
+  switch (instr->opcode()) {
+    case HloOpcode::kFusion:
+      return absl::c_any_of(instr->fused_instructions(),
+                            TransposesMinorDimension);
+    case HloOpcode::kCopy:
+      return instr->shape().layout().minor_to_major(0) !=
+             instr->operand(0)->shape().layout().minor_to_major(0);
+    case HloOpcode::kTranspose: {
+      // We have an input ([a,b,c]{x,y,z}) that's being transposed. We need to
+      // check if the minor-most dimension (x) is still the minor-most dimension
+      // after the transpose.
+      int64_t minor_input =
+          instr->operand(0)->shape().layout().minor_to_major(0);
+      int64_t minor_output = instr->shape().layout().minor_to_major(0);
+      return minor_input != instr->dimensions().at(minor_output);
+    }
+    default:
+      return false;
+  }
+}
+
 bool IsReduceInputFusion(const HloInstruction& instr) {
   return instr.opcode() == HloOpcode::kFusion &&
          absl::c_any_of(GetFusionRoots(*instr.called_computations()[0]),
diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.h b/third_party/xla/xla/service/gpu/gpu_fusible.h
index 29b9e53..897b44b 100644
--- a/third_party/xla/xla/service/gpu/gpu_fusible.h
+++ b/third_party/xla/xla/service/gpu/gpu_fusible.h
@@ -70,10 +70,25 @@
 
 inline constexpr int64_t MaxOperandsAndOutputsPerFusion() { return 64; }
 
-// Whether the op tranposes the physical data layout. Fusing such ops may lead
+// Whether the op transposes the physical data layout. Fusing such ops may lead
 // to uncoalesced data access and may thus not be beneficial.
 bool IsPhysicallyTransposing(const HloInstruction& instr);
 
+// Whether the op transposes the minor-most dimension. In the case of fusions,
+// whether the fusion contains some op that does this.
+// If the minor-most dimension is transposed, this results in uncoalesced memory
+// accesses in untiled code generators. If some other dimension is transposed,
+// this just results in additional index computations.
+// Note that this function makes several simplifying assumptions:
+// - For non-fusion instructions, we assume the output is materialized as is.
+//   For internal instructions, this may not be the case.
+// - For fusions, it simply checks the output of this function for each
+//   instruction in the fusion's computation.
+// - There's no way to tell which parameters of the fusion are transposed.
+// TODO(jreiffers): Take into account the size of the transposed dimension as
+// well.
+bool TransposesMinorDimension(const HloInstruction* instr);
+
 // Note that reduction ops are lowered in different ways. Reduce input fusions
 // are lowered by IrEmitterUnnested::EmitReductionToVector and must be rooted at
 // reduction-to-vector ops. Other reduction ops are lowered by
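
To make the minor-dimension check concrete, here is a small self-contained sketch of the kTranspose case from gpu_fusible.cc, stripped of HLO types; the layouts, permutation, and helper name are illustrative only.

// Standalone sketch of the kTranspose case: given the operand layout's
// minor-to-major order, the output layout's minor-to-major order, and the
// transpose permutation, decide whether the minor-most dimension moves.
#include <cstdint>
#include <vector>

bool SketchTransposesMinorDimension(
    const std::vector<int64_t>& operand_minor_to_major,
    const std::vector<int64_t>& output_minor_to_major,
    const std::vector<int64_t>& transpose_dimensions) {
  int64_t minor_input = operand_minor_to_major[0];
  int64_t minor_output = output_minor_to_major[0];
  // Output dimension `minor_output` reads operand dimension
  // `transpose_dimensions[minor_output]`; if that is not the operand's
  // minor-most dimension, the transpose breaks coalesced access.
  return minor_input != transpose_dimensions[minor_output];
}

// Example, mirroring the transpose_minor_default case in the test file below:
// f32[10,20,30,40]{3,2,1,0} transposed with dimensions={0,1,3,2} into
// f32[10,20,40,30]{3,2,1,0} gives minor_input == 3 and
// transpose_dimensions[3] == 2, so the function returns true.
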
diff --git a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc
index bbdd486..8766416 100644
--- a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc
@@ -210,6 +210,56 @@
   EXPECT_FALSE(IsPhysicallyTransposing(*loop_fusion));
 }
 
+TEST_F(GpuFusibleTest, TransposesMinorDimension) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    ENTRY entry {
+      default_layout = f32[10,20,30,40]{3,2,1,0} parameter(0)
+      non_default_layout = f32[10,20,30,40]{1,2,3,0} parameter(1)
+
+      transpose_minor_default = f32[10,20,40,30]{3,2,1,0} transpose(default_layout), dimensions={0,1,3,2}
+      no_transpose_minor_default = f32[10,20,40,30]{2,3,1,0} transpose(default_layout), dimensions={0,1,3,2}
+      transpose_major_default = f32[10,30,20,40]{3,2,1,0} transpose(default_layout), dimensions={0,2,1,3}
+
+      transpose_minor_non_default = f32[10,30,20,40]{1,2,3,0} transpose(non_default_layout), dimensions={0,2,1,3}
+      no_transpose_minor_non_default = f32[10,20,40,30]{1,2,0,3} transpose(non_default_layout), dimensions={0,1,3,2}
+      transpose_major_non_default = f32[10,20,40,30]{1,2,3,0} transpose(non_default_layout), dimensions={0,1,3,2}
+
+      ROOT r = tuple(transpose_minor_default, no_transpose_minor_default, transpose_major_default,
+                     transpose_minor_non_default, no_transpose_minor_non_default, transpose_major_non_default)
+    })"));
+
+  auto* tuple = (*module)->entry_computation()->root_instruction();
+  EXPECT_TRUE(TransposesMinorDimension(tuple->operand(0)));
+  EXPECT_FALSE(TransposesMinorDimension(tuple->operand(1)));
+  EXPECT_FALSE(TransposesMinorDimension(tuple->operand(2)));
+  EXPECT_TRUE(TransposesMinorDimension(tuple->operand(3)));
+  EXPECT_FALSE(TransposesMinorDimension(tuple->operand(4)));
+  EXPECT_FALSE(TransposesMinorDimension(tuple->operand(5)));
+}
+
+TEST_F(GpuFusibleTest, CopyTransposesMinorDimension) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    ENTRY entry {
+      default_layout = f32[10,20,30,40]{3,2,1,0} parameter(0)
+      non_default_layout = f32[10,20,30,40]{1,2,3,0} parameter(1)
+
+      copy_transpose_minor_default = f32[10,20,30,40]{2,3,1,0} copy(default_layout)
+      copy_no_transpose_minor_default = f32[10,20,30,40]{3,2,1,0} copy(default_layout)
+
+      copy_transpose_minor_non_default = f32[10,20,30,40]{2,1,3,0} copy(non_default_layout)
+      copy_no_transpose_minor_non_default = f32[10,20,30,40]{1,2,3,0} copy(non_default_layout)
+
+      ROOT r = tuple(copy_transpose_minor_default, copy_no_transpose_minor_default,
+                     copy_transpose_minor_non_default, copy_no_transpose_minor_non_default)
+    })"));
+
+  auto* tuple = (*module)->entry_computation()->root_instruction();
+  EXPECT_TRUE(TransposesMinorDimension(tuple->operand(0)));
+  EXPECT_FALSE(TransposesMinorDimension(tuple->operand(1)));
+  EXPECT_TRUE(TransposesMinorDimension(tuple->operand(2)));
+  EXPECT_FALSE(TransposesMinorDimension(tuple->operand(3)));
+}
+
 TEST_F(GpuFusibleTest, IsReduceInputFusion_ReductionToVector) {
   auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
     ENTRY entry {
diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc
index dd01ed3..baeccac 100644
--- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc
+++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc
@@ -15,6 +15,8 @@
 
 #include "xla/service/gpu/gpu_hlo_schedule.h"
 
+#include <cstddef>
+#include <cstdint>
 #include <deque>
 #include <memory>
 #include <optional>
@@ -22,14 +24,25 @@
 #include <utility>
 #include <vector>
 
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/strings/match.h"
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_format.h"
+#include "absl/strings/str_split.h"
 #include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_input_output_alias_config.h"
+#include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/ir/hlo_schedule.h"
 #include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/buffer_value.h"
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/cublas_cudnn.h"
 #include "xla/service/gpu/model/analytical_latency_estimator.h"
@@ -38,7 +51,12 @@
 #include "xla/service/latency_hiding_scheduler.h"
 #include "xla/service/p2p_schedule_preparation.h"
 #include "xla/service/profile_guided_latency_estimator.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status.h"
+#include "xla/statusor.h"
 #include "xla/stream_executor/device_description.h"
+#include "xla/util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/protobuf.h"
@@ -600,6 +618,9 @@
 Status ScheduleGpuModule(HloModule* module, int64_t pointer_size,
                          int64_t memory_limit,
                          const se::DeviceDescription& gpu_device_info) {
+  if (module->has_schedule()) {
+    return OkStatus();
+  }
   HloPassPipeline prepare_pipeline("p2p-schedule-preparation");
   prepare_pipeline.AddPass<P2PSchedulePreparation>();
   TF_RETURN_IF_ERROR(prepare_pipeline.Run(module).status());
@@ -734,7 +755,10 @@
         total_io_size -= GetSizeOfShape(subshape, pointer_size);
       });
 
-  return (base_limit - total_io_size) * 95 / 100;
+  int64_t limit =
+      (base_limit - total_io_size) *
+      module->config().debug_options().xla_gpu_memory_limit_slop_factor() / 100;
+  return limit;
 }
 
 }  // namespace gpu
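
As a quick illustration of the new memory-limit computation, the sketch below plugs made-up sizes into the same formula; the 95 percent slop factor matches the previously hard-coded value that xla_gpu_memory_limit_slop_factor now replaces.

// Illustrative arithmetic only: the base limit minus I/O is scaled by the
// slop factor (a percentage). All sizes here are made up.
#include <cstdint>
#include <cstdio>

int main() {
  int64_t base_limit = int64_t{80} * 1024 * 1024 * 1024;     // e.g. 80 GiB
  int64_t total_io_size = int64_t{8} * 1024 * 1024 * 1024;   // e.g. 8 GiB
  int64_t slop_factor = 95;  // percentage, previously hard-coded
  int64_t limit = (base_limit - total_io_size) * slop_factor / 100;
  std::printf("scheduler memory limit: %lld bytes\n",
              static_cast<long long>(limit));
  return 0;
}
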
diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc
index f3d4a21..e952c0c 100644
--- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -16,20 +16,34 @@
 #include "xla/service/gpu/gpu_hlo_schedule.h"
 
 #include <algorithm>
+#include <cstdint>
+#include <cstdlib>
 #include <memory>
 #include <optional>
 #include <string>
 #include <string_view>
 #include <vector>
 
+#include <gtest/gtest.h>
+#include "absl/algorithm/container.h"
+#include "absl/log/log.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/ir/hlo_schedule.h"
 #include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/backend.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/hlo_ordering.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
 #include "xla/stream_executor/device_description.h"
+#include "xla/tests/filecheck.h"
 #include "xla/tests/hlo_test_base.h"
 #include "xla/tests/test_utils.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
 #include "tsl/profiler/protobuf/profiled_instructions.pb.h"
 
 namespace xla {
@@ -512,11 +526,11 @@
       _xla_send_recv_source_target_pairs="{{0, 1}}"
     }
     send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
-      channel_id=1, control-predecessors={recv}, frontend_attributes={
+      channel_id=1, frontend_attributes={
       _xla_send_recv_source_target_pairs="{{0, 1}}"
     }
     recv-done = (f32[1, 1024, 1024], token[]) recv-done(recv), channel_id=1
-    send-done = token[] send-done(send), control-predecessors={recv-done}, channel_id=1
+    send-done = token[] send-done(send), channel_id=1
     recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done), index=0
 
     c1 = u32[] constant(1)
@@ -593,11 +607,11 @@
       _xla_send_recv_source_target_pairs="{{0, 1}}"
     }
     send-0 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all-0),
-      channel_id=1, control-predecessors={recv-0}, frontend_attributes={
+      channel_id=1, frontend_attributes={
       _xla_send_recv_source_target_pairs="{{0, 1}}"
     }
     recv-done-0 = (f32[1, 1024, 1024], token[]) recv-done(recv-0), channel_id=1
-    send-done-0 = token[] send-done(send-0), control-predecessors={recv-done-0}, channel_id=1
+    send-done-0 = token[] send-done(send-0), channel_id=1
     recv-data-0 = f32[1, 1024, 1024] get-tuple-element(recv-done-0), index=0
 
     c1 = u32[] constant(1)
@@ -615,11 +629,11 @@
       _xla_send_recv_source_target_pairs="{{1, 0}}"
     }
     send-1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all-1),
-      channel_id=2, control-predecessors={recv-1}, frontend_attributes={
+      channel_id=2, frontend_attributes={
       _xla_send_recv_source_target_pairs="{{1, 0}}"
     }
     recv-done-1 = (f32[1, 1024, 1024], token[]) recv-done(recv-1), channel_id=2
-    send-done-1 = token[] send-done(send-1), control-predecessors={recv-done-1}, channel_id=2
+    send-done-1 = token[] send-done(send-1), channel_id=2
     recv-data-1 = f32[1, 1024, 1024] get-tuple-element(recv-done-1), index=0
 
     s2 = f32[1, 1024, 1024] add(recv-data-0, s1)
@@ -695,11 +709,11 @@
       _xla_send_recv_source_target_pairs="{{0, 1}}"
     }
     send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all),
-      channel_id=1, control-predecessors={recv}, frontend_attributes={
+      channel_id=1, frontend_attributes={
       _xla_send_recv_source_target_pairs="{{0, 1}}"
     }
-    recv-done = (f32[1, 1024, 1024], token[]) recv-done(recv), channel_id=1, control-predecessors={send}
-    send-done = token[] send-done(send), control-predecessors={recv-done}, channel_id=1
+    recv-done = (f32[1, 1024, 1024], token[]) recv-done(recv), channel_id=1
+    send-done = token[] send-done(send), channel_id=1
     recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done), index=0
 
     c1 = u32[] constant(1)
@@ -758,6 +772,171 @@
   EXPECT_TRUE(HasValidFingerprint(module.get()));
 }
 
+// Checks that with the dependence added by the gpu-hlo-scheduler, the
+// pipelined Send and Recv instructions are scheduled correctly.
+TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined) {
+  const char* hlo_text = R"(
+  HloModule test
+
+  while_cond {
+    param = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) parameter(0)
+    count = get-tuple-element(param), index=0
+    ub = u32[] constant(25)
+    ROOT cond-result = pred[] compare(count, ub), direction=LT
+  }
+
+while_body {
+    param = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) parameter(0)
+    count = get-tuple-element(param), index=0
+    send-data = get-tuple-element(param), index=1
+    recv-data = get-tuple-element(param), index=2
+
+    after-all.1 = token[] after-all()
+    send.1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.1),
+      channel_id=1, frontend_attributes={
+      _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"
+    }
+    send-done.1 = token[] send-done(send.1), channel_id=1
+    recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1,
+      frontend_attributes={
+       _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"
+    }
+
+    c1 = u32[] constant(1)
+    new-count = u32[] add(count, c1)
+    replica = u32[] replica-id()
+    c10 = u32[] constant(10)
+    sum = u32[] add(replica, c10)
+    sum2 = u32[] add(sum, count)
+    conv = f32[] convert(sum2)
+    p = f32[1, 1024, 1024] broadcast(conv), dimensions={}
+    b = f32[1, 1024, 1024] add(p, recv-data)
+    c = f32[1, 1024, 1024] multiply(b, b)
+    d = f32[1, 1024, 1024] tan(c)
+    s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0},
+      lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+    new-data-0 = f32[1, 1024, 1024] add(c, s)
+
+    recv-done.1 = (f32[1, 1024, 1024], token[]) recv-done(recv.1), channel_id=1
+    new-recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1), index=0
+
+    after-all.4 = token[] after-all()
+    send.4 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.4),
+      channel_id=4, frontend_attributes={
+      _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"
+    }
+    send-done.4 = token[] send-done(send.4), channel_id=4
+    recv.4 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.4), channel_id=4,
+      frontend_attributes={
+       _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"
+    }
+    recv-done.4 = (f32[1, 1024, 1024], token[]) recv-done(recv.4), channel_id=4
+    recv-data-4 = f32[1, 1024, 1024] get-tuple-element(recv-done.4), index=0
+    new-data = f32[1, 1024, 1024] add(new-data-0, recv-data-4)
+
+    ROOT body-result = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) tuple(new-count, new-data, new-recv-data)
+  }
+
+  ENTRY main {
+    c0 = u32[] constant(0)
+    f0 = f32[] constant(0.0)
+    init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
+
+    after-all.2 = token[] after-all()
+    recv.2 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.2), channel_id=1,
+      frontend_attributes={
+       _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"
+    }
+    recv-done.2 = (f32[1, 1024, 1024], token[]) recv-done(recv.2), channel_id=1
+    recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.2), index=0
+
+    while-init =  (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) tuple(c0, init, recv-data)
+    while-result = (u32[], f32[1, 1024, 1024], f32[1, 1024, 1024]) while(while-init),
+      body=while_body, condition=while_cond,
+      backend_config={"known_trip_count":{"n":"25"}}
+
+    send-data = f32[1, 1024, 1024] get-tuple-element(while-result), index=2
+    send.2 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.2),
+      channel_id=1, frontend_attributes={
+      _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"
+    }
+    send-done.2 = token[] send-done(send.2), channel_id=1
+
+    ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(while-result), index=1
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module,
+      ParseAndReturnVerifiedModule(
+          hlo_text, GetModuleConfig(/*enable_latency_hiding_scheduler=*/true,
+                                    /*enable_gpu_async_tracker=*/true)));
+  SequentialHloOrdering order = BuildHloOrdering(module.get());
+  const std::vector<HloInstruction*>& while_body =
+      order.SequentialOrder(*module->GetComputationWithName("while_body"))
+          ->instructions();
+  const std::vector<HloInstruction*>& main =
+      order.SequentialOrder(*module->GetComputationWithName("main"))
+          ->instructions();
+  auto get_index =
+      [](absl::string_view hlo_name,
+         const std::vector<HloInstruction*>& instruction_sequence) {
+        return absl::c_find_if(instruction_sequence,
+                               [hlo_name](HloInstruction* instruction) {
+                                 return instruction->name() == hlo_name;
+                               }) -
+               instruction_sequence.begin();
+      };
+
+  EXPECT_TRUE(HasValidFingerprint(module.get()));
+
+  // The pipelined Send-Recv in the main.
+  EXPECT_LT(get_index("recv-done.2", main), get_index("while-result", main));
+  EXPECT_LT(get_index("while-result", main), get_index("send.2", main));
+
+  // The pipelined Send-Recv in the while-body.
+  EXPECT_LT(get_index("send.1", while_body), get_index("recv.1", while_body));
+
+  // The unpipelined Send-Recv in the while-body is scheduled after the
+  // pipelined Send-Done and before the pipelined Recv.
+  EXPECT_LT(get_index("send-done.1", while_body),
+            get_index("recv.4", while_body));
+  EXPECT_LT(get_index("recv-done.4", while_body),
+            get_index("recv.1", while_body));
+}
+
+TEST_F(GpuHloScheduleTest, SkipAlreadyScheduled) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule m, is_scheduled=true
+
+fused_computation {
+  param_0 = f32[1024,1024]{1,0} parameter(0)
+  ROOT exponential.1 = f32[1024,1024]{1,0} exponential(param_0)
+}
+
+fused_computation.1 {
+  param_0.1 = f32[1024,1024]{1,0} parameter(0)
+  ROOT negate.1 = f32[1024,1024]{1,0} negate(param_0.1)
+}
+
+ENTRY e {
+  p = f32[1024,1024]{1,0} parameter(0)
+  wrapped_negate = f32[1024,1024]{1,0} fusion(p), kind=kLoop, calls=fused_computation.1
+  wrapped_exponential = f32[1024,1024]{1,0} fusion(p), kind=kLoop, calls=fused_computation
+  ROOT t = (f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) tuple(wrapped_exponential, wrapped_negate)
+})")
+                    .value();
+  TF_CHECK_OK(ScheduleGpuModule(
+      module.get(), /*pointer_size=*/8,
+      /*memory_limit=*/1024 * 1024 * 1024,
+      backend().default_stream_executor()->GetDeviceDescription()));
+  EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
+// CHECK: ENTRY
+// CHECK: wrapped_negate = f32[1024,1024]{1,0}
+// CHECK: wrapped_exponential = f32[1024,1024]{1,0}
+)"));
+}
+
 class GpuHloScheduleParameterizedTest
     : public GpuHloScheduleTest,
       public ::testing::WithParamInterface<bool> {};
diff --git a/third_party/xla/xla/service/gpu/gpu_prim_cuda.h b/third_party/xla/xla/service/gpu/gpu_prim_cuda.h
new file mode 100644
index 0000000..e4ee313
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/gpu_prim_cuda.h
@@ -0,0 +1,82 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_GPU_PRIM_CUDA_H_
+#define XLA_SERVICE_GPU_GPU_PRIM_CUDA_H_
+
+#include "tsl/platform/bfloat16.h"
+
+#if GOOGLE_CUDA
+#include "cub/block/block_load.cuh"
+#include "cub/block/block_scan.cuh"
+#include "cub/block/block_store.cuh"
+#include "cub/device/device_histogram.cuh"
+#include "cub/device/device_radix_sort.cuh"
+#include "cub/device/device_reduce.cuh"
+#include "cub/device/device_scan.cuh"
+#include "cub/device/device_segmented_radix_sort.cuh"
+#include "cub/device/device_segmented_reduce.cuh"
+#include "cub/device/device_select.cuh"
+#include "cub/iterator/counting_input_iterator.cuh"
+#include "cub/iterator/transform_input_iterator.cuh"
+#include "cub/thread/thread_operators.cuh"
+#include "cub/warp/warp_reduce.cuh"
+#include "third_party/gpus/cuda/include/cusparse.h"
+
+namespace gpuprim = ::cub;
+
+// Required for sorting Eigen::half and bfloat16.
+namespace cub {
+template <>
+__device__ __forceinline__ void ThreadStoreVolatilePtr<Eigen::half>(
+    Eigen::half *ptr, Eigen::half val, Int2Type<true> /*is_primitive*/) {
+  *reinterpret_cast<volatile uint16_t *>(ptr) =
+      Eigen::numext::bit_cast<uint16_t>(val);
+}
+
+template <>
+__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer<Eigen::half>(
+    Eigen::half *ptr, Int2Type<true> /*is_primitive*/) {
+  uint16_t result = *reinterpret_cast<volatile uint16_t *>(ptr);
+  return Eigen::numext::bit_cast<Eigen::half>(result);
+}
+
+template <>
+__device__ __forceinline__ void ThreadStoreVolatilePtr<tsl::bfloat16>(
+    tsl::bfloat16 *ptr, tsl::bfloat16 val, Int2Type<true> /*is_primitive*/) {
+  *reinterpret_cast<volatile uint16_t *>(ptr) =
+      Eigen::numext::bit_cast<uint16_t>(val);
+}
+
+template <>
+__device__ __forceinline__ tsl::bfloat16
+ThreadLoadVolatilePointer<tsl::bfloat16>(tsl::bfloat16 *ptr,
+                                         Int2Type<true> /*is_primitive*/) {
+  uint16_t result = *reinterpret_cast<volatile uint16_t *>(ptr);
+  return Eigen::numext::bit_cast<tsl::bfloat16>(result);
+}
+
+template <>
+struct NumericTraits<Eigen::half>
+    : BaseTraits</*_CATEGORY=*/FLOATING_POINT, /*_PRIMITIVE=*/true,
+                 /*_NULL_TYPE=*/false, /*_UnsignedBits=*/uint16_t,
+                 /*T=*/Eigen::half> {};
+template <>
+struct NumericTraits<tsl::bfloat16>
+    : BaseTraits</*_CATEGORY=*/FLOATING_POINT, /*_PRIMITIVE=*/true,
+                 /*_NULL_TYPE=*/false, /*_UnsignedBits=*/uint16_t,
+                 /*T=*/tsl::bfloat16> {};
+}  // namespace cub
+#endif  // GOOGLE_CUDA
+
+#endif  // XLA_SERVICE_GPU_GPU_PRIM_CUDA_H_
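The specializations above work because CUB's device radix sort only ever manipulates the raw
16-bit key pattern; Eigen::half and tsl::bfloat16 just need a lossless round-trip through
uint16_t, and cub's BaseTraits then does the usual sign/exponent bit twiddling so that the
pattern order matches numeric order. A minimal sketch of that round-trip (illustrative only,
assuming the same Eigen::numext::bit_cast used in the patch):

#include <cstdint>
#include "Eigen/Core"  // provides Eigen::half and Eigen::numext::bit_cast

// Reinterpret the 16-bit payload without any numeric conversion.
inline uint16_t HalfBits(Eigen::half h) {
  return Eigen::numext::bit_cast<uint16_t>(h);
}

inline Eigen::half HalfFromBits(uint16_t bits) {
  return Eigen::numext::bit_cast<Eigen::half>(bits);
}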
diff --git a/third_party/xla/xla/service/gpu/gpu_prim_rocm.h b/third_party/xla/xla/service/gpu/gpu_prim_rocm.h
new file mode 100644
index 0000000..773e534
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/gpu_prim_rocm.h
@@ -0,0 +1,56 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_SERVICE_GPU_GPU_PRIM_ROCM_H_
+#define XLA_SERVICE_GPU_GPU_PRIM_ROCM_H_
+
+#include "tsl/platform/bfloat16.h"
+
+#if TENSORFLOW_USE_ROCM
+
+#include "rocm/include/hipcub/hipcub.hpp"
+#include "rocm/rocm_config.h"
+namespace gpuprim = ::hipcub;
+
+// Required for sorting Eigen::half and bfloat16.
+namespace rocprim {
+namespace detail {
+
+#if (TF_ROCM_VERSION >= 50200)
+template <>
+struct float_bit_mask<Eigen::half> {
+  static constexpr uint16_t sign_bit = 0x8000;
+  static constexpr uint16_t exponent = 0x7C00;
+  static constexpr uint16_t mantissa = 0x03FF;
+  using bit_type = uint16_t;
+};
+
+template <>
+struct float_bit_mask<tsl::bfloat16> {
+  static constexpr uint16_t sign_bit = 0x8000;
+  static constexpr uint16_t exponent = 0x7F80;
+  static constexpr uint16_t mantissa = 0x007F;
+  using bit_type = uint16_t;
+};
+#endif  // TF_ROCM_VERSION >= 50200
+template <>
+struct radix_key_codec_base<Eigen::half>
+    : radix_key_codec_floating<Eigen::half, uint16_t> {};
+template <>
+struct radix_key_codec_base<tsl::bfloat16>
+    : radix_key_codec_floating<tsl::bfloat16, uint16_t> {};
+};  // namespace detail
+};  // namespace rocprim
+
+#endif  // TENSORFLOW_USE_ROCM
+#endif  // XLA_SERVICE_GPU_GPU_PRIM_ROCM_H_
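The float_bit_mask constants above simply spell out the two 16-bit layouts: IEEE binary16 is
1 sign + 5 exponent + 10 mantissa bits, and bfloat16 is 1 + 8 + 7. A compile-time sanity check
of those masks (illustrative sketch, not part of the patch):

// Adjacent fields must not overlap, and together they must cover all 16 bits.
static_assert((0x8000u | 0x7C00u | 0x03FFu) == 0xFFFFu, "half fields cover 16 bits");
static_assert((0x8000u & 0x7C00u) == 0u && (0x7C00u & 0x03FFu) == 0u, "half fields are disjoint");
static_assert((0x8000u | 0x7F80u | 0x007Fu) == 0xFFFFu, "bfloat16 fields cover 16 bits");
static_assert((0x8000u & 0x7F80u) == 0u && (0x7F80u & 0x007Fu) == 0u, "bfloat16 fields are disjoint");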
diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc b/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc
new file mode 100644
index 0000000..77d182c
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc
@@ -0,0 +1,238 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/gpu_sort_rewriter.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/comparison_util.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/cub_sort_thunk.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/statusor.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Analyze sort comparer function.
+struct SortComputationAnalysis {
+  int key_operand;  // 0 or 1
+  bool descending;
+};
+
+std::optional<SortComputationAnalysis> AnalyzeSortComputation(
+    const HloComputation* computation) {
+  // Root instruction must be a comparison with a valid direction.
+  const HloCompareInstruction* compare =
+      DynCast<HloCompareInstruction>(computation->root_instruction());
+  if (compare == nullptr || compare->direction() == ComparisonDirection::kEq ||
+      compare->direction() == ComparisonDirection::kNe) {
+    return std::nullopt;
+  }
+
+  // Compare should operate on the function parameters for a single tensor.
+  const HloParameterInstruction* param0 =
+      DynCast<HloParameterInstruction>(compare->operand(0));
+  const HloParameterInstruction* param1 =
+      DynCast<HloParameterInstruction>(compare->operand(1));
+  if (param0 == nullptr || param1 == nullptr) {
+    return std::nullopt;
+  }
+
+  // When sorting a pair of tensors, the parameters should be adjacent.
+  int index0 = param0->parameter_number();
+  int index1 = param1->parameter_number();
+  int first_index = std::min(index0, index1);
+  if (first_index % 2 != 0 || std::max(index0, index1) != first_index + 1) {
+    return std::nullopt;
+  }
+
+  // Return the tensor index and the sort direction.
+  bool descending = compare->direction() == ComparisonDirection::kGt ||
+                    compare->direction() == ComparisonDirection::kGe;
+  bool reverse = first_index != index0;
+  return SortComputationAnalysis{first_index / 2, descending != reverse};
+}
+
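To make the index arithmetic above concrete, here is a worked example (the comparer is
illustrative, not taken from this change): the values occupy parameters 0/1, the keys occupy
parameters 2/3, and the root compares the key operands in swapped order.

// Hypothetical comparer for a (values, keys) pair sort.
constexpr char kSwappedKeyComparer[] = R"(
%compare {
  %lhs_value = f32[] parameter(0)
  %rhs_value = f32[] parameter(1)
  %lhs_key = u32[] parameter(2)
  %rhs_key = u32[] parameter(3)
  ROOT %lt = pred[] compare(%rhs_key, %lhs_key), direction=LT
})";
// AnalyzeSortComputation sees index0 = 3, index1 = 2, first_index = 2 (even,
// and max(index0, index1) == first_index + 1, so the pair is valid).
// key_operand = first_index / 2 = 1; direction LT gives descending = false;
// reverse = (first_index != index0) = true; descending != reverse = true,
// so the rewritten CUB sort runs in descending order.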
+// Create runner for CUB sort operation.
+StatusOr<std::unique_ptr<CubSortRunnerInterface>> CreateRunner(
+    HloSortInstruction* sort_op, const SortComputationAnalysis& sort_config) {
+  int value_index = 1 - sort_config.key_operand;
+  return CubSortRunnerInterface::Create(
+      sort_op->operand(sort_config.key_operand)->shape().element_type(),
+      sort_op->operand_count() == 2
+          ? std::optional(sort_op->operand(value_index)->shape().element_type())
+          : std::nullopt);
+}
+
+// Verify that the sort tensor shape is supported by CUB.
+bool IsCubCompatibleSort(HloSortInstruction* sort_op) {
+  VLOG(1) << "Sort instruction: " << sort_op->name();
+  if (sort_op->operand_count() != 1 && sort_op->operand_count() != 2) {
+    VLOG(2) << "Unsupported operand count: " << sort_op->operand_count();
+    return false;
+  }
+  if (sort_op->operand(0)->shape().rank() != 1) {
+    VLOG(2) << "Only 1D shapes are supported";
+    return false;
+  }
+  if (sort_op->operand(0)->shape().dimensions(0) <
+      GpuSortRewriter::kSortSizeThreshold) {
+    VLOG(2) << "Tensor shape size is too small to see an improvement";
+    return false;
+  }
+
+  auto sort_config =
+      AnalyzeSortComputation(sort_op->called_computations().front());
+  if (!sort_config.has_value()) {
+    VLOG(2) << "Only simple compare computations are supported";
+    return false;
+  }
+  if (!CreateRunner(sort_op, *sort_config).ok()) {
+    VLOG(2) << "Unsupported operand types (no compiled CUB kernels)";
+    return false;
+  }
+  VLOG(2) << "Sort operation is compatible";
+  return true;
+}
+
+// Restores the result shape after sorting a pair of tensors. The last tuple
+// element of the custom call result is the scratch buffer, which is discarded.
+HloInstruction* UnpackResultPair(HloSortInstruction* sort_op,
+                                 HloInstruction* custom_call, bool swap) {
+  HloComputation* parent = sort_op->parent();
+  HloInstruction* gte0 =
+      parent->AddInstruction(HloInstruction::CreateGetTupleElement(
+          sort_op->operand(0)->shape(), custom_call, swap ? 1 : 0));
+  HloInstruction* gte1 =
+      parent->AddInstruction(HloInstruction::CreateGetTupleElement(
+          sort_op->operand(1)->shape(), custom_call, swap ? 0 : 1));
+  return parent->AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
+}
+
+}  // namespace
+
+// Rewrites a single sort instruction with a custom call.
+StatusOr<bool> GpuSortRewriter::RunOnInstruction(HloSortInstruction* sort_op) {
+  // Get the sort tensor index and direction.
+  SortComputationAnalysis sort_config =
+      AnalyzeSortComputation(sort_op->called_computations().front()).value();
+
+  // Get scratch size requirements from CUB.
+  TF_ASSIGN_OR_RETURN(auto runner, CreateRunner(sort_op, sort_config));
+  TF_ASSIGN_OR_RETURN(
+      int64_t scratch_size,
+      runner->GetScratchSize(sort_op->operand(0)->shape().dimensions(0)));
+
+  // Values are only present if sorting a pair of tensors.
+  HloInstruction* keys = sort_op->mutable_operand(0);
+  HloInstruction* values = nullptr;
+  if (sort_op->operand_count() == 2) {
+    values = sort_op->mutable_operand(1);
+    if (sort_config.key_operand == 1) {
+      std::swap(keys, values);
+    }
+  }
+
+  // Build the resulting shape for the custom call.
+  std::vector<Shape> shapes{keys->shape()};
+  std::vector<HloInstruction*> operands{keys};
+  if (values != nullptr) {
+    shapes.push_back(values->shape());
+    operands.push_back(values);
+  }
+  shapes.push_back(ShapeUtil::MakeShape(U8, {scratch_size}));
+  Shape call_shape = ShapeUtil::MakeTupleShape(absl::MakeSpan(shapes));
+
+  // Build the custom call instruction.
+  HloInstruction* custom_call =
+      sort_op->parent()->AddInstruction(HloInstruction::CreateCustomCall(
+          call_shape, absl::MakeSpan(operands), kCubDeviceRadixSortTarget));
+
+  xla::SortOptions backend_config;
+  backend_config.set_descending(sort_config.descending);
+  TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config));
+
+  // Build the replacement instruction.
+  HloInstruction* replacement;
+  if (sort_op->operand_count() == 1) {
+    replacement =
+        sort_op->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
+            sort_op->shape(), custom_call, 0));
+  } else {
+    replacement = UnpackResultPair(sort_op, custom_call,
+                                   /*swap=*/sort_config.key_operand == 1);
+  }
+
+  // Replace sort operation with custom call followed by GTE.
+  TF_RETURN_IF_ERROR(
+      sort_op->parent()->ReplaceInstruction(sort_op, replacement));
+  return true;
+}
+
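For orientation, a sketch of the graph RunOnInstruction produces for the simple key-only case
(the names and scratch size below are illustrative; the scratch extent is whatever
GetScratchSize reports):

//   before:  ROOT %sort = f32[100000] sort(%input), dimensions={0}, to_apply=%compare
//
//   after:   %cub_sort = (f32[100000], u8[<scratch>]) custom-call(%input),
//                custom_call_target set to kCubDeviceRadixSortTarget,
//                backend_config carrying SortOptions.descending
//            ROOT %gte = f32[100000] get-tuple-element(%cub_sort), index=0
//
// For a pair sort the custom call result is (keys, values, scratch) and
// UnpackResultPair re-tuples the first two elements in the original operand order.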
+// Rewrites the sorts in the given computation into calls to CUB.
+StatusOr<bool> GpuSortRewriter::RunOnComputation(HloComputation* computation) {
+  std::vector<HloSortInstruction*> sort_ops;
+  for (auto* inst : computation->instructions()) {
+    HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
+    if (sort != nullptr && IsCubCompatibleSort(sort)) {
+      sort_ops.push_back(sort);
+    }
+  }
+  bool changed = false;
+  for (auto* sort : sort_ops) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(sort));
+    changed |= result;
+  }
+  return changed;
+}
+
+// Replace compatible sort operations with custom calls.
+StatusOr<bool> GpuSortRewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  XLA_VLOG_LINES(2, "GpuSortRewriter::Run(), before:\n" + module->ToString());
+  bool changed = false;
+  for (HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
+    changed |= result;
+  }
+  XLA_VLOG_LINES(2, "GpuSortRewriter::Run(), after:\n" + module->ToString());
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h b/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h
new file mode 100644
index 0000000..094cbc4
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h
@@ -0,0 +1,55 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_
+#define XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites sort operations into CustomCall HLOs that call into CUB.
+// Only a subset of shapes is supported - either a single tensor with a simple
+// compare function or a pair of tensors where keys are unsigned integers.
+
+class GpuSortRewriter : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "gpu-sort-rewriter"; }
+
+  // CUB radix sort is slower than XLA sort on small shapes, so do not rewrite
+  // tensors with sizes below this limit.
+  static constexpr int kSortSizeThreshold = 100000;
+
+  using HloPassInterface::Run;
+  StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  StatusOr<bool> RunOnInstruction(HloSortInstruction* sort_op);
+  StatusOr<bool> RunOnComputation(HloComputation* computation);
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_test.cc b/third_party/xla/xla/service/gpu/gpu_sort_rewriter_test.cc
new file mode 100644
index 0000000..8054077
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/gpu_sort_rewriter_test.cc
@@ -0,0 +1,314 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/gpu_sort_rewriter.h"
+
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/statusor.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+class GpuSortRewriterTest : public HloTestBase {
+ public:
+  bool RunPass(HloModule* module) {
+    return GpuSortRewriter().Run(module).value();
+  }
+
+  void ExpectDirection(const HloInstruction* instruction, bool descending) {
+    auto config = instruction->backend_config<xla::SortOptions>();
+    EXPECT_EQ(config->descending(), descending);
+  }
+};
+
+// Basic sort: ascending.
+TEST_F(GpuSortRewriterTest, SortKeysLessThan) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+  %input = f32[100000] parameter(0)
+  ROOT %sort = f32[100000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+          m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
+  ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
+                  /*descending=*/false);
+}
+
+// Basic sort: descending.
+TEST_F(GpuSortRewriterTest, SortKeysGreaterThan) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %gt = pred[] compare(%lhs, %rhs), direction=GT
+}
+
+ENTRY %main {
+  %input = f32[100000] parameter(0)
+  ROOT %sort = f32[100000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+          m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
+  ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
+                  /*descending=*/true);
+}
+
+// Comparer swaps the parameter order -> direction is reversed.
+TEST_F(GpuSortRewriterTest, SortKeysGreaterThanSwapped) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(1)
+  %rhs = f32[] parameter(0)
+  ROOT %gt = pred[] compare(%lhs, %rhs), direction=GT
+}
+
+ENTRY %main {
+  %input = f32[100000] parameter(0)
+  ROOT %sort = f32[100000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(
+          m::CustomCall({kCubDeviceRadixSortTarget}, m::Parameter()), 0)));
+  ExpectDirection(module->entry_computation()->root_instruction()->operand(0),
+                  /*descending=*/false);
+}
+
+// Sort a pair of tensors, keys go first.
+TEST_F(GpuSortRewriterTest, SortPairs) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs_key = u32[] parameter(0)
+  %rhs_key = u32[] parameter(1)
+  %lhs_value = f32[] parameter(2)
+  %rhs_value = f32[] parameter(3)
+  ROOT %lt = pred[] compare(%lhs_key, %rhs_key), direction=LT
+}
+
+ENTRY %main {
+  %input_keys = u32[100000] parameter(0)
+  %input_values = f32[100000] parameter(1)
+  ROOT %sort = (u32[100000], f32[100000]) sort(%input_keys, %input_values),
+      dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 0),
+                                  m::GetTupleElement(m::CustomCall(), 1))));
+}
+
+// Sort a pair of tensors, keys go last.
+TEST_F(GpuSortRewriterTest, SortPairsSwapped) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs_value = f32[] parameter(0)
+  %rhs_value = f32[] parameter(1)
+  %lhs_key = u32[] parameter(2)
+  %rhs_key = u32[] parameter(3)
+  ROOT %lt = pred[] compare(%lhs_key, %rhs_key), direction=LT
+}
+
+ENTRY %main {
+  %input_values = f32[100000] parameter(0)
+  %input_keys = u32[100000] parameter(1)
+  ROOT %sort = (f32[100000], u32[100000]) sort(%input_values, %input_keys),
+      dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_TRUE(RunPass(module.get()));
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple(m::GetTupleElement(m::CustomCall(), 1),
+                                  m::GetTupleElement(m::CustomCall(), 0))));
+}
+
+// CUB sort doesn't support more than two tensors.
+TEST_F(GpuSortRewriterTest, NoRewriteManyTensors) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  %unused1 = f64[] parameter(2)
+  %unused2 = f64[] parameter(3)
+  %unused3 = u64[] parameter(4)
+  %unused4 = u64[] parameter(5)
+  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+  %input1 = f32[100000] parameter(0)
+  %input2 = f64[100000] parameter(1)
+  %input3 = u64[100000] parameter(2)
+  ROOT %sort = (f32[100000], f64[100000], u64[100000]) sort(%input1, %input2, %input3),
+      dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_FALSE(RunPass(module.get()));
+}
+
+// Only 1D shapes are supported.
+TEST_F(GpuSortRewriterTest, NoRewriteManyDimensions) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+  %input = f32[100000,4] parameter(0)
+  ROOT %sort = f32[100000,4] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_FALSE(RunPass(module.get()));
+}
+
+// Kernels are compiled for a subset of types.
+TEST_F(GpuSortRewriterTest, NoRewriteUnsupportedType) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = pred[] parameter(0)
+  %rhs = pred[] parameter(1)
+  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+  %input = pred[100000] parameter(0)
+  ROOT %sort = pred[100000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_FALSE(RunPass(module.get()));
+}
+
+// Comparer must be a simple function.
+TEST_F(GpuSortRewriterTest, NoRewriteComplexComparer) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(0)
+  %lhs_scaled = f32[] multiply(%lhs, f32[] constant(2))
+  %rhs = f32[] parameter(1)
+  ROOT %lt = pred[] compare(%lhs_scaled, %rhs), direction=LT
+}
+
+ENTRY %main {
+  %input = f32[100000] parameter(0)
+  ROOT %sort = f32[100000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_FALSE(RunPass(module.get()));
+}
+
+// Comparer must use adjacent input values.
+TEST_F(GpuSortRewriterTest, NoRewriteMixedKeysValues) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs_key = u32[] parameter(0)
+  %rhs_key = u32[] parameter(1)
+  %lhs_value = u32[] parameter(2)
+  %rhs_value = u32[] parameter(3)
+  ROOT %mixed = pred[] compare(%rhs_key, %lhs_value), direction=LT
+}
+
+ENTRY %main {
+  %input_keys = u32[100000] parameter(0)
+  %input_values = u32[100000] parameter(1)
+  ROOT %sort = (u32[100000], u32[100000]) sort(%input_keys, %input_values),
+      dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_FALSE(RunPass(module.get()));
+}
+
+// Small shapes do not see improvement from CUB sort.
+TEST_F(GpuSortRewriterTest, NoRewriteSmallSize) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+%compare {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %lt = pred[] compare(%lhs, %rhs), direction=LT
+}
+
+ENTRY %main {
+  %input = f32[1000] parameter(0)
+  ROOT %sort = f32[1000] sort(%input), dimensions={0}, to_apply=%compare
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_FALSE(RunPass(module.get()));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/gpu_symbol_repository.h b/third_party/xla/xla/service/gpu/gpu_symbol_repository.h
new file mode 100644
index 0000000..231abdf
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/gpu_symbol_repository.h
@@ -0,0 +1,33 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_GPU_SYMBOL_REPOSITORY_H_
+#define XLA_SERVICE_GPU_GPU_SYMBOL_REPOSITORY_H_
+
+#include <optional>
+
+#include "xla/service/symbol_repository.h"
+#include "xla/xla.pb.h"
+
+namespace xla::gpu {
+
+// GPU-specific fields for SymbolRepositories.
+struct GpuBackendSpecificData : public BackendSpecificData {
+  std::optional<GpuCompilationEnvironment> gpu_compilation_environment;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_GPU_SYMBOL_REPOSITORY_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_transfer_manager.h b/third_party/xla/xla/service/gpu/gpu_transfer_manager.h
index 5dcc291..5006955 100644
--- a/third_party/xla/xla/service/gpu/gpu_transfer_manager.h
+++ b/third_party/xla/xla/service/gpu/gpu_transfer_manager.h
@@ -48,6 +48,8 @@
   GpuTransferManager(const GpuTransferManager&) = delete;
   GpuTransferManager& operator=(const GpuTransferManager&) = delete;
 
+  bool PackSubbyteTypes() const override { return true; }
+
   // This class keeps a pool of pinned memory
   // (StreamExecutor::HostMemoryAllocate()) that serves ReadDynamicShapes().
   // This is a bit of a hack: Callers like TensorFlow already have a full pinned
diff --git a/third_party/xla/xla/service/gpu/cublas_lt_matmul_thunk.cc b/third_party/xla/xla/service/gpu/gpublas_lt_matmul_thunk.cc
similarity index 87%
rename from third_party/xla/xla/service/gpu/cublas_lt_matmul_thunk.cc
rename to third_party/xla/xla/service/gpu/gpublas_lt_matmul_thunk.cc
index f856b01..4b3d4f1 100644
--- a/third_party/xla/xla/service/gpu/cublas_lt_matmul_thunk.cc
+++ b/third_party/xla/xla/service/gpu/gpublas_lt_matmul_thunk.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "xla/service/gpu/cublas_lt_matmul_thunk.h"
+#include "xla/service/gpu/gpublas_lt_matmul_thunk.h"
 
 #include <memory>
 #include <utility>
@@ -55,11 +55,7 @@
 
 Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) {
   TF_ASSIGN_OR_RETURN(auto plan, GetMatmulPlan(params.stream));
-  if (!algorithm_) {
-    TF_ASSIGN_OR_RETURN(auto algorithms, plan->GetAlgorithms());
-    TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size());
-    algorithm_ = algorithms[algorithm_idx_];
-  }
+  TF_ASSIGN_OR_RETURN(auto algorithm, GetMatmulAlgorithm(plan));
 
   VLOG(3) << "Running cublas_lt matmul thunk";
   const BufferAllocations& allocs = *params.buffer_allocations;
@@ -95,7 +91,7 @@
       params.stream, allocs.GetDeviceAddress(a_buffer_),
       allocs.GetDeviceAddress(b_buffer_), allocs.GetDeviceAddress(c_buffer_),
       allocs.GetDeviceAddress(d_buffer_), bias, aux, a_scale, b_scale, c_scale,
-      d_scale, d_amax, *algorithm_, scratch_allocator);
+      d_scale, d_amax, *algorithm, scratch_allocator);
 }
 
 StatusOr<se::gpu::BlasLt::MatmulPlan*> CublasLtMatmulThunk::GetMatmulPlan(
@@ -110,5 +106,19 @@
   return it->second.get();
 }
 
+StatusOr<std::optional<se::gpu::BlasLt::MatmulAlgorithm> >
+CublasLtMatmulThunk::GetMatmulAlgorithm(
+    const se::gpu::BlasLt::MatmulPlan* plan) {
+  absl::MutexLock lock(&matmul_algorithm_cache_mutex_);
+  auto it = matmul_algorithm_cache_.find(plan);
+  if (it == matmul_algorithm_cache_.end()) {
+    TF_ASSIGN_OR_RETURN(auto algorithms, plan->GetAlgorithms());
+    TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size());
+    auto algorithm = algorithms[algorithm_idx_];
+    it = matmul_algorithm_cache_.emplace(plan, algorithm).first;
+  }
+  return it->second;
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/cublas_lt_matmul_thunk.h b/third_party/xla/xla/service/gpu/gpublas_lt_matmul_thunk.h
similarity index 84%
rename from third_party/xla/xla/service/gpu/cublas_lt_matmul_thunk.h
rename to third_party/xla/xla/service/gpu/gpublas_lt_matmul_thunk.h
index b2a1cb2..5a394db 100644
--- a/third_party/xla/xla/service/gpu/cublas_lt_matmul_thunk.h
+++ b/third_party/xla/xla/service/gpu/gpublas_lt_matmul_thunk.h
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef XLA_SERVICE_GPU_CUBLAS_LT_MATMUL_THUNK_H_
-#define XLA_SERVICE_GPU_CUBLAS_LT_MATMUL_THUNK_H_
+#ifndef XLA_SERVICE_GPU_GPUBLAS_LT_MATMUL_THUNK_H_
+#define XLA_SERVICE_GPU_GPUBLAS_LT_MATMUL_THUNK_H_
 
 #include <memory>
 #include <optional>
@@ -50,12 +50,19 @@
  private:
   StatusOr<se::gpu::BlasLt::MatmulPlan*> GetMatmulPlan(
       const stream_executor::Stream* stream);
+  StatusOr<std::optional<se::gpu::BlasLt::MatmulAlgorithm> > GetMatmulAlgorithm(
+      const se::gpu::BlasLt::MatmulPlan* plan);
 
   absl::Mutex matmul_plans_cache_mutex_;
   absl::flat_hash_map<const stream_executor::Stream*,
                       se::gpu::BlasLt::MatmulPlanPtr>
       matmul_plans_cache_ ABSL_GUARDED_BY(matmul_plans_cache_mutex_);
 
+  absl::Mutex matmul_algorithm_cache_mutex_;
+  absl::flat_hash_map<const se::gpu::BlasLt::MatmulPlan*,
+                      se::gpu::BlasLt::MatmulAlgorithm>
+      matmul_algorithm_cache_ ABSL_GUARDED_BY(matmul_algorithm_cache_mutex_);
+
   GemmConfig gemm_config_;
   se::gpu::BlasLt::Epilogue epilogue_;
   int64_t algorithm_idx_;
@@ -70,10 +77,9 @@
   BufferAllocation::Slice c_scale_buffer_;
   BufferAllocation::Slice d_scale_buffer_;
   BufferAllocation::Slice d_amax_buffer_;
-  std::optional<se::gpu::BlasLt::MatmulAlgorithm> algorithm_;
 };
 
 }  // namespace gpu
 }  // namespace xla
 
-#endif  // XLA_SERVICE_GPU_CUBLAS_LT_MATMUL_THUNK_H_
+#endif  // XLA_SERVICE_GPU_GPUBLAS_LT_MATMUL_THUNK_H_
diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
index ed81951..0605af7 100644
--- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
+++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
@@ -26,6 +26,7 @@
 #include <vector>
 
 #include "absl/algorithm/container.h"
+#include "absl/log/check.h"
 #include "absl/types/span.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
@@ -271,13 +272,19 @@
                         fusion_arguments.push_back(&argument);
                       });
 
+  auto is_4bit = [](const HloInstruction* arg) {
+    return primitive_util::Is4BitType(arg->shape().element_type());
+  };
+  bool has_4_bit_input = absl::c_any_of(fusion_arguments, is_4bit);
+  bool has_4_bit_output = absl::c_any_of(hlo_roots, is_4bit);
+
   std::optional<TransposeDescription> tiled_transpose_hero =
       FindConsistentTransposeHero(hlo_roots, heroes);
 
   return HloFusionAnalysis(std::move(backend_config), std::move(hlo_roots),
                            std::move(boundary_fn), std::move(fusion_arguments),
-                           std::move(heroes), device_info,
-                           tiled_transpose_hero);
+                           std::move(heroes), device_info, tiled_transpose_hero,
+                           has_4_bit_input, has_4_bit_output);
 }
 
 // static
@@ -306,6 +313,14 @@
     return EmitterFusionKind::kTriton;
   }
 #endif
+
+  if (has_4_bit_input_ || has_4_bit_output_) {
+    // Only loop fusions can currently handle int4 inputs/outputs, because of
+    // the special IrArray handling needed for two values occupying a single
+    // byte.
+    return EmitterFusionKind::kLoop;
+  }
+
   for (auto [root, hero] : llvm::zip(fusion_roots_, fusion_heroes_)) {
     if (IsRealReductionHero(*root, *hero)) {
       return EmitterFusionKind::kReduction;
@@ -458,6 +473,16 @@
       !MayPreventVectorization(fusion_roots_, fusion_boundary_fn_)) {
     unroll_factor = ComputeMaxUnrollFactor(num_elements);
   }
+  if (has_4_bit_output_ && unroll_factor == 1) {
+    // Ensure a single thread writes to a byte containing two int4 values. The
+    // HLO Verifier ensures each int4 array has an even number of elements so
+    // it's safe to set the unroll_factor to 2. Setting unroll_factor is safe
+    // even if MayPreventVectorization returns true, as the
+    // MayPreventVectorization check is an optimization, not a correctness
+    // requirement.
+    CHECK_EQ(num_elements % 2, 0);
+    unroll_factor = 2;
+  }
   VLOG(2) << "Unroll factor: " << unroll_factor;
 
   if (GetEmitterFusionKind() == EmitterFusionKind::kScatter) {
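The unroll_factor = 2 override above exists because two int4 elements share a byte: with an
even element count and an unroll factor of 2, each thread produces both nibbles of the byte it
writes instead of two threads racing on the same byte. A minimal packing sketch (the nibble
order shown is illustrative, not necessarily the emitter's convention):

#include <cstdint>

// Packs two 4-bit values into one byte; callers guarantee both fit in 4 bits.
inline uint8_t PackInt4Pair(uint8_t even_element, uint8_t odd_element) {
  return static_cast<uint8_t>((even_element & 0x0F) | ((odd_element & 0x0F) << 4));
}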
diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h
index 131d9c0..309aa9a 100644
--- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h
+++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h
@@ -91,14 +91,17 @@
                     std::vector<const HloInstruction*> fusion_arguments,
                     std::vector<const HloInstruction*> fusion_heroes,
                     const se::DeviceDescription* device_info,
-                    std::optional<TransposeDescription> tiled_transpose)
+                    std::optional<TransposeDescription> tiled_transpose,
+                    bool has_4_bit_input, bool has_4_bit_output)
       : fusion_backend_config_(std::move(fusion_backend_config)),
         fusion_roots_(std::move(fusion_roots)),
         fusion_boundary_fn_(std::move(fusion_boundary_fn)),
         fusion_arguments_(std::move(fusion_arguments)),
         fusion_heroes_(std::move(fusion_heroes)),
         device_info_(device_info),
-        tiled_transpose_(tiled_transpose) {}
+        tiled_transpose_(tiled_transpose),
+        has_4_bit_input_(has_4_bit_input),
+        has_4_bit_output_(has_4_bit_output) {}
 
   const Shape& GetElementShape() const;
   int SmallestInputDtypeBits() const;
@@ -127,6 +130,8 @@
   std::vector<const HloInstruction*> fusion_heroes_;
   const se::DeviceDescription* device_info_;
   std::optional<TransposeDescription> tiled_transpose_;
+  const bool has_4_bit_input_ = false;
+  const bool has_4_bit_output_ = false;
 
   std::optional<ReductionCodegenInfo> reduction_codegen_info_;
   std::optional<TilingScheme> transpose_tiling_scheme_;
diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc
index 165398a..a28628c 100644
--- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc
+++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc
@@ -33,6 +33,7 @@
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/primitive_util.h"
 #include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/target_util.h"
 #include "xla/service/hlo_parser.h"
@@ -387,7 +388,7 @@
 }
 
 StatusOr<BufferAllocation::Slice> GetAllocationSlice(
-    mlir::Value v, absl::Span<const BufferAllocation> allocations,
+    mlir::Value v, absl::Span<const BufferAllocation* const> allocations,
     std::string* constant_name) {
   if (constant_name) {
     constant_name->clear();
@@ -413,9 +414,10 @@
           mlir::dyn_cast_or_null<mlir::memref::ViewOp>(v.getDefiningOp())) {
     TF_RET_CHECK(view.getSource().isa<mlir::BlockArgument>());
 
+    const BufferAllocation* allocation = allocations[GetAllocationIndex(
+        view.getSource().cast<mlir::BlockArgument>(), constant_name)];
     return BufferAllocation::Slice(
-        &allocations[GetAllocationIndex(
-            view.getSource().cast<mlir::BlockArgument>(), constant_name)],
+        allocation,
         mlir::cast<mlir::arith::ConstantOp>(view.getByteShift().getDefiningOp())
             .getValue()
             .cast<mlir::IntegerAttr>()
@@ -433,12 +435,13 @@
         module.lookupSymbol(get_global.getName()));
     int64_t index =
         global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
-    return BufferAllocation::Slice(&allocations[index], 0,
-                                   allocations[index].size());
+
+    return BufferAllocation::Slice(allocations[index], 0,
+                                   allocations[index]->size());
   }
   if (auto arg = v.dyn_cast<mlir::BlockArgument>()) {
     return BufferAllocation::Slice(
-        &allocations[GetAllocationIndex(arg, constant_name)], 0, size);
+        allocations[GetAllocationIndex(arg, constant_name)], 0, size);
   }
 
   return Unimplemented(
@@ -492,7 +495,7 @@
 
 bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
     mlir::lmhlo::FusionOp fusion,
-    absl::Span<const BufferAllocation> allocations) {
+    absl::Span<const BufferAllocation* const> allocations) {
   std::vector<mlir::mhlo::DynamicUpdateSliceOp> dus_ops =
       GetOutputDefiningDynamicUpdateSliceOps(fusion);
 
@@ -620,15 +623,21 @@
 }
 
 Shape GetShape(mlir::Value value) {
+  Shape shape;
   if (value.getType().isa<mlir::MemRefType>()) {
-    return TypeToShape(value.getType());
+    shape = TypeToShape(value.getType());
   } else if (value.getType().isa<mlir::TensorType>()) {
-    return GetShapeFromTensorType(value);
+    shape = GetShapeFromTensorType(value);
   } else if (value.getType().isa<mlir::TupleType>()) {
-    return TypeToShape(value.getType());
+    shape = TypeToShape(value.getType());
+  } else {
+    LOG(FATAL) << "Unexpected value type to get shape for";
   }
-  LOG(FATAL) << "Unexpected value type to get shape for";
-  return {};
+  if (primitive_util::Is4BitType(shape.element_type())) {
+    // 4-bit types are always packed on the GPU
+    shape.mutable_layout()->set_element_size_in_bits(4);
+  }
+  return shape;
 }
 
 std::optional<TransposeDescription> FindTiledTranspose(
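With element_size_in_bits set to 4 as above, buffer sizes follow packed arithmetic rather than
one byte per element. A back-of-the-envelope helper (sketch only, not an XLA API):

#include <cstdint>

// ceil(element_count * bits_per_element / 8): an s4[1000] array occupies
// 500 bytes instead of the 1000 an unpacked byte-per-element layout would use.
inline int64_t PackedByteSize(int64_t element_count, int64_t bits_per_element) {
  return (element_count * bits_per_element + 7) / 8;
}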
diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h
index 65440ca..8903531 100644
--- a/third_party/xla/xla/service/gpu/ir_emission_utils.h
+++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h
@@ -111,12 +111,14 @@
 }
 
 StatusOr<BufferAllocation::Slice> GetAllocationSlice(
-    mlir::Value v, absl::Span<const BufferAllocation> allocations,
+    mlir::Value v, absl::Span<const BufferAllocation* const> allocations,
     std::string* constant_name = nullptr);
 
+bool IsSingleInstructionFusion(mlir::lmhlo::FusionOp fusion);
+
 bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
     mlir::lmhlo::FusionOp fusion,
-    absl::Span<const BufferAllocation> allocations);
+    absl::Span<const BufferAllocation* const> allocations);
 
 // Returns the dynamic-update-slice instructions defining the results of a
 // fusion node. A dynamic slice update is said to be "defining" of a result if
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_context.h b/third_party/xla/xla/service/gpu/ir_emitter_context.h
index 904879e..3e9787a 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_context.h
+++ b/third_party/xla/xla/service/gpu/ir_emitter_context.h
@@ -79,15 +79,11 @@
 
   std::vector<GpuExecutable::ConstantInfo>& constants() { return constants_; }
 
-  absl::Span<const BufferAllocation> allocations() const {
-    if (buffer_assignment_) {
-      return buffer_assignment_->Allocations();
-    }
+  absl::Span<const BufferAllocation* const> allocations() const {
     return allocations_;
   }
 
-  void set_allocations(absl::Span<const BufferAllocation> allocations) {
-    CHECK_EQ(nullptr, buffer_assignment_);
+  void set_allocations(absl::Span<const BufferAllocation* const> allocations) {
     allocations_ = allocations;
   }
 
@@ -106,7 +102,12 @@
  private:
   const HloModule* hlo_module_;
   const BufferAssignment* buffer_assignment_;
-  absl::Span<const BufferAllocation> allocations_;
+
+  // Stores pointers to buffer allocations in the order of the LMHLO entry args.
+  // LMHLO-based emitters need the ordering to locate the buffer allocation.
+  // This should be removed once LMHLO-based emitters are removed.
+  absl::Span<const BufferAllocation* const> allocations_;
+
   std::string platform_name_;
   const se::DeviceDescription& gpu_device_info_;
   mlir::MLIRContext* mlir_context_;
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc
index b8e87ff..2c24214 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc
+++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc
@@ -316,13 +316,42 @@
 }
 
 Value Maximum(ImplicitLocOpBuilder& b, ValueRange values) {
-  auto cmp = Compare(b, values, mlir::mhlo::ComparisonDirection::GT);
-  return b.create<ma::SelectOp>(cmp, values[0], values[1]);
+  // ma::MaximumFOp seems to think that max(NaN, x) = x, so we don't use that.
+  //
+  // logic: isNaN(lhs) || (!isNaN(rhs) && lhs >= rhs) ? lhs : rhs
+  // See also: IEEE Std 754-2008 5.11.
+  //
+  // This also works, but we wanted to make it similar to minimum.
+  // logic: isNaN(lhs) || lhs >= rhs ? lhs : rhs
+  Value lhs_is_nan =
+      Compare(b, {values[0], values[0]}, mlir::mhlo::ComparisonDirection::NE);
+  Value rhs_is_not_nan =
+      Compare(b, {values[1], values[1]}, mlir::mhlo::ComparisonDirection::EQ);
+  Value lhs_is_ge = Compare(b, values, mlir::mhlo::ComparisonDirection::GE);
+  return b.create<ma::SelectOp>(
+      b.create<ma::OrIOp>(lhs_is_nan,
+                          b.create<ma::AndIOp>(rhs_is_not_nan, lhs_is_ge)),
+      values[0], values[1]);
 }
 
 Value Minimum(ImplicitLocOpBuilder& b, ValueRange values) {
-  auto cmp = Compare(b, values, mlir::mhlo::ComparisonDirection::LT);
-  return b.create<ma::SelectOp>(cmp, values[0], values[1]);
+  // ma::MinimumFOp seems to think that min(NaN, x) = x, so we don't use that.
+  //
+  // logic: isNaN(lhs) || (!isNaN(rhs) && lhs <= rhs) ? lhs : rhs
+  // See also: IEEE Std 754-2008 5.11.
+  //
+  // This should also work, but the tests show that it doesn't work for
+  // minimum(x, NaN):
+  // logic: isNaN(lhs) || lhs <= rhs ? lhs : rhs
+  Value lhs_is_nan =
+      Compare(b, {values[0], values[0]}, mlir::mhlo::ComparisonDirection::NE);
+  Value rhs_is_not_nan =
+      Compare(b, {values[1], values[1]}, mlir::mhlo::ComparisonDirection::EQ);
+  Value lhs_is_le = Compare(b, values, mlir::mhlo::ComparisonDirection::LE);
+  return b.create<ma::SelectOp>(
+      b.create<ma::OrIOp>(lhs_is_nan,
+                          b.create<ma::AndIOp>(rhs_is_not_nan, lhs_is_le)),
+      values[0], values[1]);
 }
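A scalar model of the select predicate used by both emitters above (assuming IEEE comparison
semantics; Compare(x, x, NE) is the "is NaN" test and Compare(y, y, EQ) the "is not NaN" test):

#include <cmath>

// Mirrors: isNaN(lhs) || (!isNaN(rhs) && lhs >= rhs) ? lhs : rhs
inline float NanPropagatingMax(float lhs, float rhs) {
  return (std::isnan(lhs) || (!std::isnan(rhs) && lhs >= rhs)) ? lhs : rhs;
}
// NanPropagatingMax(NAN, 1.0f) and NanPropagatingMax(1.0f, NAN) both return
// NaN, while NanPropagatingMax(1.0f, 2.0f) returns 2.0f.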
 
 // TODO(b/269489810): Contribute nicer builders to Triton, so we don't need to
@@ -456,16 +485,21 @@
   return CreateConst(b, ty, ScalarConstantValue<double>(constant, F64));
 }
 
+// Grouped properties of tiled dimensions used to generate block pointers.
 struct DimProperties {
-  DimProperties(int64_t index, Value offset, int block_size, int split_value)
+  DimProperties(int64_t index, Value pid, int block_size, int split_value)
       : index(index),
-        offset(offset),
+        pid(pid),
         block_size(block_size),
         split_value(split_value) {}
 
+  // Logical index of the dimension at the tiling-defining operation.
   int64_t index;
-  Value offset;
+  // Block program ID corresponding to this dimension.
+  Value pid;
+  // Elements of the dimension to process per block program.
   int block_size;
+  // Size of the major part of the dimension if it's split into two parts.
   int split_value;
 };
 
@@ -822,8 +856,7 @@
 //   split-K, batch, non-contracting LHS, non-contracting RHS,
 // where split-K and batch are optional.
 struct MatMulDims {
-  MatMulDims(const AutotuneResult::TritonGemmKey& config,
-             const HloDotInstruction& dot,
+  MatMulDims(const TritonGemmConfig& config, const HloDotInstruction& dot,
              const TritonFusionAnalysis& analysis);
 
   std::optional<int> out_split_k_dim_idx = std::nullopt;
@@ -851,7 +884,7 @@
 
 // Structure for parameters relating to the MatMul launch grid.
 struct MatMulLaunchConfig {
-  explicit MatMulLaunchConfig(const AutotuneResult::TritonGemmKey& config,
+  explicit MatMulLaunchConfig(const TritonGemmConfig& config,
                               const HloDotInstruction& dot,
                               const MatMulDims& dims);
 
@@ -862,15 +895,15 @@
   mt::ProgramIDDim noncontracting_program_id_dim;
 };
 
-MatMulDims::MatMulDims(const AutotuneResult::TritonGemmKey& config,
+MatMulDims::MatMulDims(const TritonGemmConfig& config,
                        const HloDotInstruction& dot,
                        const TritonFusionAnalysis& analysis) {
-  if (config.split_k() > 1) {
+  if (config.split_k > 1) {
     // split-k is always the first logical dimension.
     out_split_k_dim_idx = 0;
   }
 
-  int64_t num_split_k_dims = config.split_k() > 1 ? 1 : 0;
+  int64_t num_split_k_dims = config.split_k > 1 ? 1 : 0;
   const auto& dims = dot.dot_dimension_numbers();
   lhs_contracting_dim_idx = dims.lhs_contracting_dimensions(0);
   lhs_noncontracting_dim_idx =
@@ -906,7 +939,7 @@
           ->at(0)
           .count;
   // Contracting dimension length.
-  if (config.split_k() > 1 &&
+  if (config.split_k > 1 &&
       dot.operand(0)->operand(0)->opcode() == HloOpcode::kPad) {
     // Unpadded LHS shape:  [..., k, ...]
     // Padded LHS shape:    [..., padded_k, ...]
@@ -917,7 +950,7 @@
     k = unpadded_lhs_shape.dimensions(dims.lhs_contracting_dimensions(0) - 1);
   } else {
     k = dot.operand(0)->shape().dimensions(dims.lhs_contracting_dimensions(0)) *
-        config.split_k();
+        config.split_k;
   }
 
   auto* lhs_noncontracting_split_spec =
@@ -943,11 +976,11 @@
   CHECK_GE(n, 1);
 }
 
-MatMulLaunchConfig::MatMulLaunchConfig(
-    const AutotuneResult::TritonGemmKey& config, const HloDotInstruction& dot,
-    const MatMulDims& dims)
-    : grid_m((dims.m + config.block_m() - 1) / config.block_m()),
-      grid_n((dims.n + config.block_n() - 1) / config.block_n()) {
+MatMulLaunchConfig::MatMulLaunchConfig(const TritonGemmConfig& config,
+                                       const HloDotInstruction& dot,
+                                       const MatMulDims& dims)
+    : grid_m((dims.m + config.block_m - 1) / config.block_m),
+      grid_n((dims.n + config.block_n - 1) / config.block_n) {
   int64_t batch_size = dims.lhs_noncontracting_split.value_or(
       dims.out_batch_dim_idx.has_value()
           ? dot.shape().dimensions(*dims.out_batch_dim_idx)
@@ -965,29 +998,29 @@
   if (large_batch) {
     batch_program_id_dim = mt::ProgramIDDim::X;
     noncontracting_program_id_dim = mt::ProgramIDDim::Y;
-    launch_dims = {{batch_size, grid_m * grid_n, config.split_k()},
-                   {config.num_warps() * WarpSize(), 1, 1}};
+    launch_dims = {{batch_size, grid_m * grid_n, config.split_k},
+                   {config.num_warps * WarpSize(), 1, 1}};
   } else {
     batch_program_id_dim = mt::ProgramIDDim::Y;
     noncontracting_program_id_dim = mt::ProgramIDDim::X;
     launch_dims =
-        LaunchDimensions{{grid_m * grid_n, batch_size, config.split_k()},
-                         {config.num_warps() * WarpSize(), 1, 1}};
+        LaunchDimensions{{grid_m * grid_n, batch_size, config.split_k},
+                         {config.num_warps * WarpSize(), 1, 1}};
   }
 }
 
-void ValidateMatMulConfig(const AutotuneResult::TritonGemmKey& config,
+void ValidateMatMulConfig(const TritonGemmConfig& config,
                           const HloDotInstruction& dot) {
-  CHECK_GE(config.split_k(), 1);
-  CHECK_GE(config.block_m(), 16);
-  CHECK_GE(config.block_k(), 16);
-  CHECK_GE(config.block_n(), 16);
+  CHECK_GE(config.split_k, 1);
+  CHECK_GE(config.block_m, 16);
+  CHECK_GE(config.block_k, 16);
+  CHECK_GE(config.block_n, 16);
 
   const auto& dims = dot.dot_dimension_numbers();
   int num_batch_dims =
-      dims.lhs_batch_dimensions_size() - (config.split_k() > 1 ? 1 : 0);
+      dims.lhs_batch_dimensions_size() - (config.split_k > 1 ? 1 : 0);
   CHECK_LE(num_batch_dims, 1);
-  if (config.split_k() > 1) {
+  if (config.split_k > 1) {
     // Split-K dimension has to be the first batch one and have an index
     // just before the contracting one.
     const int lhs_split_k_dim_idx = dims.lhs_contracting_dimensions(0) - 1;
@@ -995,9 +1028,9 @@
     // Size of this dimension has to match the split_k value.
     CHECK_EQ(dims.lhs_batch_dimensions(0), lhs_split_k_dim_idx);
     CHECK_EQ(dims.rhs_batch_dimensions(0), rhs_split_k_dim_idx);
-    CHECK_EQ(config.split_k(),
+    CHECK_EQ(config.split_k,
              dot.operand(0)->shape().dimensions(lhs_split_k_dim_idx));
-    CHECK_EQ(config.split_k(),
+    CHECK_EQ(config.split_k,
              dot.operand(1)->shape().dimensions(rhs_split_k_dim_idx));
   }
 
@@ -1007,7 +1040,7 @@
   CHECK_EQ(dims.rhs_contracting_dimensions_size(), 1);
 
   CHECK_EQ(dot.operand(0)->shape().rank(),
-           2 + (config.split_k() > 1 ? 1 : 0) + num_batch_dims);
+           2 + (config.split_k > 1 ? 1 : 0) + num_batch_dims);
 }
 
 struct Side {
@@ -1113,7 +1146,6 @@
       if (spec == nullptr) {
         return;
       }
-      const int64_t stride = spec->at(0).stride;
       int64_t count = spec->at(0).count;
       if (side.scope == TritonFusionAnalysis::Scope::OUTPUT &&
           properties.index == dims_.out_lhs_noncontracting_dim_idx &&
@@ -1126,8 +1158,12 @@
         boundary_checks.push_back(bounds.size());
       }
       bounds.push_back(Cst64(count));
-      strides.push_back(Cst64(stride));
-      block_offsets.push_back(properties.offset);
+      strides.push_back(Cst64(spec->at(0).stride));
+      block_offsets.push_back(
+          (properties.pid == nullptr)
+              ? Cst32(0)
+              : b_.create<ma::MulIOp>(properties.pid,
+                                      Cst32(properties.block_size)));
       tensor_offsets.push_back(Cst32(spec->at(0).slice_start));
       block_dims.push_back(properties.block_size);
       dim_order.emplace(dim_order.begin(), dim_order.size());
@@ -1176,10 +1212,10 @@
       const TensorIterationSpec::DimIterationSpec* spec = analysis_.IterSpec(
           TritonFusionAnalysis::Scope::OUTPUT, hlo, *dims_.out_split_k_dim_idx);
       if (spec != nullptr) {
-        int64_t stride_split_k = spec->at(0).stride;
-        Value offset_split_k =
-            b_.create<ma::MulIOp>(ConvertScalar(pid_k), Cst(stride_split_k));
-        base = AddPtr(b_, base, offset_split_k);
+        CHECK(pid_k != nullptr);
+        base = AddPtr(b_, base,
+                      b_.create<ma::MulIOp>(ConvertScalar(pid_k),
+                                            Cst(spec->at(0).stride)));
       }
     }
 
@@ -1225,8 +1261,7 @@
 LaunchDimensions GetMatMulLaunchDimensions(
     const TritonFusionAnalysis& analysis,
     absl::Span<const HloInstruction* const> roots,
-    const FusionBoundaryFn& fusion_boundary,
-    const AutotuneResult::TritonGemmKey& config) {
+    const FusionBoundaryFn& fusion_boundary, const TritonGemmConfig& config) {
   const auto* dot = static_cast<const HloDotInstruction*>(
       HloFindIf(roots, fusion_boundary, [](const HloInstruction& node) {
         return node.opcode() == HloOpcode::kDot;
@@ -1241,8 +1276,7 @@
 Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
                   const TritonFusionAnalysis& analysis,
                   const HloComputation* computation, mlir::triton::FuncOp fn,
-                  const AutotuneResult::TritonGemmKey& config,
-                  int shmem_budget) {
+                  const TritonGemmConfig& config, int shmem_budget) {
   const HloDotInstruction* dot_instr = DynCast<HloDotInstruction>(
       hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot));
   // Use 32-bit indexing if addressing any of the inputs or the output (which
@@ -1251,7 +1285,7 @@
   bool use_64bit_indexing =
       ShapeUtil::ElementsIn(dot_instr->operand(0)->shape()) > INT_MAX ||
       ShapeUtil::ElementsIn(dot_instr->operand(1)->shape()) > INT_MAX ||
-      ShapeUtil::ElementsIn(dot_instr->shape()) * config.split_k() > INT_MAX;
+      ShapeUtil::ElementsIn(dot_instr->shape()) * config.split_k > INT_MAX;
   Type index_ty = builder.getIntegerType(use_64bit_indexing ? 64 : 32);
 
   const HloInstruction* root = dot_instr->parent()->root_instruction();
@@ -1265,10 +1299,10 @@
   Type i32_ty = b.getI32Type();
 
   ValidateMatMulConfig(config, *dot_instr);
-  const int split_k = config.split_k();
-  const int block_m = config.block_m();
-  const int block_k = config.block_k();
-  const int block_n = config.block_n();
+  const int split_k = config.split_k;
+  const int block_m = config.block_m;
+  const int block_k = config.block_k;
+  const int block_n = config.block_n;
 
   const MatMulDims dims(config, *dot_instr, analysis);
   const MatMulLaunchConfig launch_config(config, *dot_instr, dims);
@@ -1284,7 +1318,9 @@
 
   auto pid_nc =
       b.create<mt::GetProgramIdOp>(launch_config.noncontracting_program_id_dim);
-  auto pid_k = b.create<mt::GetProgramIdOp>(mt::ProgramIDDim::Z);
+  Value pid_k = (split_k > 1)
+                    ? b.create<mt::GetProgramIdOp>(mt::ProgramIDDim::Z)
+                    : Value{};
 
   auto group_id = b.create<ma::DivSIOp>(pid_nc, c32(width));
   ma::ConstantOp group_m_op = c32(group_m);
@@ -1296,13 +1332,8 @@
 
   auto pid_m = b.create<ma::AddIOp>(first_pid_m,
                                     b.create<ma::RemSIOp>(pid_nc, group_size));
-  auto pid_m_offset = b.create<ma::MulIOp>(pid_m, c32(block_m));
-
   auto pid_n = b.create<ma::DivSIOp>(b.create<ma::RemSIOp>(pid_nc, c32(width)),
                                      group_size);
-  auto pid_n_offset = b.create<ma::MulIOp>(pid_n, c32(block_n));
-
-  auto pid_k_offset = b.create<ma::MulIOp>(pid_k, c32(block_k));
 
   mlir::FloatType acc_ty = emitter.GetDotAccumulatorType();
 
@@ -1314,26 +1345,26 @@
   absl::flat_hash_map<int, const HloInstruction*> iter_args_to_parameters;
   absl::flat_hash_map<int, std::vector<int32_t>> iter_args_to_boundary_checks;
 
-  Side lhs{TritonFusionAnalysis::Scope::LHS,
-           /*tiled_dims=*/
-           {DimProperties(dims.lhs_noncontracting_dim_idx, pid_m_offset,
-                          block_m, /*split_value=*/1),
-            DimProperties(dims.lhs_contracting_dim_idx, pid_k_offset, block_k,
-                          split_k)},
-           dims.lhs_batch_dim_idx};
-  Side rhs{TritonFusionAnalysis::Scope::RHS,
-           /*tiled_dims=*/
-           {DimProperties(dims.rhs_contracting_dim_idx, pid_k_offset, block_k,
-                          split_k),
-            DimProperties(dims.rhs_noncontracting_dim_idx, pid_n_offset,
-                          block_n, /*split_value=*/1)},
-           dims.rhs_batch_dim_idx};
+  Side lhs{
+      TritonFusionAnalysis::Scope::LHS,
+      /*tiled_dims=*/
+      {DimProperties(dims.lhs_noncontracting_dim_idx, pid_m, block_m,
+                     /*split_value=*/1),
+       DimProperties(dims.lhs_contracting_dim_idx, pid_k, block_k, split_k)},
+      dims.lhs_batch_dim_idx};
+  Side rhs{
+      TritonFusionAnalysis::Scope::RHS,
+      /*tiled_dims=*/
+      {DimProperties(dims.rhs_contracting_dim_idx, pid_k, block_k, split_k),
+       DimProperties(dims.rhs_noncontracting_dim_idx, pid_n, block_n,
+                     /*split_value=*/1)},
+      dims.rhs_batch_dim_idx};
   Side out{TritonFusionAnalysis::Scope::OUTPUT,
            /*tiled_dims=*/
-           {DimProperties(dims.out_lhs_noncontracting_dim_idx, pid_m_offset,
-                          block_m, /*split_value=*/1),
-            DimProperties(dims.out_rhs_noncontracting_dim_idx, pid_n_offset,
-                          block_n, /*split_value=*/1)},
+           {DimProperties(dims.out_lhs_noncontracting_dim_idx, pid_m, block_m,
+                          /*split_value=*/1),
+            DimProperties(dims.out_rhs_noncontracting_dim_idx, pid_n, block_n,
+                          /*split_value=*/1)},
            dims.out_batch_dim_idx};
 
   auto body_builder = [&](mlir::OpBuilder&, mlir::Location, Value ki,
@@ -1389,10 +1420,14 @@
     if (need_masking) {
       auto elements_in_tile =
           b.create<ma::SubIOp>(CreateConst(b, i32_ty, dims.k), ki);
-      auto range_k = b.create<ma::AddIOp>(
-          Splat(b, b.create<ma::MulIOp>(pid_k, CreateConst(b, i32_ty, block_k)),
-                block_k),
-          Range(b, block_k));
+      auto range_k = Range(b, block_k);
+      if (pid_k != nullptr) {
+        range_k = b.create<ma::AddIOp>(
+            range_k,
+            Splat(b,
+                  b.create<ma::MulIOp>(pid_k, CreateConst(b, i32_ty, block_k)),
+                  block_k));
+      }
       auto apply_mask = [&](int64_t dim, Value input) {
         auto ty = input.getType().cast<mlir::RankedTensorType>();
         TensorValue range_expanded = b.create<mt::ExpandDimsOp>(range_k, dim)
@@ -1416,8 +1451,13 @@
                         });
 
     // Execute matrix multiplication of input tiles and pass the accumulator.
+    // TODO(manany): Revisit this once we enable Hopper workloads. The
+    // maxNumImpreciseAcc flag was introduced for Hopper to accumulate in a
+    // lower precision than the output type. The change was introduced here:
+    // https://github.com/openai/triton/commit/31b0c521427109a8eda609b58d756c380b21599a
     Value accumulator_next = b.create<mt::DotOp>(dot_input_lhs, dot_input_rhs,
-                                                 iter_args.back(), allow_tf32);
+                                                 iter_args.back(), allow_tf32,
+                                                 /*maxNumImpreciseAcc=*/0);
     iter_args_next.push_back(accumulator_next);
 
     b.create<mlir::scf::YieldOp>(iter_args_next);
@@ -1491,8 +1531,7 @@
 
 LaunchDimensions GetSoftMaxLaunchDimensions(
     absl::Span<const HloInstruction* const> roots,
-    const FusionBoundaryFn& fusion_boundary,
-    const AutotuneResult::TritonGemmKey& config) {
+    const FusionBoundaryFn& fusion_boundary, const TritonGemmConfig& config) {
   const HloInstruction* reduce =
       HloFindIf(roots, fusion_boundary, [](const HloInstruction& node) {
         return node.opcode() == HloOpcode::kReduce;
@@ -1505,13 +1544,13 @@
     num_rows *= reduce_input_shape.dimensions_minor(minor_axis);
   }
 
-  return {{num_rows, 1, 1}, {config.num_warps() * WarpSize(), 1, 1}};
+  return {{num_rows, 1, 1}, {config.num_warps * WarpSize(), 1, 1}};
 }
 
 Status EmitSoftMax(mlir::OpBuilder builder, absl::string_view libdevice_path,
                    const TritonFusionAnalysis& analysis,
                    const HloComputation* computation, mlir::triton::FuncOp fn,
-                   const AutotuneResult::TritonGemmKey& config, int) {
+                   const TritonGemmConfig& config, int) {
   const HloInstruction* root = computation->root_instruction();
   auto loc = mlir::NameLoc::get(builder.getStringAttr(root->name()));
   ImplicitLocOpBuilder b(loc, builder);
@@ -1542,40 +1581,40 @@
   CHECK_EQ(reduce->dimensions()[0], reduce_input_shape.rank() - 1);
 
   int row_len = reduce_input_shape.dimensions_minor(0);
-  int block_row = 1;
+  int block_size = 1;
 
-  // block_row must be a power of two.
-  while (block_row < row_len) {
-    block_row *= 2;
+  // block_size must be a power of two.
+  while (block_size < row_len) {
+    block_size *= 2;
   }
 
-  Value row_index = b.create<ma::ExtSIOp>(
+  Value pid = b.create<ma::ExtSIOp>(
       b.getI64Type(), b.create<mt::GetProgramIdOp>(mt::ProgramIDDim::X));
   Value row_stride = CreateConst(b, b.getI32Type(), row_len);
 
   absl::flat_hash_map<const HloInstruction*, Value> values_out;
   auto make_tensor_pointer = [&](Value base) {
     Value offset = b.create<ma::MulIOp>(
-        row_index, b.create<ma::ExtSIOp>(b.getI64Type(), row_stride));
+        pid, b.create<ma::ExtSIOp>(b.getI64Type(), row_stride));
     return b.create<mt::MakeTensorPtrOp>(
         /*base=*/AddPtr(b, base, offset),
         /*shape=*/ValueRange{CreateConst(b, b.getI64Type(), row_len)},
         /*strides=*/ValueRange{CreateConst(b, b.getI64Type(), 1)},
         /*offsets=*/ValueRange{CreateConst(b, b.getI32Type(), 0)},
-        /*tensorShape=*/std::vector<int32_t>{block_row},
+        /*tensorShape=*/std::vector<int32_t>{block_size},
         /*order=*/std::vector<int32_t>{0});
   };
 
   std::vector<int32_t> boundary_checks;
-  if (block_row != row_len) {
+  if (block_size != row_len) {
     boundary_checks.push_back(0);
   }
   values_out[computation->parameter_instruction(0)] = EmitParameterLoad(
       b, make_tensor_pointer(fn.getArgument(0)), boundary_checks);
   // Dimension 0 is the reduced one by construction and it's the only one
   // present in the tile shapes.
-  std::vector<DimProperties> tiled_dims = {
-      DimProperties(0, row_index, block_row, /*split_value=*/1)};
+  std::vector<DimProperties> tiled_dims = {DimProperties(
+      /*index=*/0, pid, block_size, /*split_value=*/1)};
   TF_ASSIGN_OR_RETURN(
       Value result,
       EmitScope(b, libdevice_path, &analysis,
@@ -1635,9 +1674,8 @@
 StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
     const TritonFusionAnalysis& analysis, absl::string_view fn_name,
     const HloComputation* hlo_computation,
-    const se::DeviceDescription& device_info,
-    const AutotuneResult::TritonGemmKey& config, TritonIrEmitter ir_emitter,
-    mlir::MLIRContext& mlir_context) {
+    const se::DeviceDescription& device_info, const TritonGemmConfig& config,
+    TritonIrEmitter ir_emitter, mlir::MLIRContext& mlir_context) {
   mlir_context.loadDialect<mt::TritonDialect>();
   mlir::OpBuilder b(&mlir_context);
   auto loc = mlir::NameLoc::get(b.getStringAttr(hlo_computation->name()));
@@ -1688,9 +1726,9 @@
     const TritonFusionAnalysis& analysis, absl::string_view fn_name,
     const HloComputation* hlo_computation, absl::string_view fusion_kind,
     const se::CudaComputeCapability& cc,
-    const se::DeviceDescription& device_info,
-    const AutotuneResult::TritonGemmKey& config, llvm::Module* llvm_module,
-    TritonIrEmitter ir_emitter, mlir::MLIRContext& mlir_context) {
+    const se::DeviceDescription& device_info, const TritonGemmConfig& config,
+    llvm::Module* llvm_module, TritonIrEmitter ir_emitter,
+    mlir::MLIRContext& mlir_context) {
   if (fusion_kind == kTritonGemmFusionKind) {
     // This is a heuristic that serves as a proxy for register usage and code
     // size.
@@ -1717,9 +1755,9 @@
     // See go/tiling-heuristic for more details.
     constexpr int64_t kComplexityHeuristicLimit = 9000;
     int64_t complexity_heuristic_value =
-        (config.block_m() * config.block_n() +
-         (config.block_m() + config.block_n()) * config.block_k()) /
-        config.num_warps();
+        (config.block_m * config.block_n +
+         (config.block_m + config.block_n) * config.block_k) /
+        config.num_warps;
     VLOG(2) << "Complexity heuristic: " << complexity_heuristic_value;
     if (complexity_heuristic_value > kComplexityHeuristicLimit) {
       return ResourceExhausted("Tiling complexity heuristic exceeded: %d > %d",
@@ -1734,7 +1772,7 @@
                          config, ir_emitter, mlir_context));
 
   VLOG(3) << hlo_computation->ToString(HloPrintOptions::ShortParsable());
-  VLOG(2) << config.ShortDebugString();
+  VLOG(2) << config.ToString();
 
   // Compile Triton kernel to LLVM.
   std::optional<llvm::raw_fd_ostream> log_stream;
@@ -1780,7 +1818,7 @@
     }
   }
 
-  CreateTritonPipeline(pm, cc, config.num_warps(), config.num_stages());
+  CreateTritonPipeline(pm, cc, config.num_warps, config.num_stages);
   if (log_stream.has_value()) {
     pm.printAsTextualPipeline(log_stream.value());
     log_stream->write("\n\n", 2);
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.h b/third_party/xla/xla/service/gpu/ir_emitter_triton.h
index 9edf278..52d66dd 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_triton.h
+++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.h
@@ -29,6 +29,7 @@
 #include "xla/service/gpu/gemm_rewriter_triton.h"
 #include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/launch_dimensions.h"
+#include "xla/service/gpu/matmul_utils.h"
 #include "xla/statusor.h"
 #include "xla/stream_executor/device_description.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -44,32 +45,27 @@
 LaunchDimensions GetMatMulLaunchDimensions(
     const TritonFusionAnalysis& analysis,
     absl::Span<const HloInstruction* const> roots,
-    const FusionBoundaryFn& fusion_boundary,
-    const AutotuneResult::TritonGemmKey& config);
+    const FusionBoundaryFn& fusion_boundary, const TritonGemmConfig& config);
 // Use tiling and execution parameters from 'config'.
 Status EmitMatMul(mlir::OpBuilder b, absl::string_view libdevice_path,
                   const TritonFusionAnalysis& analysis,
                   const HloComputation* computation, mlir::triton::FuncOp fn,
-                  const AutotuneResult::TritonGemmKey& config,
-                  int shmem_budget);
+                  const TritonGemmConfig& config, int shmem_budget);
 
 // Compute the launch dimensions for the given Triton SoftMax.
 LaunchDimensions GetSoftMaxLaunchDimensions(
     absl::Span<const HloInstruction* const> roots,
-    const FusionBoundaryFn& fusion_boundary,
-    const AutotuneResult::TritonGemmKey& config);
+    const FusionBoundaryFn& fusion_boundary, const TritonGemmConfig& config);
 // Generate Softmax in Triton IR inside 'fn'.
 // Use execution parameters from 'config'.
 Status EmitSoftMax(mlir::OpBuilder b, absl::string_view libdevice_path,
                    const TritonFusionAnalysis& analysis,
                    const HloComputation* computation, mlir::triton::FuncOp fn,
-                   const AutotuneResult::TritonGemmKey& config,
-                   int shmem_budget);
+                   const TritonGemmConfig& config, int shmem_budget);
 
 using TritonIrEmitter = std::function<Status(
     mlir::OpBuilder, absl::string_view, const TritonFusionAnalysis& analysis,
-    const HloComputation*, mlir::triton::FuncOp,
-    const AutotuneResult::TritonGemmKey&, int)>;
+    const HloComputation*, mlir::triton::FuncOp, const TritonGemmConfig&, int)>;
 
 // Generate Triton IR by running the provided generator and compile it into LLVM
 // IR.
@@ -78,18 +74,17 @@
     const TritonFusionAnalysis& analysis, absl::string_view fn_name,
     const HloComputation* hlo_computation, absl::string_view fusion_kind,
     const se::CudaComputeCapability& cc,
-    const se::DeviceDescription& device_info,
-    const AutotuneResult::TritonGemmKey& config, llvm::Module* llvm_module,
-    TritonIrEmitter ir_emitter, mlir::MLIRContext& mlir_context);
+    const se::DeviceDescription& device_info, const TritonGemmConfig& config,
+    llvm::Module* llvm_module, TritonIrEmitter ir_emitter,
+    mlir::MLIRContext& mlir_context);
 
 // Creates the initial Triton module for the given fusion. Visible for testing,
 // use TritonWrapper instead.
 StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
     const TritonFusionAnalysis& analysis, absl::string_view fn_name,
     const HloComputation* hlo_computation,
-    const se::DeviceDescription& device_info,
-    const AutotuneResult::TritonGemmKey& config, TritonIrEmitter ir_emitter,
-    mlir::MLIRContext& mlir_context);
+    const se::DeviceDescription& device_info, const TritonGemmConfig& config,
+    TritonIrEmitter ir_emitter, mlir::MLIRContext& mlir_context);
 
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_large_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_large_test.cc
index 2fc8733..c5d98bb 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_triton_large_test.cc
+++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_large_test.cc
@@ -41,9 +41,10 @@
 ENTRY e {
   arg0 = f16[65536,32800] parameter(0)
   arg1 = f16[32800,32] parameter(1)
-  ROOT custom-call = f16[65536,32] custom-call(arg0, arg1),
+  gemm = (f16[65536,32], s8[0]) custom-call(arg0, arg1),
     custom_call_target="__cublas$gemm",
     backend_config="{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}"
+  ROOT get-tuple-element = f16[65536,32] get-tuple-element((f16[65536,32], s8[0]) gemm), index=0
 }
 )";
 
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc
index 24c6869..ff027d0 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc
+++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc
@@ -210,13 +210,14 @@
   p1 = $0[33,68]{1,0} parameter(1)
   p0 = f32[15,33]{1,0} parameter(0)
   fusion = f32[33,68]{1,0} fusion(p1), kind=kLoop, calls=fused_computation
-  ROOT custom-call = f32[15,68]{1,0} custom-call(p0, fusion),
+  gemm = (f32[15,68]{1,0}, s8[0]{0}) custom-call(p0, fusion),
     custom_call_target="__cublas$$gemm",
     backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":
       {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],
       "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},
       "alpha_imag":0,"precision_config":
       {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"}
+  ROOT get-tuple-element = f32[15,68]{1,0} get-tuple-element((f32[15,68]{1,0}, s8[0]{0}) gemm), index=0
 })";
   const std::string hlo_ref = absl::Substitute(
       kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type),
@@ -322,13 +323,14 @@
   p1 = $0[11,63]{1,0} parameter(1)
   p0 = f32[92,11]{1,0} parameter(0)
   fusion = f32[11,63]{1,0} fusion(p1, p2), kind=kLoop, calls=fused_computation
-  ROOT custom-call = f32[92,63]{1,0} custom-call(p0, fusion),
+  gemm = (f32[92,63]{1,0}, s8[0]{0}) custom-call(p0, fusion),
     custom_call_target="__cublas$$gemm",
     backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":
       {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],
       "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},
       "alpha_imag":0,"precision_config":
       {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"}
+  ROOT get-tuple-element = f32[92,63]{1,0} get-tuple-element((f32[92,63]{1,0}, s8[0]{0}) gemm), index=0
 })";
   const std::string hlo_ref = absl::Substitute(
       kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type),
@@ -448,13 +450,14 @@
   p1 = $0[11,63]{1,0} parameter(1)
   p0 = f32[92,11]{1,0} parameter(0)
   fusion = f32[11,63]{1,0} fusion(p1, p2), kind=kLoop, calls=fused_computation
-  ROOT custom-call = f32[92,63]{1,0} custom-call(p0, fusion),
+  gemm = (f32[92,63]{1,0}, s8[0]{0}) custom-call(p0, fusion),
     custom_call_target="__cublas$$gemm",
     backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":
       {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],
       "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},
       "alpha_imag":0,"precision_config":
       {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"}
+  ROOT get-tuple-element = f32[92,63]{1,0} get-tuple-element((f32[92,63]{1,0}, s8[0]{0}) gemm), index=0
 })";
   const std::string hlo_ref = absl::Substitute(
       kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type),
@@ -555,13 +558,14 @@
   p0 = $1[92,11]{1,0} parameter(0)
   fusion = $1[11,63]{1,0} fusion(p1, p2, p3), kind=kLoop,
     calls=fused_computation
-  ROOT custom-call = $1[92,63]{1,0} custom-call(p0, fusion),
+  gemm = ($1[92,63]{1,0}, s8[0]{0}) custom-call(p0, fusion),
     custom_call_target="__cublas$$gemm",
     backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":
       {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],
       "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},
       "alpha_imag":0,"precision_config":
       {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"}
+  ROOT get-tuple-element = $1[92,63]{1,0} get-tuple-element(($1[92,63]{1,0}, s8[0]{0}) gemm), index=0
 })";
   const std::string hlo_ref = absl::Substitute(
       kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type1),
@@ -668,13 +672,14 @@
   p0 = f32[92,11]{1,0} parameter(0)
   fusion = f32[11,63]{1,0} fusion(p1), kind=kLoop,
     calls=fused_computation
-  ROOT custom-call = f32[92,63]{1,0} custom-call(p0, fusion),
+  gemm = (f32[92,63]{1,0}, s8[0]{0}) custom-call(p0, fusion),
     custom_call_target="__cublas$$gemm",
     backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":
       {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],
       "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},
       "alpha_imag":0,"precision_config":
       {"operand_precision":["HIGHEST","HIGHEST"]},"epilogue":"DEFAULT"}
+  ROOT get-tuple-element = f32[92,63]{1,0} get-tuple-element((f32[92,63]{1,0}, s8[0]{0}) gemm), index=0
 })";
   const std::string hlo_ref = absl::Substitute(
       kHloRefTemplate, primitive_util::LowercasePrimitiveTypeName(data_type));
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc
index 63ba010..1489773 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc
+++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc
@@ -36,6 +36,7 @@
 #include "xla/service/gpu/gemm_rewriter_triton.h"
 #include "xla/service/gpu/gpu_device_info_for_tests.h"
 #include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/matmul_utils.h"
 #include "xla/service/gpu/tests/gpu_codegen_test.h"
 #include "xla/service/pattern_matcher.h"
 #include "xla/service/pattern_matcher_gmock.h"
@@ -87,13 +88,13 @@
 class TritonFilecheckTest : public TritonGemmTest {
  public:
   StatusOr<bool> CreateTritonIrAndFileCheck(
-      absl::string_view hlo_text, const AutotuneResult::TritonGemmKey& config,
+      absl::string_view hlo_text, const TritonGemmConfig& config,
       TritonIrEmitter emitter, absl::string_view triton_fusion_name,
       absl::string_view filecheck_pattern);
 };
 
 StatusOr<bool> TritonFilecheckTest::CreateTritonIrAndFileCheck(
-    absl::string_view hlo_text, const AutotuneResult::TritonGemmKey& config,
+    absl::string_view hlo_text, const TritonGemmConfig& config,
     TritonIrEmitter emitter, absl::string_view triton_fusion_name,
     absl::string_view filecheck_pattern) {
   TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> verified_module,
@@ -138,12 +139,7 @@
     calls=triton_gemm_r,
     backend_config={kind: "__triton_gemm", triton_gemm_config: {"block_m":16,"block_n":64,"block_k":32,"split_k":1,"num_stages":1,"num_warps":2}}
 })";
-  AutotuneResult::TritonGemmKey config;
-  config.set_split_k(1);
-  config.set_block_m(16);
-  config.set_block_k(32);
-  config.set_block_n(64);
-
+  TritonGemmConfig config(16, 64, 32, 1, 1, 1);
   ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul,
                                          "triton_gemm_r", R"(
 CHECK:    tt.func @triton_fn(%[[LHS:.*]]: !tt.ptr<i8, 1> {tt.divisibility = 16 : i32}, %[[RHS:.*]]: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %[[OUT:.*]]: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}) {
@@ -162,7 +158,6 @@
 CHECK-DAG:  %[[GROUP_M:.*]] = arith.constant 8 : i32
 CHECK-DAG:  %[[WIDTH:.*]] = arith.constant 24 : i32
 CHECK:      %[[PID_NC:.*]] = tt.get_program_id x
-CHECK:      %[[PID_K:.*]] = tt.get_program_id z
 CHECK:      %[[GROUP_ID:.*]] = arith.divsi %[[PID_NC]], %[[WIDTH]]
 CHECK:      %[[FIRST_PID_M:.*]] = arith.muli %[[GROUP_ID]], %[[GROUP_M]]
 CHECK:      %[[MAX_M:.*]] = arith.subi %[[NUM_TILES_M]], %[[FIRST_PID_M]]
@@ -170,15 +165,14 @@
 CHECK:      %[[GROUP_SIZE:.*]] = arith.select %[[CMP]], %[[MAX_M]], %[[GROUP_M]]
 CHECK:      %[[PID_M:.*]] = arith.remsi %[[PID_NC]], %[[GROUP_SIZE]]
 CHECK:      %[[TILE_INDEX_M:.*]] = arith.addi %[[FIRST_PID_M]], %[[PID_M]] : i32
-CHECK:      %[[TILE_OFFSET_M:.*]] = arith.muli %[[TILE_INDEX_M]], %[[TILE_SIZE_M]]
 CHECK:      %[[TMP:.*]] = arith.remsi %[[PID_NC]], %[[WIDTH]] : i32
 CHECK:      %[[TILE_INDEX_N:.*]] = arith.divsi %[[TMP]], %[[GROUP_SIZE]] : i32
-CHECK:      %[[TILE_OFFSET_N:.*]] = arith.muli %[[TILE_INDEX_N]], %[[TILE_SIZE_N]]
-CHECK:      %[[TILE_OFFSET_K:.*]] = arith.muli %[[PID_K]], %[[TILE_SIZE_K]]
+CHECK:      %[[TILE_OFFSET_M_LHS:.*]] = arith.muli %[[TILE_INDEX_M]], %[[TILE_SIZE_M]]
 CHECK:      %[[LHS_PTR:.*]] = tt.make_tensor_ptr %[[LHS]]
-CHECK:      %[[LHS_TILE_PTR:.*]] = tt.advance %[[LHS_PTR]], [%[[TILE_OFFSET_M]], %[[TILE_OFFSET_K]]]
+CHECK:      %[[LHS_TILE_PTR:.*]] = tt.advance %[[LHS_PTR]], [%[[TILE_OFFSET_M_LHS]], %[[C0]]]
+CHECK:      %[[TILE_OFFSET_N_RHS:.*]] = arith.muli %[[TILE_INDEX_N]], %[[TILE_SIZE_N]]
 CHECK:      %[[RHS_PTR:.*]] = tt.make_tensor_ptr %[[RHS]]
-CHECK:      %[[RHS_TILE_PTR:.*]] = tt.advance %[[RHS_PTR]], [%[[TILE_OFFSET_K]], %[[TILE_OFFSET_N]]]
+CHECK:      %[[RHS_TILE_PTR:.*]] = tt.advance %[[RHS_PTR]], [%[[C0]], %[[TILE_OFFSET_N_RHS]]]
 CHECK:        %[[FOR:.*]]:3 = scf.for %[[BLOCK_K:.*]] = %[[C0]] to %[[SIZE_K]] step %[[TILE_SIZE_K]]
 CHECK-SAME:       iter_args(%[[LHS_ITER_PTR:.*]] = %[[LHS_TILE_PTR]], %[[RHS_ITER_PTR:.*]] = %[[RHS_TILE_PTR]], %[[ACC:.*]] = %[[ZERO_MN]])
 CHECK:        %[[LHS_TILE:.*]] = tt.load %[[LHS_ITER_PTR]] {boundaryCheck = array<i32: 1>
@@ -187,16 +181,13 @@
 CHECK:        %[[RHS_ITER_PTR_NEXT:.*]] = tt.advance %[[RHS_ITER_PTR]], [%[[TILE_SIZE_K]], %[[C0]]]
 CHECK:        %[[CONVERTED:.*]] = arith.sitofp %[[LHS_TILE]] : tensor<16x32xi8> to tensor<16x32xf32>
 CHECK:        %[[TILE_K_LIMIT:.*]] = arith.subi %[[SIZE_K]], %[[BLOCK_K]] : i32
-CHECK:        %[[K_OFFSET:.*]] = arith.muli %[[PID_K]], %[[TILE_SIZE_K]] : i32
-CHECK:        %[[K_OFFSET_SPLAT_K:.*]] = tt.splat %[[K_OFFSET]] : (i32) -> tensor<32xi32>
 CHECK:        %[[K_TILE_IOTA:.*]] = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
-CHECK:        %[[K_OFFSETS:.*]] = arith.addi %[[K_OFFSET_SPLAT_K]], %[[K_TILE_IOTA]] : tensor<32xi32>
-CHECK:        %[[K_OFFSETS_1K:.*]] = tt.expand_dims %[[K_OFFSETS]] {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32>
+CHECK:        %[[K_OFFSETS_1K:.*]] = tt.expand_dims %[[K_TILE_IOTA]] {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32>
 CHECK:        %[[TILE_K_LIMIT_1K:.*]] = tt.splat %[[TILE_K_LIMIT]] : (i32) -> tensor<1x32xi32>
 CHECK:        %[[LHS_INBOUNDS_1K:.*]] = arith.cmpi slt, %[[K_OFFSETS_1K]], %[[TILE_K_LIMIT_1K]] : tensor<1x32xi32>
 CHECK:        %[[LHS_INBOUNDS_MK:.*]] = tt.broadcast %[[LHS_INBOUNDS_1K]] : (tensor<1x32xi1>) -> tensor<16x32xi1>
 CHECK:        %[[LHS_MASKED:.*]] = arith.select %[[LHS_INBOUNDS_MK]], %[[CONVERTED]], %[[ZERO_MK]]
-CHECK:        %[[K_OFFSETS_K1:.*]] = tt.expand_dims %[[K_OFFSETS]] {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
+CHECK:        %[[K_OFFSETS_K1:.*]] = tt.expand_dims %[[K_TILE_IOTA]] {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
 CHECK:        %[[TILE_K_LIMIT_K1:.*]] = tt.splat %[[TILE_K_LIMIT]] : (i32) -> tensor<32x1xi32>
 CHECK:        %[[RHS_INBOUNDS_K1:.*]] = arith.cmpi slt, %[[K_OFFSETS_K1]], %[[TILE_K_LIMIT_K1]] : tensor<32x1xi32>
 CHECK:        %[[RHS_INBOUNDS_KN:.*]] = tt.broadcast %[[RHS_INBOUNDS_K1]] : (tensor<32x1xi1>) -> tensor<32x64xi1>
@@ -204,8 +195,10 @@
 CHECK:        %[[ACC_NEXT:.*]] = tt.dot %[[LHS_MASKED]], %[[RHS_MASKED]], %[[ACC]]
 CHECK:        scf.yield %[[LHS_ITER_PTR_NEXT]], %[[RHS_ITER_PTR_NEXT]], %[[ACC_NEXT]] : !tt.ptr<tensor<16x32xi8>, 1>, !tt.ptr<tensor<32x64xf32>, 1>, tensor<16x64xf32>
 CHECK:      }
+CHECK:      %[[TILE_OFFSET_M_OUT:.*]] = arith.muli %[[TILE_INDEX_M]], %[[TILE_SIZE_M]]
+CHECK:      %[[TILE_OFFSET_N_OUT:.*]] = arith.muli %[[TILE_INDEX_N]], %[[TILE_SIZE_N]]
 CHECK:      %[[OUT_PTR:.*]] = tt.make_tensor_ptr %[[OUT]], [%[[C80]], %[[SIZE_M]]], [%[[SIZE_M]], %[[C1]]], [%[[C0]], %[[C0]]] {order = array<i32: 1, 0>} : <tensor<16x64xf32>, 1>
-CHECK:      %[[OUT_OFFSET:.*]] = tt.advance %[[OUT_PTR]], [%[[TILE_OFFSET_M]], %[[TILE_OFFSET_N]]] : <tensor<16x64xf32>, 1>
+CHECK:      %[[OUT_OFFSET:.*]] = tt.advance %[[OUT_PTR]], [%[[TILE_OFFSET_M_OUT]], %[[TILE_OFFSET_N_OUT]]] : <tensor<16x64xf32>, 1>
 CHECK:      tt.store %[[OUT_OFFSET]], %[[FOR]]#2 {boundaryCheck = array<i32: 1>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<16x64xf32>, 1>, tensor<16x64xf32>
 CHECK:      tt.return
 CHECK:    }
@@ -334,13 +327,7 @@
   llvm::Module llvm_module("module", llvm_ctx);
   mlir::MLIRContext mlir_context;
 
-  AutotuneResult::TritonGemmKey config;
-  config.set_block_m(16);
-  config.set_block_n(32);
-  config.set_block_k(512);
-  config.set_split_k(1);
-  config.set_num_stages(4);
-  config.set_num_warps(8);
+  TritonGemmConfig config(16, 32, 512, 1, 4, 8);
   EXPECT_THAT(
       TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation),
                     "test_fn", triton_dot_computation, kTritonGemmFusionKind,
@@ -350,10 +337,10 @@
       tsl::testing::StatusIs(tsl::error::RESOURCE_EXHAUSTED,
                              "Shared memory size limit exceeded."));
 
-  config.set_block_m(64);
-  config.set_block_n(128);
-  config.set_block_k(128);
-  config.set_num_stages(1);
+  config.block_m = 64;
+  config.block_n = 128;
+  config.block_k = 128;
+  config.num_stages = 1;
   TF_ASSERT_OK_AND_ASSIGN(
       const auto result,
       TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation),
@@ -792,13 +779,7 @@
   mlir::MLIRContext mlir_context;
 
   // Fails if the tiling is too complex.
-  AutotuneResult::TritonGemmKey config;
-  config.set_block_m(512);
-  config.set_block_n(512);
-  config.set_block_k(32);
-  config.set_split_k(1);
-  config.set_num_stages(1);
-  config.set_num_warps(2);
+  TritonGemmConfig config(512, 512, 32, 1, 1, 2);
   EXPECT_THAT(
       TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation),
                     "test_fn", triton_dot_computation, kTritonGemmFusionKind,
@@ -810,9 +791,9 @@
           "Tiling complexity heuristic exceeded: 147456 > 9000"));
 
   // Succeeds if the tiling is not too complex.
-  config.set_block_m(32);
-  config.set_block_n(32);
-  config.set_block_k(32);
+  config.block_m = 32;
+  config.block_n = 32;
+  config.block_k = 32;
   TF_CHECK_OK(
       TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation),
                     "test_fn", triton_dot_computation, kTritonGemmFusionKind,
@@ -995,6 +976,15 @@
   }
 };
 
+class TritonGemmLevel2TestAny : public TritonGemmLevel2Test {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = TritonGemmLevel2Test::GetDebugOptionsForTest();
+    debug_options.set_xla_gpu_triton_gemm_any(true);
+    return debug_options;
+  }
+};
+
 TEST_F(TritonGemmLevel2Test, BinaryOperationWithSmallInputsIsFused) {
   const std::string kHloText = R"(
 HloModule m
@@ -1271,6 +1261,198 @@
   EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-3, /*arel=*/2e-3}));
 }
 
+TEST_F(TritonGemmLevel2TestAny, MinimumHandlesNaNsOnTheLeft) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f32[5,5] parameter(0)
+  neg1 = f32[] constant(-1)
+  neg1s = f32[5,5] broadcast(neg1), dimensions={}
+  nans = f32[5,5] sqrt(neg1s)
+  min = f32[5,5] minimum(nans, neg1s)
+  ROOT _ = f32[5,5] dot(p0, min),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: fusion
+; CHECK-SAME: kind=kCustom
+; CHECK-SAME: block_m
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmLevel2TestAny, MinimumHandlesNaNsOnTheRight) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f32[5,5] parameter(0)
+  neg1 = f32[] constant(-1)
+  neg1s = f32[5,5] broadcast(neg1), dimensions={}
+  nans = f32[5,5] sqrt(neg1s)
+  min = f32[5,5] minimum(neg1s, nans)
+  ROOT _ = f32[5,5] dot(p0, min),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: fusion
+; CHECK-SAME: kind=kCustom
+; CHECK-SAME: block_m
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmLevel2TestAny, MaximumHandlesNaNsOnTheLeft) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f32[5,5] parameter(0)
+  neg1 = f32[] constant(-1)
+  neg1s = f32[5,5] broadcast(neg1), dimensions={}
+  nans = f32[5,5] sqrt(neg1s)
+  max = f32[5,5] maximum(nans, neg1s)
+  ROOT _ = f32[5,5] dot(p0, max),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: fusion
+; CHECK-SAME: kind=kCustom
+; CHECK-SAME: block_m
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmLevel2TestAny, MaximumHandlesNaNsOnTheRight) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f32[5,5] parameter(0)
+  neg1 = f32[] constant(-1)
+  neg1s = f32[5,5] broadcast(neg1), dimensions={}
+  nans = f32[5,5] sqrt(neg1s)
+  max = f32[5,5] maximum(neg1s, nans)
+  ROOT _ = f32[5,5] dot(p0, max),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: fusion
+; CHECK-SAME: kind=kCustom
+; CHECK-SAME: block_m
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmLevel2TestAny, MinimumReturnsLHS) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f32[5,5] parameter(0)
+  zero = f32[] constant(0)
+  zeros = f32[5,5] broadcast(zero), dimensions={}
+  one = f32[] constant(1)
+  ones = f32[5,5] broadcast(one), dimensions={}
+  min = f32[5,5] minimum(zeros, ones)
+  ROOT _ = f32[5,5] dot(p0, min),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: fusion
+; CHECK-SAME: kind=kCustom
+; CHECK-SAME: block_m
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3,
+                                                /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmLevel2TestAny, MinimumReturnsRHS) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f32[5,5] parameter(0)
+  zero = f32[] constant(0)
+  zeros = f32[5,5] broadcast(zero), dimensions={}
+  one = f32[] constant(1)
+  ones = f32[5,5] broadcast(one), dimensions={}
+  min = f32[5,5] minimum(ones, zeros)
+  ROOT _ = f32[5,5] dot(p0, min),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: fusion
+; CHECK-SAME: kind=kCustom
+; CHECK-SAME: block_m
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3,
+                                                /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmLevel2TestAny, MaximumReturnsLHS) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f32[5,5] parameter(0)
+  zero = f32[] constant(0)
+  zeros = f32[5,5] broadcast(zero), dimensions={}
+  one = f32[] constant(1)
+  ones = f32[5,5] broadcast(one), dimensions={}
+  max = f32[5,5] maximum(ones, zeros)
+  ROOT _ = f32[5,5] dot(p0, max),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: fusion
+; CHECK-SAME: kind=kCustom
+; CHECK-SAME: block_m
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3,
+                                                /*arel=*/1e-3}));
+}
+
+TEST_F(TritonGemmLevel2TestAny, MaximumReturnsRHS) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+ENTRY e {
+  p0 = f32[5,5] parameter(0)
+  zero = f32[] constant(0)
+  zeros = f32[5,5] broadcast(zero), dimensions={}
+  one = f32[] constant(1)
+  ones = f32[5,5] broadcast(one), dimensions={}
+  max = f32[5,5] maximum(zeros, ones)
+  ROOT _ = f32[5,5] dot(p0, max),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+
+  MatchOptimizedHlo(kHloText, R"(
+; CHECK: fusion
+; CHECK-SAME: kind=kCustom
+; CHECK-SAME: block_m
+)");
+
+  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3,
+                                                /*arel=*/1e-3}));
+}
+
 TEST_F(TritonGemmTest, SineOutputIsNotFused) {
   const std::string kHloText = R"(
 HloModule m
@@ -1623,9 +1805,10 @@
 ENTRY e {
   arg0 = f16[5,7] parameter(0)
   arg1 = f16[7,33] parameter(1)
-  ROOT custom-call = f16[5,33] custom-call(arg0, arg1),
+  gemm = (f16[5,33], s8[0]{0}) custom-call(arg0, arg1),
     custom_call_target="__cublas$gemm",
     backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}
+  ROOT get-tuple-element = f16[5,33]{1,0} get-tuple-element((f16[5,33]{1,0}, s8[0]{0}) gemm), index=0
 }
 )";
 
@@ -1659,9 +1842,10 @@
 ENTRY e {
   arg0 = f32[5,7] parameter(0)
   arg1 = f32[7,33] parameter(1)
-  ROOT custom-call = f32[5,33] custom-call(arg0, arg1),
+  gemm = (f32[5,33], s8[0]{0}) custom-call(arg0, arg1),
     custom_call_target="__cublas$gemm",
     backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}
+  ROOT get-tuple-element = f32[5,33]{1,0} get-tuple-element((f32[5,33]{1,0}, s8[0]{0}) gemm), index=0
 }
 )";
 
@@ -1700,9 +1884,10 @@
 ENTRY e {
   arg0 = bf16[512,16]{1,0} parameter(0)
   arg1 = bf16[512,256]{1,0} parameter(1)
-  ROOT custom-call = bf16[16,256]{1,0} custom-call(arg0, arg1),
+  gemm = (bf16[16,256]{1,0}, s8[0]{0}) custom-call(arg0, arg1),
     custom_call_target="__cublas$gemm",
     backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[0],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}
+  ROOT get-tuple-element = bf16[16,256]{1,0} get-tuple-element((bf16[16,256]{1,0}, s8[0]{0}) gemm), index=0
 }
 )";
 
@@ -1778,8 +1963,8 @@
       TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation),
                     "test_fn", triton_dot_computation, kTritonGemmFusionKind,
                     GetCudaComputeCapability(), dev_info,
-                    config.triton_gemm_config(), &llvm_module, &EmitMatMul,
-                    mlir_context));
+                    TritonGemmConfig::FromProto(config.triton_gemm_config()),
+                    &llvm_module, &EmitMatMul, mlir_context));
   // The config is chosen so that the used memory size is slightly above the
   // 48 kB boundary of standard / optin shared memory so that any GPU that
   // has the optin one should be able to execute the test.
@@ -1817,9 +2002,10 @@
 ENTRY e {
   arg0 = f16[128,32]{1,0} parameter(0)
   arg1 = f16[64,32]{1,0} parameter(1)
-  ROOT custom-call = f16[128,64]{1,0} custom-call(arg0, arg1),
+  gemm = (f16[128,64]{1,0}, s8[0]{0}) custom-call(arg0, arg1),
     custom_call_target="__cublas$gemm",
     backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[1],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}
+  ROOT get-tuple-element = f16[128,64]{1,0} get-tuple-element((f16[128,64]{1,0}, s8[0]{0}) gemm), index=0
 }
 )";
 
@@ -1853,9 +2039,10 @@
 ENTRY e {
   arg0 = f32[64,128]{1,0} parameter(0)
   arg1 = f32[1024,64]{1,0} parameter(1)
-  ROOT custom-call = f32[128,1024]{1,0} custom-call(arg0, arg1),
+  gemm = (f32[128,1024]{1,0}, s8[0]{0}) custom-call(arg0, arg1),
     custom_call_target="__cublas$gemm",
     backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[0],"rhs_contracting_dimensions":[1],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}
+  ROOT get-tuple-element = f32[128,1024]{1,0} get-tuple-element((f32[128,1024]{1,0}, s8[0]{0}) gemm), index=0
 }
 )";
 
@@ -1899,9 +2086,10 @@
   p0 = s8[144,256]{1,0} parameter(0)
   fusion = bf16[144,256]{1,0} fusion(p0), kind=kInput, calls=fused_computation
   p1 = bf16[256,122]{1,0} parameter(1)
-  ROOT custom-call = bf16[144,122]{1,0} custom-call(fusion, p1),
+  gemm = (bf16[144,122]{1,0}, s8[0]{0}) custom-call(fusion, p1),
     custom_call_target="__cublas$gemm",
     backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}
+  ROOT get-tuple-element = bf16[144,122]{1,0} get-tuple-element((bf16[144,122]{1,0}, s8[0]{0}) gemm), index=0
 }
 )";
 
@@ -2439,9 +2627,10 @@
   constant_2 = f32[] constant(321)
   parameter_0 = f32[92,11]{1,0} parameter(0)
   broadcast.2 = f32[11,63]{1,0} broadcast(constant_2), dimensions={}
-  ROOT custom-call = f32[63,92]{1,0} custom-call(broadcast.2, parameter_0),
+  gemm = (f32[63,92]{1,0}, s8[0]{0}) custom-call(broadcast.2, parameter_0),
     custom_call_target="__cublas$gemm",
     backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["0"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}
+  ROOT get-tuple-element = f32[63,92]{1,0} get-tuple-element((f32[63,92]{1,0}, s8[0]{0}) gemm), index=0
 })";
 
   EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest,
@@ -2479,9 +2668,10 @@
   constant = f32[] constant(123)
   broadcast = f32[11,63]{1,0} broadcast(constant), dimensions={}
   broadcast.1 = f32[63,45]{1,0} broadcast(constant_1), dimensions={}
-  ROOT custom-call = f32[11,45]{1,0} custom-call(broadcast, broadcast.1),
+  gemm = (f32[11,45]{1,0}, s8[0]{0}) custom-call(broadcast, broadcast.1),
     custom_call_target="__cublas$gemm",
     backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}
+  ROOT get-tuple-element = f32[11,45]{1,0} get-tuple-element((f32[11,45]{1,0}, s8[0]{0}) gemm), index=0
 })";
 
   EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest,
@@ -2584,9 +2774,10 @@
   tmp_1 = pred[3,32]{1,0} parameter(0)
   fusion.1 = f32[3,32]{1,0} fusion(tmp_7, tmp_5, tmp_1, tmp_2, tmp_3), kind=kLoop, calls=fused_computation.1
   fusion = f32[3,57]{1,0} fusion(tmp_18, tmp_14, tmp_15, tmp_16, tmp_12, /*index=5*/tmp_9, tmp_10), kind=kLoop, calls=fused_computation
-  ROOT custom-call = f32[32,57]{0,1} custom-call(fusion.1, fusion),
+  gemm = (f32[32,57]{0,1}, s8[0]{0}) custom-call(fusion.1, fusion),
     custom_call_target="__cublas$gemm",
     backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["0"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}
+  ROOT get-tuple-element = f32[32,57]{0,1} get-tuple-element((f32[32,57]{0,1}, s8[0]{0}) gemm), index=0
 })";
 
   EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest,
@@ -2639,13 +2830,14 @@
   p1 = s32[11,63]{1,0} parameter(1)
   p0 = bf16[92,11]{1,0} parameter(0)
   fusion = bf16[11,63]{1,0} fusion(p1, p2), kind=kLoop, calls=fused_computation
-  ROOT custom-call = bf16[92,63]{1,0} custom-call(p0, fusion),
+  gemm = (bf16[92,63]{1,0}, s8[0]{0}) custom-call(p0, fusion),
     custom_call_target="__cublas$gemm",
     backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":
       {"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],
       "lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},
       "alpha_imag":0,"precision_config":
       {"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}
+  ROOT get-tuple-element = bf16[92,63]{1,0} get-tuple-element((bf16[92,63]{1,0}, s8[0]{0}) gemm), index=0
 })";
 
   EXPECT_TRUE(RunAndCompareTwoModules(kHloTextRef, kHloTextTest,
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
index 2727ebe..a9cd847 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
+++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
@@ -33,6 +33,7 @@
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/log/check.h"
+#include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
@@ -42,6 +43,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/DerivedTypes.h"
@@ -70,6 +72,8 @@
 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"  // from @llvm-project
 #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"  // from @llvm-project
 #include "mlir/Target/LLVMIR/Export.h"  // from @llvm-project
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/ffi.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -83,16 +87,20 @@
 #include "xla/mlir_hlo/transforms/gpu_passes.h"
 #include "xla/primitive_util.h"
 #include "xla/service/buffer_assignment.h"
+#include "xla/service/custom_call_status.h"
 #include "xla/service/custom_call_target_registry.h"
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/conditional_thunk.h"
 #include "xla/service/gpu/convolution_thunk.h"
 #include "xla/service/gpu/copy_thunk.h"
-#include "xla/service/gpu/fft_thunk.h"
 #include "xla/service/gpu/for_thunk.h"
 #include "xla/service/gpu/fused_mha_thunk.h"
+#include "xla/service/gpu/fusions/fusion_emitter.h"
 #include "xla/service/gpu/fusions/fusions.h"
+#include "xla/service/gpu/fusions/input_slices.h"
+#include "xla/service/gpu/fusions/loop.h"
 #include "xla/service/gpu/fusions/thunk_util.h"
+#include "xla/service/gpu/fusions/transpose.h"
 #include "xla/service/gpu/gemm_thunk.h"
 #include "xla/service/gpu/gpu_asm_opts_util.h"
 #include "xla/service/gpu/gpu_conv_runner.h"
@@ -116,6 +124,7 @@
 #include "xla/service/gpu/parallel_loop_emitter.h"
 #include "xla/service/gpu/replica_id_thunk.h"
 #include "xla/service/gpu/runtime3/custom_call_thunk.h"
+#include "xla/service/gpu/runtime3/fft_thunk.h"
 #include "xla/service/gpu/sequential_thunk.h"
 #include "xla/service/gpu/thunk.h"
 #include "xla/service/gpu/while_thunk.h"
@@ -147,7 +156,7 @@
 
 #if GOOGLE_CUDA || TF_HIPBLASLT
 #include "xla/service/gpu/cub_sort_thunk.h"
-#include "xla/service/gpu/cublas_lt_matmul_thunk.h"
+#include "xla/service/gpu/gpublas_lt_matmul_thunk.h"
 #include "xla/service/gpu/ir_emitter_triton.h"
 #endif  // GOOGLE_CUDA || TF_HIPBLASLT
 
@@ -338,10 +347,6 @@
 
 StatusOr<BufferAllocation::Slice> IrEmitterUnnested::GetAllocationSlice(
     mlir::Value v) {
-  if (ir_emitter_context_->emit_ir_from_hlo()) {
-    return InternalError(
-        "Getting buffer allocation for MLIR when emitting from HLO");
-  }
   return xla::gpu::GetAllocationSlice(v, ir_emitter_context_->allocations(),
                                       nullptr);
 }
@@ -371,14 +376,28 @@
       module.lookupSymbol(get_global.getName()));
   auto literal = global.getInitialValue()->dyn_cast<mlir::DenseElementsAttr>();
   TF_RET_CHECK(literal);
-  TF_ASSIGN_OR_RETURN(int element_bytes,
-                      GetElementTypeBytes(literal.getType().getElementType()));
   std::vector<uint8_t> content;
   TF_RETURN_IF_ERROR(CopyDenseElementsDataToXlaFormat(literal, &content));
-  ir_emitter_context_->emit_constant(
-      literal.getType().getNumElements(), element_bytes, global.getSymName(),
-      global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt(), content,
-      &b_);
+  int num_elements, element_bytes;
+  if (literal.getType().getElementType().isInteger(4)) {
+    // Treat an int4 constant as an int8 constant with half as many elements.
+    TF_RET_CHECK(content.size() ==
+                 (literal.getType().getNumElements() + 1) / 2);
+    num_elements = content.size();
+    element_bytes = 1;
+  } else {
+    num_elements = literal.getType().getNumElements();
+    TF_ASSIGN_OR_RETURN(
+        element_bytes, GetElementTypeBytes(literal.getType().getElementType()));
+  }
+
+  int64_t arg_index =
+      global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
+  int allocation_index = ir_emitter_context_->allocations()[arg_index]->index();
+
+  ir_emitter_context_->emit_constant(num_elements, element_bytes,
+                                     global.getSymName(), allocation_index,
+                                     content, &b_);
   return OkStatus();
 }
 
@@ -1030,32 +1049,6 @@
   return OkStatus();
 }
 
-Status IrEmitterUnnested::EmitCubDeviceRadixSort(mlir::Operation* op) {
-  auto radix_sort_op = mlir::cast<mlir::lmhlo_gpu::RadixSortOp>(op);
-  if (radix_sort_op.getInputs().size() != 1 &&
-      radix_sort_op.getInputs().size() != 2) {
-    return InternalError("Invalid number of operands for radix sort");
-  }
-
-  TF_ASSIGN_OR_RETURN(std::vector<BufferAllocation::Slice> operands,
-                      GetAllocationSlices(radix_sort_op.getInputs()));
-  TF_ASSIGN_OR_RETURN(std::vector<BufferAllocation::Slice> results,
-                      GetAllocationSlices(radix_sort_op.getOutput()));
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch,
-                      GetAllocationSlice(radix_sort_op.getScratch()));
-
-  auto thunk = std::make_unique<CubSortThunk>(
-      Thunk::ThunkInfo::WithProfileAnnotation(op),
-      GetShape(op->getOperand(0)).element_type(),
-      radix_sort_op.getInputs().size() == 2
-          ? std::optional(GetShape(op->getOperand(1)).element_type())
-          : std::nullopt,
-      operands, results, scratch, radix_sort_op.getDescending());
-
-  AddThunkToThunkSequence(std::move(thunk));
-  return OkStatus();
-}
-
 Status IrEmitterUnnested::EmitFusedMHAThunk(mlir::Operation* op) {
   using mlir::dyn_cast;
   using mlir::lmhlo_gpu::fusedMHAOp;
@@ -1342,6 +1335,33 @@
 }
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+Status IrEmitterUnnested::EmitCubDeviceRadixSort(mlir::Operation* op) {
+  auto radix_sort_op = mlir::cast<mlir::lmhlo_gpu::RadixSortOp>(op);
+  if (radix_sort_op.getInputs().size() != 1 &&
+      radix_sort_op.getInputs().size() != 2) {
+    return InternalError("Invalid number of operands for radix sort");
+  }
+
+  TF_ASSIGN_OR_RETURN(std::vector<BufferAllocation::Slice> operands,
+                      GetAllocationSlices(radix_sort_op.getInputs()));
+  TF_ASSIGN_OR_RETURN(std::vector<BufferAllocation::Slice> results,
+                      GetAllocationSlices(radix_sort_op.getOutput()));
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch,
+                      GetAllocationSlice(radix_sort_op.getScratch()));
+
+  auto thunk = std::make_unique<CubSortThunk>(
+      Thunk::ThunkInfo::WithProfileAnnotation(op),
+      GetShape(op->getOperand(0)).element_type(),
+      radix_sort_op.getInputs().size() == 2
+          ? std::optional(GetShape(op->getOperand(1)).element_type())
+          : std::nullopt,
+      operands, results, scratch, radix_sort_op.getDescending());
+
+  AddThunkToThunkSequence(std::move(thunk));
+  return OkStatus();
+}
+
 Status IrEmitterUnnested::EmitCholeskyThunk(mlir::Operation* op) {
   auto cholesky_op = mlir::cast<mlir::lmhlo_gpu::CholeskyOp>(op);
 
@@ -1445,76 +1465,146 @@
 }
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
+// Converts an MLIR dictionary attribute attached to a custom call operation
+// into custom call thunk attributes that are forwarded to the FFI handler.
+static StatusOr<CustomCallThunk::AttributesMap> BuildAttributesMap(
+    mlir::DictionaryAttr dict) {
+  CustomCallThunk::AttributesMap attributes;
+  for (auto& kv : dict) {
+    std::string_view name = kv.getName().strref();
+
+    auto integer = [&](mlir::IntegerAttr integer) {
+      switch (integer.getType().getIntOrFloatBitWidth()) {
+        case 32:
+          attributes[name] = static_cast<int32_t>(integer.getInt());
+          return OkStatus();
+        default:
+          return absl::InvalidArgumentError(absl::StrCat(
+              "Unsupported integer attribute bit width for attribute: ", name));
+      }
+    };
+
+    auto fp = [&](mlir::FloatAttr fp) {
+      switch (fp.getType().getIntOrFloatBitWidth()) {
+        case 32:
+          attributes[name] = static_cast<float>(fp.getValue().convertToFloat());
+          return OkStatus();
+        default:
+          return absl::InvalidArgumentError(absl::StrCat(
+              "Unsupported float attribute bit width for attribute: ", name));
+      }
+    };
+
+    auto str = [&](mlir::StringAttr str) {
+      attributes[name] = str.getValue().str();
+      return OkStatus();
+    };
+
+    TF_RETURN_IF_ERROR(
+        llvm::TypeSwitch<mlir::Attribute, Status>(kv.getValue())
+            .Case<mlir::IntegerAttr>(integer)
+            .Case<mlir::FloatAttr>(fp)
+            .Case<mlir::StringAttr>(str)
+            .Default([&](mlir::Attribute) {
+              return absl::InvalidArgumentError(absl::StrCat(
+                  "Unsupported attribute type for attribute: ", name));
+            }));
+  }
+  return attributes;
+}
+
 Status IrEmitterUnnested::EmitCustomCallThunk(mlir::Operation* op) {
   auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
   const std::string call_target_name = custom_call.getCallTargetName().str();
 
-  void* call_target = CustomCallTargetRegistry::Global()->Lookup(
-      call_target_name, std::string(platform_name()));
-
-  // Typed custom calls only are supported by XLA runtime. It's ok to emit a
-  // thunk with an unresolved custom call target, as we'll never execute it.
-  bool is_typed_custom_call =
+  // Typed FFI custom calls are a replacement for legacy custom calls with a
+  // richer type-safe API. They are under construction and not fully supported.
+  bool is_ffi_custom_call =
       custom_call.getApiVersion() ==
       mlir::mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI;
 
-  if (!call_target && !is_typed_custom_call) {
-    if (ir_emitter_context_->debug_options().xla_gpu_mock_custom_calls()) {
-      // Don't run anything on custom call.
+  void* call_target = CustomCallTargetRegistry::Global()->Lookup(
+      call_target_name, std::string(platform_name()));
+
+  StatusOr<XLA_FFI_Handler*> handler = ffi::FindHandler(call_target_name);
+
+  // At least one implementation should be available at run time.
+  bool found_custom_call = !is_ffi_custom_call && call_target != nullptr;
+  bool found_ffi_handler = is_ffi_custom_call && handler.ok();
+
+  if (!found_custom_call && !found_ffi_handler) {
+    auto& debug_options = ir_emitter_context_->debug_options();
+
+    // If true, then all custom calls that are not found in the custom call or
+    // FFI registries become no-ops (we don't emit any thunks for them).
+    if (debug_options.xla_gpu_mock_custom_calls()) {
       return OkStatus();
     }
-    return Unimplemented(
-        "No registered implementation for custom call to \"%s\" for platform "
-        "\"%s\"",
-        call_target_name, platform_name());
+
+    // TODO(ezhulenev): Custom calls registered with the XLA runtime are not
+    // part of the legacy or FFI registries. For now we simply ignore them.
+    if (debug_options.xla_gpu_enable_xla_runtime_executable()) {
+      return OkStatus();
+    }
+
+    return absl::UnimplementedError(
+        absl::StrCat("No registered implementation for custom call to ",
+                     call_target_name, " for platform ", platform_name()));
   }
 
-  std::vector<CustomCallThunk::OptionalSlice> operands;
-  std::vector<CustomCallThunk::OptionalSlice> results;
-  if (custom_call.getTargetArgMapping()) {
-    auto values_to_slices_with_token_holes =
-        [&](mlir::ValueRange operands,
-            mlir::ArrayRef<int64_t> op_to_target_mapping, int64_t num_target)
-        -> StatusOr<std::vector<CustomCallThunk::OptionalSlice>> {
-      std::vector<CustomCallThunk::OptionalSlice> slices(num_target);
-      for (auto index_and_value_it :
-           llvm::zip(op_to_target_mapping, operands)) {
-        int64_t index = std::get<0>(index_and_value_it);
-        mlir::Value value = std::get<1>(index_and_value_it);
-        TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
-                            GetAllocationSlice(value));
-        slices[index] = slice;
-      }
-      return slices;
-    };
+  using Slices = std::vector<std::optional<CustomCallThunk::Slice>>;
 
-    mlir::lmhlo::CustomCallTargetArgMappingAttr target_mapping =
-        *custom_call.getTargetArgMapping();
-    TF_ASSIGN_OR_RETURN(operands, values_to_slices_with_token_holes(
-                                      custom_call.getArgs(),
-                                      target_mapping.getArgsToTargetArgs(),
-                                      target_mapping.getNumArgs()));
-    TF_ASSIGN_OR_RETURN(results, values_to_slices_with_token_holes(
-                                     custom_call.getOutput(),
-                                     target_mapping.getResultsToTargetResults(),
-                                     target_mapping.getNumResults()));
+  // Initialize slices and shapes from the value range.
+  auto init_from_values = [&](mlir::ValueRange values, Slices* slices) {
+    for (mlir::Value value : values) {
+      TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(value));
+      slices->push_back(CustomCallThunk::Slice{slice, GetShape(value)});
+    }
+    return OkStatus();
+  };
+
+  // Initialize slices and shapes from the value range with token holes.
+  auto init_from_mapped_values = [&](mlir::ValueRange values,
+                                     absl::Span<const int64_t> target_mapping,
+                                     int64_t target_size, Slices* slices) {
+    slices->resize(target_size);
+    for (auto [index, value] : llvm::zip(target_mapping, values)) {
+      TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(value));
+      (*slices)[index] = CustomCallThunk::Slice{slice, GetShape(value)};
+    }
+    return OkStatus();
+  };
+
+  Slices operands, results;
+
+  // If we have a target mapping, then the number of operands and results of a
+  // custom call handler can be larger than the number of operands and results
+  // in the IR. These holes come from the HLO token operands and results.
+  if (auto target_mapping = custom_call.getTargetArgMapping()) {
+    auto arg_mapping = target_mapping->getArgsToTargetArgs();
+    auto res_mapping = target_mapping->getResultsToTargetResults();
+
+    TF_RETURN_IF_ERROR(
+        init_from_mapped_values(custom_call.getArgs(), arg_mapping,
+                                target_mapping->getNumArgs(), &operands));
+    TF_RETURN_IF_ERROR(
+        init_from_mapped_values(custom_call.getOutput(), res_mapping,
+                                target_mapping->getNumResults(), &results));
+
   } else {
-    auto values_to_slices = [&](mlir::ValueRange values)
-        -> StatusOr<std::vector<CustomCallThunk::OptionalSlice>> {
-      std::vector<CustomCallThunk::OptionalSlice> slices;
-      for (mlir::Value value : values) {
-        TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
-                            GetAllocationSlice(value));
-        slices.push_back(slice);
-      }
-      return slices;
-    };
-
-    TF_ASSIGN_OR_RETURN(operands, values_to_slices(custom_call.getArgs()));
-    TF_ASSIGN_OR_RETURN(results, values_to_slices(custom_call.getOutput()));
+    TF_RETURN_IF_ERROR(init_from_values(custom_call.getArgs(), &operands));
+    TF_RETURN_IF_ERROR(init_from_values(custom_call.getOutput(), &results));
   }
 
+  // For legacy custom calls we convert all API versions into the latest
+  // status-returning one and pass the backend config as an opaque string.
   CustomCallThunk::CustomCallTarget custom_call_target;
+  std::string opaque;
+
+  // For XLA FFI handlers we decode the opaque backend config into an
+  // attributes map at IR emission time, so that we do not need to parse MLIR
+  // at run time. For FFI handlers the backend config must be a compatible
+  // MLIR dictionary.
+  CustomCallThunk::AttributesMap attributes;
 
   // For information about this calling convention, see
   // xla/g3doc/custom_call.md.
@@ -1542,29 +1632,55 @@
           reinterpret_cast<status_returning_call_type>(call_target);
       break;
     case mlir::mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI:
-      custom_call_target = [](CustomCallThunk::Stream, void**, const char*,
-                              size_t, XlaCustomCallStatus*) {
-        LOG(FATAL) << "Typed FFI custom call must be called by XLA runtime";
-      };
+      // We already checked `handler` above.
       break;
     default:
       return InternalError("Unknown custom-call API version enum value: %d",
                            custom_call.getApiVersion());
   }
 
-  // Thunks support only user-encoded string backend config.
-  std::string backend_config;
-  if (auto str = custom_call.getBackendConfig()
-                     .value_or(mlir::Attribute())
-                     .dyn_cast_or_null<mlir::StringAttr>()) {
-    backend_config = str.str();
+  auto backend_config =
+      custom_call.getBackendConfig().value_or(mlir::Attribute());
+
+  switch (custom_call.getApiVersion()) {
+    case mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL:
+    case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
+    case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED:
+      if (auto str = backend_config.dyn_cast_or_null<mlir::StringAttr>()) {
+        opaque = str.str();
+        break;
+      }
+      return absl::InternalError(
+          "Unsupported backend config. Expected a string attribute");
+
+    case mlir::mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI:
+      if (auto dict = backend_config.dyn_cast_or_null<mlir::DictionaryAttr>()) {
+        TF_ASSIGN_OR_RETURN(attributes, BuildAttributesMap(dict));
+        break;
+      }
+      return absl::InternalError(
+          "Unsupported backend config. Expected a dictionary attribute");
+
+    default:
+      return InternalError("Unknown custom-call API version enum value: %d",
+                           custom_call.getApiVersion());
   }
 
-  auto thunk = std::make_unique<CustomCallThunk>(
-      Thunk::ThunkInfo::WithProfileAnnotation(op),
-      std::move(custom_call_target), std::move(operands), std::move(results),
-      backend_config);
-  AddThunkToThunkSequence(std::move(thunk));
+  auto ffi_thunk = [&] {
+    return std::make_unique<CustomCallThunk>(
+        Thunk::ThunkInfo::WithProfileAnnotation(op), *handler,
+        std::move(operands), std::move(results), std::move(attributes));
+  };
+
+  auto legacy_thunk = [&] {
+    return std::make_unique<CustomCallThunk>(
+        Thunk::ThunkInfo::WithProfileAnnotation(op),
+        std::move(custom_call_target), std::move(operands), std::move(results),
+        std::move(opaque));
+  };
+
+  AddThunkToThunkSequence(found_ffi_handler ? ffi_thunk() : legacy_thunk());
+
   return OkStatus();
 }
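For context on the token-hole mapping above, a minimal standalone sketch (plain C++ containers and a hypothetical MapWithTokenHoles helper, not the XLA types): mapped values land at their target indices, while unmapped positions, which correspond to HLO tokens, stay std::nullopt.

#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Slice = int;  // Hypothetical stand-in for BufferAllocation::Slice.

std::vector<std::optional<Slice>> MapWithTokenHoles(
    const std::vector<Slice>& values, const std::vector<int64_t>& mapping,
    int64_t target_size) {
  // Positions not named in the mapping (token operands/results) stay nullopt.
  std::vector<std::optional<Slice>> slices(target_size);
  for (std::size_t i = 0; i < values.size() && i < mapping.size(); ++i) {
    slices[mapping[i]] = values[i];
  }
  return slices;
}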
 
@@ -1749,7 +1865,7 @@
 #if GOOGLE_CUDA
 Status IrEmitterUnnested::EmitTritonFusion(
     const HloFusionAnalysis& hlo_fusion_analysis, mlir::Operation* op,
-    const AutotuneResult::TritonGemmKey& config,
+    const TritonGemmConfig& config,
     const absl::flat_hash_map<const mlir::Operation*, const HloInstruction*>&
         hlo_for_lmhlo) {
   // Note: In this method we can't use `BuildKernelThunk` as usual,
@@ -1804,9 +1920,8 @@
           hlo_fusion_analysis.fusion_boundary(), config);
     } else {  // Must be a MatMul
       CHECK_EQ(fusion_kind, kTritonGemmFusionKind);
-      TF_ASSIGN_OR_RETURN(
-          auto analysis,
-          TritonFusionAnalysis::Execute(*hlo_computation, config.split_k()));
+      TF_ASSIGN_OR_RETURN(auto analysis, TritonFusionAnalysis::Execute(
+                                             *hlo_computation, config.split_k));
       TF_ASSIGN_OR_RETURN(
           triton_wrapper_result,
           TritonWrapper(analysis, impl_fn_name, hlo_computation,
@@ -1851,6 +1966,63 @@
 
 #endif  // GOOGLE_CUDA
 
+// Checks if the fusion should be emitted as an in-place dynamic update slice
+// or a memcpy fusion. The logic is copied from GetFusionEmitter.
+bool IsSpecializedLoopFusion(
+    mlir::Operation* op, absl::Span<const BufferAllocation* const> allocations,
+    HloFusionAnalysis& analysis) {
+  auto fusion_op = mlir::cast<mlir::lmhlo::FusionOp>(op);
+  if (!allocations.empty() && fusion_op != nullptr) {
+    bool is_single = IsSingleInstructionFusion(fusion_op);
+    if (!is_single &&
+        CanEmitFusedDynamicUpdateSliceInPlaceForGpu(fusion_op, allocations)) {
+      return true;
+    }
+    if (is_single && analysis.fusion_roots().size() == 1 &&
+        analysis.fusion_roots().front()->opcode() == HloOpcode::kCopy) {
+      mlir::Value operand = GetHloOperands(fusion_op).front();
+      mlir::Value output = GetHloOutputs(fusion_op).front();
+      Shape operand_shape = GetShape(operand);
+      Shape output_shape = GetShape(output);
+      if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) &&
+          GetAllocationSlice(operand, allocations).ok()) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+Status IrEmitterUnnested::EmitFusion(const HloFusionInstruction* instr,
+                                     HloFusionAnalysis& fusion_analysis) {
+  // TODO(anlunx): Support kReduction, kTriton, and kScatter.
+  std::unique_ptr<FusionInterface> emitter;
+  switch (fusion_analysis.GetEmitterFusionKind()) {
+    case HloFusionAnalysis::EmitterFusionKind::kInputSlices:
+      emitter = std::make_unique<InputSlicesFusion>(fusion_analysis);
+      break;
+    case HloFusionAnalysis::EmitterFusionKind::kLoop:
+      // TODO(anlunx): Support MemcpyFusion and InPlaceDynamicUpdateSlice.
+      emitter = std::make_unique<LoopFusion>(fusion_analysis);
+      break;
+    case HloFusionAnalysis::EmitterFusionKind::kTranspose:
+      emitter = std::make_unique<TransposeFusion>(fusion_analysis);
+      break;
+    default:
+      return FailedPrecondition(
+          "Fusion type not supported by the HLO emitter.");
+      break;
+  }
+
+  TF_ASSIGN_OR_RETURN(auto emission_result,
+                      emitter->Emit(*ir_emitter_context_, elemental_emitter_,
+                                    nullptr, *instr, kernel_reuse_cache_, &b_));
+  for (auto& thunk : emission_result.thunks) {
+    AddThunkToThunkSequence(std::move(thunk));
+  }
+  return OkStatus();
+}
+
 Status IrEmitterUnnested::EmitFusion(
     mlir::Operation* op,
     const absl::flat_hash_map<const mlir::Operation*, const HloInstruction*>&
@@ -1910,18 +2082,20 @@
           triton_config.set_num_stages(1);
           triton_config.set_num_warps(2);
         }
-        return EmitTritonFusion(fusion_analysis, fusion_op,
-                                backend_config.triton_gemm_config(),
-                                hlo_for_lmhlo);
+        return EmitTritonFusion(
+            fusion_analysis, fusion_op,
+            TritonGemmConfig::FromProto(backend_config.triton_gemm_config()),
+            hlo_for_lmhlo);
       }
       if (backend_config.kind() == kTritonSoftmaxFusionKind) {
         auto& triton_config = *backend_config.mutable_triton_gemm_config();
         triton_config.set_num_stages(1);
         triton_config.set_num_warps(
             DeriveNumWarpsFromTritonSoftmaxComputation(fused_computation));
-        return EmitTritonFusion(fusion_analysis, fusion_op,
-                                backend_config.triton_gemm_config(),
-                                hlo_for_lmhlo);
+        return EmitTritonFusion(
+            fusion_analysis, fusion_op,
+            TritonGemmConfig::FromProto(backend_config.triton_gemm_config()),
+            hlo_for_lmhlo);
       }
 #endif
       LOG(FATAL) << "Unsupported fusion kind: " << backend_config.kind();
@@ -1971,11 +2145,12 @@
 
   std::string name = GetIrNameFromLoc(select_and_scatter_op.getLoc());
 
+  const HloInstruction* init_value = select_and_scatter->operand(2);
   // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
   // consisting of two thunks, an initializer KernelThunk that initializes
   // the output and another KernelThunk that accumulates the scattered
   // elements.
-  TF_RETURN_IF_ERROR(BuildInitializerThunk(op,
+  TF_RETURN_IF_ERROR(BuildInitializerThunk(op, select_and_scatter, init_value,
                                            select_and_scatter_op.getInitValue(),
                                            select_and_scatter_op.getOut()));
 
@@ -2882,16 +3057,16 @@
                                         launch_dimensions);
 }
 
-Status IrEmitterUnnested::BuildInitializerThunk(mlir::Operation* op,
-                                                mlir::Value init_value,
-                                                mlir::Value dest) {
+Status IrEmitterUnnested::BuildInitializerThunk(
+    mlir::Operation* op, const HloInstruction* instr,
+    const HloInstruction* init_value, mlir::Value init_value_mlir,
+    mlir::Value dest) {
   // initial value must be a scalar memref.
-  auto init_type = init_value.getType().dyn_cast<mlir::MemRefType>();
-  TF_RET_CHECK(init_type.getRank() == 0);
+  TF_RET_CHECK(init_value->shape().rank() == 0);
 
   TF_ASSIGN_OR_RETURN(std::optional<std::unique_ptr<Thunk>> constant_init_thunk,
                       BuildConstantInitializerThunk(*ir_emitter_context_, op,
-                                                    init_value, dest));
+                                                    instr, init_value, dest));
   if (constant_init_thunk) {
     AddThunkToThunkSequence(*std::move(constant_init_thunk));
     return OkStatus();
@@ -2905,8 +3080,8 @@
                       CalculateLaunchDimensions(
                           dest_shape, ir_emitter_context_->gpu_device_info()));
   TF_ASSIGN_OR_RETURN(auto ir_arrays,
-                      BuildKernelThunkForNonFusionOp(op, {init_value, dest},
-                                                     launch_dimensions));
+                      BuildKernelThunkForNonFusionOp(
+                          op, {init_value_mlir, dest}, launch_dimensions));
   auto& [inputs, outputs] = ir_arrays;
   auto init_array = inputs[0];
 
@@ -3086,9 +3261,6 @@
   if (mlir::isa<mlir::lmhlo_gpu::fusedMHABackwardOp>(op)) {
     return EmitFusedMHABackwardThunk(op);
   }
-  if (mlir::isa<mlir::lmhlo_gpu::RadixSortOp>(op)) {
-    return EmitCubDeviceRadixSort(op);
-  }
 #endif  // GOOGLE_CUDA
 
   if (mlir::isa<mlir::lmhlo_gpu::ConvForwardOp,
@@ -3101,6 +3273,9 @@
   }
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+  if (mlir::isa<mlir::lmhlo_gpu::RadixSortOp>(op)) {
+    return EmitCubDeviceRadixSort(op);
+  }
   if (mlir::isa<mlir::lmhlo_gpu::CholeskyOp>(op)) {
     if (ir_emitter_context_->emit_ir_from_hlo()) {
       return EmitCholeskyThunk(hlo_for_lmhlo.at(op));
@@ -3121,6 +3296,28 @@
   }
 
   if (mlir::isa<mlir::lmhlo::FusionOp>(op)) {
+    if (ir_emitter_context_->emit_ir_from_hlo()) {
+      const HloFusionInstruction* instr =
+          Cast<HloFusionInstruction>(hlo_for_lmhlo.at(op));
+      TF_ASSIGN_OR_RETURN(auto backend_config,
+                          instr->backend_config<FusionBackendConfig>());
+      const se::DeviceDescription& device_info =
+          ir_emitter_context_->gpu_device_info();
+      TF_ASSIGN_OR_RETURN(auto fusion_analysis,
+                          HloFusionAnalysis::Create(instr, &device_info));
+      HloFusionAnalysis::EmitterFusionKind kind =
+          fusion_analysis.GetEmitterFusionKind();
+      // TODO(anlunx): Add support for emitting kTriton, kScatter, kReduction,
+      // and specialized kLoops.
+      if (kind != HloFusionAnalysis::EmitterFusionKind::kTriton &&
+          kind != HloFusionAnalysis::EmitterFusionKind::kScatter &&
+          kind != HloFusionAnalysis::EmitterFusionKind::kReduction &&
+          !IsSpecializedLoopFusion(op, ir_emitter_context_->allocations(),
+                                   fusion_analysis)) {
+        return EmitFusion(instr, fusion_analysis);
+      }
+    }
+
     return EmitFusion(op, hlo_for_lmhlo);
   }
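The new HLO-based EmitFusion overload above selects a polymorphic emitter per fusion kind; a generic sketch of that factory-and-interface shape (hypothetical names, not the XLA classes):

#include <memory>

enum class FusionKind { kLoop, kTranspose, kInputSlices, kOther };

struct FusionEmitter {
  virtual ~FusionEmitter() = default;
  virtual void Emit() const = 0;
};
struct LoopEmitter final : FusionEmitter { void Emit() const override {} };
struct TransposeEmitter final : FusionEmitter { void Emit() const override {} };

// Returns nullptr for kinds this sketch does not handle; the real code returns
// a FailedPrecondition error for unsupported kinds instead.
std::unique_ptr<FusionEmitter> MakeEmitter(FusionKind kind) {
  switch (kind) {
    case FusionKind::kLoop:
      return std::make_unique<LoopEmitter>();
    case FusionKind::kTranspose:
      return std::make_unique<TransposeEmitter>();
    default:
      return nullptr;
  }
}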
 
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h
index b7af4f7..2bdb58a 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h
+++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h
@@ -141,14 +141,14 @@
   Status EmitConvolutionReorderThunk(mlir::Operation* op);
   Status EmitTritonFusion(
       const HloFusionAnalysis& hlo_fusion_analysis, mlir::Operation* op,
-      const AutotuneResult::TritonGemmKey& config,
+      const TritonGemmConfig& config,
       const absl::flat_hash_map<const mlir::Operation*, const HloInstruction*>&
           hlo_for_lmhlo);
   Status EmitFusedMHAThunk(mlir::Operation* op);
   Status EmitFusedMHABackwardThunk(mlir::Operation* op);
-  Status EmitCubDeviceRadixSort(mlir::Operation* op);
 #endif  // GOOGLE_CUDA
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+  Status EmitCubDeviceRadixSort(mlir::Operation* op);
   Status EmitCholeskyThunk(mlir::Operation* op);
   Status EmitCholeskyThunk(const HloInstruction* instr);
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -158,6 +158,8 @@
       mlir::Operation* op,
       const absl::flat_hash_map<const mlir::Operation*, const HloInstruction*>&
           hlo_for_lmhlo);
+  Status EmitFusion(const HloFusionInstruction* instr,
+                    HloFusionAnalysis& fusion_analysis);
   Status EmitSelectAndScatter(
       mlir::Operation* op,
       const absl::flat_hash_map<const mlir::Operation*, const HloInstruction*>&
@@ -396,8 +398,9 @@
                                  mlir::ValueRange needed_operands,
                                  const LaunchDimensions& launch_dimensions);
 
-  Status BuildInitializerThunk(mlir::Operation* op, mlir::Value init_value,
-                               mlir::Value dest);
+  Status BuildInitializerThunk(mlir::Operation* op, const HloInstruction* instr,
+                               const HloInstruction* init_value,
+                               mlir::Value init_value_mlir, mlir::Value dest);
 
   // Returns a WhileThunk that invokes thunk sequences for 'condition' and
   // 'body' sub-computations of while instruction 'hlo'.
diff --git a/third_party/xla/xla/service/gpu/kernel_arguments.cc b/third_party/xla/xla/service/gpu/kernel_arguments.cc
index a408e8b..5fcea6d 100644
--- a/third_party/xla/xla/service/gpu/kernel_arguments.cc
+++ b/third_party/xla/xla/service/gpu/kernel_arguments.cc
@@ -14,20 +14,29 @@
 ==============================================================================*/
 #include "xla/service/gpu/kernel_arguments.h"
 
-#include <utility>
 #include <optional>
+#include <utility>
+#include <vector>
 
+#include "absl/types/span.h"
 #include "llvm/ADT/STLExtras.h"
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
-#include "xla/mlir_hlo/transforms/gpu_passes.h"
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/service/buffer_assignment.h"
 #include "xla/service/gpu/gpu_constants.h"
 #include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/shape_util.h"
+#include "xla/status.h"
+#include "xla/statusor.h"
+#include "tsl/platform/errors.h"
 
 namespace xla {
 namespace gpu {
 
 StatusOr<KernelArgument> KernelArgument::Create(
-    absl::Span<const BufferAllocation> allocations, mlir::Value value,
+    absl::Span<const BufferAllocation* const> allocations, mlir::Value value,
     bool is_written) {
   TF_ASSIGN_OR_RETURN(
       auto slice, xla::gpu::GetAllocationSlice(value, allocations, nullptr));
@@ -35,7 +44,7 @@
 }
 
 StatusOr<KernelArguments> KernelArguments::Create(
-    absl::Span<const BufferAllocation> allocations,
+    absl::Span<const BufferAllocation* const> allocations,
     mlir::lmhlo::FusionOp fusion) {
   auto operands = GetHloOperands(fusion);
   auto outputs = GetHloOutputs(fusion);
@@ -56,6 +65,32 @@
   return KernelArguments{std::move(kernel_arguments)};
 }
 
+StatusOr<KernelArguments> KernelArguments::Create(
+    const BufferAssignment& buffer_assignment,
+    const HloFusionInstruction* fusion) {
+  std::vector<KernelArgument> kernel_arguments;
+  for (const HloInstruction* operand : fusion->operands()) {
+    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
+                        buffer_assignment.GetUniqueSlice(operand, {}));
+    kernel_arguments.emplace_back(KernelArgument(
+        /*value=*/nullptr, operand->shape(), slice, /*written=*/false));
+  }
+
+  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+      fusion->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
+        if (!subshape.IsArray()) {
+          return OkStatus();
+        }
+        TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
+                            buffer_assignment.GetUniqueSlice(fusion, index));
+        kernel_arguments.emplace_back(KernelArgument(
+            /*value=*/nullptr, subshape, slice, /*written=*/true));
+        return OkStatus();
+      }));
+
+  return KernelArguments{std::move(kernel_arguments)};
+}
+
 std::vector<KernelArgument> KernelArguments::ProcessArguments(
     std::vector<KernelArgument> kernel_arguments) {
   absl::flat_hash_set<BufferAllocation::Slice> buffers_written;
@@ -116,7 +151,7 @@
 }
 
 StatusOr<KernelArguments> KernelArguments::Create(
-    absl::Span<const BufferAllocation> allocations,
+    absl::Span<const BufferAllocation* const> allocations,
     mlir::Operation* non_fusion_op, mlir::ValueRange needed_operands) {
   std::vector<KernelArgument> kernel_arguments;
   kernel_arguments.reserve(needed_operands.size());
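The new buffer-assignment-based Create overload earlier in this file treats fusion operands as read-only arguments and every array piece of the fusion result as a written argument; a rough sketch with simplified hypothetical types:

#include <string>
#include <utility>
#include <vector>

struct Arg {
  std::string name;  // Stand-in for a buffer slice.
  bool written;
};

std::vector<Arg> CollectFusionArgs(std::vector<std::string> operands,
                                   std::vector<std::string> result_arrays) {
  std::vector<Arg> args;
  for (auto& op : operands) {
    args.push_back({std::move(op), /*written=*/false});
  }
  for (auto& out : result_arrays) {
    args.push_back({std::move(out), /*written=*/true});
  }
  return args;
}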
diff --git a/third_party/xla/xla/service/gpu/kernel_arguments.h b/third_party/xla/xla/service/gpu/kernel_arguments.h
index a795303..1292622 100644
--- a/third_party/xla/xla/service/gpu/kernel_arguments.h
+++ b/third_party/xla/xla/service/gpu/kernel_arguments.h
@@ -20,6 +20,7 @@
 #include <vector>
 
 #include "mlir/IR/Value.h"  // from @llvm-project
+#include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h"
 #include "xla/service/buffer_assignment.h"
 #include "xla/shape.h"
@@ -33,7 +34,7 @@
 class KernelArgument {
  public:
   static StatusOr<KernelArgument> Create(
-      absl::Span<const BufferAllocation> allocations, mlir::Value value,
+      absl::Span<const BufferAllocation* const> allocations, mlir::Value value,
       bool is_written);
 
   mlir::Value value() const { return value_; }
@@ -67,11 +68,15 @@
 class KernelArguments {
  public:
   static StatusOr<KernelArguments> Create(
-      absl::Span<const BufferAllocation> allocations,
+      absl::Span<const BufferAllocation* const> allocations,
       mlir::lmhlo::FusionOp fusion);
 
   static StatusOr<KernelArguments> Create(
-      absl::Span<const BufferAllocation> allocations,
+      const BufferAssignment& buffer_assignment,
+      const HloFusionInstruction* fusion);
+
+  static StatusOr<KernelArguments> Create(
+      absl::Span<const BufferAllocation* const> allocations,
       mlir::Operation* non_fusion_op, mlir::ValueRange needed_operands);
 
   const std::vector<KernelArgument>& args() const { return args_; }
diff --git a/third_party/xla/xla/service/gpu/kernel_thunk.cc b/third_party/xla/xla/service/gpu/kernel_thunk.cc
index bbd1f12..1ddff6e 100644
--- a/third_party/xla/xla/service/gpu/kernel_thunk.cc
+++ b/third_party/xla/xla/service/gpu/kernel_thunk.cc
@@ -15,13 +15,19 @@
 
 #include "xla/service/gpu/kernel_thunk.h"
 
+#include <cstdint>
 #include <memory>
 #include <string>
 #include <utility>
+#include <variant>
 #include <vector>
 
-#include "xla/service/gpu/gpu_executable.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/gpu/kernel_arguments.h"
+#include "xla/service/gpu/launch_dimensions.h"
 #include "xla/service/gpu/stream_executor_util.h"
+#include "xla/service/gpu/thunk.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/kernel.h"
 #include "xla/stream_executor/stream_executor.h"
@@ -42,11 +48,15 @@
 
 }  // namespace
 
-KernelThunk::KernelThunk(mlir::Operation* op, std::string kernel_name,
-                         absl::Span<const KernelArgument> kernel_arguments,
-                         LaunchDimensions launch_dimensions,
-                         int64_t shmem_bytes)
-    : Thunk(Kind::kKernel, Thunk::ThunkInfo::WithProfileAnnotation(op)),
+KernelThunk::KernelThunk(
+    std::variant<mlir::Operation*, const HloInstruction*> op,
+    std::string kernel_name, absl::Span<const KernelArgument> kernel_arguments,
+    LaunchDimensions launch_dimensions, int64_t shmem_bytes)
+    : Thunk(Kind::kKernel, std::holds_alternative<mlir::Operation*>(op)
+                               ? Thunk::ThunkInfo::WithProfileAnnotation(
+                                     std::get<mlir::Operation*>(op))
+                               : Thunk::ThunkInfo::WithProfileAnnotation(
+                                     std::get<const HloInstruction*>(op))),
       kernel_name_(std::move(kernel_name)),
       launch_dimensions_(std::move(launch_dimensions)),
       shmem_bytes_(shmem_bytes) {
@@ -59,6 +69,11 @@
     }
   }
 
+  if (std::holds_alternative<const HloInstruction*>(op)) {
+    // Skip populating MLIR values_ if emitting from HLO.
+    return;
+  }
+
   values_.reserve(kernel_arguments.size());
   for (const auto& kernel_argument : kernel_arguments) {
     if (!kernel_argument.first_with_same_slice().has_value()) {
diff --git a/third_party/xla/xla/service/gpu/kernel_thunk.h b/third_party/xla/xla/service/gpu/kernel_thunk.h
index 6fcf87a..c95372b 100644
--- a/third_party/xla/xla/service/gpu/kernel_thunk.h
+++ b/third_party/xla/xla/service/gpu/kernel_thunk.h
@@ -52,7 +52,8 @@
   // output of the computation. Also, the values must correspond to each arg
   // directly, not to their base allocation (e.g. they can be the result of an
   // `mlir::memref::ViewOp`).
-  KernelThunk(mlir::Operation* op, std::string kernel_name,
+  KernelThunk(std::variant<mlir::Operation*, const HloInstruction*> op,
+              std::string kernel_name,
               absl::Span<const KernelArgument> kernel_arguments,
               LaunchDimensions launch_dimensions, int64_t shmem_bytes);
   KernelThunk(const KernelThunk&) = delete;
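KernelThunk now accepts either an MLIR op or an HLO instruction through std::variant; a small standalone sketch of the dispatch pattern (generic stand-in types, not the XLA ones):

#include <string>
#include <variant>

struct MlirOp { std::string name; };
struct HloInstr { std::string name; };

// Picks the annotation source from whichever alternative is set, analogous to
// the ThunkInfo initialization in kernel_thunk.cc above.
std::string AnnotationFor(std::variant<const MlirOp*, const HloInstr*> op) {
  if (std::holds_alternative<const MlirOp*>(op)) {
    return std::get<const MlirOp*>(op)->name;
  }
  return std::get<const HloInstr*>(op)->name;
}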
diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index 96970b0..20c2224 100644
--- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -114,7 +114,10 @@
                  << ". Defaulting to telling LLVM that we're compiling for sm_"
                  << sm_version;
   }
-  return absl::StrCat("sm_", sm_version);
+  // If the target is sm_90, hard-code it to sm_90a so that all instructions
+  // can be used. We don't need the portability that sm_90 provides.
+  std::string_view extension = sm_version == 90 ? "a" : "";
+  return absl::StrCat("sm_", sm_version, extension);
 }
 
 // Convenience function for producing a name of a temporary compilation product
diff --git a/third_party/xla/xla/service/gpu/matmul_utils.cc b/third_party/xla/xla/service/gpu/matmul_utils.cc
index 3462dd2..8136ac0 100644
--- a/third_party/xla/xla/service/gpu/matmul_utils.cc
+++ b/third_party/xla/xla/service/gpu/matmul_utils.cc
@@ -18,6 +18,7 @@
 #include <algorithm>
 #include <cstdint>
 #include <optional>
+#include <string>
 #include <tuple>
 #include <type_traits>
 #include <utility>
@@ -25,12 +26,14 @@
 
 #include "absl/algorithm/container.h"
 #include "absl/log/check.h"
+#include "absl/strings/str_cat.h"
 #include "absl/types/span.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h"
 #include "xla/primitive_util.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
+#include "xla/status.h"
 #include "xla/status_macros.h"
 #include "xla/statusor.h"
 #include "xla/stream_executor/blas.h"
@@ -290,12 +293,13 @@
     absl::Span<const int64_t> rhs_batch_dims,
     absl::Span<const int64_t> rhs_contracting_dims, const Shape& output_shape,
     double alpha_real, double alpha_imag, double beta,
-    std::optional<int64_t> algorithm, int64_t compute_precision) {
+    std::optional<int64_t> algorithm, int64_t compute_precision, bool grad_x,
+    bool grad_y) {
   return GemmConfig::For(lhs_shape, lhs_batch_dims, lhs_contracting_dims,
                          rhs_shape, rhs_batch_dims, rhs_contracting_dims,
                          /*c_shape=*/output_shape, /*bias_shape_ptr=*/nullptr,
                          output_shape, alpha_real, alpha_imag, beta, algorithm,
-                         compute_precision);
+                         compute_precision, grad_x, grad_y);
 }
 
 /*static*/ StatusOr<GemmConfig> GemmConfig::For(
@@ -305,7 +309,7 @@
     absl::Span<const int64_t> rhs_contracting_dims, const Shape& c_shape,
     const Shape* bias_shape_ptr, const Shape& output_shape, double alpha_real,
     double alpha_imag, double beta, std::optional<int64_t> algorithm,
-    int64_t compute_precision) {
+    int64_t compute_precision, bool grad_x, bool grad_y) {
   absl::Span<const int64_t> lhs_col_dims = lhs_contracting_dims;
   TF_ASSIGN_OR_RETURN(
       std::vector<int64_t> lhs_row_dims,
@@ -406,16 +410,16 @@
                                output_shape.element_type()));
   }
 
-  return GemmConfig{
-      lhs_layout,
-      rhs_layout,
-      c_layout,
-      output_layout,
-      {alpha_real, alpha_imag},
-      beta,
-      compute_precision,
-      algorithm,
-  };
+  return GemmConfig{lhs_layout,
+                    rhs_layout,
+                    c_layout,
+                    output_layout,
+                    {alpha_real, alpha_imag},
+                    beta,
+                    compute_precision,
+                    algorithm,
+                    grad_x,
+                    grad_y};
 }
 
 /*static*/ StatusOr<GemmConfig> GemmConfig::For(const HloInstruction* gemm) {
@@ -433,12 +437,16 @@
   const Shape& output_shape =
       gemm->shape().IsTuple() ? gemm->shape().tuple_shapes(0) : gemm->shape();
 
+  auto attributes = gemm->frontend_attributes().map();
+  bool grad_x = (attributes["grad_x"] == "true");
+  bool grad_y = (attributes["grad_y"] == "true");
+
   return GemmConfig::For(
       lhs_shape, dot_dims.lhs_batch_dimensions(),
       dot_dims.lhs_contracting_dimensions(), rhs_shape,
       dot_dims.rhs_batch_dimensions(), dot_dims.rhs_contracting_dimensions(),
       output_shape, config.alpha_real(), config.alpha_imag(), config.beta(),
-      algorithm, se::blas::kDefaultComputePrecision);
+      algorithm, se::blas::kDefaultComputePrecision, grad_x, grad_y);
 }
 
 /*static*/ StatusOr<GemmConfig> GemmConfig::For(mlir::lmhlo_gpu::GEMMOp op) {
@@ -447,6 +455,13 @@
   std::optional<int64_t> algorithm;
   if (op.getAlgorithm()) algorithm = *op.getAlgorithm();
 
+  bool grad_x = false;
+  bool grad_y = false;
+  auto attr_grad_x = op.getGradX();
+  if (attr_grad_x) grad_x = attr_grad_x.value();
+  auto attr_grad_y = op.getGradY();
+  if (attr_grad_y) grad_y = attr_grad_y.value();
+
   int64_t compute_precision = 0;  // Default
   if (op.getPrecisionConfig().has_value()) {
     auto precision_config = op.getPrecisionConfig();
@@ -465,7 +480,8 @@
       dot_dims.getRhsBatchingDimensions(),
       dot_dims.getRhsContractingDimensions(), GetShape(op.getC()),
       op.getAlphaReal().convertToDouble(), op.getAlphaImag().convertToDouble(),
-      op.getBeta().convertToDouble(), algorithm, compute_precision);
+      op.getBeta().convertToDouble(), algorithm, compute_precision, grad_x,
+      grad_y);
 }
 
 namespace {
@@ -502,15 +518,14 @@
 }
 
 template <typename Scale, typename Input, typename Output>
-Status DoGemmWithAlgorithm(int64_t batch_size, int64_t m, int64_t n, int64_t k,
-                           const MatrixDescriptor& lhs,
-                           const MatrixDescriptor& rhs,
-                           const MatrixDescriptor& output, Scale alpha,
-                           Scale beta, se::Stream* stream,
-                           se::blas::AlgorithmType algorithm,
-                           se::blas::ComputePrecision compute_precision,
-                           const se::NumericOptions& numeric_options,
-                           se::blas::ProfileResult* profile_result) {
+Status DoGemmWithAlgorithm(
+    int64_t batch_size, int64_t m, int64_t n, int64_t k,
+    const MatrixDescriptor& lhs, const MatrixDescriptor& rhs,
+    const MatrixDescriptor& output, se::DeviceMemoryBase workspace, Scale alpha,
+    Scale beta, se::Stream* stream, se::blas::AlgorithmType algorithm,
+    se::blas::ComputePrecision compute_precision,
+    const se::NumericOptions& numeric_options,
+    se::blas::ProfileResult* profile_result, se::blas::CallContext context) {
   CHECK(output.transpose == se::blas::Transpose::kNoTranspose);
   PrimitiveType lhs_type = primitive_util::NativeToPrimitiveType<Input>();
   PrimitiveType output_type = primitive_util::NativeToPrimitiveType<Output>();
@@ -519,39 +534,49 @@
                                                       compute_precision));
   se::DeviceMemory<Output> output_data(output.data);
 
+  // Set a workspace for all Blas operations launched below.
+  se::blas::BlasSupport::ScopedWorkspace scoped_workspace(
+      stream->parent()->AsBlas(), &workspace);
+
   if (batch_size != 1) {
     return stream->ThenBlasGemmStridedBatchedWithAlgorithm(
         lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
         lhs.leading_dim_stride, lhs.batch_stride, rhs.cast<Input>(),
         rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data,
         output.leading_dim_stride, output.batch_stride, batch_size,
-        computation_type, algorithm, numeric_options, profile_result);
+        computation_type, algorithm, numeric_options, profile_result, context);
   } else {
     return stream->ThenBlasGemmWithAlgorithm(
         lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
         lhs.leading_dim_stride, rhs.cast<Input>(), rhs.leading_dim_stride, beta,
         &output_data, output.leading_dim_stride, computation_type, algorithm,
-        numeric_options, profile_result);
+        numeric_options, profile_result, context);
   }
 }
 
 template <typename Scale, typename Input, typename Output>
 Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k,
               const MatrixDescriptor& lhs, const MatrixDescriptor& rhs,
-              const MatrixDescriptor& output, Scale alpha, Scale beta,
-              se::Stream* stream,
+              const MatrixDescriptor& output, se::DeviceMemoryBase workspace,
+              Scale alpha, Scale beta, se::Stream* stream,
               std::optional<se::blas::AlgorithmType> algorithm,
               se::blas::ComputePrecision compute_precision,
               const se::NumericOptions& numeric_options,
-              se::blas::ProfileResult* profile_result) {
+              se::blas::ProfileResult* profile_result,
+              se::blas::CallContext context) {
   CHECK(output.transpose == se::blas::Transpose::kNoTranspose);
   se::DeviceMemory<Output> output_data(output.data);
 
+  // Set a workspace for all Blas operations launched below.
+  se::blas::BlasSupport::ScopedWorkspace scoped_workspace(
+      stream->parent()->AsBlas(), &workspace);
+
 #if GOOGLE_CUDA
   if (algorithm) {
     return DoGemmWithAlgorithm<Scale, Input, Output>(
-        batch_size, m, n, k, lhs, rhs, output, alpha, beta, stream, *algorithm,
-        compute_precision, numeric_options, profile_result);
+        batch_size, m, n, k, lhs, rhs, output, workspace, alpha, beta, stream,
+        *algorithm, compute_precision, numeric_options, profile_result,
+        context);
   }
 #endif
 
@@ -561,20 +586,21 @@
         lhs.leading_dim_stride, lhs.batch_stride, rhs.cast<Input>(),
         rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data,
         output.leading_dim_stride, output.batch_stride, batch_size,
-        numeric_options);
+        numeric_options, context);
   }
 
   return stream->ThenBlasGemm(
       lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
       lhs.leading_dim_stride, rhs.cast<Input>(), rhs.leading_dim_stride, beta,
-      &output_data, output.leading_dim_stride, numeric_options);
+      &output_data, output.leading_dim_stride, numeric_options, context);
 }
 
 }  // namespace
 
 Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer,
                se::DeviceMemoryBase rhs_buffer,
-               se::DeviceMemoryBase output_buffer, bool deterministic_ops,
+               se::DeviceMemoryBase output_buffer,
+               se::DeviceMemoryBase workspace_buffer, bool deterministic_ops,
                se::Stream* stream,
                std::optional<se::blas::AlgorithmType> algorithm,
                se::blas::ProfileResult* profile_result) {
@@ -602,42 +628,64 @@
 
   if (!algorithm) algorithm = config.algorithm;
 
+  se::blas::CallContext context = se::blas::CallContext::kNone;
+  if (config.grad_x) {
+    context = must_swap_operands ? se::blas::CallContext::kBackpropInput2
+                                 : se::blas::CallContext::kBackpropInput1;
+  }
+  if (config.grad_y) {
+    context = must_swap_operands ? se::blas::CallContext::kBackpropInput1
+                                 : se::blas::CallContext::kBackpropInput2;
+  }
+
   std::tuple<PrimitiveType, PrimitiveType, PrimitiveType> operand_types{
       lhs_layout.dtype, rhs_layout.dtype, output_layout.dtype};
 
-#define TYPED_GEMM(SCALENTYPE, ATYPE, BTYPE, CTYPE)                         \
-  if (operand_types == std::make_tuple(ATYPE, BTYPE, CTYPE)) {              \
-    using NativeScaleType =                                                 \
-        primitive_util::PrimitiveTypeToNative<SCALENTYPE>::type;            \
-    using NativeAType = primitive_util::PrimitiveTypeToNative<ATYPE>::type; \
-    using NativeCType = primitive_util::PrimitiveTypeToNative<CTYPE>::type; \
-    return DoGemm<NativeScaleType, NativeAType, NativeCType>(               \
-        batch_size, m, n, k, lhs, rhs, output,                              \
-        static_cast<NativeScaleType>(config.alpha.real()),                  \
-        static_cast<NativeScaleType>(config.beta), stream, algorithm,       \
-        config.compute_precision, numeric_options, profile_result);         \
+  // Skip a degenerate gemm by zeroing the output instead. In general this is
+  // not safe because it suppresses NaN propagation; however, cuBLAS has the
+  // same internal optimization for compatibility with the NETLIB
+  // implementation, so we are not making things worse (and that optimization
+  // is incompatible with CUDA graphs, so we make sure we do not trigger it).
+  if (config.alpha.real() == 0.0 && config.alpha.imag() == 0.0 &&
+      config.beta == 0.0) {
+    stream->ThenMemZero(&output_buffer, output_buffer.size());
+    return tsl::OkStatus();
   }
 
-#define TYPED_GEMM_COMPLEX(SCALENTYPE, ATYPE, BTYPE, CTYPE)                 \
-  if (operand_types == std::make_tuple(ATYPE, BTYPE, CTYPE)) {              \
-    using NativeScaleType =                                                 \
-        primitive_util::PrimitiveTypeToNative<SCALENTYPE>::type;            \
-    using NativeAType = primitive_util::PrimitiveTypeToNative<ATYPE>::type; \
-    using NativeCType = primitive_util::PrimitiveTypeToNative<CTYPE>::type; \
-    return DoGemm<NativeScaleType, NativeAType, NativeCType>(               \
-        batch_size, m, n, k, lhs, rhs, output,                              \
-        static_cast<NativeScaleType>(config.alpha),                         \
-        static_cast<NativeScaleType>(config.beta), stream, algorithm,       \
-        config.compute_precision, numeric_options, profile_result);         \
+#define TYPED_GEMM(SCALENTYPE, ATYPE, BTYPE, CTYPE)                          \
+  if (operand_types == std::make_tuple(ATYPE, BTYPE, CTYPE)) {               \
+    using NativeScaleType =                                                  \
+        primitive_util::PrimitiveTypeToNative<SCALENTYPE>::type;             \
+    using NativeAType = primitive_util::PrimitiveTypeToNative<ATYPE>::type;  \
+    using NativeCType = primitive_util::PrimitiveTypeToNative<CTYPE>::type;  \
+    return DoGemm<NativeScaleType, NativeAType, NativeCType>(                \
+        batch_size, m, n, k, lhs, rhs, output, workspace_buffer,             \
+        static_cast<NativeScaleType>(config.alpha.real()),                   \
+        static_cast<NativeScaleType>(config.beta), stream, algorithm,        \
+        config.compute_precision, numeric_options, profile_result, context); \
+  }
+
+#define TYPED_GEMM_COMPLEX(SCALENTYPE, ATYPE, BTYPE, CTYPE)                  \
+  if (operand_types == std::make_tuple(ATYPE, BTYPE, CTYPE)) {               \
+    using NativeScaleType =                                                  \
+        primitive_util::PrimitiveTypeToNative<SCALENTYPE>::type;             \
+    using NativeAType = primitive_util::PrimitiveTypeToNative<ATYPE>::type;  \
+    using NativeCType = primitive_util::PrimitiveTypeToNative<CTYPE>::type;  \
+    return DoGemm<NativeScaleType, NativeAType, NativeCType>(                \
+        batch_size, m, n, k, lhs, rhs, output, workspace_buffer,             \
+        static_cast<NativeScaleType>(config.alpha),                          \
+        static_cast<NativeScaleType>(config.beta), stream, algorithm,        \
+        config.compute_precision, numeric_options, profile_result, context); \
   }
 
   if (output_layout.dtype == S32) {
     if (!algorithm) algorithm = se::blas::kDefaultGemmAlgo;
     return DoGemmWithAlgorithm<int32_t, int8_t, int32_t>(
-        batch_size, m, n, k, lhs, rhs, output,
+        batch_size, m, n, k, lhs, rhs, output, workspace_buffer,
         static_cast<int32_t>(config.alpha.real()),
         static_cast<int32_t>(config.beta), stream, *algorithm,
-        se::blas::kDefaultComputePrecision, numeric_options, profile_result);
+        se::blas::kDefaultComputePrecision, numeric_options, profile_result,
+        context);
   }
 
   TYPED_GEMM(F32, BF16, BF16, BF16)
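The grad_x/grad_y handling added in this hunk can be summarized with a small standalone sketch (a hypothetical enum mirroring se::blas::CallContext): swapping the GEMM operands also swaps which input the gradient refers to.

enum class CallContext { kNone, kBackpropInput1, kBackpropInput2 };

CallContext GradCallContext(bool grad_x, bool grad_y, bool swapped_operands) {
  CallContext context = CallContext::kNone;
  if (grad_x) {
    context = swapped_operands ? CallContext::kBackpropInput2
                               : CallContext::kBackpropInput1;
  }
  if (grad_y) {
    context = swapped_operands ? CallContext::kBackpropInput1
                               : CallContext::kBackpropInput2;
  }
  return context;
}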
@@ -657,7 +705,7 @@
       primitive_util::LowercasePrimitiveTypeName(lhs_layout.dtype),
       primitive_util::LowercasePrimitiveTypeName(rhs_layout.dtype),
       primitive_util::LowercasePrimitiveTypeName(output_layout.dtype));
-}
+}  // namespace gpu
 
 namespace gpublas_lt {
 
@@ -720,5 +768,35 @@
 
 }  // namespace gpublas_lt
 
+/*static*/ TritonGemmConfig TritonGemmConfig::FromProto(
+    const AutotuneResult::TritonGemmKey& proto) {
+  TritonGemmConfig config;
+  config.block_m = proto.block_m();
+  config.block_n = proto.block_n();
+  config.block_k = proto.block_k();
+  config.split_k = proto.split_k();
+  config.num_stages = proto.num_stages();
+  config.num_warps = proto.num_warps();
+  return config;
+}
+
+AutotuneResult::TritonGemmKey TritonGemmConfig::ToProto() const {
+  AutotuneResult::TritonGemmKey key;
+  key.set_block_m(block_m);
+  key.set_block_n(block_n);
+  key.set_block_k(block_k);
+  key.set_split_k(split_k);
+  key.set_num_stages(num_stages);
+  key.set_num_warps(num_warps);
+  return key;
+}
+
+std::string TritonGemmConfig::ToString() const {
+  return absl::StrCat("{block_m:", block_m, ",block_n:", block_n,
+                      ",block_k:", block_k, ",split_k:", split_k,
+                      ",num_stages:", num_stages, ",num_warps:", num_warps,
+                      "}");
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/matmul_utils.h b/third_party/xla/xla/service/gpu/matmul_utils.h
index b6d5189..07ba566 100644
--- a/third_party/xla/xla/service/gpu/matmul_utils.h
+++ b/third_party/xla/xla/service/gpu/matmul_utils.h
@@ -18,10 +18,13 @@
 
 #include <cstdint>
 #include <optional>
+#include <string>
+#include <tuple>
 #include <utility>
 #include <vector>
 
 #include "absl/types/span.h"
+#include "xla/autotuning.pb.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h"
 #include "xla/service/gpu/backend_configs.pb.h"
@@ -85,6 +88,14 @@
 };
 
 struct GemmConfig : public se::gpu::GemmConfig {
+  // For legacy Gemm operations, XLA:GPU allocates its own workspace and
+  // passes it to all BLAS API calls.
+  //
+  // Size of the workspace based on NVIDIA recommendation:
+  // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
+  static constexpr int64_t kHopperWorkspace = 32 * 1024 * 1024;  // 32 MiB
+  static constexpr int64_t kDefaultWorkspace = 4 * 1024 * 1024;  // 4 MiB
+
   static StatusOr<GemmConfig> For(const HloInstruction* gemm);
   static StatusOr<GemmConfig> For(mlir::lmhlo_gpu::GEMMOp op);
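Given the kHopperWorkspace/kDefaultWorkspace constants above, the intended selection is presumably along these lines (a sketch assuming Hopper is detected via compute capability major version 9; not a quote of the actual XLA helper):

#include <cstdint>

constexpr int64_t kHopperWorkspace = 32 * 1024 * 1024;  // 32 MiB
constexpr int64_t kDefaultWorkspace = 4 * 1024 * 1024;  // 4 MiB

int64_t GemmWorkspaceBytes(int cc_major) {
  // Hopper (sm_90) gets the larger workspace recommended by the cuBLAS docs.
  return cc_major >= 9 ? kHopperWorkspace : kDefaultWorkspace;
}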
 
@@ -94,7 +105,8 @@
       absl::Span<const int64_t> rhs_batch_dims,
       absl::Span<const int64_t> rhs_contracting_dims, const Shape& output_shape,
       double alpha_real, double alpha_imag, double beta,
-      std::optional<int64_t> algorithm, int64_t compute_precision);
+      std::optional<int64_t> algorithm, int64_t compute_precision, bool grad_x,
+      bool grad_y);
 
   // As above with additional `c_shape` and `bias_shape_ptr` parameter, both
   // which are only necessarily for F8 gemms.
@@ -105,7 +117,7 @@
       absl::Span<const int64_t> rhs_contracting_dims, const Shape& c_shape,
       const Shape* bias_shape_ptr, const Shape& output_shape, double alpha_real,
       double alpha_imag, double beta, std::optional<int64_t> algorithm,
-      int64_t compute_precision);
+      int64_t compute_precision, bool grad_x, bool grad_y);
 
   template <typename CublasLtMatmulMaybeF8Op,
             typename = std::enable_if<
@@ -140,7 +152,8 @@
         op.getBias() == nullptr ? nullptr : &bias_shape, GetShape(op.getD()),
         op.getAlphaReal().convertToDouble(),
         op.getAlphaImag().convertToDouble(), op.getBeta().convertToDouble(),
-        op.getAlgorithm(), compute_precision);
+        op.getAlgorithm(), compute_precision, /*grad_x=*/false,
+        /*grad_y=*/false);
   }
 };
 
@@ -150,7 +163,8 @@
 // If `algorithm` is provided, it overrides the one specified in `config`.
 Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer,
                se::DeviceMemoryBase rhs_buffer,
-               se::DeviceMemoryBase output_buffer, bool deterministic_ops,
+               se::DeviceMemoryBase output_buffer,
+               se::DeviceMemoryBase workspace_buffer, bool deterministic_ops,
                se::Stream* stream,
                std::optional<se::blas::AlgorithmType> algorithm = std::nullopt,
                se::blas::ProfileResult* profile_result = nullptr);
@@ -165,6 +179,52 @@
 
 }  // namespace gpublas_lt
 
+// Prefer this struct over AutotuneResult::TritonGemmKey in code. Among other
+// advantages, it can be used as a key in hash maps.
+struct TritonGemmConfig {
+  constexpr TritonGemmConfig() = default;
+  constexpr TritonGemmConfig(int block_m, int block_n, int block_k, int split_k,
+                             int num_stages, int num_warps)
+      : block_m(block_m),
+        block_n(block_n),
+        block_k(block_k),
+        split_k(split_k),
+        num_stages(num_stages),
+        num_warps(num_warps) {}
+
+  int block_m = 0;
+  int block_n = 0;
+  int block_k = 0;
+  int split_k = 0;
+  int num_stages = 0;
+  int num_warps = 0;
+
+ private:
+  auto ToTuple() const {
+    return std::make_tuple(block_m, block_n, block_k, split_k, num_stages,
+                           num_warps);
+  }
+
+ public:
+  static TritonGemmConfig FromProto(const AutotuneResult::TritonGemmKey& proto);
+  AutotuneResult::TritonGemmKey ToProto() const;
+
+  std::string ToString() const;
+
+  bool operator==(const TritonGemmConfig& other) const {
+    return ToTuple() == other.ToTuple();
+  }
+
+  bool operator<(const TritonGemmConfig& other) const {
+    return ToTuple() < other.ToTuple();
+  }
+
+  template <typename H>
+  friend H AbslHashValue(H h, const TritonGemmConfig& config) {
+    return H::combine(std::move(h), config.ToTuple());
+  }
+};
+
 }  // namespace gpu
 }  // namespace xla
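Because TritonGemmConfig defines operator== and AbslHashValue via ToTuple, it can serve directly as a hash-map key (e.g. for caching autotuning results); a self-contained sketch of the same pattern with a simplified stand-in struct:

#include <tuple>
#include <utility>
#include "absl/container/flat_hash_map.h"

struct MiniConfig {
  int block_m = 0;
  int block_n = 0;
  auto ToTuple() const { return std::make_tuple(block_m, block_n); }
  bool operator==(const MiniConfig& o) const { return ToTuple() == o.ToTuple(); }
  template <typename H>
  friend H AbslHashValue(H h, const MiniConfig& c) {
    return H::combine(std::move(h), c.ToTuple());
  }
};

// Configs (e.g. autotuning candidates) can now key a hash map of measurements.
absl::flat_hash_map<MiniConfig, double> runtimes_ms = {
    {MiniConfig{/*block_m=*/64, /*block_n=*/64}, 1.0},
};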
 
diff --git a/third_party/xla/xla/service/gpu/model/analytical_latency_estimator.cc b/third_party/xla/xla/service/gpu/model/analytical_latency_estimator.cc
index 1c0a5b2..d13c96d 100644
--- a/third_party/xla/xla/service/gpu/model/analytical_latency_estimator.cc
+++ b/third_party/xla/xla/service/gpu/model/analytical_latency_estimator.cc
@@ -59,8 +59,9 @@
   }
 
   absl::Duration total_estimated_time =
-      GpuPerformanceModel::EstimateRunTimeForInstruction(instr,
-                                                         &*cost_analysis_)
+      GpuPerformanceModel::EstimateRunTimeForInstruction(
+          instr, &*cost_analysis_,
+          GpuPerformanceModelOptions::ForModule(instr->GetModule()))
           .exec_time;
   LatencyEstimator::TimeCost cost_in_us =
       absl::ToDoubleMicroseconds(total_estimated_time);
diff --git a/third_party/xla/xla/service/gpu/model/gpu_cost_model_stats_collection.cc b/third_party/xla/xla/service/gpu/model/gpu_cost_model_stats_collection.cc
index d5acb19..1447287 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_cost_model_stats_collection.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_cost_model_stats_collection.cc
@@ -39,8 +39,9 @@
     for (auto* fusion_instr : computation->instructions()) {
       if (fusion_instr->opcode() != HloOpcode::kFusion) continue;
 
-      GpuPerformanceModel::RecordEstimatedRunTime(fusion_instr,
-                                                  &cost_analysis_);
+      GpuPerformanceModel::RecordEstimatedRunTime(
+          fusion_instr, &cost_analysis_,
+          GpuPerformanceModelOptions::ForModule(module));
     }
   }
   return false;
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc
index b12170e..51579cd 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc
@@ -239,7 +239,8 @@
 
 /*static*/ EstimateRunTimeData
 GpuPerformanceModel::EstimateRunTimeForInstruction(
-    const HloInstruction* instr, const GpuHloCostAnalysis* cost_analysis) {
+    const HloInstruction* instr, const GpuHloCostAnalysis* cost_analysis,
+    const GpuPerformanceModelOptions& config) {
   const se::DeviceDescription* device_info = cost_analysis->device_info_;
 
   int64_t flops = cost_analysis->flop_count(*instr);
@@ -255,7 +256,7 @@
   absl::Duration compute_time = ComputeTime(*device_info, flops, num_threads);
   absl::Duration read_time = ProducerInputAccessTime(
       cost_analysis, *device_info, launch_dimensions.num_blocks(),
-      /*producer=*/instr, fusion_analysis);
+      /*producer=*/instr, fusion_analysis, config);
   absl::Duration write_time =
       absl::Seconds(1.0f * bytes_written / device_info->memory_bandwidth());
   absl::Duration exec_time = std::max(compute_time, read_time + write_time);
@@ -281,13 +282,14 @@
     const se::DeviceDescription& gpu_device_info, int64_t num_blocks,
     const HloInstruction* producer,
     std::optional<HloFusionAnalysis>& fusion_analysis,
+    const GpuPerformanceModelOptions& config,
     const HloInstruction* fused_consumer) {
   absl::Duration ret = absl::ZeroDuration();
   float producer_output_utilization = 1.f;
   ConstHloInstructionSet consumer_operands;
   bool consumer_transposes = false;
   if (fused_consumer) {
-    consumer_transposes = IsPhysicallyTransposing(*fused_consumer);
+    consumer_transposes = TransposesMinorDimension(fused_consumer);
     producer_output_utilization = cost_analysis->operand_utilization(
         *fused_consumer, fused_consumer->operand_index(producer));
     for (const HloInstruction* op : fused_consumer->operands()) {
@@ -295,7 +297,7 @@
     }
   }
 
-  bool producer_transposes = IsPhysicallyTransposing(*producer);
+  bool producer_transposes = TransposesMinorDimension(producer);
   for (int i = 0; i < producer->operand_count(); ++i) {
     // Information about data read taking into account utilization.
     // If `operand_utilization` is 0, `operand_bytes_accessed` should be also 0.
@@ -345,13 +347,20 @@
                       fusion_analysis->GetEmitterFusionKind() ==
                           HloFusionAnalysis::EmitterFusionKind::kTranspose) ||
                      (!producer_transposes && !consumer_transposes);
+    // Fusing two row reductions breaks coalescing.
+    coalesced &= ((fusion_analysis &&
+                   fusion_analysis->GetEmitterFusionKind() !=
+                       HloFusionAnalysis::EmitterFusionKind::kReduction) ||
+                  !fused_consumer || !IsInputFusibleReduction(*producer) ||
+                  !IsInputFusibleReduction(*fused_consumer));
     const auto& operand_shape = producer->operand(i)->shape();
 
     CHECK_LE(common_utilization, producer_output_utilization);
     float n_bytes_total = operand_bytes_accessed *
                           (producer_output_utilization - common_utilization);
     ret += ReadTime(gpu_device_info, num_blocks, /*n_bytes_net=*/n_bytes_net,
-                    n_bytes_total, operand_shape.element_type(), coalesced);
+                    n_bytes_total, operand_shape.element_type(),
+                    coalesced || !config.consider_coalescing);
   }
   return ret;
 }
@@ -369,6 +378,7 @@
 
 GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes(
     const HloInstruction* producer, const GpuHloCostAnalysis* cost_analysis,
+    const GpuPerformanceModelOptions& config,
     std::vector<HloInstruction*> fused_consumers, bool multi_output) {
   VLOG(8) << "Producer: " << producer->name();
   if (producer->opcode() == HloOpcode::kFusion) {
@@ -378,7 +388,7 @@
   const se::DeviceDescription* device_info = cost_analysis->device_info_;
 
   EstimateRunTimeData producer_data =
-      EstimateRunTimeForInstruction(producer, cost_analysis);
+      EstimateRunTimeForInstruction(producer, cost_analysis, config);
 
   int64_t fused_consumer_count = fused_consumers.size();
   float total_producer_utilization = 0;
@@ -419,7 +429,7 @@
     // don't currently have an analysis that is able to detect these cases.
     absl::Duration input_access_time_by_this_consumer = ProducerInputAccessTime(
         cost_analysis, *device_info, launch_dimensions_fused.num_blocks(),
-        producer, analysis_fused, fused_consumer);
+        producer, analysis_fused, config, fused_consumer);
 
     exec_time_fused += std::max(compute_time_by_this_consumer,
                                 input_access_time_by_this_consumer);
@@ -431,7 +441,7 @@
     producer_output_read_time_unfused += ReadTime(
         *device_info, launch_dimensions_unfused.num_blocks(), n_bytes_net,
         n_bytes_total, fused_consumer->shape().element_type(),
-        /*coalesced=*/!IsPhysicallyTransposing(*fused_consumer));
+        /*coalesced=*/!TransposesMinorDimension(fused_consumer));
   }
 
   absl::Duration time_unfused =
@@ -458,12 +468,13 @@
 }
 
 void GpuPerformanceModel::RecordEstimatedRunTime(
-    HloInstruction* instruction, const GpuHloCostAnalysis* cost_analysis) {
+    HloInstruction* instruction, const GpuHloCostAnalysis* cost_analysis,
+    const GpuPerformanceModelOptions& config) {
   DCHECK(Cast<const HloFusionInstruction>(instruction)) << "expected fusion";
   DCHECK(cost_analysis != nullptr) << "expected cost analysis";
 
   EstimateRunTimeData data =
-      EstimateRunTimeForInstruction(instruction, cost_analysis);
+      EstimateRunTimeForInstruction(instruction, cost_analysis, config);
   double cycles = absl::ToDoubleNanoseconds(data.exec_time) *
                   cost_analysis->device_info_->clock_rate_ghz();
 
@@ -643,7 +654,7 @@
 
   if (HloDataflowAnalysis::IsAsynchronousOperationDone(instr.opcode())) {
     VLOG(8) << "Returning 0 cost for async done op " << instr.name();
-    return absl::Microseconds(0);
+    return absl::ZeroDuration();
   }
   switch (instr.opcode()) {
     case HloOpcode::kAllReduce:
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h
index 2223b31..8bca1d8 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h
+++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h
@@ -54,6 +54,27 @@
   absl::Duration exec_time;
 };
 
+struct GpuPerformanceModelOptions {
+  // Whether to attempt to model the effect of uncoalesced reads.
+  bool consider_coalescing = false;
+
+  static GpuPerformanceModelOptions Default() {
+    return GpuPerformanceModelOptions();
+  }
+
+  static GpuPerformanceModelOptions PriorityFusion() {
+    GpuPerformanceModelOptions config;
+    config.consider_coalescing = true;
+    return config;
+  }
+
+  static GpuPerformanceModelOptions ForModule(const HloModule* module) {
+    return module->config().debug_options().xla_gpu_enable_priority_fusion()
+               ? PriorityFusion()
+               : Default();
+  }
+};
+
 class GpuPerformanceModel {
  public:
   struct RunTimes {
@@ -62,16 +83,19 @@
   };
 
   static EstimateRunTimeData EstimateRunTimeForInstruction(
-      const HloInstruction* instr, const GpuHloCostAnalysis* cost_analysis);
+      const HloInstruction* instr, const GpuHloCostAnalysis* cost_analysis,
+      const GpuPerformanceModelOptions& config);
 
   static RunTimes EstimateRunTimes(
       const HloInstruction* producer, const GpuHloCostAnalysis* cost_analysis,
+      const GpuPerformanceModelOptions& config,
       std::vector<HloInstruction*> fused_consumers = {},
       bool multi_output = false);
 
   // Writes estimated execution time to FusionBackendConfig.reification_cost.
   static void RecordEstimatedRunTime(HloInstruction* instruction,
-                                     const GpuHloCostAnalysis* cost_analysis);
+                                     const GpuHloCostAnalysis* cost_analysis,
+                                     const GpuPerformanceModelOptions& config);
   static absl::Duration ComputeTime(
       const se::DeviceDescription& gpu_device_info, int64_t flops,
       int64_t num_threads);
@@ -81,6 +105,7 @@
       const se::DeviceDescription& gpu_device_info, int64_t num_blocks,
       const HloInstruction* producer,
       std::optional<HloFusionAnalysis>& fusion_analysis,
+      const GpuPerformanceModelOptions& config,
       const HloInstruction* fused_consumer = nullptr);
 };
 
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc
index 2b282d6..1eed176 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc
+++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc
@@ -75,8 +75,8 @@
   HloInstruction* root = module->entry_computation()->root_instruction();
   ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_));
 
-  GpuPerformanceModel::RunTimes t =
-      GpuPerformanceModel::EstimateRunTimes(root, &analysis_);
+  GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
+      root, &analysis_, GpuPerformanceModelOptions::Default());
   // Dominated by the DRAM bandwidth.
   EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 57, 10);
 }
@@ -102,12 +102,13 @@
   HloInstruction* root = module->entry_computation()->root_instruction();
   ASSERT_IS_OK(root->Accept(&analysis_));
 
-  GpuPerformanceModel::RunTimes t =
-      GpuPerformanceModel::EstimateRunTimes(root, &analysis_);
+  GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
+      root, &analysis_, GpuPerformanceModelOptions::Default());
   // Dominated by the kernel launch overhead.
   EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 5, 1);
 
-  GpuPerformanceModel::RecordEstimatedRunTime(root, &analysis_);
+  GpuPerformanceModel::RecordEstimatedRunTime(
+      root, &analysis_, GpuPerformanceModelOptions::Default());
   double recorded_cycles = root->backend_config<FusionBackendConfig>()
                                ->reification_cost()
                                .end_to_end_cycles();
@@ -135,12 +136,13 @@
   HloInstruction* root = module->entry_computation()->root_instruction();
   ASSERT_IS_OK(root->Accept(&analysis_));
 
-  GpuPerformanceModel::RunTimes t =
-      GpuPerformanceModel::EstimateRunTimes(root, &analysis_);
+  GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
+      root, &analysis_, GpuPerformanceModelOptions::Default());
   // Dominated by the DRAM bandwidth.
   EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 175, 30);
 
-  GpuPerformanceModel::RecordEstimatedRunTime(root, &analysis_);
+  GpuPerformanceModel::RecordEstimatedRunTime(
+      root, &analysis_, GpuPerformanceModelOptions::Default());
   double recorded_cycles = root->backend_config<FusionBackendConfig>()
                                ->reification_cost()
                                .end_to_end_cycles();
@@ -170,8 +172,8 @@
   HloInstruction* root = module->entry_computation()->root_instruction();
   ASSERT_IS_OK(root->Accept(&analysis_));
 
-  GpuPerformanceModel::RunTimes t =
-      GpuPerformanceModel::EstimateRunTimes(root, &analysis_);
+  GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
+      root, &analysis_, GpuPerformanceModelOptions::Default());
   // Parameter 0 read is accelerated by L1 cache even though the total data
   // volume is the same as in the test LargeReadWrite above.
   EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 118, 12);
@@ -200,8 +202,8 @@
   HloInstruction* root = module->entry_computation()->root_instruction();
   ASSERT_IS_OK(root->Accept(&analysis_));
 
-  GpuPerformanceModel::RunTimes t =
-      GpuPerformanceModel::EstimateRunTimes(root, &analysis_);
+  GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
+      root, &analysis_, GpuPerformanceModelOptions::Default());
   // Parameter 0 read is accelerated by L2 cache (does not fit in L1).
   EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 123, 12);
 }
@@ -233,8 +235,8 @@
   HloInstruction* root = module->entry_computation()->root_instruction();
   ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_));
 
-  GpuPerformanceModel::RunTimes t =
-      GpuPerformanceModel::EstimateRunTimes(root, &analysis_);
+  GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
+      root, &analysis_, GpuPerformanceModelOptions::Default());
   EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 5, 1);
 }
 
@@ -316,8 +318,8 @@
     std::vector<HloInstruction*> consumers{
         module->entry_computation()->GetInstructionWithName("reduce.2")};
 
-    return GpuPerformanceModel::EstimateRunTimes(producer, &analysis,
-                                                 consumers);
+    return GpuPerformanceModel::EstimateRunTimes(
+        producer, &analysis, GpuPerformanceModelOptions::Default(), consumers);
   };
 
   TF_ASSERT_OK_AND_ASSIGN(auto large_small_reduce_runtime,
@@ -357,13 +359,45 @@
       module->entry_computation()->GetInstructionWithName("transpose.1");
   std::vector<HloInstruction*> consumers{
       module->entry_computation()->GetInstructionWithName("reduce.1")};
-  GpuPerformanceModel::RunTimes t =
-      GpuPerformanceModel::EstimateRunTimes(producer, &analysis_, consumers);
+  GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
+      producer, &analysis_, GpuPerformanceModelOptions::PriorityFusion(),
+      consumers);
 
   EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 105, 10);
   EXPECT_NEAR(absl::ToInt64Microseconds(t.time_fused), 1030, 10);
 }
 
+TEST_F(GpuPerformanceModelTest, FusingNonMinorTransposeIntoReduceIsFast) {
+  constexpr absl::string_view kHlo = R"(
+HloModule testmodule
+
+max {
+  p0 = f32[] parameter(0)
+  p1 = f32[] parameter(1)
+  ROOT max = f32[] maximum(p0, p1)
+}
+
+ENTRY fusion {
+  c = f32[] constant(-inf)
+  p0 = f32[1500,32,128]{1,2,0} parameter(0)
+  transpose.1 = f32[1500,128,32]{2,0,1} transpose(p0), dimensions={0,2,1}
+  ROOT reduce.1 = f32[1500,32] reduce(transpose.1, c), dimensions={1}, to_apply=max
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
+  ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_));
+
+  auto* producer =
+      module->entry_computation()->GetInstructionWithName("transpose.1");
+  std::vector<HloInstruction*> consumers{
+      module->entry_computation()->GetInstructionWithName("reduce.1")};
+  GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
+      producer, &analysis_, GpuPerformanceModelOptions::Default(), consumers);
+
+  EXPECT_LT(t.time_fused, t.time_unfused);
+}
+
 TEST_F(GpuPerformanceModelTest, DusScalesWithUpdates) {
   constexpr absl::string_view kHlo = R"(
 HloModule testmodule
@@ -410,9 +444,11 @@
   ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_));
 
   GpuPerformanceModel::RunTimes t1 = GpuPerformanceModel::EstimateRunTimes(
-      module->entry_computation()->root_instruction()->operand(0), &analysis_);
+      module->entry_computation()->root_instruction()->operand(0), &analysis_,
+      GpuPerformanceModelOptions::Default());
   GpuPerformanceModel::RunTimes t2 = GpuPerformanceModel::EstimateRunTimes(
-      module->entry_computation()->root_instruction()->operand(1), &analysis_);
+      module->entry_computation()->root_instruction()->operand(1), &analysis_,
+      GpuPerformanceModelOptions::Default());
 
   // DUS scales with the size of the updates, so these two fusions should have
   // the same cost.
diff --git a/third_party/xla/xla/service/gpu/multi_output_fusion.cc b/third_party/xla/xla/service/gpu/multi_output_fusion.cc
index ce3b9ee..a88845b 100644
--- a/third_party/xla/xla/service/gpu/multi_output_fusion.cc
+++ b/third_party/xla/xla/service/gpu/multi_output_fusion.cc
@@ -217,7 +217,7 @@
       [&](const HloInstruction& producer,
           const HloInstruction& consumer) -> FusionDecision {
         GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes(
-            &producer, cost_analysis,
+            &producer, cost_analysis, GpuPerformanceModelOptions::Default(),
             // `EstimateRunTimes`'s interface violates const correctness, so we
             // need the const cast here.
             {const_cast<HloInstruction*>(&consumer)},
diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc
index 6f4646f..7f985c4 100644
--- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc
+++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc
@@ -54,6 +54,7 @@
 #include "xla/service/gpu/gpu_asm_opts_util.h"
 #include "xla/service/gpu/gpu_conv_padding_legalization.h"
 #include "xla/service/gpu/gpu_conv_rewriter.h"
+#include "xla/service/gpu/gpu_sort_rewriter.h"
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
 #include "xla/service/gpu/metrics.h"
@@ -309,6 +310,13 @@
   return OkStatus();
 }
 
+Status NVPTXCompiler::AddCustomKernelReplacementPasses(
+    HloPassPipeline* pipeline, const DebugOptions& debug_options) {
+  if (debug_options.xla_gpu_enable_cub_radix_sort()) {
+    pipeline->AddPass<GpuSortRewriter>();
+  }
+  return OkStatus();
+}
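Note (sketch, not part of the patch): the rewrite is opt-in. The setter below follows the usual proto-generated naming for the xla_gpu_enable_cub_radix_sort field read above and is assumed rather than quoted from this change.

    DebugOptions debug_options;
    debug_options.set_xla_gpu_enable_cub_radix_sort(true);
    // With the flag set, AddCustomKernelReplacementPasses appends GpuSortRewriter
    // to the pipeline; with it unset, the pipeline is left untouched.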
 namespace {
 // Try to load ptx from files defined in the FLAGS. If successful, return true.
 bool MaybeLoadPtxFromFile(const HloModuleConfig module_config,
diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.h b/third_party/xla/xla/service/gpu/nvptx_compiler.h
index ef80167..b1c4c2b 100644
--- a/third_party/xla/xla/service/gpu/nvptx_compiler.h
+++ b/third_party/xla/xla/service/gpu/nvptx_compiler.h
@@ -62,6 +62,9 @@
       AutotuneConfig& autotune_config,
       tsl::thread::ThreadPool* thread_pool) override;
 
+  Status AddCustomKernelReplacementPasses(
+      HloPassPipeline* pipeline, const DebugOptions& debug_options) override;
+
   HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const override;
 
   StatusOr<std::pair<std::string, std::vector<uint8_t>>> CompileTargetBinary(
diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc
index 2f69db5..e45b8e7 100644
--- a/third_party/xla/xla/service/gpu/priority_fusion.cc
+++ b/third_party/xla/xla/service/gpu/priority_fusion.cc
@@ -140,7 +140,8 @@
                            HloInstruction* original_producer,
                            HloInstruction* original_consumer) override {
     if (fusion_process_dump_) {
-      auto* fusion_step = fusion_process_dump_->add_fusion_steps();
+      auto* fusion_step =
+          fusion_process_dump_->add_fusion_steps()->mutable_fusion();
 
       // Explicit std::string is needed for OSS proto implementation.
       fusion_step->set_fusion_name(std::string(fusion->name()));
@@ -238,12 +239,29 @@
     // Don't fuse if we can't fuse in all users.
     if (auto fusion_decision = CanFuseWithAllUsers(producer);
         !fusion_decision) {
+      if (fusion_process_dump_) {
+        auto* step = fusion_process_dump_->add_fusion_steps()
+                         ->mutable_producer_ineligible();
+        step->set_producer_name(std::string(producer->name()));
+        step->set_reason(fusion_decision.Explain());
+      }
       return std::numeric_limits<Priority>::min();
     }
 
     GpuPerformanceModel::RunTimes run_times =
-        GpuPerformanceModel::EstimateRunTimes(producer, &cost_analysis_,
-                                              producer->users());
+        GpuPerformanceModel::EstimateRunTimes(
+            producer, &cost_analysis_,
+            GpuPerformanceModelOptions::PriorityFusion(), producer->users());
+    if (fusion_process_dump_) {
+      auto* step =
+          fusion_process_dump_->add_fusion_steps()->mutable_update_priority();
+      step->set_producer_name(std::string(producer->name()));
+      for (auto* consumer : producer->users()) {
+        step->add_consumer_names(std::string(consumer->name()));
+      }
+      step->set_us_fused(absl::ToDoubleMicroseconds(run_times.time_fused));
+      step->set_us_unfused(absl::ToDoubleMicroseconds(run_times.time_unfused));
+    }
     return absl::ToInt64Nanoseconds(run_times.time_unfused -
                                     run_times.time_fused);
   }
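Note (illustrative numbers only): the priority returned above is simply the time saved by fusing, expressed in nanoseconds.

    absl::Duration time_unfused = absl::Microseconds(10);
    absl::Duration time_fused = absl::Microseconds(7);
    // 10us unfused vs. 7us fused => priority of 3000 ns in favor of fusing.
    int64_t priority = absl::ToInt64Nanoseconds(time_unfused - time_fused);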
diff --git a/third_party/xla/xla/service/gpu/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/priority_fusion_test.cc
index e756c84..f78f127 100644
--- a/third_party/xla/xla/service/gpu/priority_fusion_test.cc
+++ b/third_party/xla/xla/service/gpu/priority_fusion_test.cc
@@ -526,5 +526,28 @@
               GmockMatch(m::Scatter(m::Parameter(), m::Add(), m::Add())));
 }
 
+TEST_F(PriorityFusionTest, DoNotFuseReduceIntoReduceEvenIfOccupancyIsHigh) {
+  constexpr absl::string_view kHlo = R"(
+    HloModule test_module
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    ENTRY main {
+      p0 = f32[4,3584,128,168]{3,2,1,0} parameter(0)
+      c = f32[] constant(0)
+      r1 = f32[4,3584,128]{2,1,0} reduce(p0, c), dimensions={3}, to_apply=add
+      ROOT r2 = f32[4,3584]{1,0} reduce(r1, c), dimensions={2}, to_apply=add
+    })";
+
+  RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
+CHECK: ROOT {{.*}} reduce(
+CHECK: ROOT {{.*}} reduce(
+  )");
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD
index 38b5845..1a1cc8e 100644
--- a/third_party/xla/xla/service/gpu/runtime/BUILD
+++ b/third_party/xla/xla/service/gpu/runtime/BUILD
@@ -86,6 +86,7 @@
         "//xla/runtime:custom_call",
         "//xla/runtime:custom_call_registry",
         "//xla/runtime:executable",
+        "//xla/service:collective_ops_utils",
         "//xla/service:computation_placer_hdr",
         "//xla/service:executable",
         "//xla/service:global_device_id",
@@ -189,7 +190,9 @@
     name = "cub_sort",
     srcs = ["cub_sort.cc"],
     hdrs = ["cub_sort.h"],
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
+        "TENSORFLOW_USE_ROCM=1",
+    ]),
     visibility = ["//visibility:public"],
     deps = [
         ":support",
@@ -200,7 +203,7 @@
         "//xla/service:executable",
         "//xla/stream_executor:device_memory",
         "@com_google_absl//absl/status",
-    ] + if_cuda_is_configured([
+    ] + if_gpu_is_configured([
         "//xla/service/gpu:cub_sort_thunk",
     ]),
 )
@@ -252,12 +255,12 @@
         ":conv",
         ":conv_reorder",
         ":cub_sort",
-        ":cublas_lt_matmul",
         ":custom_call",
         ":custom_call_registry",
         ":fft",
         ":fused_attention",
         ":gemm",
+        ":gpublas_lt_matmul",
         ":graph_launch",
         ":io_feed",
         ":kernel_launch",
@@ -298,7 +301,7 @@
         "//xla/runtime:custom_call_registry",
         "//xla/runtime:executable",
         "//xla/runtime:state",
-        "//xla/service/gpu:fft_thunk",
+        "//xla/service/gpu/runtime3:fft_thunk",
         "//xla/stream_executor:fft",
         "//xla/translate/mhlo_to_hlo:attribute_exporter",
     ],
@@ -454,6 +457,7 @@
         "//xla/stream_executor:device_memory",
         "@com_google_absl//absl/container:node_hash_map",
         "@com_google_absl//absl/status",
+        "@local_tsl//tsl/platform:errors",
     ] + if_cuda_is_configured([
         "//xla/service/gpu:gemm_algorithm_picker",
         "//xla/stream_executor/gpu:redzone_allocator",
@@ -477,6 +481,7 @@
         "//xla/runtime:custom_call_registry",
         "//xla/runtime:executable",
         "//xla/service:executable",
+        "//xla/service/gpu:buffer_allocations",
         "//xla/service/gpu:non_atomically_upgradeable_rw_lock",
         "//xla/stream_executor",
         "//xla/stream_executor/gpu:gpu_graph",
@@ -485,6 +490,7 @@
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/profiler/lib:profiler_lock",
         "@local_tsl//tsl/profiler/lib:traceme",
         "@local_tsl//tsl/profiler/lib:traceme_encode",
@@ -566,9 +572,9 @@
 )
 
 cc_library(
-    name = "cublas_lt_matmul",
-    srcs = ["cublas_lt_matmul.cc"],
-    hdrs = ["cublas_lt_matmul.h"],
+    name = "gpublas_lt_matmul",
+    srcs = ["gpublas_lt_matmul.cc"],
+    hdrs = ["gpublas_lt_matmul.h"],
     local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM"]),
     visibility = ["//visibility:public"],
     deps = [
@@ -583,7 +589,6 @@
         "//xla/service:executable",
         "//xla/service/gpu:matmul_utils",
         "//xla/stream_executor",
-        "//xla/stream_executor/cuda:cublas_lt_header",
         "@local_tsl//tsl/platform:status",
     ] + if_rocm_is_configured([
         "@local_config_rocm//rocm:rocm_headers",
diff --git a/third_party/xla/xla/service/gpu/runtime/collectives.cc b/third_party/xla/xla/service/gpu/runtime/collectives.cc
index 7dabce9..00021ab 100644
--- a/third_party/xla/xla/service/gpu/runtime/collectives.cc
+++ b/third_party/xla/xla/service/gpu/runtime/collectives.cc
@@ -27,6 +27,7 @@
 #include "absl/strings/str_cat.h"
 #include "xla/runtime/custom_call.h"
 #include "xla/runtime/executable.h"
+#include "xla/service/collective_ops_utils.h"
 #include "xla/service/computation_placer.h"
 #include "xla/service/global_device_id.h"
 #include "xla/service/gpu/gpu_executable_run_options.h"
@@ -105,7 +106,7 @@
     const NcclExecuteParams& params, int64_t group_mode, int64_t op_id,
     absl::Span<const int64_t> replica_group_offsets,
     absl::Span<const int64_t> replica_group_values, int64_t stream_id,
-    bool enable_clique_optimization) {
+    bool enable_clique_optimization_flag) {
   // TODO(b/233930690): Pass the attribute below as a nested array.
   // Pass an array of arrays using two vectors; one specifying all the values
   // and another specifying the (ending) offsets of each array in the other
@@ -120,9 +121,12 @@
     replica_groups.push_back(replica_group);
   }
 
-  return LockNcclComm(params, replica_groups,
-                      static_cast<CollectiveOpGroupMode>(group_mode), op_id,
-                      stream_id, enable_clique_optimization);
+  // Always enable clique optimization for single-host runs, which are
+  // indicated by the absence of nccl_unique_id_callback.
+  return LockNcclComm(
+      params, replica_groups, static_cast<CollectiveOpGroupMode>(group_mode),
+      op_id, stream_id,
+      enable_clique_optimization_flag || !params.nccl_unique_id_callback);
 }
 #endif  // XLA_ENABLE_XCCL
 
diff --git a/third_party/xla/xla/service/gpu/runtime/cub_sort.cc b/third_party/xla/xla/service/gpu/runtime/cub_sort.cc
index 32b7099..e6cbd13 100644
--- a/third_party/xla/xla/service/gpu/runtime/cub_sort.cc
+++ b/third_party/xla/xla/service/gpu/runtime/cub_sort.cc
@@ -26,7 +26,7 @@
 #include "xla/service/service_executable_run_options.h"
 #include "xla/stream_executor/device_memory.h"
 
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "xla/service/gpu/cub_sort_thunk.h"
 #endif
 
@@ -41,7 +41,7 @@
 absl::Status CubDeviceRadixSortKeysImpl(
     const ServiceExecutableRunOptions* run_options, FlatMemrefView input_view,
     FlatMemrefView output_view, FlatMemrefView scratch_view, bool descending) {
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   return RunCubSort(input_view.dtype, std::nullopt,
                     GetDeviceAddress(input_view), DeviceMemoryBase(),
                     GetDeviceAddress(output_view), DeviceMemoryBase(),
@@ -56,7 +56,7 @@
     FlatMemrefView input_keys_view, FlatMemrefView input_values_view,
     FlatMemrefView output_keys_view, FlatMemrefView output_values_view,
     FlatMemrefView scratch_view, bool descending) {
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   return RunCubSort(
       input_keys_view.dtype, input_values_view.dtype,
       GetDeviceAddress(input_keys_view), GetDeviceAddress(input_values_view),
diff --git a/third_party/xla/xla/service/gpu/runtime/executable.cc b/third_party/xla/xla/service/gpu/runtime/executable.cc
index f75fe65..799f837 100644
--- a/third_party/xla/xla/service/gpu/runtime/executable.cc
+++ b/third_party/xla/xla/service/gpu/runtime/executable.cc
@@ -33,12 +33,12 @@
 #include "xla/service/gpu/runtime/conv.h"
 #include "xla/service/gpu/runtime/conv_reorder.h"
 #include "xla/service/gpu/runtime/cub_sort.h"
-#include "xla/service/gpu/runtime/cublas_lt_matmul.h"
 #include "xla/service/gpu/runtime/custom_call.h"
 #include "xla/service/gpu/runtime/custom_call_registry.h"
 #include "xla/service/gpu/runtime/fft.h"
 #include "xla/service/gpu/runtime/fused_attention.h"
 #include "xla/service/gpu/runtime/gemm.h"
+#include "xla/service/gpu/runtime/gpublas_lt_matmul.h"
 #include "xla/service/gpu/runtime/graph_launch.h"
 #include "xla/service/gpu/runtime/io_feed.h"
 #include "xla/service/gpu/runtime/memcpy.h"
@@ -94,14 +94,13 @@
 #if GOOGLE_CUDA
   RegisterFusedAttentionCustomCalls(registry);
   RegisterFusedAttentionBackwardCustomCalls(registry);
-  RegisterCubSortCustomCalls(registry);
 #endif  // GOOGLE_CUDA
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   // Graph launch kernels depend on Cuda Graph API.
   RegisterGraphLaunchCustomCalls(registry);
   RegisterConcurrentRegionCustomCalls(registry);
   RegisterStreamSynchronizationCustomCalls(registry);
-
+  RegisterCubSortCustomCalls(registry);
   RegisterXlaClassicCustomCalls(registry);
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 }
@@ -147,10 +146,12 @@
 
 GpuRuntimeExecutable::GpuRuntimeExecutable(
     std::string module_name, std::vector<int64_t> buffer_sizes,
+    std::vector<std::vector<int64_t>> allocation_indices,
     std::unique_ptr<JitExecutable> jit_executable, DebugOptions debug_options,
     ModulesState modules_state)
     : module_name_(std::move(module_name)),
       buffer_sizes_(std::move(buffer_sizes)),
+      allocation_indices_(std::move(allocation_indices)),
       executable_(std::move(jit_executable)),
       debug_options_(std::move(debug_options)),
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -162,10 +163,12 @@
 
 GpuRuntimeExecutable::GpuRuntimeExecutable(
     std::string module_name, std::vector<int64_t> buffer_sizes,
+    std::vector<std::vector<int64_t>> allocation_indices,
     std::unique_ptr<Executable> aot_executable, DebugOptions debug_options,
     ModulesState modules_state)
     : module_name_(std::move(module_name)),
       buffer_sizes_(std::move(buffer_sizes)),
+      allocation_indices_(std::move(allocation_indices)),
       executable_(std::move(aot_executable)),
       debug_options_(std::move(debug_options)),
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -232,6 +235,7 @@
 
   return std::unique_ptr<GpuRuntimeExecutable>(new GpuRuntimeExecutable(
       std::move(module_name), std::move(program->buffer_sizes),
+      std::move(program->allocation_indices),
       std::make_unique<JitExecutable>(std::move(*jit_executable)),
       std::move(program->debug_options), std::move(*modules_state)));
 }
@@ -241,10 +245,10 @@
 //===----------------------------------------------------------------------===//
 
 /*static*/ StatusOr<std::unique_ptr<GpuRuntimeExecutable>>
-GpuRuntimeExecutable::Create(std::string module_name,
-                             absl::Span<const int64_t> buffer_sizes,
-                             Executable executable,
-                             DebugOptions debug_options) {
+GpuRuntimeExecutable::Create(
+    std::string module_name, std::vector<int64_t> buffer_sizes,
+    std::vector<std::vector<int64_t>> allocation_indices, Executable executable,
+    DebugOptions debug_options) {
   // Instantiate state for all registered runtime modules.
   auto modules_state = ModulesState::Instantiate();
   if (!modules_state.ok())
@@ -252,8 +256,8 @@
                          modules_state.status().message());
 
   return std::unique_ptr<GpuRuntimeExecutable>(new GpuRuntimeExecutable(
-      std::move(module_name),
-      std::vector<int64_t>(buffer_sizes.begin(), buffer_sizes.end()),
+      std::move(module_name), std::move(buffer_sizes),
+      std::move(allocation_indices),
       std::make_unique<Executable>(std::move(executable)),
       std::move(debug_options), std::move(*modules_state)));
 }
@@ -389,8 +393,6 @@
       executor_graphs->snapshot();
   CapturedFunctionExecutionCount::Snapshot execution_count =
       captured_function_counts_(executor)->snapshot();
-  OrdinalToFallback::Snapshot ordinal_to_fallback =
-      ordinal_to_fallback_.snapshot();
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
   // Kernels in concurrent regions should be launched on borrowed stream, so
@@ -404,7 +406,7 @@
   FftPlans::Snapshot fft_plans = fft_plans_.snapshot();
 
 #if GOOGLE_CUDA || TF_HIPBLASLT
-  MatmulPlans::Snapshot matmul_plans = cublas_lt_matmul_plans_.snapshot();
+  MatmulPlans::Snapshot matmul_plans = gpublas_lt_matmul_plans_.snapshot();
 #endif
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -429,7 +431,7 @@
       &fused_attention_runners, &fused_attention_backward_runners,
 #endif  // GOOGLE_CUDA
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-      &graph_instances, &execution_count, &ordinal_to_fallback,
+      &graph_instances, &execution_count,
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
       &concurrent_region_status,
       // Null pointer will be interpreted as an absence of async collectives
@@ -448,25 +450,9 @@
   // Instantiate all CUDA graphs before executing the main function.
   if (debug_options_.xla_gpu_graph_num_runs_to_instantiate() < 0 &&
       !graph_instances_.InstantiatedAllGraphs(run_options, executable)) {
-    // To instantiate all Gpu graphs we have to pass a valid device pointer
-    // because some device operations in XLA (e.g. memcpy) query device
-    // information from a pointer. We have to find the largest allocation
-    // available, to guarantee that all memref slices are within bounds,
-    // otherwise we might get crashes from a Gpu driver.
-    void* device_ptr = temp_buffer.opaque();
-    size_t device_ptr_size = temp_buffer.size();
-
-    for (unsigned i = 0; i < buffer_allocations.size(); ++i) {
-      auto mem = buffer_allocations.GetDeviceAddress(i);
-      if (mem.size() > device_ptr_size) {
-        device_ptr = mem.opaque();
-        device_ptr_size = mem.size();
-      }
-    }
-
     if (auto instantiated = graph_instances_.InstantiateAllGraphs(
-            run_options, executable, user_data, device_ptr,
-            &ordinal_to_fallback,
+            run_options, executable, user_data, buffer_allocations,
+            buffer_sizes_, allocation_indices_,
             debug_options_.xla_gpu_graph_eviction_timeout_seconds());
         !instantiated.ok()) {
       return InternalError("Failed to instantiate GPU graphs: %s",
diff --git a/third_party/xla/xla/service/gpu/runtime/executable.h b/third_party/xla/xla/service/gpu/runtime/executable.h
index 5b4744f..d8f8018 100644
--- a/third_party/xla/xla/service/gpu/runtime/executable.h
+++ b/third_party/xla/xla/service/gpu/runtime/executable.h
@@ -16,6 +16,7 @@
 #ifndef XLA_SERVICE_GPU_RUNTIME_EXECUTABLE_H_
 #define XLA_SERVICE_GPU_RUNTIME_EXECUTABLE_H_
 
+#include <cstdint>
 #include <memory>
 #include <string>
 #include <string_view>
@@ -30,10 +31,10 @@
 #include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h"
 #include "xla/service/gpu/runtime/collectives.h"
 #include "xla/service/gpu/runtime/conv.h"
-#include "xla/service/gpu/runtime/cublas_lt_matmul.h"
 #include "xla/service/gpu/runtime/fft.h"
 #include "xla/service/gpu/runtime/fused_attention.h"
 #include "xla/service/gpu/runtime/gemm.h"
+#include "xla/service/gpu/runtime/gpublas_lt_matmul.h"
 #include "xla/service/gpu/runtime/graph_launch.h"
 #include "xla/service/gpu/runtime/kernel_launch.h"
 #include "xla/service/service_executable_run_options.h"
@@ -64,15 +65,18 @@
 struct GpuRuntimeProgram {
   GpuRuntimeProgram(std::string entry_point, std::string module,
                     std::vector<int64_t> buffer_sizes,
+                    std::vector<std::vector<int64_t>> allocation_indices,
                     DebugOptions debug_options)
       : entry_point(std::move(entry_point)),
         module(std::move(module)),
         buffer_sizes(std::move(buffer_sizes)),
+        allocation_indices(std::move(allocation_indices)),
         debug_options(std::move(debug_options)) {}
 
   std::string entry_point;
   std::string module;
   std::vector<int64_t> buffer_sizes;
+  std::vector<std::vector<int64_t>> allocation_indices;
   DebugOptions debug_options;
 };
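Construction sketch, not part of the patch: every value below is a placeholder; the module string and debug options would come from the compiler.

    GpuRuntimeProgram program(
        /*entry_point=*/"main", /*module=*/std::move(mlir_module_str),
        /*buffer_sizes=*/{1024, 4096},
        /*allocation_indices=*/{{}, {0, 1}},  // one vector per exported function
        std::move(debug_options));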
 
@@ -96,7 +100,8 @@
 
   // Creates GpuRuntimeExecutable from the AOT compiled binary.
   static StatusOr<std::unique_ptr<GpuRuntimeExecutable>> Create(
-      std::string module_name, absl::Span<const int64_t> buffer_sizes,
+      std::string module_name, std::vector<int64_t> buffer_sizes,
+      std::vector<std::vector<int64_t>> allocation_indices,
       runtime::Executable executable, DebugOptions debug_options);
 
   // Executes entry function with the given buffer arguments.
@@ -119,11 +124,13 @@
  private:
   GpuRuntimeExecutable(std::string module_name,
                        std::vector<int64_t> buffer_sizes,
+                       std::vector<std::vector<int64_t>> allocation_indices,
                        std::unique_ptr<runtime::JitExecutable> jit_executable,
                        DebugOptions debug_options, ModulesState modules_state);
 
   GpuRuntimeExecutable(std::string module_name,
                        std::vector<int64_t> buffer_sizes,
+                       std::vector<std::vector<int64_t>> allocation_indices,
                        std::unique_ptr<runtime::Executable> aot_executable,
                        DebugOptions debug_options, ModulesState modules_state);
 
@@ -135,6 +142,10 @@
 
   std::vector<int64_t> buffer_sizes_;
 
+  // `rt.allocation_index` attributes for all exported functions. Indexed by
+  // function ordinal.
+  std::vector<std::vector<int64_t>> allocation_indices_;
+
   // In JIT compilation mode `JitExecutable` is used. In AOT compilation mode
   // `Executable` is used.
   std::variant<std::unique_ptr<runtime::JitExecutable>,
@@ -168,14 +179,13 @@
   FftPlans fft_plans_;
 
 #if GOOGLE_CUDA || TF_HIPBLASLT  // Keep matmul execution plans.
-  MatmulPlans cublas_lt_matmul_plans_;
+  MatmulPlans gpublas_lt_matmul_plans_;
 #endif
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   // Keep captured and instantiated GPU graphs instances.
   GraphInstances graph_instances_;
   CapturedFunctionExecutionCounts captured_function_counts_;
-  OrdinalToFallback ordinal_to_fallback_;
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
   // Keep an executable state for all registered runtime modules.
diff --git a/third_party/xla/xla/service/gpu/runtime/fft.cc b/third_party/xla/xla/service/gpu/runtime/fft.cc
index fda8f01..a668a8e 100644
--- a/third_party/xla/xla/service/gpu/runtime/fft.cc
+++ b/third_party/xla/xla/service/gpu/runtime/fft.cc
@@ -21,8 +21,8 @@
 #include "xla/runtime/custom_call.h"
 #include "xla/runtime/executable.h"
 #include "xla/runtime/state.h"
-#include "xla/service/gpu/fft_thunk.h"
 #include "xla/service/gpu/runtime/support.h"
+#include "xla/service/gpu/runtime3/fft_thunk.h"
 #include "xla/stream_executor/fft.h"
 
 namespace xla {
diff --git a/third_party/xla/xla/service/gpu/runtime/fft.h b/third_party/xla/xla/service/gpu/runtime/fft.h
index bc45e31..7a34e2d 100644
--- a/third_party/xla/xla/service/gpu/runtime/fft.h
+++ b/third_party/xla/xla/service/gpu/runtime/fft.h
@@ -20,7 +20,7 @@
 
 #include "xla/mlir/runtime/transforms/custom_call_encoding.h"
 #include "xla/runtime/custom_call_registry.h"
-#include "xla/service/gpu/fft_thunk.h"
+#include "xla/service/gpu/runtime3/fft_thunk.h"
 
 namespace xla {
 namespace gpu {
diff --git a/third_party/xla/xla/service/gpu/runtime/gemm.cc b/third_party/xla/xla/service/gpu/runtime/gemm.cc
index 29edea4..a57e43e 100644
--- a/third_party/xla/xla/service/gpu/runtime/gemm.cc
+++ b/third_party/xla/xla/service/gpu/runtime/gemm.cc
@@ -15,6 +15,7 @@
 
 #include "xla/service/gpu/runtime/gemm.h"
 
+#include <cstdint>
 #include <limits>
 #include <optional>
 #include <utility>
@@ -33,6 +34,7 @@
 #include "xla/stream_executor/blas.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/xla.pb.h"
+#include "tsl/platform/errors.h"
 
 #if GOOGLE_CUDA
 #include "xla/service/gpu/gemm_algorithm_picker.h"
@@ -48,6 +50,7 @@
 
 #if GOOGLE_CUDA
 
+// TODO(ezhulenev): Delete runtime autotuning from XLA.
 Status DoRuntimeAutotuning(se::Stream* stream, GemmConfig& config,
                            se::DeviceMemoryBase lhs, se::DeviceMemoryBase rhs,
                            se::DeviceMemoryBase out, const Shape& output_shape,
@@ -92,8 +95,9 @@
             // we pass a non-null ProfileResult, DoGemmWithAlgorithm should
             // always return true, and the actual success-ness is returned in
             // ProfileResult::is_valid.
-            TF_RETURN_IF_ERROR(RunGemm(config, lhs, rhs, out, deterministic_ops,
-                                       stream, algorithm, &profile_result));
+            TF_RETURN_IF_ERROR(
+                RunGemm(config, lhs, rhs, out, se::DeviceMemoryBase(nullptr, 0),
+                        deterministic_ops, stream, algorithm, &profile_result));
             return std::move(profile_result);
           }));
 
@@ -111,13 +115,14 @@
                              NonAtomicallyUpgradeableRWLock* gpu_lock,
                              State<GemmConfig> state, StridedMemrefView lhs,
                              StridedMemrefView rhs, StridedMemrefView out,
-                             int64_t algorithm, double alpha_real,
-                             double alpha_imag, double beta,
+                             StridedMemrefView workspace, int64_t algorithm,
+                             double alpha_real, double alpha_imag, double beta,
                              DotDimensionNumbers dot_dims,
                              absl::Span<const int32_t> precision) {
   se::DeviceMemoryBase lhs_data = GetDeviceAddress(lhs);
   se::DeviceMemoryBase rhs_data = GetDeviceAddress(rhs);
   se::DeviceMemoryBase output_data = GetDeviceAddress(out);
+  se::DeviceMemoryBase workspace_data = GetDeviceAddress(workspace);
   const bool deterministic_ops = debug_options->xla_gpu_deterministic_ops();
 
   VLOG(3) << "Running GEMM";
@@ -152,7 +157,7 @@
 #endif
   }
 
-  return RunGemm(*gemm_config, lhs_data, rhs_data, output_data,
+  return RunGemm(*gemm_config, lhs_data, rhs_data, output_data, workspace_data,
                  deterministic_ops, stream);
 }
 
@@ -177,6 +182,7 @@
         .Arg<StridedMemrefView>()  // lhs
         .Arg<StridedMemrefView>()  // rhs
         .Arg<StridedMemrefView>()  // out
+        .Arg<StridedMemrefView>()  // workspace
         .Attr<int64_t>("algorithm")
         .Attr<double>("alpha_real")
         .Attr<double>("alpha_imag")
diff --git a/third_party/xla/xla/service/gpu/runtime/cublas_lt_matmul.cc b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.cc
similarity index 99%
rename from third_party/xla/xla/service/gpu/runtime/cublas_lt_matmul.cc
rename to third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.cc
index 0fa8351..c008bf5 100644
--- a/third_party/xla/xla/service/gpu/runtime/cublas_lt_matmul.cc
+++ b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.cc
@@ -13,7 +13,7 @@
 limitations under the License.1
 ==============================================================================*/
 
-#include "xla/service/gpu/runtime/cublas_lt_matmul.h"
+#include "xla/service/gpu/runtime/gpublas_lt_matmul.h"
 
 #include <optional>
 #include <string>
diff --git a/third_party/xla/xla/service/gpu/runtime/cublas_lt_matmul.h b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.h
similarity index 90%
rename from third_party/xla/xla/service/gpu/runtime/cublas_lt_matmul.h
rename to third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.h
index e1c1f12..be85ea6 100644
--- a/third_party/xla/xla/service/gpu/runtime/cublas_lt_matmul.h
+++ b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul.h
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef XLA_SERVICE_GPU_RUNTIME_CUBLAS_LT_MATMUL_H_
-#define XLA_SERVICE_GPU_RUNTIME_CUBLAS_LT_MATMUL_H_
+#ifndef XLA_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_H_
+#define XLA_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_H_
 
 #include "xla/mlir/runtime/transforms/custom_call_encoding.h"
 #include "xla/runtime/custom_call_registry.h"
@@ -43,4 +43,4 @@
 }  // namespace gpu
 }  // namespace xla
 
-#endif  // XLA_SERVICE_GPU_RUNTIME_CUBLAS_LT_MATMUL_H_
+#endif  // XLA_SERVICE_GPU_RUNTIME_GPUBLAS_LT_MATMUL_H_
diff --git a/third_party/xla/xla/service/gpu/runtime/graph_launch.cc b/third_party/xla/xla/service/gpu/runtime/graph_launch.cc
index fa42e13..30b6e40 100644
--- a/third_party/xla/xla/service/gpu/runtime/graph_launch.cc
+++ b/third_party/xla/xla/service/gpu/runtime/graph_launch.cc
@@ -31,8 +31,10 @@
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
 #include "xla/runtime/custom_call.h"
 #include "xla/runtime/executable.h"
+#include "xla/service/gpu/buffer_allocations.h"
 #include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h"
 #include "xla/service/gpu/runtime/concurrent_region.h"
 #include "xla/service/gpu/runtime/conv.h"
@@ -270,7 +272,9 @@
 Status GraphInstances::InstantiateAllGraphs(
     const ServiceExecutableRunOptions* run_options,
     const Executable& executable, const CustomCall::UserData& user_data,
-    void* ptr, OrdinalToFallback::Snapshot* ordinal_to_fallback,
+    const BufferAllocations& buffer_allocations,
+    absl::Span<const int64_t> buffer_sizes,
+    absl::Span<const std::vector<int64_t>> allocation_indices,
     std::optional<uint64_t> eviction_timeout_seconds) {
   // We have only "main" function in the executable.
   if (executable.num_functions() == 1) return OkStatus();
@@ -303,9 +307,6 @@
                           "xla.gpu.graph.capture"))
       continue;
 
-    StatusOr<std::monostate*> fallback = ordinal_to_fallback->Get(ordinal);
-    if (fallback.ok()) continue;
-
     VLOG(3) << "Instantiate Gpu graph defined by capture function @"
             << executable.function_name(ordinal) << " (ordinal = " << ordinal
             << ")";
@@ -320,6 +321,12 @@
     assert(signature.num_results() == 0 && "unexpected number of results");
     Arguments<MemrefDesc> args(signature.num_operands());
 
+    // Mapping from graph capture argument to buffer allocation index.
+    absl::Span<const int64_t> capture_allocs = allocation_indices[ordinal];
+    if (capture_allocs.size() != signature.num_operands())
+      return absl::InternalError(
+          "Invalid number of allocation indices for a graph capture function");
+
     // Prepare arguments for the graph capture function.
     for (size_t j = 0; j < signature.num_operands(); ++j) {
       auto* memref = llvm::dyn_cast<MemrefType>(signature.operand(j));
@@ -336,8 +343,11 @@
       std::array<int64_t, 1> sizes = {memref->size(0)};
       std::array<int64_t, 1> strides = {1};
 
-      args.emplace_back<MemrefDesc>(memref->element_type(), ptr,
-                                    /*offset=*/0, sizes, strides);
+      int64_t allocation_index = capture_allocs[j];
+      args.emplace_back<MemrefDesc>(
+          memref->element_type(),
+          buffer_allocations.GetDeviceAddress(allocation_index).opaque(),
+          /*offset=*/0, sizes, strides);
     }
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -557,7 +567,6 @@
     StreamExecutorConvRunners::Snapshot* convs,
     StreamExecutorGraphInstances::Snapshot* instances,
     CapturedFunctionExecutionCount::Snapshot* counts,
-    OrdinalToFallback::Snapshot* ordinal_to_fallback,
     GemmConfigs::Snapshot* gemm_config, runtime::Executable* executable,
     NonAtomicallyUpgradeableRWLock* gpu_lock,
     ConcurrentRegionStatus* region_status, CustomCall::RemainingArgs fwd_args,
@@ -591,10 +600,7 @@
   // work around disable graph execution and run everything in op-by-op mode.
   bool is_profiling = tsl::profiler::ProfilerLock::HasActiveSession();
 
-  StatusOr<std::monostate*> fallback =
-      ordinal_to_fallback->Get(capture.ordinal);
-
-  if (count < num_runs_to_instantiate || is_profiling || fallback.ok()) {
+  if (count < num_runs_to_instantiate || is_profiling) {
     VLOG(3) << "Run gpu graph in op-by-op mode: ordinal = " << capture.ordinal;
     return RunGraphOpByOp(run_options, function_ref, fwd_args, user_data());
   }
@@ -656,43 +662,22 @@
   TF_ASSIGN_OR_RETURN(
       auto g, CaptureGraph(run_options, function_ref, args, user_data()));
 
-  se::gpu::OwnedGpuGraphExec::UpdateResult update_result;
-  {
-    // At this point we have to grab a writer lock, because we might potentially
-    // have concurrent execution of the cached graph instance.
-    absl::WriterMutexLock lock(instance->mutex.get());
+  // At this point we have to grab a writer lock, because we might potentially
+  // have concurrent execution of the cached graph instance.
+  absl::WriterMutexLock lock(instance->mutex.get());
 
-    // Update captured graph executable.
-    TF_ASSIGN_OR_RETURN(update_result, instance->exec.Update(std::move(g)));
-  }
+  // Update captured graph executable.
+  TF_RETURN_IF_ERROR(instance->exec.Update(std::move(g)));
 
-  switch (update_result) {
-    case se::gpu::OwnedGpuGraphExec::UpdateResult::kFallback: {
-      LOG(WARNING) << "Fallback to op-by-op mode because memset node breaks "
-                      "graph update";
-      // Deallocate instance.
-      TF_RETURN_IF_ERROR(instances->Erase(capture.ordinal));
-      // Set ordinal_to_fallback to prevent future instantiation of this graph.
-      TF_ASSIGN_OR_RETURN(
-          std::monostate * fallback,
-          ordinal_to_fallback->GetOrCreate(
-              capture.ordinal,
-              []() -> StatusOr<std::monostate> { return std::monostate{}; }));
-      DCHECK(fallback);
-      return RunGraphOpByOp(run_options, function_ref, fwd_args, user_data());
-    }
-    case se::gpu::OwnedGpuGraphExec::UpdateResult::kSuccess:
-      // Update captured pointer hash.
-      absl::WriterMutexLock lock(instance->mutex.get());
-      instance->ptr_hash = ptrs_hash;
+  // Update captured pointer hash.
+  instance->ptr_hash = ptrs_hash;
 
-      TraceMe trace([&] {
-        return TraceMeEncode("gpu.graph.launch_updated",
-                             {{"ordinal", capture.ordinal}});
-      });
+  TraceMe trace([&] {
+    return TraceMeEncode("gpu.graph.launch_updated",
+                         {{"ordinal", capture.ordinal}});
+  });
 
-      return instance->exec.Launch(run_options->stream());
-  }
+  return instance->exec.Launch(run_options->stream());
 
 #else  // #if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
 
@@ -715,7 +700,6 @@
         .UserData<StreamExecutorConvRunners::Snapshot*>()
         .UserData<StreamExecutorGraphInstances::Snapshot*>()
         .UserData<CapturedFunctionExecutionCount::Snapshot*>()
-        .UserData<OrdinalToFallback::Snapshot*>()
         .UserData<GemmConfigs::Snapshot*>()
         .UserData<Executable*>()
         .UserData<NonAtomicallyUpgradeableRWLock*>()
diff --git a/third_party/xla/xla/service/gpu/runtime/graph_launch.h b/third_party/xla/xla/service/gpu/runtime/graph_launch.h
index e484154..6fe5145 100644
--- a/third_party/xla/xla/service/gpu/runtime/graph_launch.h
+++ b/third_party/xla/xla/service/gpu/runtime/graph_launch.h
@@ -17,6 +17,7 @@
 #define XLA_SERVICE_GPU_RUNTIME_GRAPH_LAUNCH_H_
 
 #include <atomic>
+#include <cstdint>
 #include <memory>
 #include <optional>
 #include <string>
@@ -24,8 +25,10 @@
 #include <variant>
 
 #include "absl/container/node_hash_map.h"
+#include "absl/types/span.h"
 #include "xla/runtime/custom_call_registry.h"
 #include "xla/runtime/executable.h"
+#include "xla/service/gpu/buffer_allocations.h"
 #include "xla/service/service_executable_run_options.h"
 #include "xla/stream_executor/stream_executor.h"
 
@@ -48,10 +51,6 @@
 class CapturedFunctionExecutionCount
     : public runtime::StateVector<std::unique_ptr<std::atomic<uint64_t>>> {};
 
-// Create the i-th value if the capture function with ordinal i causes graph
-// update failure.
-class OrdinalToFallback : public runtime::StateVector<std::monostate> {};
-
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 // A state vector that owns all instantiated GPU graphs. Graph capture function
@@ -110,8 +109,10 @@
   Status InstantiateAllGraphs(
       const ServiceExecutableRunOptions* run_options,
       const runtime::Executable& executable,
-      const runtime::CustomCall::UserData& user_data, void* ptr,
-      OrdinalToFallback::Snapshot* ordinal_to_fallback,
+      const runtime::CustomCall::UserData& user_data,
+      const BufferAllocations& buffer_allocations,
+      absl::Span<const int64_t> buffer_sizes,
+      absl::Span<const std::vector<int64_t>> allocation_indices,
       std::optional<uint64_t> eviction_timeout_seconds = std::nullopt);
 
   // Returns true if all Gpu graphs were already instantiated.
diff --git a/third_party/xla/xla/service/gpu/runtime/kernel_launch.cc b/third_party/xla/xla/service/gpu/runtime/kernel_launch.cc
index 6e3a7ec..78165c3c 100644
--- a/third_party/xla/xla/service/gpu/runtime/kernel_launch.cc
+++ b/third_party/xla/xla/service/gpu/runtime/kernel_launch.cc
@@ -80,22 +80,23 @@
   assert((*kernel)->name() == name && "unexpected loaded kernel");
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-  TF_ASSIGN_OR_RETURN(bool is_capturing, se::gpu::IsStreamCapturing(stream));
-#else
-  bool is_capturing = false;
-#endif
-
-  if (is_capturing) {
-    if (region_status->IsInConcurrentRegion()) {
-      VLOG(3) << "Launching " << (*kernel)->name()
-              << "in a concurrent region during GPU graph capture";
+  if (VLOG_IS_ON(3)) {
+    TF_ASSIGN_OR_RETURN(bool is_capturing, se::gpu::IsStreamCapturing(stream));
+    if (is_capturing) {
+      if (region_status->IsInConcurrentRegion()) {
+        LOG(INFO) << "Launching " << (*kernel)->name()
+                  << " in a concurrent region during GPU graph capture";
+      } else {
+        LOG(INFO) << "Launching " << (*kernel)->name()
+                  << " during GPU graph capture";
+      }
     } else {
-      VLOG(3) << "Launching " << (*kernel)->name()
-              << "during GPU graph capture";
+      LOG(INFO) << "Launching " << (*kernel)->name();
     }
-  } else {
-    VLOG(3) << "Launching " << (*kernel)->name();
   }
+#else
+  VLOG(3) << "Launching " << (*kernel)->name();
+#endif
 
   absl::InlinedVector<se::DeviceMemoryBase, 8> buffer_args(
       args_size_including_temp_buffer);
diff --git a/third_party/xla/xla/service/gpu/runtime/support.h b/third_party/xla/xla/service/gpu/runtime/support.h
index 57bce21..98cefce 100644
--- a/third_party/xla/xla/service/gpu/runtime/support.h
+++ b/third_party/xla/xla/service/gpu/runtime/support.h
@@ -72,6 +72,9 @@
     const runtime::StridedMemrefView& memref) {
   uint64_t size = primitive_util::ByteWidth(memref.dtype);
   for (auto dim : memref.sizes) size *= dim;
+  if (primitive_util::Is4BitType(memref.dtype)) {
+    size = (size + 1) / 2;
+  }
   return se::DeviceMemoryBase(memref.data, size);
 }
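Worked example (assuming primitive_util::ByteWidth reports one byte per 4-bit element, which is what the halving above corrects for): for an S4 memref with sizes = {5}, size starts as 1 * 5 = 5 and becomes (5 + 1) / 2 = 3, i.e. two packed elements per byte rounded up, so the device buffer spans 3 bytes.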
 
@@ -101,7 +104,8 @@
     absl::Span<const int64_t> lhs_contract, absl::Span<const int64_t> rhs_batch,
     absl::Span<const int64_t> rhs_contract, int64_t compute_precision,
     const std::optional<runtime::StridedMemrefView> c = std::nullopt,
-    const std::optional<runtime::StridedMemrefView>& bias = std::nullopt) {
+    const std::optional<runtime::StridedMemrefView>& bias = std::nullopt,
+    bool grad_x = false, bool grad_y = false) {
   Shape c_shape = ToShape(c.value_or(out));
   Shape bias_shape;
   Shape* bias_shape_ptr = nullptr;
@@ -112,7 +116,7 @@
   return GemmConfig::For(ToShape(lhs), lhs_batch, lhs_contract, ToShape(rhs),
                          rhs_batch, rhs_contract, c_shape, bias_shape_ptr,
                          ToShape(out), alpha_real, alpha_imag, beta, algorithm,
-                         compute_precision);
+                         compute_precision, grad_x, grad_y);
 }
 
 // adds Dot Dimension Attribute encodings for calls to Gemm and cuBLASLt
diff --git a/third_party/xla/xla/service/gpu/runtime3/BUILD b/third_party/xla/xla/service/gpu/runtime3/BUILD
index 0ff859f..efdb945 100644
--- a/third_party/xla/xla/service/gpu/runtime3/BUILD
+++ b/third_party/xla/xla/service/gpu/runtime3/BUILD
@@ -146,19 +146,51 @@
     ]),
     visibility = ["//visibility:public"],
     deps = [
+        "//xla:executable_run_options",
+        "//xla:shape_util",
+        "//xla:status",
         "//xla:util",
+        "//xla/ffi",
+        "//xla/ffi:call_frame",
+        "//xla/ffi/api:c_api",
         "//xla/service:buffer_assignment",
+        "//xla/service:custom_call_status",
         "//xla/service:custom_call_status_internal",
+        "//xla/service:executable",
         "//xla/service/gpu:buffer_allocations",
         "//xla/service/gpu:thunk",
+        "//xla/stream_executor:device_memory",
         "//xla/stream_executor/gpu:gpu_stream_header",
         "//xla/stream_executor/gpu:gpu_types_header",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings:str_format",
         "@local_tsl//tsl/platform:errors",
     ],
 )
 
 cc_library(
+    name = "fft_thunk",
+    srcs = ["fft_thunk.cc"],
+    hdrs = ["fft_thunk.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//xla:types",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:buffer_assignment",
+        "//xla/service/gpu:buffer_allocations",
+        "//xla/service/gpu:thunk",
+        "//xla/stream_executor",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:status",
+    ],
+)
+
+cc_library(
     name = "triangular_solve_thunk",
     srcs = if_gpu_is_configured(["triangular_solve_thunk.cc"]),
     hdrs = if_gpu_is_configured(["triangular_solve_thunk.h"]),
diff --git a/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.cc b/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.cc
index 3402467..b53a258 100644
--- a/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.cc
+++ b/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.cc
@@ -15,10 +15,24 @@
 
 #include "xla/service/gpu/runtime3/custom_call_thunk.h"
 
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
 #include "absl/strings/str_format.h"
+#include "xla/executable_run_options.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/call_frame.h"
+#include "xla/ffi/ffi.h"
 #include "xla/service/buffer_assignment.h"
+#include "xla/service/custom_call_status.h"
+#include "xla/service/custom_call_status_internal.h"
+#include "xla/service/gpu/thunk.h"
+#include "xla/service/service_executable_run_options.h"
+#include "xla/status.h"
+#include "xla/stream_executor/device_memory.h"
 #include "xla/util.h"
-#include "tsl/platform/errors.h"
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "xla/stream_executor/gpu/gpu_stream.h"
@@ -27,31 +41,47 @@
 namespace xla {
 namespace gpu {
 
+using xla::ffi::CallFrame;
+using xla::ffi::CallFrameBuilder;
+using xla::ffi::CallOptions;
+
 CustomCallThunk::CustomCallThunk(ThunkInfo thunk_info,
                                  CustomCallTarget call_target,
-                                 std::vector<OptionalSlice> operands,
-                                 std::vector<OptionalSlice> results,
+                                 std::vector<std::optional<Slice>> operands,
+                                 std::vector<std::optional<Slice>> results,
                                  const std::string& opaque)
     : Thunk(Thunk::kCustomCall, thunk_info),
-      call_target_(std::move(call_target)),
       operands_(std::move(operands)),
       results_(std::move(results)),
+      call_target_(std::move(call_target)),
       opaque_(opaque) {}
 
-Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) {
+CustomCallThunk::CustomCallThunk(ThunkInfo thunk_info, XLA_FFI_Handler* handler,
+                                 std::vector<std::optional<Slice>> operands,
+                                 std::vector<std::optional<Slice>> results,
+                                 AttributesMap attributes)
+    : Thunk(Thunk::kCustomCall, thunk_info),
+      operands_(std::move(operands)),
+      results_(std::move(results)),
+      handler_(std::move(handler)),
+      attributes_(std::move(attributes)) {}
+
+Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
   // gpu_stream is CUstream or e.g. the equivalent type in ROCm.
   std::vector<void*> buffers;
   buffers.reserve(operands_.size() + results_.size());
-  for (const std::vector<OptionalSlice>& slices : {operands_, results_}) {
-    for (const OptionalSlice& slice : slices) {
-      if (slice) {
-        if (!slice->allocation())
-          return InternalError("custom call input missing buffer allocation");
-        buffers.push_back(
-            params.buffer_allocations->GetDeviceAddress(*slice).opaque());
-      } else {
+  for (auto& slices : {operands_, results_}) {
+    for (const std::optional<Slice>& slice : slices) {
+      if (!slice.has_value()) {
         buffers.push_back(nullptr);
+        continue;
       }
+
+      if (!slice->slice.allocation())
+        return InternalError("custom call input missing buffer allocation");
+
+      buffers.push_back(
+          params.buffer_allocations->GetDeviceAddress(slice->slice).opaque());
     }
   }
 
@@ -73,5 +103,44 @@
 #endif  //   GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 }
 
+Status CustomCallThunk::ExecuteFfiHandler(const ExecuteParams& params) {
+  // TODO(ezhulenev): This is not optimal, as we do a lot of extra allocation
+  // on every call. We have to keep attributes separate from arguments, as
+  // they do not change after the thunk is constructed.
+  CallFrameBuilder builder;
+
+  for (auto& slices : {operands_, results_}) {
+    for (const std::optional<Slice>& slice : slices) {
+      // TODO(ezhulenev): Add a token argument type to XLA:FFI.
+      if (!slice.has_value()) {
+        return InternalError("FFI handlers do not support tokens (yet)!");
+      }
+
+      if (!slice->slice.allocation())
+        return InternalError("custom call input missing buffer allocation");
+
+      builder.AddBufferArg(
+          params.buffer_allocations->GetDeviceAddress(slice->slice),
+          slice->shape.element_type(), slice->shape.dimensions());
+    }
+  }
+
+  builder.AddAttributes(attributes_);
+  CallFrame call_frame = builder.Build();
+
+  // TODO(ezhulenev): Remove `ServiceExecutableRunOptions` from FFI handler
+  // execution context, as apparently it's not easily accessible from Thunk.
+  ExecutableRunOptions run_options;
+  run_options.set_stream(params.stream);
+  ServiceExecutableRunOptions service_run_options(run_options);
+
+  CallOptions options = {&service_run_options};
+  return Call(handler_, call_frame, options);
+}
+
+Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) {
+  return handler_ ? ExecuteFfiHandler(params) : ExecuteCustomCall(params);
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.h b/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.h
index c4724c3..f544485 100644
--- a/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.h
+++ b/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.h
@@ -16,9 +16,21 @@
 #ifndef XLA_SERVICE_GPU_RUNTIME3_CUSTOM_CALL_THUNK_H_
 #define XLA_SERVICE_GPU_RUNTIME3_CUSTOM_CALL_THUNK_H_
 
-#include "xla/service/custom_call_status_internal.h"
-#include "xla/service/gpu/buffer_allocations.h"
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <optional>
+#include <string>
+#include <variant>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/service/buffer_assignment.h"
+#include "xla/service/custom_call_status.h"
 #include "xla/service/gpu/thunk.h"
+#include "xla/shape.h"
+#include "xla/status.h"
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "xla/stream_executor/gpu/gpu_types.h"
@@ -40,8 +52,6 @@
 // compiler is allowed to create.
 class CustomCallThunk : public Thunk {
  public:
-  using OptionalSlice = ::std::optional<BufferAllocation::Slice>;
-
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   using Stream = stream_executor::gpu::GpuStreamHandle;
 #else   //  GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -50,18 +60,46 @@
 
   using CustomCallTarget = std::function<void(Stream, void**, const char*,
                                               size_t, XlaCustomCallStatus*)>;
+
+  // We keep the buffer allocation slice together with its shape so that we
+  // can fill FFI arguments with the required details.
+  struct Slice {
+    BufferAllocation::Slice slice;
+    Shape shape;
+  };
+
+  using Attribute = std::variant<int32_t, float, std::string>;
+  using AttributesMap = absl::flat_hash_map<std::string, Attribute>;
+
   CustomCallThunk(ThunkInfo thunk_info, CustomCallTarget call_target,
-                  std::vector<OptionalSlice> operands,
-                  std::vector<OptionalSlice> results,
+                  std::vector<std::optional<Slice>> operands,
+                  std::vector<std::optional<Slice>> results,
                   const std::string& opaque);
 
+  CustomCallThunk(ThunkInfo thunk_info, XLA_FFI_Handler* handler,
+                  std::vector<std::optional<Slice>> operands,
+                  std::vector<std::optional<Slice>> results,
+                  AttributesMap attributes);
+
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
  private:
-  const CustomCallTarget call_target_;
-  const std::vector<OptionalSlice> operands_;
-  const std::vector<OptionalSlice> results_;
-  const std::string opaque_;
+  Status ExecuteCustomCall(const ExecuteParams& params);
+  Status ExecuteFfiHandler(const ExecuteParams& params);
+
+  std::vector<std::optional<Slice>> operands_;
+  std::vector<std::optional<Slice>> results_;
+
+  // This is the legacy custom call API; it is discouraged and will be
+  // deprecated once the XLA:FFI mechanism is ready.
+  CustomCallTarget call_target_;
+  std::string opaque_;
+
+  // XLA FFI provides a type-safe mechanism for registering external
+  // functions with the XLA runtime. It is still under construction and
+  // missing many features. Long term it will replace legacy custom calls.
+  XLA_FFI_Handler* handler_ = nullptr;
+  AttributesMap attributes_;
 };
 
 }  // namespace gpu
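
For illustration only (not part of this change): a minimal sketch of how the new AttributesMap introduced above might be populated, assuming only the Attribute/AttributesMap aliases declared in custom_call_thunk.h; the MakeExampleAttributes helper is hypothetical.

#include <cstdint>
#include <string>
#include <variant>

#include "absl/container/flat_hash_map.h"

using Attribute = std::variant<int32_t, float, std::string>;
using AttributesMap = absl::flat_hash_map<std::string, Attribute>;

// Hypothetical helper: builds attributes matching the variant alternatives
// declared for CustomCallThunk::Attribute (int32_t, float, std::string).
AttributesMap MakeExampleAttributes() {
  AttributesMap attrs;
  attrs["algorithm"] = int32_t{1};         // integer attribute
  attrs["epsilon"] = 1e-5f;                // float attribute
  attrs["mode"] = std::string("default");  // string attribute
  return attrs;
}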
diff --git a/third_party/xla/xla/service/gpu/fft_thunk.cc b/third_party/xla/xla/service/gpu/runtime3/fft_thunk.cc
similarity index 99%
rename from third_party/xla/xla/service/gpu/fft_thunk.cc
rename to third_party/xla/xla/service/gpu/runtime3/fft_thunk.cc
index 9198d43..711ae99 100644
--- a/third_party/xla/xla/service/gpu/fft_thunk.cc
+++ b/third_party/xla/xla/service/gpu/runtime3/fft_thunk.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "xla/service/gpu/fft_thunk.h"
+#include "xla/service/gpu/runtime3/fft_thunk.h"
 
 #include <string>
 
diff --git a/third_party/xla/xla/service/gpu/fft_thunk.h b/third_party/xla/xla/service/gpu/runtime3/fft_thunk.h
similarity index 95%
rename from third_party/xla/xla/service/gpu/fft_thunk.h
rename to third_party/xla/xla/service/gpu/runtime3/fft_thunk.h
index 6b2224a..4e0de39 100644
--- a/third_party/xla/xla/service/gpu/fft_thunk.h
+++ b/third_party/xla/xla/service/gpu/runtime3/fft_thunk.h
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef XLA_SERVICE_GPU_FFT_THUNK_H_
-#define XLA_SERVICE_GPU_FFT_THUNK_H_
+#ifndef XLA_SERVICE_GPU_RUNTIME3_FFT_THUNK_H_
+#define XLA_SERVICE_GPU_RUNTIME3_FFT_THUNK_H_
 
 #include <optional>
 
@@ -97,4 +97,4 @@
 }  // namespace gpu
 }  // namespace xla
 
-#endif  // XLA_SERVICE_GPU_FFT_THUNK_H_
+#endif  // XLA_SERVICE_GPU_RUNTIME3_FFT_THUNK_H_
diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc
index 0a8b4f9..345b1b6 100644
--- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc
+++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc
@@ -110,11 +110,11 @@
 
 StatusOr<HloInstruction*> MakeSplitKOperand(
     HloInstruction& dot, const TritonFusionAnalysis& analysis,
-    const AutotuneResult::TritonGemmKey& tiling,
-    const int64_t contracting_dim_idx, const int operand_number) {
+    const TritonGemmConfig& config, const int64_t contracting_dim_idx,
+    const int operand_number) {
   HloInstruction* operand = dot.mutable_operand(operand_number);
   const int64_t k = operand->shape().dimensions(contracting_dim_idx);
-  const bool need_padding = k % tiling.split_k() != 0;
+  const bool need_padding = k % config.split_k != 0;
 
   TritonFusionAnalysis::Scope scope = (operand_number == 0)
                                           ? TritonFusionAnalysis::Scope::LHS
@@ -136,10 +136,10 @@
           "Sliced contracting dimension is not supported yet.");
     }
     if (check_divisibility && !HasDivisibleSuffixAllowingSplit(
-                                  fragment.subfragments, tiling.split_k())) {
+                                  fragment.subfragments, config.split_k)) {
       return UncompilableMatmul("Contracting dimension is too fragmented.");
     }
-    if (tiling.split_k() > ceil(1.0 * fragment.count / tiling.block_k())) {
+    if (config.split_k > ceil(1.0 * fragment.count / config.block_k)) {
       return UncompilableMatmul(
           "Too small divisible part of the contracting dimension.");
     }
@@ -169,11 +169,14 @@
 
     PaddingConfig padding_config = MakeNoPaddingConfig(operand->shape().rank());
     padding_config.mutable_dimensions(contracting_dim_idx)
-        ->set_edge_padding_high(tiling.split_k() - k % tiling.split_k());
+        ->set_edge_padding_high(config.split_k - k % config.split_k);
 
-    TF_ASSIGN_OR_RETURN(operand, MakePadHlo(operand, zero, padding_config));
+    TF_ASSIGN_OR_RETURN(HloInstruction * pad,
+                        MakePadHlo(operand, zero, padding_config));
+    *pad->mutable_shape()->mutable_layout() = operand->shape().layout();
+    operand = pad;
   }
-  CHECK_GE(operand->shape().dimensions(contracting_dim_idx), tiling.split_k());
+  CHECK_GE(operand->shape().dimensions(contracting_dim_idx), config.split_k);
 
   // Add bitcast.
   const Shape& shape = operand->shape();
@@ -182,8 +185,8 @@
   for (int i = 0; i < shape.rank(); ++i) {
     const int64_t dimension_size = shape.dimensions(i);
     if (i == contracting_dim_idx) {
-      new_shape.add_dimensions(tiling.split_k());
-      new_shape.add_dimensions(dimension_size / tiling.split_k());
+      new_shape.add_dimensions(config.split_k);
+      new_shape.add_dimensions(dimension_size / config.split_k);
     } else {
       new_shape.add_dimensions(dimension_size);
     }
@@ -206,11 +209,12 @@
   return MakeBitcastHlo(operand, new_shape);
 }
 
-// Apply split K configuration from the tiling to the fused dot() computation:
-// bitcast the operands, change the output shape and the dot dimensions.
-Status MakeDotComputationSplitKBatch(
-    HloComputation* computation, const AutotuneResult::TritonGemmKey& tiling,
-    bool disable_reduced_precision_reduction) {
+// Apply split K configuration from the tiling config to the fused dot()
+// computation: bitcast the operands, change the output shape and the dot
+// dimensions.
+Status MakeDotComputationSplitKBatch(HloComputation* computation,
+                                     const TritonGemmConfig& config,
+                                     bool disable_reduced_precision_reduction) {
   HloInstruction* dot =
       hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot);
   TF_ASSIGN_OR_RETURN(const auto analysis,
@@ -268,10 +272,10 @@
     if (current == dot) {
       TF_ASSIGN_OR_RETURN(
           HloInstruction * lhs,
-          MakeSplitKOperand(*dot, analysis, tiling, lhs_contracting_idx, 0));
+          MakeSplitKOperand(*dot, analysis, config, lhs_contracting_idx, 0));
       TF_ASSIGN_OR_RETURN(
           HloInstruction * rhs,
-          MakeSplitKOperand(*dot, analysis, tiling, rhs_contracting_idx, 1));
+          MakeSplitKOperand(*dot, analysis, config, rhs_contracting_idx, 1));
       if (lhs->operand(0)->opcode() == HloOpcode::kPad) {
         CHECK_EQ(rhs->operand(0)->opcode(), HloOpcode::kPad);
         did_pad = true;
@@ -290,9 +294,8 @@
       expanded->mutable_shape()->mutable_layout()->add_minor_to_major(0);
       dot->SetupDerivedInstruction(expanded);
     } else {
-      expanded = computation->AddInstruction(
-          current->CloneWithNewShape(ShapeUtil::PrependMajorDimension(
-              tiling.split_k(), current->shape())));
+      expanded = computation->AddInstruction(current->CloneWithNewShape(
+          ShapeUtil::PrependMajorDimension(config.split_k, current->shape())));
       if (expanded->opcode() == HloOpcode::kTranspose) {
         const auto* old_transpose = Cast<HloTransposeInstruction>(current);
         auto* new_transpose = Cast<HloTransposeInstruction>(expanded);
@@ -320,7 +323,7 @@
         TF_RETURN_IF_ERROR(expanded->ReplaceOperandWithDifferentShape(
             i, MakeBroadcastHlo(operand, broadcast_dimensions,
                                 ShapeUtil::PrependMajorDimension(
-                                    tiling.split_k(), operand->shape()))));
+                                    config.split_k, operand->shape()))));
       }
     }
   }
@@ -342,14 +345,14 @@
     // For the case without padding, we already checked this in
     // MakeSplitKOperand with the divisibility check.
     TF_RETURN_IF_ERROR(
-        TritonFusionAnalysis::Execute(*computation, tiling.split_k()).status());
+        TritonFusionAnalysis::Execute(*computation, config.split_k).status());
   }
 
   return OkStatus();
 }
 
 Status MakeDotSplitKBatch(HloInstruction* dot_fusion,
-                          const AutotuneResult::TritonGemmKey& tiling) {
+                          const TritonGemmConfig& config) {
   CHECK_EQ(dot_fusion->opcode(), HloOpcode::kFusion);
 
   if (dot_fusion->shape().IsTuple()) {
@@ -365,7 +368,7 @@
   const Layout output_layout = dot_fusion->shape().layout();
 
   TF_RETURN_IF_ERROR(MakeDotComputationSplitKBatch(
-      dot_fusion->fused_instructions_computation(), tiling,
+      dot_fusion->fused_instructions_computation(), config,
       disable_reduced_precision_reduction));
   const HloInstruction* root = dot_fusion->fused_expression_root();
 
diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.h b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.h
index 436247b..c74b4dc 100644
--- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.h
+++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.h
@@ -20,6 +20,7 @@
 #include "absl/types/span.h"
 #include "xla/autotuning.pb.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/gpu/matmul_utils.h"
 #include "xla/status.h"
 
 namespace xla {
@@ -30,11 +31,11 @@
 bool HasDivisibleSuffixAllowingSplit(absl::Span<int64_t const> span,
                                      int64_t divisor);
 
-// Apply split K configuration from the tiling to the fusion instruction:
+// Apply split K configuration from the tiling config to the fusion instruction:
 // in addition to MakeDotComputationSplitKBatch on its computation add the
 // necessary reduction after it.
 Status MakeDotSplitKBatch(HloInstruction* dot_fusion,
-                          const AutotuneResult::TritonGemmKey& tiling);
+                          const TritonGemmConfig& config);
 
 }  // namespace gpu
 }  // namespace xla
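
A note on the new TritonGemmConfig parameter: judging from the test updates below, its positional constructor appears to take (block_m, block_n, block_k, split_k, num_stages, num_warps), matching the proto setters it replaces; the snippet is a sketch under that assumption.

// Sketch only: positional arguments assumed to map to the former proto fields.
TritonGemmConfig config(/*block_m=*/16, /*block_n=*/16, /*block_k=*/16,
                        /*split_k=*/4, /*num_stages=*/1, /*num_warps=*/4);
// Fields stay mutable, e.g. the FragmentedKSupported test below retunes:
config.split_k = 8;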
diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc
index a03627a..da131e5 100644
--- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc
+++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc
@@ -29,6 +29,9 @@
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/layout.h"
 #include "xla/service/gpu/gemm_rewriter_triton.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/hlo_verifier.h"
+#include "xla/service/layout_assignment.h"
 #include "xla/service/pattern_matcher.h"
 #include "xla/service/pattern_matcher_gmock.h"
 #include "xla/shape_util.h"
@@ -90,15 +93,9 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(hlo_text));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_split_k(4);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  TritonGemmConfig config(16, 16, 16, 4, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
   EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
             HloOpcode::kReduce);
 }
@@ -127,15 +124,9 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(hlo_text));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_split_k(4);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  TritonGemmConfig config(16, 16, 16, 4, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
   EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
             HloOpcode::kReduce);
 }
@@ -161,19 +152,13 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(hlo_text));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_split_k(4);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  EXPECT_THAT(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key),
-      tsl::testing::StatusIs(
-          tsl::error::CANCELLED,
-          absl::StrFormat(
-              "Operation non-distributive over addition after dot.")));
+  TritonGemmConfig config(16, 16, 16, 4, 1, 4);
+  EXPECT_THAT(MakeDotSplitKBatch(
+                  module->entry_computation()->root_instruction(), config),
+              tsl::testing::StatusIs(
+                  tsl::error::CANCELLED,
+                  absl::StrFormat(
+                      "Operation non-distributive over addition after dot.")));
 }
 
 TEST_F(SplitKTest, MakeSplitKWithNonDivisibleDimensionSize) {
@@ -198,15 +183,9 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(kHloText));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_split_k(2);
-  key.set_num_stages(1);
-  key.set_num_warps(2);
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  TritonGemmConfig config(16, 16, 16, 2, 1, 2);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
 }
 
 TEST_F(SplitKTest, AvoidSplitKWithSlicedContractingDimension) {
@@ -227,19 +206,13 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(hlo_text));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_split_k(2);
-  key.set_num_stages(1);
-  key.set_num_warps(2);
-  EXPECT_THAT(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key),
-      tsl::testing::StatusIs(
-          tsl::error::CANCELLED,
-          absl::StrFormat(
-              "Sliced contracting dimension is not supported yet.")));
+  TritonGemmConfig config(16, 16, 16, 2, 1, 2);
+  EXPECT_THAT(MakeDotSplitKBatch(
+                  module->entry_computation()->root_instruction(), config),
+              tsl::testing::StatusIs(
+                  tsl::error::CANCELLED,
+                  absl::StrFormat(
+                      "Sliced contracting dimension is not supported yet.")));
 }
 
 TEST_F(SplitKTest, MakeSplitKWithNonStandardOutputLayout) {
@@ -265,16 +238,10 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(kHloText));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_split_k(4);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
+  TritonGemmConfig config(16, 16, 16, 4, 1, 4);
 
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
 
   EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
             HloOpcode::kReduce);
@@ -306,15 +273,9 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(hlo_text));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(32);
-  key.set_block_n(64);
-  key.set_block_k(64);
-  key.set_split_k(8);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  TritonGemmConfig config(32, 64, 64, 8, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
   EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
             HloOpcode::kReduce);
 }
@@ -342,15 +303,9 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(kHloText));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_split_k(4);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  TritonGemmConfig config(16, 16, 16, 4, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
 }
 
 TEST_F(SplitKTest, SupportsIndivisibleSimpleSplitK4) {
@@ -373,15 +328,41 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(kHloText));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_split_k(4);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  TritonGemmConfig config(16, 16, 16, 4, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
+}
+
+TEST_F(SplitKTest, SupportsIndivisibleWithCustomLayout) {
+  constexpr absl::string_view kHloText = R"(
+HloModule t
+
+triton_gemm_dot {
+  parameter_0 = s8[480,129]{0,1} parameter(0)
+  convert_0 = bf16[480,129]{0,1} convert(parameter_0)
+  parameter_1 = bf16[16,129]{0,1} parameter(1)
+  ROOT dot.0 = bf16[480,16]{1,0} dot(convert_0, parameter_1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+
+ENTRY e {
+  p0 = s8[480,129]{0,1} parameter(0)
+  p1 = bf16[16,129]{0,1} parameter(1)
+  ROOT fusion = bf16[480,16]{1,0} fusion(p0, p1),
+    kind=kCustom, calls=triton_gemm_dot, backend_config="__triton_gemm"
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(kHloText));
+
+  constexpr TritonGemmConfig kConfig(16, 16, 16, 4, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), kConfig));
+
+  TF_EXPECT_OK(HloVerifier(/*layout_sensitive=*/true,
+                           /*allow_mixed_precision=*/true,
+                           LayoutAssignment::InstructionCanChangeLayout)
+                   .Run(module.get())
+                   .status());
 }
 
 TEST_F(SplitKTest, SupportsIndivisibleSimpleSplitK16) {
@@ -404,15 +385,9 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(kHloText));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_split_k(16);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  TritonGemmConfig config(16, 16, 16, 16, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
 }
 
 TEST_F(SplitKTest, SupportsIndivisibleWithTranspose) {
@@ -436,15 +411,9 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(kHloText));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_split_k(16);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  TritonGemmConfig config(16, 16, 16, 16, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
 }
 
 TEST_F(SplitKTest, SupportIndivisibleWithBroadcast) {
@@ -468,15 +437,9 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(kHloText));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_split_k(16);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  TritonGemmConfig config(16, 16, 16, 16, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
 }
 
 TEST_F(SplitKTest, SupportsIndivisibleWithBitcast) {
@@ -500,15 +463,9 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(kHloText));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_split_k(16);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  TritonGemmConfig config(16, 16, 16, 16, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
 }
 
 TEST_F(SplitKTest, SkipSmallK) {
@@ -534,18 +491,12 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(hlo_text));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(128);
-  key.set_split_k(4);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  EXPECT_THAT(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key),
-      tsl::testing::StatusIs(
-          tsl::error::CANCELLED,
-          "Too small divisible part of the contracting dimension."));
+  TritonGemmConfig config(16, 16, 128, 4, 1, 4);
+  EXPECT_THAT(MakeDotSplitKBatch(
+                  module->entry_computation()->root_instruction(), config),
+              tsl::testing::StatusIs(
+                  tsl::error::CANCELLED,
+                  "Too small divisible part of the contracting dimension."));
 }
 
 TEST_F(SplitKTest, FragmentedKSupported) {
@@ -570,24 +521,19 @@
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(hlo_text));
 
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(32);
-  key.set_block_n(32);
-  key.set_block_k(16);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-
+  TritonGemmConfig config(32, 32, 16, 1, 1, 4);
   // 5 divides the contracting dimension, but not its major subdimensions.
-  key.set_split_k(5);
+  config.split_k = 5;
   EXPECT_THAT(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key),
+      MakeDotSplitKBatch(module->entry_computation()->root_instruction(),
+                         config),
       tsl::testing::StatusIs(tsl::error::CANCELLED,
                              "Contracting dimension is too fragmented."));
 
   // 8 fits the constraints.
-  key.set_split_k(8);
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  config.split_k = 8;
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
   const HloInstruction* root = module->entry_computation()->root_instruction();
   EXPECT_EQ(root->opcode(), HloOpcode::kReduce);
   const HloComputation* dot_computation = module->entry_computation()
@@ -597,7 +543,7 @@
   const HloInstruction* p0 = dot_computation->parameter_instruction(0);
   TF_ASSERT_OK_AND_ASSIGN(
       const auto analysis,
-      TritonFusionAnalysis::Execute(*dot_computation, key.split_k()));
+      TritonFusionAnalysis::Execute(*dot_computation, config.split_k));
   EXPECT_EQ(dot_computation->root_instruction()->shape(),
             ShapeUtil::MakeShapeWithDescendingLayout(F16, {8, 7, 5}));
   EXPECT_THAT(
@@ -628,16 +574,11 @@
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(hlo_text));
 
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  key.set_split_k(4);
+  TritonGemmConfig config(16, 16, 16, 4, 1, 4);
   // Because HasDivisibleSuffixAllowingSplit({128, 3}, 4) == false.
   EXPECT_THAT(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key),
+      MakeDotSplitKBatch(module->entry_computation()->root_instruction(),
+                         config),
       tsl::testing::StatusIs(tsl::error::CANCELLED,
                              "Contracting dimension is too fragmented."));
 }
@@ -661,15 +602,9 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(kHloText));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_split_k(2);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  TritonGemmConfig config(16, 16, 16, 2, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
   EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
             HloOpcode::kReduce);
   const HloComputation* dot_computation = module->entry_computation()
@@ -716,15 +651,9 @@
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(kHloText));
 
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_split_k(4);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  TritonGemmConfig config(16, 16, 16, 4, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
 
   EXPECT_THAT(module->entry_computation()->root_instruction(),
               GmockMatch(m::Convert(m::Reduce(m::Fusion(), m::Constant()))));
@@ -754,15 +683,9 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(hlo_text));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(16);
-  key.set_block_k(16);
-  key.set_split_k(4);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  TritonGemmConfig config(16, 16, 16, 4, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
   EXPECT_THAT(module->entry_computation()->root_instruction(),
               GmockMatch(m::Convert(m::Reduce(m::Fusion(), m::Constant()))));
 }
@@ -786,15 +709,9 @@
 })";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                           ParseAndReturnVerifiedModule(hlo_text));
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(16);
-  key.set_block_n(128);
-  key.set_block_k(32);
-  key.set_split_k(8);
-  key.set_num_stages(1);
-  key.set_num_warps(4);
-  TF_EXPECT_OK(
-      MakeDotSplitKBatch(module->entry_computation()->root_instruction(), key));
+  TritonGemmConfig config(16, 128, 32, 8, 1, 4);
+  TF_EXPECT_OK(MakeDotSplitKBatch(
+      module->entry_computation()->root_instruction(), config));
   const auto* transpose =
       Cast<HloTransposeInstruction>(module->entry_computation()
                                         ->root_instruction()
diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD
index 04d886e..ad4aebe 100644
--- a/third_party/xla/xla/service/gpu/tests/BUILD
+++ b/third_party/xla/xla/service/gpu/tests/BUILD
@@ -146,6 +146,7 @@
     deps = [
         ":gpu_codegen_test",
         "//xla:statusor",
+        "//xla:test",
         "//xla:xla_proto_cc",
         "//xla/hlo/ir:hlo",
         "//xla/service:gpu_plugin",
@@ -423,11 +424,9 @@
     tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
+        "//xla:shape_util",
         "//xla/hlo/ir:hlo",
-        "//xla/service:hlo_module_config",
-        "//xla/service:hlo_parser",
-        "//xla/tests:hlo_test_base",
-        "@local_tsl//tsl/platform:test",
+        "@com_google_googletest//:gtest",
         "@local_tsl//tsl/platform:test_main",
     ],
 )
@@ -699,6 +698,9 @@
 xla_cc_test(
     name = "sorting_test",
     srcs = ["sorting_test.cc"],
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
+        "TENSORFLOW_USE_ROCM=1",
+    ]),
     tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
@@ -985,3 +987,14 @@
     shard_count = 1,
     deps = [":simple_optimization_test"],
 )
+
+xla_cc_test(
+    name = "gpu_int4_test",
+    srcs = ["gpu_int4_test.cc"],
+    tags = tf_cuda_tests_tags(),
+    deps = [
+        ":gpu_codegen_test",
+        "@local_tsl//tsl/platform:test",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
diff --git a/third_party/xla/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc b/third_party/xla/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc
index 8aa6775..66c7b6b 100644
--- a/third_party/xla/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc
@@ -53,7 +53,7 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2,2], y: f32[2,2]) -> f32[3,2,2] {
 ; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,2,2]{2,1,0} parameter(0)
 ; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[3,2,2]{2,1,0} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
 ; CHECK:           custom_call_target="__cublas${{(lt\$matmul|gemm)}}",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -92,7 +92,7 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[3,2,2]) -> f32[3,2,2] {
 ; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
 ; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,2,2]{2,1,0} parameter(1)
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[3,2,2]{2,1,0} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
 ; CHECK    :       custom_call_target="__cublas${{(lt\$matmul|gemm)}}",
 ; CHECK    :       backend_config={
 ; CHECK-DAG:         "alpha_real":1
diff --git a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc b/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc
index dec11a3..c742d2e 100644
--- a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc
@@ -29,6 +29,7 @@
 #include "xla/service/pattern_matcher.h"
 #include "xla/service/pattern_matcher_gmock.h"
 #include "xla/statusor.h"
+#include "xla/test.h"
 #include "xla/tests/filecheck.h"
 #include "xla/xla.pb.h"
 #include "tsl/lib/core/status_test_util.h"
@@ -239,7 +240,7 @@
 ; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> f32[2,4] {
 ; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
 ; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -277,7 +278,7 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,3], y: f32[3,4]) -> f32[2,4] {
 ; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
 ; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -319,7 +320,7 @@
 ; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4,5]{2,1,0} parameter(1)
 ; CHECK-DAG:     [[BITCAST0:%[^ ]+]] = f32[2,12]{0,1} bitcast([[P0]])
 ; CHECK-DAG:     [[BITCAST1:%[^ ]+]] = f32[12,5]{1,0} bitcast([[P1]])
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,5]{1,0} custom-call([[BITCAST0]], [[BITCAST1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[BITCAST0]], [[BITCAST1]]),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -358,7 +359,7 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2], y: f32[3,4]) -> f32[2,4] {
 ; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,2]{1,0} parameter(0)
 ; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -397,7 +398,7 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,3,2], y: f32[5,3,4]) -> f32[5,2,4] {
 ; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[5,3,2]{2,1,0} parameter(0)
 ; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -436,7 +437,7 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,5,3], y: f32[5,3,4]) -> f32[5,2,4] {
 ; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,5,3]{2,1,0} parameter(0)
 ; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -476,7 +477,7 @@
 ; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,2,5]{2,1,0} parameter(0)
 ; CHECK-DAG:     [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
 ; CHECK-DAG:     [[FUSION:%[^ ]+]] = f32[5,2,3]{2,1,0} transpose([[P0]])
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[FUSION]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[FUSION]], [[P1]]),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -518,7 +519,7 @@
 ; CHECK:    [[BC0:%[^ ]+]] = f32[80000,3,2]{2,1,0} bitcast([[P0]])
 ; CHECK:    [[P1:%[^ ]+]] = f32[20000,4,3,4]{3,2,1,0} parameter(1)
 ; CHECK:    [[BC1:%[^ ]+]] = f32[80000,3,4]{2,1,0} bitcast([[P1]])
-; CHECK:    [[OUT:%[^ ]+]] = f32[80000,2,4]{2,1,0} custom-call([[BC0]], [[BC1]]),
+; CHECK:    [[GEMM:%[^ ]+]] = (f32[80000,2,4]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[BC0]], [[BC1]]),
 ; CHECK:           custom_call_target="__cublas$gemm",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -534,6 +535,7 @@
 ; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
 ; CHECK-DAG:         }
 ; CHECK:           }
+; CHECK:   [[OUT:%[^ ]+]] = f32[80000,2,4]{2,1,0} get-tuple-element([[GEMM]]), index=0
 ; CHECK:   ROOT {{[^ ]+}} = f32[20000,4,2,4]{3,2,1,0} bitcast([[OUT]])
 )");
 }
@@ -557,7 +559,7 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,3], y: f32[3,4]) -> f32[4,2] {
 ; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
 ; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} custom-call([[P1]], [[P0]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P1]], [[P0]]),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -596,7 +598,7 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,5,4] {
 ; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0)
 ; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = f32[5,2,4]{2,0,1} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -613,7 +615,7 @@
 ; CHECK-DAG:         }
 ; CHECK-DAG:         "epilogue":"DEFAULT"
 ; CHECK:           }
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,5,4]{2,1,0} bitcast([[GEMM]])
+; CHECK:         ROOT [[OUT:%[^ ]+]] = f32[2,5,4]{2,1,0} bitcast
 )");
 }
 
@@ -636,7 +638,7 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,4,5] {
 ; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0)
 ; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -653,7 +655,7 @@
 ; CHECK-DAG:         }
 ; CHECK-DAG:         "epilogue":"DEFAULT"
 ; CHECK:           }
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,4,5]{2,1,0} [[OP:[^ ]+]]([[GEMM]])
+; CHECK:         ROOT [[OUT:%[^ ]+]] = f32[2,4,5]{2,1,0} [[OP:[^ ]+]]
 )");
 }
 
@@ -678,7 +680,7 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
 ; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
 ; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":3
@@ -719,7 +721,7 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: c64[2,2], y: c64[2,2]) -> c64[2,2] {
 ; CHECK-NEXT:    [[P0:%[^ ]+]] = c64[2,2]{1,0} parameter(0)
 ; CHECK-NEXT:    [[P1:%[^ ]+]] = c64[2,2]{1,0} parameter(1)
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = c64[2,2]{1,0} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":3
@@ -758,7 +760,7 @@
   EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
   MatchOptimizedHlo(hlo_text,
                     R"(
-; CHECK:    {{[^ ]+}} = f32[2,2]{1,0} custom-call({{[^,]+}}, {{[^)]+}}),
+; CHECK:    {{[^ ]+}} = {{.*}} custom-call({{[^,]+}}, {{[^)]+}}),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -798,7 +800,7 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
 ; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
 ; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT:    [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0]], [[P1]]),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -833,13 +835,13 @@
   if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) {
     MatchOptimizedHlo(hlo_text,
                       R"(
-; CHECK: bf16[16,8]{1,0} custom-call(bf16[16,8]{1,0} {{.*}}, bf16[8,8]{1,0} {{.*}}), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
+; CHECK: {{.*}} custom-call(bf16[16,8]{1,0} {{.*}}, bf16[8,8]{1,0} {{.*}}), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
   )",
                       /*print_operand_shape=*/true);
   } else {
     MatchOptimizedHlo(hlo_text,
                       R"(
-; CHECK: bf16[12,8]{1,0} custom-call(bf16[12,4]{1,0} [[P0:%[^ ]+]], bf16[4,8]{1,0} [[P1:%[^ ]+]]), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
+; CHECK: {{.*}} custom-call(bf16[12,4]{1,0} [[P0:%[^ ]+]], bf16[4,8]{1,0} [[P1:%[^ ]+]]), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
   )",
                       /*print_operand_shape=*/true);
   }
@@ -861,13 +863,19 @@
   if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) {
     MatchOptimizedHlo(hlo_text,
                       R"(
-    ; CHECK: bf16[3,8,8]{2,1,0} custom-call(bf16[3,8,8]{2,1,0} {{.*}}, bf16[3,8,8]{2,1,0} {{.*}}), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
+    ; CHECK: {{.*}} custom-call(bf16[3,8,8]{2,1,0} {{.*}}, bf16[3,8,8]{2,1,0} {{.*}}), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
+    )",
+                      /*print_operand_shape=*/true);
+  } else if (GetParam()) {
+    MatchOptimizedHlo(hlo_text,
+                      R"(
+    ; CHECK: ROOT [[OUT:%[^ ]+]] = bf16[3,4,2]{2,1,0} custom-call(bf16[3,3,4]{2,1,0} [[A:%[^ ]+]], bf16[3,3,2]{2,1,0} [[B:%[^ ]+]]), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
     )",
                       /*print_operand_shape=*/true);
   } else {
     MatchOptimizedHlo(hlo_text,
                       R"(
-    ; CHECK: ROOT [[OUT:%[^ ]+]] = bf16[3,4,2]{2,1,0} custom-call(bf16[3,3,4]{2,1,0} [[A:%[^ ]+]], bf16[3,3,2]{2,1,0} [[B:%[^ ]+]]), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
+    ; CHECK: {{.*}} custom-call(bf16[3,3,4]{2,1,0} [[A:%[^ ]+]], bf16[3,3,2]{2,1,0} [[B:%[^ ]+]]), custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
     )",
                       /*print_operand_shape=*/true);
   }
@@ -888,13 +896,13 @@
   if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
     MatchOptimizedHlo(hlo_text,
                       R"(
-; CHECK: s32[12,8]{1,0} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
+; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
   )",
                       /*print_operand_shape=*/true);
   } else {
     MatchOptimizedHlo(hlo_text,
                       R"(
-; CHECK: s32[12,8]{1,0} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+; CHECK: {{.*}} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
 
   )",
                       /*print_operand_shape=*/true);
@@ -918,8 +926,18 @@
   if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
     MatchOptimizedHlo(hlo_text,
                       R"(
-; CHECK %custom-call = s32[8,4]{1,0} custom-call(s8[8,4]{1,0} %fusion.1, s8[4,4]{0,1} %bitcast.13), custom_call_target="__cublas$gemm", backend_config={"selected_algorithm":"0","alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}
-; CHECK: ROOT %bitcast.1 = s32[1,8,4]{2,1,0} bitcast(s32[8,4]{1,0} %custom-call)
+; CHECK: [[GEMM:%[^ ]+]] = (s32[8,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(s8[8,4]{1,0} %fusion.1, s8[4,4]{0,1} %bitcast.13), custom_call_target="__cublas$gemm",
+; CHECK:   backend_config={
+; CHECK-DAG:   "selected_algorithm":"0"
+; CHECK-DAG:   "alpha_real":1
+; CHECK-DAG:   "alpha_imag":0
+; CHECK-DAG:   "beta":0
+; CHECK-DAG:   "dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]}
+; CHECK-DAG:   "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}
+; CHECK-DAG:   "epilogue":"DEFAULT"
+; CHECK:   }
+; CHECK: [[RES:%[^ ]+]] = s32[8,4]{1,0} get-tuple-element((s32[8,4]{1,0}, s8[{{[0-9]+}}]{0}) [[GEMM]]), index=0
+; CHECK: ROOT [[OUT:%[^ ]+]] = s32[1,8,4]{2,1,0} bitcast(s32[8,4]{1,0} [[RES]])
   )",
                       /*print_operand_shape=*/true);
   }
@@ -943,7 +961,7 @@
   if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
     MatchOptimizedHlo(hlo_text,
                       R"(
-; CHECK: s32[12,8]{1,0} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
+; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
 ; CHECK:           custom_call_target="__cublas$gemm",
 ; CHECK:           backend_config={
 ; CHECK-DAG:       "alpha_real":1
@@ -953,7 +971,7 @@
   } else {
     MatchOptimizedHlo(hlo_text,
                       R"(
-; CHECK: s32[12,8]{1,0} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+; CHECK: {{.*}} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
 
   )",
                       /*print_operand_shape=*/true);
@@ -977,7 +995,7 @@
   if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
     MatchOptimizedHlo(hlo_text,
                       R"(
-; CHECK: s32[12,8]{1,0} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
+; CHECK: {{.*}} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]),
 ; CHECK:           custom_call_target="__cublas$gemm",
 ; CHECK:           backend_config={
 ; CHECK-DAG:       "alpha_real":1
@@ -988,7 +1006,7 @@
   } else {
     MatchOptimizedHlo(hlo_text,
                       R"(
-; CHECK: s32[12,8]{1,0} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+; CHECK: {{.*}} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
 
   )",
                       /*print_operand_shape=*/true);
@@ -1010,13 +1028,13 @@
   if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) {
     MatchOptimizedHlo(hlo_text,
                       R"(
-; CHECK: s32[16,12]{1,0} custom-call(s8[16,4]{1,0} [[A:%[^ ]+]], s8[4,12]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
+; CHECK: {{.*}} custom-call(s8[16,4]{1,0} [[A:%[^ ]+]], s8[4,12]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm"
   )",
                       /*print_operand_shape=*/true);
   } else {
     MatchOptimizedHlo(hlo_text,
                       R"(
-; CHECK: s32[13,9]{1,0} dot(s32[13,4]{1,0} [[A:%[^ ]+]], s32[4,9]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+; CHECK: {{.*}} dot(s32[13,4]{1,0} [[A:%[^ ]+]], s32[4,9]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}
 
   )",
                       /*print_operand_shape=*/true);
@@ -1109,8 +1127,9 @@
 
   // This is a type combination which is not supported by cublasLt, expect
   // GemmRewriter to choose legacy cublas.
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::CustomCall({"__cublas$gemm"})));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
 }
 
 TEST_P(ParameterizedGemmRewriteTest, UpcastingC64ToC128) {
@@ -1132,8 +1151,9 @@
 
   // This is a type combination which is not supported by cublasLt, expect
   // GemmRewriter to choose legacy cublas.
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::CustomCall({"__cublas$gemm"})));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
 }
 
 TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF32) {
@@ -1153,8 +1173,14 @@
   TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
   EXPECT_TRUE(changed);
 
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::CustomCall({CustomCallTarget()})));
+  if (GetParam()) {
+    EXPECT_THAT(module->entry_computation()->root_instruction(),
+                GmockMatch(m::CustomCall({CustomCallTarget()})));
+  } else {
+    EXPECT_THAT(
+        module->entry_computation()->root_instruction(),
+        GmockMatch(m::GetTupleElement(m::CustomCall({CustomCallTarget()}), 0)));
+  }
 }
 
 TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF64) {
@@ -1176,8 +1202,9 @@
 
   // This is a type combination which is not supported by cublasLt, expect
   // GemmRewriter to choose legacy cublas.
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::CustomCall({"__cublas$gemm"})));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
 }
 
 TEST_P(ParameterizedGemmRewriteTest, UpcastingF32ToF64) {
@@ -1199,8 +1226,9 @@
 
   // This is a type combination which is not supported by cublasLt, expect
   // GemmRewriter to choose legacy cublas.
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::CustomCall({"__cublas$gemm"})));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::GetTupleElement(m::CustomCall({"__cublas$gemm"}), 0)));
 }
 
 TEST_P(ParameterizedGemmRewriteTest, DoNotUpconvertOutput) {
@@ -1226,8 +1254,14 @@
 
   // input fp16 and output fp32 combination is supported by legacy cublas and
   // cublasLt, expect GemmRewriter to fuse the convert into gemm.
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Convert(m::CustomCall({CustomCallTarget()}))));
+  if (GetParam()) {
+    EXPECT_THAT(module->entry_computation()->root_instruction(),
+                GmockMatch(m::Convert(m::CustomCall({CustomCallTarget()}))));
+  } else {
+    EXPECT_THAT(module->entry_computation()->root_instruction(),
+                GmockMatch(m::Convert(m::GetTupleElement(
+                    m::CustomCall({CustomCallTarget()}), 0))));
+  }
 }
 
 TEST_P(ParameterizedGemmRewriteTest, UnsupportedMixTypeGemm) {
@@ -1253,8 +1287,14 @@
 
   // u8 is not supported by legacy cublas and cublasLt, expect
   // GemmRewriter to not fuse the convert into gemm.
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Convert(m::CustomCall({CustomCallTarget()}))));
+  if (GetParam()) {
+    EXPECT_THAT(module->entry_computation()->root_instruction(),
+                GmockMatch(m::Convert(m::CustomCall({CustomCallTarget()}))));
+  } else {
+    EXPECT_THAT(module->entry_computation()->root_instruction(),
+                GmockMatch(m::Convert(m::GetTupleElement(
+                    m::CustomCall({CustomCallTarget()}), 0))));
+  }
 }
 
 TEST_P(ParameterizedGemmRewriteTest, CheckIsGemmAliasedBeforeFusion) {
@@ -1283,8 +1323,14 @@
   // input fp16 and output fp32 combination is supported by legacy cublas and
   // cublasLt, but gemm output is already aliased with one of the input expect
   // GemmRewriter to not fuse the convert into gemm.
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Convert(m::CustomCall({CustomCallTarget()}))));
+  if (GetParam()) {
+    EXPECT_THAT(module->entry_computation()->root_instruction(),
+                GmockMatch(m::Convert(m::CustomCall({CustomCallTarget()}))));
+  } else {
+    EXPECT_THAT(module->entry_computation()->root_instruction(),
+                GmockMatch(m::Convert(m::GetTupleElement(
+                    m::CustomCall({CustomCallTarget()}), 0))));
+  }
 }
 
 INSTANTIATE_TEST_SUITE_P(CublasTestsBothLegacyAndLt,
@@ -1333,9 +1379,11 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], param_2: f32[2,2]) -> f32[2,2] {
 ; CHECK-DAG:     [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
 ; CHECK-DAG:     [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK:         ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[X]], [[Y]], {{[^,)]+}}),
+; CHECK:         [[O:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], {{[^,)]+}}),
 ; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           output_to_operand_aliasing={{{{}: \(2, {}\)}}},
+; CHECK:           output_to_operand_aliasing={
+; CHECK-SAME:        {0}: (2, {})
+; CHECK-SAME:      }
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":3
 ; CHECK-DAG:         "alpha_imag":0
@@ -1351,6 +1399,7 @@
 ; CHECK-DAG:         }
 ; CHECK-DAG:         "epilogue":"DEFAULT"
 ; CHECK:           }
+; CHECK:         ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[O]]), index=0
 )");
 }
 
@@ -1377,7 +1426,7 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] {
 ; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
 ; CHECK-DAG:     [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
 ; CHECK:           custom_call_target="__cublas$gemm",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":3
@@ -1416,7 +1465,7 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] {
 ; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
 ; CHECK-DAG:     [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
 ; CHECK:           custom_call_target="__cublas$gemm",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -1459,9 +1508,11 @@
 ; CHECK-DAG:     [[P2:%[^ ]+]] = (f32[2,2]{1,0}, f32[3,3]{1,0}) parameter(2)
 ; CHECK-DAG:     [[BIAS:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[P2]]), index=0
 ; CHECK-DAG:     [[BIAS_COPY:%[^ ]+]] = f32[2,2]{1,0} copy([[BIAS]])
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]], [[BIAS_COPY]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], [[BIAS_COPY]]),
 ; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           output_to_operand_aliasing={{{{}: \(2, {}\)}}},
+; CHECK:           output_to_operand_aliasing={
+; CHECK-SAME:        {0}: (2, {})
+; CHECK-SAME:      }
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
 ; CHECK-DAG:         "alpha_imag":0
@@ -1504,9 +1555,11 @@
 ; CHECK-DAG:     [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0)
 ; CHECK-DAG:     [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1)
 ; CHECK-DAG:     [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2)
-; CHECK:         ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[X]], [[Y]], [[BIAS]]),
+; CHECK:         [[GEMM:%[^ ]+]] = (f32[2,2]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], [[BIAS]]),
 ; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           output_to_operand_aliasing={{{{}: \(2, {}\)}}},
+; CHECK:           output_to_operand_aliasing={
+; CHECK-SAME:        {0}: (2, {})
+; CHECK-SAME:      }
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":3
 ; CHECK-DAG:         "alpha_imag":0
@@ -1546,7 +1599,7 @@
 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[1024,1024], y: f32[1024,1024], bias: f32[1024,1024]) -> f32[1024,1024] {
 ; CHECK-DAG:     [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0)
 ; CHECK-DAG:     [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1)
-; CHECK-NEXT:    [[GEMM:%[^ ]+]] = f32[1024,1024]{1,0} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = (f32[1024,1024]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
 ; CHECK:           custom_call_target="__cublas$gemm",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -1591,9 +1644,11 @@
 ; CHECK-LABEL: ENTRY %BF16GemmWithBias (x: bf16[8,8], y: bf16[8,8], param_2: bf16[8,8]) -> bf16[8,8] {
 ; CHECK-DAG:    [[X:%[^ ]+]] = bf16[8,8]{1,0} parameter(0)
 ; CHECK-DAG:    [[Y:%[^ ]+]] = bf16[8,8]{1,0} parameter(1)
-; CHECK:        ROOT [[GEMM:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[X]], [[Y]], {{[^,)]+}}),
+; CHECK:        [[GEMM:%[^ ]+]] = (bf16[8,8]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[X]], [[Y]], {{[^,)]+}}),
 ; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           output_to_operand_aliasing={{{{}: \(2, {}\)}}},
+; CHECK:           output_to_operand_aliasing={
+; CHECK-SAME:        {0}: (2, {})
+; CHECK-SAME:      }
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
 ; CHECK-DAG:         "alpha_imag":0
@@ -1638,9 +1693,11 @@
 ; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], param_2: f32[2,4]) -> f32[2,4] {
 ; CHECK-DAG:     [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0)
 ; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
-; CHECK:         ROOT [[GEMM:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], {{[^,)]+}}),
+; CHECK:         [[GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]], {{[^,)]+}}),
 ; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           output_to_operand_aliasing={{{{}: \(2, {}\)}}},
+; CHECK:           output_to_operand_aliasing={
+; CHECK-SAME:        {0}: (2, {})
+; CHECK-SAME:      }
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
 ; CHECK-DAG:         "alpha_imag":0
@@ -1683,7 +1740,7 @@
 ; CHECK-DAG:     [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1)
 ; CHECK-DAG:     [[P2:%[^ ]+]] = f32[2,3]{1,0} parameter(2)
 ; CHECK-DAG:     [[P3:%[^ ]+]] = f32[3,4]{1,0} parameter(3)
-; CHECK-NEXT:    [[FIRST_GEMM:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]),
+; CHECK-NEXT:    [[FIRST_GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1]]),
 ; CHECK:           custom_call_target="__cublas$gemm",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -1700,9 +1757,12 @@
 ; CHECK-DAG:         }
 ; CHECK-DAG:         "epilogue":"DEFAULT"
 ; CHECK:           }
-; CHECK-NEXT:    ROOT [[SECOND_GEMM:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P2]], [[P3]], [[FIRST_GEMM]]),
+; CHECK:         [[FIRST_GEMM_OUT:%[^ ]+]] = f32[2,4]{1,0} get-tuple-element([[FIRST_GEMM]]), index=0
+; CHECK-NEXT:    [[SECOND_GEMM:%[^ ]+]] = (f32[2,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P2]], [[P3]], [[FIRST_GEMM_OUT]]),
 ; CHECK:           custom_call_target="__cublas$gemm",
-; CHECK:           output_to_operand_aliasing={{{{}: \(2, {}\)}}},
+; CHECK:           output_to_operand_aliasing={
+; CHECK-SAME:        {0}: (2, {})
+; CHECK-SAME:      }
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
 ; CHECK-DAG:         "alpha_imag":0
@@ -1753,8 +1813,10 @@
     TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
                             GetOptimizedModule(hlo_text));
     EXPECT_THAT(optimized_module->entry_computation()->root_instruction(),
-                GmockMatch(m::CustomCall(m::Parameter(0), m::Parameter(1),
-                                         m::Negate(m::Parameter(2)))));
+                GmockMatch(m::GetTupleElement(
+                    m::CustomCall(m::Parameter(0), m::Parameter(1),
+                                  m::Negate(m::Parameter(2))),
+                    0)));
   }
 }
 
@@ -1788,8 +1850,10 @@
     TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
                             GetOptimizedModule(hlo_text));
     EXPECT_THAT(optimized_module->entry_computation()->root_instruction(),
-                GmockMatch(m::CustomCall(m::Parameter(0), m::Parameter(1),
-                                         m::Negate(m::Parameter(2)))));
+                GmockMatch(m::GetTupleElement(
+                    m::CustomCall(m::Parameter(0), m::Parameter(1),
+                                  m::Negate(m::Parameter(2))),
+                    0)));
   }
 }
 #endif
@@ -1818,7 +1882,9 @@
       optimized_module->entry_computation()->root_instruction(),
       GmockMatch(m::Fusion(
           m::Parameter(2),
-          m::CustomCall({"__cublas$gemm"}, m::Parameter(0), m::Parameter(1)))));
+          m::GetTupleElement(m::CustomCall({"__cublas$gemm"}, m::Parameter(0),
+                                           m::Parameter(1)),
+                             0))));
 }
 
 // Test batch gemm matrix bias add fusion with mix type that is not supported
@@ -1847,7 +1913,9 @@
       optimized_module->entry_computation()->root_instruction(),
       GmockMatch(m::Fusion(
           m::Parameter(2),
-          m::CustomCall({"__cublas$gemm"}, m::Parameter(0), m::Parameter(1)))));
+          m::GetTupleElement(m::CustomCall({"__cublas$gemm"}, m::Parameter(0),
+                                           m::Parameter(1)),
+                             0))));
 }
 
 TEST_F(LegacyCublasGemmRewriteTest, MergeBitcastAndAdd) {
@@ -1872,8 +1940,11 @@
       module->entry_computation()->root_instruction(),
       GmockMatch(
           m::Bitcast(
-              m::CustomCall({"__cublas$gemm"}, m::Parameter(0), m::Parameter(1),
-                            m::Bitcast(m::Parameter(2)).WithShape(F32, {2, 2})))
+              m::GetTupleElement(
+                  m::CustomCall(
+                      {"__cublas$gemm"}, m::Parameter(0), m::Parameter(1),
+                      m::Bitcast(m::Parameter(2)).WithShape(F32, {2, 2})),
+                  0))
               .WithShape(F32, {4})));
 }
 
@@ -1919,11 +1990,18 @@
   EXPECT_THAT(
       module->entry_computation()->root_instruction(),
       GmockMatch(m::Tuple(
-          m::CustomCall(m::Parameter(0), m::Parameter(1),
-                        m::Negate(m::Parameter(2))),
-          m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
-          m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
-          m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()))));
+          m::GetTupleElement(m::CustomCall(m::Parameter(0), m::Parameter(1),
+                                           m::Negate(m::Parameter(2))),
+                             0),
+          m::GetTupleElement(
+              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+              0),
+          m::GetTupleElement(
+              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+              0),
+          m::GetTupleElement(
+              m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+              0))));
 }
 
 #if GOOGLE_CUDA
@@ -4719,7 +4797,7 @@
   MatchOptimizedHlo(hlo_text,
                     R"(
 ; CHECK-LABEL: ENTRY %PreAdaTest (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16]) -> f8e4m3fn[16,16] {
-; CHECK:    {{.*}} = f16[16,16]{1,0} custom-call({{.*}}, {{.*}})
+; CHECK:    {{.*}} = {{.*}} custom-call({{.*}}, {{.*}})
 ; CHECK-DAG:  custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>"
           )");
 }
@@ -4747,7 +4825,7 @@
 ; CHECK-NEXT:    [[P0_CONVERT:%[^ ]+]] = f16[16,16]{1,0} convert([[P0]])
 ; CHECK-NEXT:    [[P1:%[^ ]+]] = f8e5m2[16,16]{1,0} parameter(1)
 ; CHECK-NEXT:    [[P1_CONVERT:%[^ ]+]] = f16[16,16]{1,0} convert([[P1]])
-; CHECK-NEXT:    [[DOT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0_CONVERT]], [[P1_CONVERT]]),
+; CHECK-NEXT:    [[DOT:%[^ ]+]] = {{.*}} custom-call([[P0_CONVERT]], [[P1_CONVERT]]),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -4764,7 +4842,7 @@
 ; CHECK-DAG:         }
 ; CHECK-DAG:         "epilogue":"DEFAULT"
 ; CHECK:           }
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f8e5m2[16,16]{1,0} convert([[DOT]])
+; CHECK:         ROOT [[OUT:%[^ ]+]] = f8e5m2[16,16]{1,0} convert
       )",
                                                 replacements_));
 }
@@ -5848,11 +5926,11 @@
 ; CHECK-DAG:         "alpha_real":1
 ; CHECK-DAG:         "alpha_imag":0
 ; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{ 
+; CHECK-DAG:         "dot_dimension_numbers":{
 ; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
 ; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
 ; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[] 
+; CHECK-DAG:           "rhs_batch_dimensions":[]
 ; CHECK-DAG:         }
 ; CHECK-DAG:         "precision_config":{
 ; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
@@ -5910,11 +5988,11 @@
 ; CHECK-DAG:         "alpha_real":1
 ; CHECK-DAG:         "alpha_imag":0
 ; CHECK-DAG:         "beta":0
-; CHECK-DAG:         "dot_dimension_numbers":{ 
+; CHECK-DAG:         "dot_dimension_numbers":{
 ; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
 ; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
 ; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[] 
+; CHECK-DAG:           "rhs_batch_dimensions":[]
 ; CHECK-DAG:         }
 ; CHECK-DAG:         "precision_config":{
 ; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
@@ -6605,11 +6683,11 @@
 ; CHECK-DAG:         "alpha_real":1
 ; CHECK-DAG:         "alpha_imag":0
 ; CHECK-DAG:         "beta":1
-; CHECK-DAG:         "dot_dimension_numbers":{ 
+; CHECK-DAG:         "dot_dimension_numbers":{
 ; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
 ; CHECK-DAG:           "rhs_contracting_dimensions":["1"]
 ; CHECK-DAG:           "lhs_batch_dimensions":[]
-; CHECK-DAG:           "rhs_batch_dimensions":[] 
+; CHECK-DAG:           "rhs_batch_dimensions":[]
 ; CHECK-DAG:         }
 ; CHECK-DAG:         "precision_config":{
 ; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
@@ -6862,6 +6940,38 @@
       )");
 }
 
+TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDPrecisionF8) {
+#if CUDA_VERSION < 12000
+  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
+#endif  // CUDA_VERSION < 12000
+  const char* hlo_template = R"(
+    HloModule test
+
+    ENTRY test {
+      x = f8e4m3fn[1600,3200] parameter(0)
+      y = f8e4m3fn[3200,1600] parameter(1)
+      x_f32 = f32[1600,3200] convert(x)
+      y_f32 = f32[3200,1600] convert(y)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      x_scale_bcast = f32[1600,3200] broadcast(x_scale), dimensions={}
+      y_scale_bcast = f32[3200,1600] broadcast(y_scale), dimensions={}
+      x_unscaled = f32[1600,3200] multiply(x_f32, x_scale_bcast)
+      y_unscaled = f32[3200,1600] multiply(y_f32, y_scale_bcast)
+      ROOT out = f32[1600,1600] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={<<precision>>,<<precision>>}
+    }
+)";
+
+  absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+  replacements["<<precision>>"] = "default";
+  const auto hlo_text_default = absl::StrReplaceAll(hlo_template, replacements);
+  EXPECT_TRUE(RunAndCompare(hlo_text_default, ErrorSpec{1e-3, 1e-3}));
+
+  replacements["<<precision>>"] = "highest";
+  const auto hlo_text_highest = absl::StrReplaceAll(hlo_template, replacements);
+  EXPECT_TRUE(RunAndCompare(hlo_text_highest, ErrorSpec{1e-4, 1e-4}));
+}
+
 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) {
 #if CUDA_VERSION < 12000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
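
A minimal sketch of the template-substitution pattern the new UnscaledABUnscaledDPrecisionF8 test above relies on: absl::StrReplaceAll expands the <<precision>> placeholder, and the resulting module text is run once per precision setting with its own error tolerance. ExpandPrecision is an illustrative helper, not part of the test fixture.

#include <string>

#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"

// Hypothetical helper: substitute one precision setting into the HLO template.
std::string ExpandPrecision(absl::string_view hlo_template,
                            absl::string_view precision) {
  return absl::StrReplaceAll(hlo_template, {{"<<precision>>", precision}});
}

The test then runs the expanded module through RunAndCompare twice: once with "default" precision at a 1e-3 tolerance and once with "highest" precision at a tighter 1e-4 tolerance.
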
@@ -7070,7 +7180,7 @@
 ; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
 ; CHECK-NEXT:    [[P3_B:%[^ ]+]] = f32[32,16]{1,0} broadcast([[P3]]), dimensions={}
 ; CHECK-NEXT:    [[P1_UNSCALED:%[^ ]+]] = f32[32,16]{1,0} multiply([[P1_CV]], [[P3_B]])
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[P0_UNSCALED]], [[P1_UNSCALED]]),
+; CHECK-NEXT:    [[GEMM:%[^ ]+]] = {{.*}} custom-call([[P0_UNSCALED]], [[P1_UNSCALED]]),
 ; CHECK:           custom_call_target="<<CUBLAS_CUSTOM_CALL_TARGET_PLACEHOLDER>>",
 ; CHECK:           backend_config={
 ; CHECK-DAG:         "alpha_real":1
@@ -7132,7 +7242,7 @@
         static_cast<GpuExecutable*>(executable.get());
     absl::Span<const BufferAllocation> allocations =
         gpu_executable->GetAllocations();
-    CHECK_EQ(allocations.size(), expected_number_of_allocations);
+    ASSERT_EQ(allocations.size(), expected_number_of_allocations);
   }
 };
 
@@ -7151,7 +7261,7 @@
 )";
 
   // Bias should be fused into the multiplication.
-  CheckNumberOfAllocations(hlo_text, 3);
+  CheckNumberOfAllocations(hlo_text, 4);
   EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
 }
 
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc
index 0b097c9..bd6bb00 100644
--- a/third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc
@@ -208,6 +208,25 @@
   EXPECT_EQ(CollectiveCount<HloOpcode::kReduceScatter>(module), 1);
 }
 
+TEST_F(GpuAllGatherOptimizerTest, DifferentOperandShapes) {
+  absl::string_view hlo_string = R"(
+HloModule TestModule
+
+ENTRY main {
+param.1 = bf16[8,64,128]{2,1,0} parameter(0)
+param.2 = bf16[8,128,64]{2,1,0} parameter(1)
+all-gather.1 = bf16[8,128,128]{2,1,0} all-gather(param.1), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true
+all-gather.2 = bf16[8,128,128]{2,1,0} all-gather(param.2), channel_id=5, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={2}, use_global_device_ids=true
+add.1 = bf16[8,128,128]{2,1,0} add(all-gather.1, all-gather.2)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
+                                               /*num_replicas=*/8,
+                                               /*num_partitions=*/1,
+                                               /*expect_change=*/false));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_dyn_shape_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_dyn_shape_test.cc
index 47116ef..73fb7e8 100644
--- a/third_party/xla/xla/service/gpu/tests/gpu_dyn_shape_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gpu_dyn_shape_test.cc
@@ -14,9 +14,12 @@
 ==============================================================================*/
 #include <utility>
 
+#include <gtest/gtest.h>
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/service/gpu/tests/gpu_codegen_test.h"
-#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
 
 namespace xla {
 namespace gpu {
@@ -38,8 +41,8 @@
   CompileAndVerifyIr(std::move(hlo_module),
                      R"(
 ; CHECK-LABEL: is_thread_0-true
-; CHECK-LABEL: custom_call.in_dyn_bounds-true
-; CHECK-LABEL: custom_call.in_bounds-true
+; CHECK-LABEL: x_padded.in_dyn_bounds-true
+; CHECK-LABEL: x_padded.in_bounds-true
 ; CHECK: %[[dyn_dim_size:.*]] = load i32, ptr
 ; CHECK: %[[dyn_element_total:.*]] = mul i32 1, %[[dyn_dim_size:.*]]
 ; CHECK: %[[linear_index:.*]] = add nuw nsw i32
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_int4_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_int4_test.cc
new file mode 100644
index 0000000..26355c5
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/tests/gpu_int4_test.cc
@@ -0,0 +1,91 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <optional>
+#include <string>
+#include <utility>
+
+#include "xla/service/gpu/tests/gpu_codegen_test.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class GpuInt4Test : public GpuCodegenTest {};
+
+TEST_F(GpuInt4Test, TestInt4ParameterSize) {
+  const std::string hlo_text = R"(
+  HloModule Reshape
+  ENTRY main {
+    x = s4[4] parameter(0)
+    ROOT y = s8[4] convert(x)
+  })";
+  auto hlo_module =
+      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()).value();
+
+  // The input should be 2 bytes and the output should be 4 bytes
+  auto expected_ir = R"(
+; CHECK: define void {{.*}} dereferenceable(2){{.*}} dereferenceable(4)
+)";
+  CompileAndVerifyIr(std::move(hlo_module),
+                     MakePlatformSpecificLlvm(expected_ir),
+                     /*match_optimized_ir=*/true);
+  EXPECT_TRUE(RunAndCompare(hlo_text, /*error=*/std::nullopt));
+}
+
+TEST_F(GpuInt4Test, TestInt4OutputSize) {
+  const std::string hlo_text = R"(
+  HloModule Reshape
+  ENTRY main {
+    x = s8[4] parameter(0)
+    ROOT y = s4[4] convert(x)
+  })";
+  auto hlo_module =
+      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()).value();
+
+  // The input should be 4 bytes and the output should be 2 bytes
+  auto expected_ir = R"(
+; CHECK: define void {{.*}} dereferenceable(4){{.*}} dereferenceable(2)
+)";
+  CompileAndVerifyIr(std::move(hlo_module),
+                     MakePlatformSpecificLlvm(expected_ir),
+                     /*match_optimized_ir=*/true);
+  EXPECT_TRUE(RunAndCompare(hlo_text, /*error=*/std::nullopt));
+}
+
+TEST_F(GpuInt4Test, TestConstantSize) {
+  const std::string hlo_text = R"(
+  HloModule Reshape
+  ENTRY main {
+    x = s4[4] constant({1, 2, 3, 4})
+    ROOT y = s8[4] convert(x)
+  })";
+  auto hlo_module =
+      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()).value();
+
+  // The constant should be 2 bytes and the output should be 4 bytes
+  auto expected_ir = R"(
+; CHECK: define void {{.*}} dereferenceable(2){{.*}} dereferenceable(4)
+)";
+  CompileAndVerifyIr(std::move(hlo_module),
+                     MakePlatformSpecificLlvm(expected_ir),
+                     /*match_optimized_ir=*/true);
+  EXPECT_TRUE(RunAndCompare(hlo_text, /*error=*/std::nullopt));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
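
For readers checking the dereferenceable(2) / dereferenceable(4) expectations in the new int4 tests: s4 elements occupy 4 bits each, so an s4[4] buffer packs into 2 bytes while s8[4] needs 4 bytes. A standalone sketch of that arithmetic, with PackedByteSize as a hypothetical helper rather than an XLA API:

#include <cstdint>

// Bytes needed for `count` elements of `bits` bits each, rounded up to whole
// bytes.
constexpr int64_t PackedByteSize(int64_t count, int64_t bits) {
  return (count * bits + 7) / 8;
}

static_assert(PackedByteSize(4, 4) == 2, "s4[4] -> dereferenceable(2)");
static_assert(PackedByteSize(4, 8) == 4, "s8[4] -> dereferenceable(4)");
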
diff --git a/third_party/xla/xla/service/gpu/tests/sorting_test.cc b/third_party/xla/xla/service/gpu/tests/sorting_test.cc
index 4c9990f..2d6cca3 100644
--- a/third_party/xla/xla/service/gpu/tests/sorting_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/sorting_test.cc
@@ -69,7 +69,7 @@
 }
 
 // Size of the radix sort tests.
-static constexpr int kRadixSortTestSize = 100;
+static constexpr int kRadixSortTestSize = 100000;
 
 template <typename T>
 bool CheckOrder(T lhs, T rhs, bool asc, int pos) {
@@ -106,9 +106,9 @@
 
 ENTRY %main {
   %input = $0[$1] parameter(0)
-  %sort = ($0[$1], u8[1000]) custom-call(%input),
+  %sort = ($0[$1], u8[$2]) custom-call(%input),
       custom_call_target="__cub$$DeviceRadixSort",
-      backend_config="{\"descending\": $2}"
+      backend_config="{\"descending\": $3}"
   ROOT %gte = get-tuple-element(%sort), index=0
 }
 )";
@@ -117,7 +117,9 @@
   std::string hlo = absl::Substitute(
       kHloTemplate,
       GetTypeName(std::get<0>(GetParam())->shape().element_type()),
-      kRadixSortTestSize, ascending ? "false" : "true");
+      kRadixSortTestSize,
+      kRadixSortTestSize * 10,  // scratch buffer size
+      ascending ? "false" : "true");
 
   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
   std::vector<Literal*> literals = {std::get<0>(GetParam()).get()};
@@ -143,9 +145,9 @@
 ENTRY %main {
   %keys = $0[$2] parameter(0)
   %values = $1[$2] convert(%keys)
-  ROOT %sort = ($0[$2], $1[$2], u8[1000]) custom-call(%keys, %values),
+  ROOT %sort = ($0[$2], $1[$2], u8[$3]) custom-call(%keys, %values),
       custom_call_target="__cub$$DeviceRadixSort",
-      backend_config="{\"descending\": $3}"
+      backend_config="{\"descending\": $4}"
 }
 )";
 
@@ -154,6 +156,7 @@
       kHloTemplate,
       GetTypeName(std::get<0>(GetParam())->shape().element_type()),
       GetTypeName(std::get<1>(GetParam())), kRadixSortTestSize,
+      kRadixSortTestSize * 20,  // scratch buffer size
       ascending ? "false" : "true");
 
   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
diff --git a/third_party/xla/xla/service/gpu/triton_autotuner.cc b/third_party/xla/xla/service/gpu/triton_autotuner.cc
index 1c5c585..8e91a17 100644
--- a/third_party/xla/xla/service/gpu/triton_autotuner.cc
+++ b/third_party/xla/xla/service/gpu/triton_autotuner.cc
@@ -59,6 +59,7 @@
 #include "xla/service/gpu/gpu_fusible.h"
 #include "xla/service/gpu/instruction_fusion.h"
 #include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/service/gpu/matmul_utils.h"
 #include "xla/service/gpu/split_k_gemm_rewriter.h"
 #include "xla/service/gpu/stream_executor_util.h"
 #include "xla/service/hlo_module_config.h"
@@ -74,6 +75,7 @@
 #include "xla/stream_executor/stream.h"
 #include "xla/util.h"
 #include "xla/xla.pb.h"
+#include "tsl/lib/core/bits.h"
 #include "tsl/platform/blocking_counter.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/status.h"
@@ -95,21 +97,8 @@
 
 namespace {
 
-// Constructs an autotuning key for a gemm performed in Triton.
-static AutotuneResult::TritonGemmKey GemmKey(int64_t block_m, int64_t block_n,
-                                             int64_t block_k, int64_t split_k,
-                                             int64_t num_stages,
-                                             int64_t num_warps) {
-  AutotuneResult::TritonGemmKey key;
-  key.set_block_m(block_m);
-  key.set_block_n(block_n);
-  key.set_block_k(block_k);
-  key.set_split_k(split_k);
-  key.set_num_stages(num_stages);
-  key.set_num_warps(num_warps);
-  return key;
-}
-
+// Currently supported minimum tile size.
+constexpr int kMinTileSize = 16;
 // Not a hard limit, just an assumption that should stay valid.
 constexpr int kMaxTileSize = 512;
 
@@ -157,10 +146,10 @@
 
     // This cannot be the "else" branch of the previous "if".
     if (backend_config.has_triton_gemm_config()) {
-      const AutotuneResult::TritonGemmKey& tiling =
-          backend_config.triton_gemm_config();
-      if (tiling.split_k() > 1) {
-        TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, tiling));
+      const TritonGemmConfig config =
+          TritonGemmConfig::FromProto(backend_config.triton_gemm_config());
+      if (config.split_k > 1) {
+        TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config));
       }
     }
 
@@ -174,11 +163,11 @@
 
 // This contains all alternative Triton GEMM configs related to one fusion.
 struct GemmConfigSet {
-  std::vector<AutotuneResult::TritonGemmKey> configs;
+  std::vector<TritonGemmConfig> configs;
 };
 
 struct ExecutableCandidate {
-  AutotuneResult::TritonGemmKey config;
+  TritonGemmConfig config;
   // Not nullptr.
   std::unique_ptr<Executable> executable;
 };
@@ -250,15 +239,49 @@
   absl::flat_hash_set<AutotuneCacheKey> handled_fusions_;
 };
 
+struct TileSizeLimit {
+  int64_t block_m = 0;
+  int64_t block_n = 0;
+  int64_t block_k = 0;
+};
+
+TileSizeLimit GetUpperLimit(const HloDotInstruction& dot) {
+  // This is not a sharp upper limit; the actual m value can be much smaller
+  // depending on how much of the m dimension is physically contiguous.
+  // TODO(tdanyluk): Get the exact m value by running a TritonFusionAnalysis.
+  const int64_t m = dot.operand(0)->shape().dimensions(
+      NonContractingDimensionIndex(dot, /*operand_number=*/0));
+  // In theory the same caveat applies to n as to m, but the current
+  // implementation cannot take advantage of that in practice.
+  const int64_t n = dot.operand(1)->shape().dimensions(
+      NonContractingDimensionIndex(dot, /*operand_number=*/1));
+  // This is the k value before the split-k transform is applied.
+  const int64_t k = dot.operand(0)->shape().dimensions(
+      ContractingDimensionIndex(dot, /*operand_number=*/0));
+  const int64_t block_m_limit =
+      std::max<int64_t>(tsl::NextPowerOfTwoS64(m), kMinTileSize);
+  const int64_t block_n_limit =
+      std::max<int64_t>(tsl::NextPowerOfTwoS64(n), kMinTileSize);
+  const int64_t block_k_limit =
+      std::max<int64_t>(tsl::NextPowerOfTwoS64(k), kMinTileSize);
+  return {block_m_limit, block_n_limit, block_k_limit};
+}
+
+int64_t GetSplitKLimit(int64_t block_k, int64_t block_k_limit) {
+  return std::max<int64_t>(block_k_limit / block_k, 1);
+}
+
 // Search space for exhaustive matmul autotuning.
 constexpr std::array<int, 6> BLOCK_SIZES = {16, 32, 64, 128, 256, 512};
 constexpr std::array<int, 4> NUM_STAGES = {1, 2, 3, 4};
 constexpr std::array<int, 4> NUM_WARPS = {2, 4, 8, 16};
 constexpr std::array<int, 5> SPLIT_K = {1, 2, 4, 8, 16};
 
-std::vector<AutotuneResult::TritonGemmKey> GetExhaustiveMatmulAutotuneConfigs(
+std::vector<TritonGemmConfig> GetExhaustiveMatmulAutotuneConfigs(
+    const HloDotInstruction& dot,
     const se::CudaComputeCapability compute_capability, const int max_split_k) {
-  std::vector<AutotuneResult::TritonGemmKey> configs;
+  const TileSizeLimit limit = GetUpperLimit(dot);
+  std::vector<TritonGemmConfig> configs;
   bool mma_layout_v2 =
       compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE);
   for (int num_warps : NUM_WARPS) {
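
A self-contained sketch of the limit computed by GetUpperLimit above: each block dimension is capped at the next power of two of the corresponding GEMM dimension, floored at the 16-element minimum tile. NextPowerOfTwo below is a local stand-in for tsl::NextPowerOfTwoS64 so the snippet compiles on its own.

#include <algorithm>
#include <cstdint>

constexpr int64_t kMinTile = 16;  // mirrors kMinTileSize above

// Local stand-in for tsl::NextPowerOfTwoS64.
constexpr int64_t NextPowerOfTwo(int64_t x) {
  int64_t p = 1;
  while (p < x) p <<= 1;
  return p;
}

constexpr int64_t BlockLimit(int64_t dim) {
  return std::max<int64_t>(NextPowerOfTwo(dim), kMinTile);
}

// E.g. a dot with m=70, n=5, k=300 gets block limits 128, 16 and 512.
static_assert(BlockLimit(70) == 128);
static_assert(BlockLimit(5) == 16);
static_assert(BlockLimit(300) == 512);

With these limits the nested loops above can skip oversized block_m / block_n / block_k candidates early, and split_k is additionally capped at block_k_limit / block_k via GetSplitKLimit.
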
@@ -268,18 +291,27 @@
         continue;
       }
       for (int block_m : BLOCK_SIZES) {
+        if (block_m > limit.block_m) {
+          continue;
+        }
         for (int block_n : BLOCK_SIZES) {
           // Exclude configs not supported by MMA layout v2.
-          if (mma_layout_v2 && (block_m * block_n / 256) % num_warps != 0) {
+          if (block_n > limit.block_n ||
+              (mma_layout_v2 && (block_m * block_n / 256) % num_warps != 0)) {
             continue;
           }
           for (int block_k : BLOCK_SIZES) {
+            if (block_k > limit.block_k) {
+              continue;
+            }
             for (int split_k : SPLIT_K) {
-              if (split_k > max_split_k) {
+              if (split_k >
+                  std::min<int64_t>(max_split_k,
+                                    GetSplitKLimit(block_k, limit.block_k))) {
                 continue;
               }
-              auto config = GemmKey(block_m, block_n, block_k, split_k,
-                                    num_stages, num_warps);
+              auto config = TritonGemmConfig(block_m, block_n, block_k, split_k,
+                                             num_stages, num_warps);
               configs.push_back(std::move(config));
             }
           }
@@ -290,60 +322,90 @@
   return configs;
 }
 
-std::vector<AutotuneResult::TritonGemmKey> GetFixedMatmulAutotuneConfigs(
+std::vector<TritonGemmConfig> GetFixedMatmulAutotuneConfigs(
     const se::CudaComputeCapability compute_capability, const int max_split_k) {
-  std::vector<AutotuneResult::TritonGemmKey> configs = {
-      GemmKey(32, 32, 256, 1, 1, 4), GemmKey(64, 32, 32, 16, 1, 4),
-      GemmKey(32, 64, 64, 4, 1, 4),  GemmKey(128, 128, 64, 4, 1, 4),
-      GemmKey(16, 16, 256, 1, 1, 4), GemmKey(16, 128, 32, 16, 1, 4),
-      GemmKey(16, 64, 128, 1, 1, 4), GemmKey(16, 128, 32, 8, 1, 4),
-      GemmKey(16, 16, 512, 1, 1, 4), GemmKey(32, 16, 512, 1, 1, 4),
-      GemmKey(64, 32, 64, 1, 2, 8)};
+  // Shorter name for better formatting.
+  using Config = TritonGemmConfig;
+  std::vector<Config> configs = {
+      Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4),
+      Config(32, 64, 64, 4, 1, 4),  Config(128, 128, 64, 4, 1, 4),
+      Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4),
+      Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4),
+      Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4),
+      Config(64, 32, 64, 1, 2, 8)};
   if (compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE)) {
     absl::c_copy(
-        std::vector<AutotuneResult::TritonGemmKey>{
-            GemmKey(128, 256, 32, 1, 3, 8),  GemmKey(256, 128, 32, 1, 3, 8),
-            GemmKey(256, 64, 32, 1, 4, 4),   GemmKey(64, 256, 32, 1, 4, 4),
-            GemmKey(128, 64, 32, 1, 4, 4),   GemmKey(64, 128, 32, 1, 4, 4),
-            GemmKey(256, 128, 128, 1, 3, 8), GemmKey(256, 64, 128, 1, 4, 4),
-            GemmKey(64, 256, 128, 1, 4, 4),  GemmKey(128, 128, 128, 1, 4, 4),
-            GemmKey(128, 64, 64, 1, 4, 4),   GemmKey(64, 128, 64, 1, 4, 4),
-            GemmKey(128, 32, 64, 1, 4, 4),   GemmKey(64, 32, 64, 1, 4, 4),
-            GemmKey(32, 128, 32, 1, 4, 4),   GemmKey(128, 128, 32, 1, 4, 4),
-            GemmKey(16, 16, 256, 1, 3, 4),   GemmKey(128, 128, 64, 2, 1, 8),
-            GemmKey(64, 64, 64, 1, 2, 4),    GemmKey(16, 64, 256, 8, 1, 4),
-            GemmKey(256, 256, 128, 1, 3, 8)},
+        std::vector<Config>{
+            Config(128, 256, 32, 1, 3, 8),  Config(256, 128, 32, 1, 3, 8),
+            Config(256, 64, 32, 1, 4, 4),   Config(64, 256, 32, 1, 4, 4),
+            Config(128, 64, 32, 1, 4, 4),   Config(64, 128, 32, 1, 4, 4),
+            Config(256, 128, 128, 1, 3, 8), Config(256, 64, 128, 1, 4, 4),
+            Config(64, 256, 128, 1, 4, 4),  Config(128, 128, 128, 1, 4, 4),
+            Config(128, 64, 64, 1, 4, 4),   Config(64, 128, 64, 1, 4, 4),
+            Config(128, 32, 64, 1, 4, 4),   Config(64, 32, 64, 1, 4, 4),
+            Config(32, 128, 32, 1, 4, 4),   Config(128, 128, 32, 1, 4, 4),
+            Config(16, 16, 256, 1, 3, 4),   Config(128, 128, 64, 2, 1, 8),
+            Config(64, 64, 64, 1, 2, 4),    Config(16, 64, 256, 8, 1, 4),
+            Config(256, 256, 128, 1, 3, 8)},
         std::back_inserter(configs));
   }
   if (compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER)) {
     configs.erase(
         std::remove_if(configs.begin(), configs.end(),
-                       [](const AutotuneResult::TritonGemmKey& config) {
-                         return (config.block_m() * config.block_n() / 256) %
-                                    config.num_warps() !=
+                       [](const Config& config) {
+                         return (config.block_m * config.block_n / 256) %
+                                    config.num_warps !=
                                 0;
                        }),
         configs.end());
   }
-  configs.erase(
-      std::remove_if(configs.begin(), configs.end(),
-                     [&](const AutotuneResult::TritonGemmKey& config) {
-                       return config.split_k() > max_split_k;
-                     }),
-      configs.end());
+  configs.erase(std::remove_if(configs.begin(), configs.end(),
+                               [&](const Config& config) {
+                                 return config.split_k > max_split_k;
+                               }),
+                configs.end());
+  return configs;
+}
+
+// Takes the configs parameter by value so that callers can move it in.
+std::vector<TritonGemmConfig> ReduceTileSizes(
+    const HloDotInstruction& dot, std::vector<TritonGemmConfig> configs) {
+  const TileSizeLimit limit = GetUpperLimit(dot);
+  // Decrease the block sizes and split_k if they are unnecessarily big.
+  for (TritonGemmConfig& config : configs) {
+    config.block_m = std::min<int64_t>(config.block_m, limit.block_m);
+    config.block_n = std::min<int64_t>(config.block_n, limit.block_n);
+    config.block_k = std::min<int64_t>(config.block_k, limit.block_k);
+    config.split_k = std::min<int64_t>(
+        config.split_k, GetSplitKLimit(config.block_k, limit.block_k));
+  }
+
+  // Remove duplicates.
+  absl::flat_hash_set<TritonGemmConfig> configs_so_far;
+  configs.erase(std::remove_if(configs.begin(), configs.end(),
+                               [&](const TritonGemmConfig& config) {
+                                 return !configs_so_far.insert(config).second;
+                               }),
+                configs.end());
+  CHECK(!configs.empty());
   return configs;
 }
 
 int GetLogEveryN() { return VLOG_IS_ON(3) ? 100 : 1000; }
 
 StatusOr<std::unique_ptr<HloModule>> TritonGemmAutotuneExtractor(
-    const AutotuneResult::TritonGemmKey& key,
+    const TritonGemmConfig& config,
     const se::DeviceDescription& gpu_device_info,
-    const HloFusionInstruction* fusion, DebugOptions debug_opts) {
+    const HloFusionInstruction* fusion, DebugOptions debug_opts,
+    bool allow_filtering_kernels_spilling_registers) {
   std::unique_ptr<HloModule> new_module =
       AutotunerUtil::ExtractInstructionIntoNewModule(*fusion);
   // Reduce memory usage during compilation by disabling GPU runtime.
   debug_opts.set_xla_gpu_enable_xla_runtime_executable(false);
+  if (!allow_filtering_kernels_spilling_registers) {
+    debug_opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(
+        false);
+  }
   new_module->mutable_config().set_debug_options(debug_opts);
 
   HloComputation* entry_computation = new_module->entry_computation();
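
The clamp-then-dedup step performed by ReduceTileSizes above can be summarized with the following standalone sketch. TileConfig and the limit argument are illustrative stand-ins for TritonGemmConfig and TileSizeLimit, and the split_k clamping is omitted for brevity.

#include <algorithm>
#include <cstdint>
#include <set>
#include <tuple>
#include <vector>

struct TileConfig {
  int64_t block_m, block_n, block_k;
  auto Key() const { return std::tie(block_m, block_n, block_k); }
  bool operator<(const TileConfig& other) const { return Key() < other.Key(); }
};

std::vector<TileConfig> ClampAndDedup(std::vector<TileConfig> configs,
                                      const TileConfig& limit) {
  // Shrink any block size that exceeds the per-dot limit.
  for (TileConfig& c : configs) {
    c.block_m = std::min(c.block_m, limit.block_m);
    c.block_n = std::min(c.block_n, limit.block_n);
    c.block_k = std::min(c.block_k, limit.block_k);
  }
  // Clamping can make configs identical; keep only the first occurrence,
  // preserving the original order.
  std::set<TileConfig> seen;
  configs.erase(std::remove_if(configs.begin(), configs.end(),
                               [&](const TileConfig& c) {
                                 return !seen.insert(c).second;
                               }),
                configs.end());
  return configs;
}
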
@@ -351,11 +413,11 @@
 
   TF_ASSIGN_OR_RETURN(auto backend_config,
                       cloned_dot_fusion->backend_config<FusionBackendConfig>());
-  *backend_config.mutable_triton_gemm_config() = key;
+  *backend_config.mutable_triton_gemm_config() = config.ToProto();
   TF_RETURN_IF_ERROR(cloned_dot_fusion->set_backend_config(backend_config));
 
-  if (key.split_k() > 1) {
-    TF_RETURN_IF_ERROR(MakeDotSplitKBatch(cloned_dot_fusion, key));
+  if (config.split_k > 1) {
+    TF_RETURN_IF_ERROR(MakeDotSplitKBatch(cloned_dot_fusion, config));
     GpuFloatSupport bf16_support(BF16);
     FloatNormalization float_normalization(&bf16_support);
     TF_RETURN_IF_ERROR(float_normalization.Run(new_module.get()).status());
@@ -401,6 +463,11 @@
   return new_module;
 }
 
+bool ShouldAllowFilteringKernelsSpillingRegisters(
+    const GemmConfigSet& gemm_config_set) {
+  return gemm_config_set.configs.size() > 1;
+}
+
 StatusOr<absl::flat_hash_map<const HloFusionInstruction*, ExecutableSet>>
 CompileMany(const AutotuneConfig& config, AutotunerCompileUtil& util,
             tsl::thread::ThreadPool* thread_pool,
@@ -441,11 +508,11 @@
 
   // Returns true on success.
   auto compile =
-      [&](const HloFusionInstruction* fusion,
-          const AutotuneResult::TritonGemmKey& conf) -> StatusOr<bool> {
-    CHECK(conf.block_m() <= kMaxTileSize);
-    CHECK(conf.block_n() <= kMaxTileSize);
-    CHECK(conf.block_k() <= kMaxTileSize);
+      [&](const HloFusionInstruction* fusion, const TritonGemmConfig& conf,
+          bool allow_filtering_kernels_spilling_registers) -> StatusOr<bool> {
+    CHECK_LE(conf.block_m, kMaxTileSize);
+    CHECK_LE(conf.block_n, kMaxTileSize);
+    CHECK_LE(conf.block_k, kMaxTileSize);
     // TODO(b/296884861): Reenable GPU runtime, when it will have much smaller
     // memory overhead (regarding the size of the executables).
     // We can also remove the force_disable_gpu_runtime argument at that
@@ -453,7 +520,8 @@
     TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                         util.Compile([&](const DebugOptions& opts) {
                           return TritonGemmAutotuneExtractor(
-                              conf, gpu_device_info, fusion, opts);
+                              conf, gpu_device_info, fusion, opts,
+                              allow_filtering_kernels_spilling_registers);
                         }));
 
     if (executable != nullptr) {
@@ -506,13 +574,14 @@
       const HloFusionInstruction* fusion = key_value.first;
       const GemmConfigSet& gemm_config_set = key_value.second;
 
-      for (const AutotuneResult::TritonGemmKey& conf :
-           gemm_config_set.configs) {
+      for (const TritonGemmConfig& conf : gemm_config_set.configs) {
         thread_pool->Schedule([&, fusion] {
-          StatusOr<bool> has_executable = compile(fusion, conf);
+          StatusOr<bool> has_executable = compile(
+              fusion, conf,
+              ShouldAllowFilteringKernelsSpillingRegisters(gemm_config_set));
           TF_CHECK_OK(has_executable.status())
               << "Failure occured when compiling fusion " << fusion->name()
-              << " with config '" << conf.ShortDebugString()
+              << " with config '" << conf.ToString()
               << "'\nFused HLO computation:\n"
               << fusion->fused_instructions_computation()->ToString();
           log(has_executable.value());
@@ -543,9 +612,12 @@
       const HloFusionInstruction* fusion = key_value.first;
       const GemmConfigSet& gemm_config_set = key_value.second;
 
-      for (const AutotuneResult::TritonGemmKey& gemm_config :
-           gemm_config_set.configs) {
-        TF_ASSIGN_OR_RETURN(bool has_executable, compile(fusion, gemm_config));
+      for (const TritonGemmConfig& gemm_config : gemm_config_set.configs) {
+        TF_ASSIGN_OR_RETURN(
+            bool has_executable,
+            compile(
+                fusion, gemm_config,
+                ShouldAllowFilteringKernelsSpillingRegisters(gemm_config_set)));
         log(has_executable);
       }
 
@@ -636,10 +708,10 @@
   VLOG(2) << "Running " << executable_count << " configs for " << fusion->name()
           << ".";
   for (const ExecutableCandidate& candidate : executable_set.candidates) {
-    VLOG(5) << "Trying triton tiling: " << candidate.config.ShortDebugString();
+    VLOG(5) << "Trying triton tiling: " << candidate.config.ToString();
 
     AutotuneResult res;
-    *res.mutable_triton() = candidate.config;
+    *res.mutable_triton() = candidate.config.ToProto();
 
     TF_ASSIGN_OR_RETURN(std::optional<ProfilingOutput> profiling_output,
                         util.ProfileExecutable(candidate.executable.get(),
@@ -659,7 +731,7 @@
     if (profiling_output->duration >= absl::Seconds(1)) {
       LOG(WARNING) << "Slow kernel for " << fusion->name()
                    << " took: " << profiling_output->duration
-                   << ". config: " << candidate.config.ShortDebugString();
+                   << ". config: " << candidate.config.ToString();
     }
     *res.mutable_run_time() =
         tsl::proto_utils::ToDurationProto(profiling_output->duration);
@@ -727,13 +799,14 @@
                             AutotunerCompileUtil& util,
                             const AutotuneResult result,
                             const HloFusionInstruction* fusion, int fusion_id) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
-                      util.ExtractModule([&](const DebugOptions& debug_opts) {
-                        return TritonGemmAutotuneExtractor(
-                            result.triton(),
-                            config.GetExecutor()->GetDeviceDescription(),
-                            fusion, debug_opts);
-                      }));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModule> module,
+      util.ExtractModule([&](const DebugOptions& debug_opts) {
+        return TritonGemmAutotuneExtractor(
+            TritonGemmConfig::FromProto(result.triton()),
+            config.GetExecutor()->GetDeviceDescription(), fusion, debug_opts,
+            /*allow_filtering_kernels_spilling_registers=*/true);
+      }));
   module->set_name(std::string(fusion->name()));
   // Using the original module for its debug info and name in the first
   // parameter. It's better to include the name of both the original module
@@ -760,6 +833,17 @@
       executable_sets,
       CompileMany(config, util, thread_pool, debug_opts, gemm_config_sets));
 
+  // Sort the candidates to make their execution order well-defined for each
+  // fusion.
+  for (auto& key_value : executable_sets) {
+    ExecutableSet& executable_set = key_value.second;
+    std::vector<ExecutableCandidate>& candidates = executable_set.candidates;
+    absl::c_sort(candidates, [](const ExecutableCandidate& a,
+                                const ExecutableCandidate& b) {
+      return a.config < b.config;
+    });
+  }
+
   for (const auto& key_value : executable_sets) {
     const HloFusionInstruction* fusion = key_value.first;
     const ExecutableSet& executable_set = key_value.second;
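
A brief sketch of the new deterministic ordering above: candidates are sorted by their config before profiling, presumably so the execution order no longer depends on the order in which the (possibly multi-threaded) compilation step finished. Candidate and Config below are illustrative stand-ins for ExecutableCandidate and TritonGemmConfig, which the diff compares with operator<.

#include <algorithm>
#include <tuple>
#include <vector>

struct Config {
  int block_m = 0, block_n = 0, block_k = 0, split_k = 0;
  bool operator<(const Config& other) const {
    return std::tie(block_m, block_n, block_k, split_k) <
           std::tie(other.block_m, other.block_n, other.block_k, other.split_k);
  }
};

struct Candidate {
  Config config;
};

// Sort candidates into a well-defined order before profiling them.
void SortCandidates(std::vector<Candidate>& candidates) {
  std::sort(candidates.begin(), candidates.end(),
            [](const Candidate& a, const Candidate& b) {
              return a.config < b.config;
            });
}
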
@@ -788,7 +872,7 @@
 
 }  // anonymous namespace
 
-std::vector<AutotuneResult::TritonGemmKey> GetPossibleMatmulAutotuneConfigs(
+std::vector<TritonGemmConfig> GetPossibleMatmulAutotuneConfigs(
     const HloDotInstruction& dot,
     const se::CudaComputeCapability compute_capability,
     const DebugOptions& debug_options, bool exhaustive_tiling_search) {
@@ -796,7 +880,7 @@
   constexpr int kMinGemmElements = 32 * 32;
   if (ShapeUtil::ElementsIn(dot.operand(0)->shape()) <= kMinGemmElements &&
       ShapeUtil::ElementsIn(dot.operand(1)->shape()) <= kMinGemmElements) {
-    return {GemmKey(32, 32, 32, 1, 1, 4)};
+    return ReduceTileSizes(dot, {TritonGemmConfig(32, 32, 32, 1, 1, 4)});
   }
   // Split-K optimization enables more even utilization of a GPU in cases
   // where tiling just the non-contracting dimensions of a GEMM does not create
@@ -815,9 +899,10 @@
                                       ShapeUtil::ElementsIn(dot.shape()))
           : 1;
   return exhaustive_tiling_search
-             ? GetExhaustiveMatmulAutotuneConfigs(compute_capability,
+             ? GetExhaustiveMatmulAutotuneConfigs(dot, compute_capability,
                                                   max_split_k)
-             : GetFixedMatmulAutotuneConfigs(compute_capability, max_split_k);
+             : ReduceTileSizes(dot, GetFixedMatmulAutotuneConfigs(
+                                        compute_capability, max_split_k));
 }
 
 StatusOr<bool> TritonAutotuner::Run(
diff --git a/third_party/xla/xla/service/gpu/triton_autotuner.h b/third_party/xla/xla/service/gpu/triton_autotuner.h
index bff2f17..9e1e07f 100644
--- a/third_party/xla/xla/service/gpu/triton_autotuner.h
+++ b/third_party/xla/xla/service/gpu/triton_autotuner.h
@@ -24,6 +24,7 @@
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/matmul_utils.h"
 #include "xla/service/hlo_pass_interface.h"
 #include "xla/statusor.h"
 #include "xla/stream_executor/device_description.h"
@@ -54,7 +55,7 @@
 
 // TODO(b/266210099): have a way to generate/load these dynamically.
 // Returns a list of possible tilings for a GEMM performed in Triton.
-std::vector<AutotuneResult::TritonGemmKey> GetPossibleMatmulAutotuneConfigs(
+std::vector<TritonGemmConfig> GetPossibleMatmulAutotuneConfigs(
     const HloDotInstruction& dot, se::CudaComputeCapability compute_capability,
     const DebugOptions& debug_options, bool exhaustive_tiling_search = false);
 
diff --git a/third_party/xla/xla/service/gpu/triton_autotuner_test.cc b/third_party/xla/xla/service/gpu/triton_autotuner_test.cc
index e6221f1..83a0012 100644
--- a/third_party/xla/xla/service/gpu/triton_autotuner_test.cc
+++ b/third_party/xla/xla/service/gpu/triton_autotuner_test.cc
@@ -37,6 +37,7 @@
 #include "xla/service/gpu/autotuner_util.h"
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/gemm_rewriter_triton.h"
+#include "xla/service/gpu/matmul_utils.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/service/hlo_pass_pipeline.h"
 #include "xla/service/pattern_matcher.h"
@@ -247,15 +248,14 @@
                                                   .value();
   const se::CudaComputeCapability compute_capability{
       se::CudaComputeCapability::VOLTA, /*minor=*/0};
-  const std::vector<AutotuneResult::TritonGemmKey> configs =
+  const std::vector<TritonGemmConfig> configs =
       GetPossibleMatmulAutotuneConfigs(
           *Cast<HloDotInstruction>(
               module->entry_computation()->root_instruction()),
           compute_capability, GetDebugOptionsForTest());
-  EXPECT_FALSE(std::any_of(configs.begin(), configs.end(),
-                           [](const AutotuneResult::TritonGemmKey& key) {
-                             return key.num_stages() > 2;
-                           }));
+  EXPECT_FALSE(std::any_of(
+      configs.begin(), configs.end(),
+      [](const TritonGemmConfig& config) { return config.num_stages > 2; }));
 }
 
 TEST_F(TritonAutotunerTest, AmpereUsesMoreThanTwoStages) {
@@ -269,15 +269,14 @@
                                                   .value();
   const se::CudaComputeCapability compute_capability{
       se::CudaComputeCapability::AMPERE, /*minor=*/0};
-  const std::vector<AutotuneResult::TritonGemmKey> configs =
+  const std::vector<TritonGemmConfig> configs =
       GetPossibleMatmulAutotuneConfigs(
           *Cast<HloDotInstruction>(
               module->entry_computation()->root_instruction()),
           compute_capability, GetDebugOptionsForTest());
-  EXPECT_TRUE(std::any_of(configs.begin(), configs.end(),
-                          [](const AutotuneResult::TritonGemmKey& key) {
-                            return key.num_stages() > 2;
-                          }));
+  EXPECT_TRUE(std::any_of(
+      configs.begin(), configs.end(),
+      [](const TritonGemmConfig& config) { return config.num_stages > 2; }));
 }
 
 TEST_F(TritonAutotunerTest, SmallOutputCanUseLargeSplitK) {
@@ -291,15 +290,14 @@
                                                   .value();
   const se::CudaComputeCapability compute_capability{
       se::CudaComputeCapability::AMPERE, /*minor=*/0};
-  const std::vector<AutotuneResult::TritonGemmKey> configs =
+  const std::vector<TritonGemmConfig> configs =
       GetPossibleMatmulAutotuneConfigs(
           *Cast<HloDotInstruction>(
               module->entry_computation()->root_instruction()),
           compute_capability, GetDebugOptionsForTest());
-  EXPECT_TRUE(std::any_of(configs.begin(), configs.end(),
-                          [](const AutotuneResult::TritonGemmKey& key) {
-                            return key.split_k() >= 16;
-                          }));
+  EXPECT_TRUE(std::any_of(
+      configs.begin(), configs.end(),
+      [](const TritonGemmConfig& config) { return config.split_k >= 16; }));
 }
 
 TEST_F(TritonAutotunerTest, LargeOutputDoesNotUseLargeSplitK) {
@@ -313,15 +311,14 @@
                                                   .value();
   const se::CudaComputeCapability compute_capability{
       se::CudaComputeCapability::AMPERE, /*minor=*/0};
-  const std::vector<AutotuneResult::TritonGemmKey> configs =
+  const std::vector<TritonGemmConfig> configs =
       GetPossibleMatmulAutotuneConfigs(
           *Cast<HloDotInstruction>(
               module->entry_computation()->root_instruction()),
           compute_capability, GetDebugOptionsForTest());
-  EXPECT_FALSE(std::any_of(configs.begin(), configs.end(),
-                           [](const AutotuneResult::TritonGemmKey& key) {
-                             return key.split_k() > 1;
-                           }));
+  EXPECT_FALSE(std::any_of(
+      configs.begin(), configs.end(),
+      [](const TritonGemmConfig& config) { return config.split_k > 1; }));
 }
 
 TEST_F(TritonAutotunerTest, Int8FusedGemm) {
@@ -677,15 +674,14 @@
                                                   .value();
   const se::CudaComputeCapability compute_capability{
       se::CudaComputeCapability::AMPERE, /*minor=*/0};
-  const std::vector<AutotuneResult::TritonGemmKey> configs =
+  const std::vector<TritonGemmConfig> configs =
       GetPossibleMatmulAutotuneConfigs(
           *Cast<HloDotInstruction>(
               module->entry_computation()->root_instruction()),
           compute_capability, GetDebugOptionsForTest());
-  EXPECT_TRUE(std::all_of(configs.begin(), configs.end(),
-                          [](const AutotuneResult::TritonGemmKey& key) {
-                            return key.split_k() == 1;
-                          }));
+  EXPECT_TRUE(std::all_of(
+      configs.begin(), configs.end(),
+      [](const TritonGemmConfig& config) { return config.split_k == 1; }));
 }
 
 }  // namespace
diff --git a/third_party/xla/xla/service/heap_simulator.cc b/third_party/xla/xla/service/heap_simulator.cc
index 516c1e5..1496df4 100644
--- a/third_party/xla/xla/service/heap_simulator.cc
+++ b/third_party/xla/xla/service/heap_simulator.cc
@@ -44,6 +44,7 @@
 #include "xla/hlo/utils/hlo_live_range.h"
 #include "xla/map_util.h"
 #include "xla/service/memory_space_assignment/repacking.h"
+#include "xla/service/time_utils.h"
 #include "xla/status.h"
 #include "xla/util.h"
 
@@ -927,14 +928,27 @@
 
 template <typename BufferType>
 void GlobalDecreasingSizeBestFitHeap<BufferType>::SlicedBufferInterval::
-    UpdateSliceStartTimes(const std::vector<int64_t>& start_times) {
-  CHECK_EQ(start_times.size(), num_slices());
+    UpdateExclusiveSliceStartTimes(
+        const std::vector<int64_t>& exclusive_start_times) {
+  std::vector<int64_t> inclusive_start_times = exclusive_start_times;
+  absl::c_for_each(inclusive_start_times,
+                   [](int64_t& t) { t = ExclusiveToInclusiveStartTime(t); });
+  UpdateInclusiveSliceStartTimes(inclusive_start_times);
+}
+
+template <typename BufferType>
+void GlobalDecreasingSizeBestFitHeap<BufferType>::SlicedBufferInterval::
+    UpdateInclusiveSliceStartTimes(
+        const std::vector<int64_t>& inclusive_start_times) {
+  CHECK_EQ(inclusive_start_times.size(), num_slices());
   CHECK(mutable_full_buffer_interval_ != nullptr);
-  mutable_full_buffer_interval_->start = start_times.front();
+  mutable_full_buffer_interval_->start = inclusive_start_times.front();
   for (size_t slice_time = 0; slice_time < num_slices(); ++slice_time) {
-    make_free_chunks_intervals_[slice_time].start = start_times[slice_time];
+    make_free_chunks_intervals_[slice_time].start =
+        inclusive_start_times[slice_time];
     if (slice_time != num_slices() - 1) {
-      make_free_chunks_intervals_[slice_time].end = start_times[slice_time + 1];
+      make_free_chunks_intervals_[slice_time].end =
+          ExclusiveToInclusiveEndTime(inclusive_start_times[slice_time + 1]);
     } else {
       make_free_chunks_intervals_[slice_time].end = full_buffer_interval_.end;
     }
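
As a reading aid for the conversion above and the updated heap_simulator_test expectations further below: the exclusive-to-inclusive mapping implied by those expectations shifts start times forward by one and end times back by one, so exclusive starts {100, 125, 150, 175} become inclusive starts 101, 126, 151, 176. A tiny sketch of that arithmetic, restated locally rather than quoting the actual xla/service/time_utils implementation:

#include <cstdint>

constexpr int64_t ExclusiveToInclusiveStart(int64_t t) { return t + 1; }
constexpr int64_t ExclusiveToInclusiveEnd(int64_t t) { return t - 1; }

// Matches the test expectations: exclusive start 100 -> inclusive start 101,
// and the slice ending just before inclusive start 125 ends at 124.
static_assert(ExclusiveToInclusiveStart(100) == 101);
static_assert(ExclusiveToInclusiveEnd(125) == 124);
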
diff --git a/third_party/xla/xla/service/heap_simulator.h b/third_party/xla/xla/service/heap_simulator.h
index 14f80a1..b6b78d7 100644
--- a/third_party/xla/xla/service/heap_simulator.h
+++ b/third_party/xla/xla/service/heap_simulator.h
@@ -482,10 +482,13 @@
     //
     // REQUIRES:
     // - The SlicedBufferInterval was constructed using CreateMutableInterval.
-    // - start_times.size() == NumSlices()
-    // - start_times should be set such that it is permissible for any slice
-    //   size to map to any start time.
-    void UpdateSliceStartTimes(const std::vector<int64_t>& start_times);
+    // - *_start_times.size() == NumSlices()
+    // - *_start_times should be set such that it is permissible for any
+    //   slice size to map to any start time.
+    void UpdateExclusiveSliceStartTimes(
+        const std::vector<int64_t>& exclusive_start_times);
+    void UpdateInclusiveSliceStartTimes(
+        const std::vector<int64_t>& inclusive_start_times);
 
     // Updates the free time for all the slices.
     //
diff --git a/third_party/xla/xla/service/heap_simulator_test.cc b/third_party/xla/xla/service/heap_simulator_test.cc
index 5a4fa1d..8ee842c 100644
--- a/third_party/xla/xla/service/heap_simulator_test.cc
+++ b/third_party/xla/xla/service/heap_simulator_test.cc
@@ -1575,7 +1575,7 @@
 
     // // Slice B.
     sliced_buffer_b.Slice({5, 5});
-    sliced_buffer_b.UpdateSliceStartTimes({25, 30});
+    sliced_buffer_b.UpdateInclusiveSliceStartTimes({25, 30});
 
     // Place and commit B (and C transitively via colocation). B should be
     // placed at an offset that accommodates C; however, it should not have the
@@ -2084,20 +2084,21 @@
   EXPECT_THAT(mutable_sliced_buffer_interval_->SliceSizesSortedByOffset(),
               ::testing::ElementsAre(4, 5, 5, 6));
 
-  mutable_sliced_buffer_interval_->UpdateSliceStartTimes({100, 125, 150, 175});
+  mutable_sliced_buffer_interval_->UpdateInclusiveSliceStartTimes(
+      {100, 125, 150, 175});
 
   EXPECT_EQ(BufferIntervalToTuple(
                 mutable_sliced_buffer_interval_->IntervalForMakeFreeChunks(0)),
             BufferIntervalToTuple(
-                {p0_value_.get(), 4, 100, 125, ColocationTy(), true}));
+                {p0_value_.get(), 4, 100, 124, ColocationTy(), true}));
   EXPECT_EQ(BufferIntervalToTuple(
                 mutable_sliced_buffer_interval_->IntervalForMakeFreeChunks(1)),
             BufferIntervalToTuple(
-                {p0_value_.get(), 4, 125, 150, ColocationTy(), true}));
+                {p0_value_.get(), 4, 125, 149, ColocationTy(), true}));
   EXPECT_EQ(BufferIntervalToTuple(
                 mutable_sliced_buffer_interval_->IntervalForMakeFreeChunks(2)),
             BufferIntervalToTuple(
-                {p0_value_.get(), 4, 150, 175, ColocationTy(), true}));
+                {p0_value_.get(), 4, 150, 174, ColocationTy(), true}));
   EXPECT_EQ(BufferIntervalToTuple(
                 mutable_sliced_buffer_interval_->IntervalForMakeFreeChunks(3)),
             BufferIntervalToTuple({p0_value_.get(), 20, 175, 200,
@@ -2107,6 +2108,30 @@
             BufferIntervalToTuple({p0_value_.get(), 20, 100, 200,
                                    ColocationTy({p1_value_.get()}), true}));
 
+  mutable_sliced_buffer_interval_->UpdateExclusiveSliceStartTimes(
+      {100, 125, 150, 175});
+
+  EXPECT_EQ(BufferIntervalToTuple(
+                mutable_sliced_buffer_interval_->IntervalForMakeFreeChunks(0)),
+            BufferIntervalToTuple(
+                {p0_value_.get(), 4, 101, 125, ColocationTy(), true}));
+  EXPECT_EQ(BufferIntervalToTuple(
+                mutable_sliced_buffer_interval_->IntervalForMakeFreeChunks(1)),
+            BufferIntervalToTuple(
+                {p0_value_.get(), 4, 126, 150, ColocationTy(), true}));
+  EXPECT_EQ(BufferIntervalToTuple(
+                mutable_sliced_buffer_interval_->IntervalForMakeFreeChunks(2)),
+            BufferIntervalToTuple(
+                {p0_value_.get(), 4, 151, 175, ColocationTy(), true}));
+  EXPECT_EQ(BufferIntervalToTuple(
+                mutable_sliced_buffer_interval_->IntervalForMakeFreeChunks(3)),
+            BufferIntervalToTuple({p0_value_.get(), 20, 176, 200,
+                                   ColocationTy({p1_value_.get()}), true}));
+  EXPECT_EQ(BufferIntervalToTuple(
+                mutable_sliced_buffer_interval_->full_buffer_interval()),
+            BufferIntervalToTuple({p0_value_.get(), 20, 101, 200,
+                                   ColocationTy({p1_value_.get()}), true}));
+
   mutable_sliced_buffer_interval_->UpdateEndTime(300);
 
   // Only the BufferInterval for the last slice time should have changed end
@@ -2115,11 +2140,11 @@
             175);
   EXPECT_EQ(BufferIntervalToTuple(
                 mutable_sliced_buffer_interval_->IntervalForMakeFreeChunks(3)),
-            BufferIntervalToTuple({p0_value_.get(), 20, 175, 300,
+            BufferIntervalToTuple({p0_value_.get(), 20, 176, 300,
                                    ColocationTy({p1_value_.get()}), true}));
   EXPECT_EQ(BufferIntervalToTuple(
                 mutable_sliced_buffer_interval_->full_buffer_interval()),
-            BufferIntervalToTuple({p0_value_.get(), 20, 100, 300,
+            BufferIntervalToTuple({p0_value_.get(), 20, 101, 300,
                                    ColocationTy({p1_value_.get()}), true}));
 }
 
diff --git a/third_party/xla/xla/service/hlo_computation_deduplicator.cc b/third_party/xla/xla/service/hlo_computation_deduplicator.cc
index 87bf540..fc1f497 100644
--- a/third_party/xla/xla/service/hlo_computation_deduplicator.cc
+++ b/third_party/xla/xla/service/hlo_computation_deduplicator.cc
@@ -84,7 +84,7 @@
     // with large number of instructions or large-size constants due to increase
     // in time taken to stringify.
     if (comp->IsEntryComputation() || comp->instruction_count() > 128 ||
-        ContainsLargeConstants(comp)) {
+        ContainsLargeConstants(comp) || comp->IsCollectiveCalledComputation()) {
       continue;
     }
     std::string comp_str = comp->ToString(options);
diff --git a/third_party/xla/xla/service/hlo_computation_deduplicator_test.cc b/third_party/xla/xla/service/hlo_computation_deduplicator_test.cc
index 723de91..3dab680 100644
--- a/third_party/xla/xla/service/hlo_computation_deduplicator_test.cc
+++ b/third_party/xla/xla/service/hlo_computation_deduplicator_test.cc
@@ -589,5 +589,34 @@
   std::vector<HloComputation *> computations = module->MakeComputationSorted();
   EXPECT_EQ(computations.size(), (total_regions + 1));
 }
+
+TEST_F(HloComputationDeduplicatorTest, DontDeduplicateReduceAllReduce) {
+  // Note: this module is synthetic and only exercises deduplication behavior.
+  const std::string_view text = R"(
+  HloModule TestModule
+
+  add.1 {
+    Arg_0 = s32[] parameter(0)
+    Arg_1 = s32[] parameter(1)
+    ROOT add.2 = s32[] add(Arg_0, Arg_1)
+  }
+  add.2 {
+    Arg_0 = s32[] parameter(0)
+    Arg_1 = s32[] parameter(1)
+    ROOT add.2 = s32[] add(Arg_0, Arg_1)
+  }
+
+  ENTRY main {
+    Arg_0.1 = s32[10] parameter(0)
+    constant.3 = s32[] constant(0)
+    rd1 = s32[] reduce(Arg_0.1, constant.3), dimensions={0}, to_apply=add.1
+    Arg_1.1 = s32[] parameter(1)
+    rd2 = s32[] all-reduce(Arg_1.1), to_apply=add.2
+    ROOT multiply.14 = s32[] multiply(rd1, rd2)
+  }
+  )";
+  auto computation_names = RunDeduplicatePass(text, /*expect_true=*/false);
+  EXPECT_EQ(computation_names.size(), 3);
+}
 }  //  namespace
 }  //  namespace xla
diff --git a/third_party/xla/xla/service/hlo_creation_utils.cc b/third_party/xla/xla/service/hlo_creation_utils.cc
index 54e8f9c..34e188e 100644
--- a/third_party/xla/xla/service/hlo_creation_utils.cc
+++ b/third_party/xla/xla/service/hlo_creation_utils.cc
@@ -16,16 +16,19 @@
 #include "xla/service/hlo_creation_utils.h"
 
 #include <algorithm>
+#include <cstdint>
 #include <iterator>
 #include <memory>
 #include <numeric>
 #include <optional>
 #include <string>
-#include <utility>
 #include <vector>
 
 #include "absl/algorithm/container.h"
+#include "absl/log/check.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "xla/client/lib/comparators.h"
 #include "xla/client/xla_builder.h"
 #include "xla/client/xla_computation.h"
@@ -33,11 +36,18 @@
 #include "xla/hlo/ir/hlo_clone_context.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
-#include "xla/literal.h"
+#include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/literal_util.h"
+#include "xla/primitive_util.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/service/shape_inference.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/statusor.h"
 #include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 using absl::StrCat;
@@ -319,6 +329,12 @@
     return hlo;
   }
   Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
+  if (primitive_util::Is4BitType(shape.element_type())) {
+    shape.mutable_layout()->set_element_size_in_bits(4);
+  } else {
+    shape.mutable_layout()->set_element_size_in_bits(0);
+  }
+
   hlo = hlo->parent()->AddInstruction(HloInstruction::CreateConvert(shape, hlo),
                                       metadata);
   CHECK_EQ(hlo->shape().element_type(), type);
@@ -598,8 +614,10 @@
   const Shape& operand_shape = operand->shape();
   CHECK_GE(operand_shape.dimensions_size(), n);
   int64_t new_shape_leading_bound = 1;
+  bool new_shape_leading_is_dynamic = false;
   for (int64_t i = 0; i < n; i++) {
     new_shape_leading_bound *= operand_shape.dimensions(i);
+    new_shape_leading_is_dynamic |= operand_shape.is_dynamic_dimension(i);
   }
 
   std::vector<int64_t> new_shape_dims;
@@ -610,8 +628,15 @@
             operand_shape.dimensions().end(),
             std::back_inserter(new_shape_dims));
 
-  Shape output_shape =
-      ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims);
+  std::vector<bool> new_shape_dynamic_dims;
+  new_shape_dynamic_dims.reserve(operand_shape.dimensions_size() - n + 1);
+  new_shape_dynamic_dims.push_back(new_shape_leading_is_dynamic);
+  std::copy(operand_shape.dynamic_dimensions().begin() + n,
+            operand_shape.dynamic_dimensions().end(),
+            std::back_inserter(new_shape_dynamic_dims));
+
+  Shape output_shape = ShapeUtil::MakeShape(
+      operand_shape.element_type(), new_shape_dims, new_shape_dynamic_dims);
 
   return MakeReshapeHlo(output_shape, operand);
 }
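Side note on the dynamic-dimension handling introduced above: the collapsed leading dimension is dynamic if any of the first n input dimensions is dynamic, and the trailing dynamic flags are copied through unchanged. A minimal standalone sketch of that bookkeeping, using toy types rather than xla::Shape/ShapeUtil (names here are illustrative only):

#include <cstdint>
#include <iostream>
#include <vector>

// Toy stand-in for a shape: dimension bounds plus a parallel is-dynamic flag.
struct ToyShape {
  std::vector<int64_t> dims;
  std::vector<bool> dynamic;
};

// Collapse the first n dimensions into one, propagating dynamic-ness the same
// way as the change above: the merged bound is the product of the bounds, and
// the merged dimension is dynamic if any collapsed dimension was dynamic.
ToyShape CollapseFirstNDims(const ToyShape& s, int64_t n) {
  ToyShape out;
  int64_t leading_bound = 1;
  bool leading_dynamic = false;
  for (int64_t i = 0; i < n; ++i) {
    leading_bound *= s.dims[i];
    leading_dynamic = leading_dynamic || s.dynamic[i];
  }
  out.dims.push_back(leading_bound);
  out.dynamic.push_back(leading_dynamic);
  for (size_t i = static_cast<size_t>(n); i < s.dims.size(); ++i) {
    out.dims.push_back(s.dims[i]);
    out.dynamic.push_back(s.dynamic[i]);
  }
  return out;
}

int main() {
  // f32[2,<=3,4] collapsed over the first two dimensions becomes f32[<=6,4]:
  // the bounds multiply and the merged dimension stays dynamic.
  ToyShape s{{2, 3, 4}, {false, true, false}};
  ToyShape c = CollapseFirstNDims(s, 2);
  std::cout << c.dims[0] << (c.dynamic[0] ? " (dynamic)" : "") << " x "
            << c.dims[1] << "\n";  // prints: 6 (dynamic) x 4
  return 0;
}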
diff --git a/third_party/xla/xla/service/hlo_dataflow_analysis.cc b/third_party/xla/xla/service/hlo_dataflow_analysis.cc
index 1f72360..6c9f40e 100644
--- a/third_party/xla/xla/service/hlo_dataflow_analysis.cc
+++ b/third_party/xla/xla/service/hlo_dataflow_analysis.cc
@@ -24,11 +24,14 @@
 #include <vector>
 
 #include "absl/algorithm/container.h"
+#include "absl/base/attributes.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/functional/function_ref.h"
+#include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -86,11 +89,12 @@
 using absl::StrAppend;
 using absl::StrCat;
 
-HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
-                                         bool bitcast_defines_value,
-                                         const CanShareBuffer& can_share_buffer,
-                                         const ForwardsValue& forwards_value)
+HloDataflowAnalysis::HloDataflowAnalysis(
+    const HloModule& module, bool ssa_form, bool bitcast_defines_value,
+    const CanShareBuffer& can_share_buffer, const ForwardsValue& forwards_value,
+    absl::flat_hash_set<absl::string_view> execution_threads)
     : module_(module),
+      execution_threads_(std::move(execution_threads)),
       ssa_form_(ssa_form),
       bitcast_defines_value_(bitcast_defines_value),
       call_graph_(CallGraph::Build(&module)),
@@ -366,6 +370,10 @@
       StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
   StrAppend(&out, "  Instruction value sets:\n");
   for (const HloComputation* computation : module_.computations()) {
+    if (!HloInstruction::IsThreadIncluded(computation->execution_thread(),
+                                          execution_threads_)) {
+      continue;
+    }
     for (const HloInstruction* instruction : computation->instructions()) {
       StrAppend(&out, "Instruction: \n  ", instruction->name(), ":\n");
       if (instruction->shape().IsTuple()) {
@@ -572,20 +580,6 @@
   return false;
 }
 
-bool HloDataflowAnalysis::UpdateSetDimensionSizeValueSet(
-    HloInstruction* set_dimension_size) {
-  CHECK_EQ(set_dimension_size->opcode(), HloOpcode::kSetDimensionSize);
-  const InstructionValueSet& operand_set =
-      GetInstructionValueSet(set_dimension_size->operand(0));
-  InstructionValueSet& set_dimension_size_set =
-      GetInstructionValueSet(set_dimension_size);
-  if (operand_set != set_dimension_size_set) {
-    set_dimension_size_set = operand_set;
-    return true;
-  }
-  return false;
-}
-
 bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
   CHECK_EQ(send->opcode(), HloOpcode::kSend);
   bool changed = false;
@@ -632,6 +626,10 @@
           }
         });
   }
+  if (!HloInstruction::IsThreadIncluded(async_start->async_execution_thread(),
+                                        execution_threads_)) {
+    return changed;
+  }
   // AsyncStart forwards the async wrapped computation root values to element
   // {1} of its output.
   HloInstruction* root =
@@ -661,7 +659,10 @@
   CHECK_EQ(async_update->shape(), async_update->operand(0)->shape());
   bool changed = false;
   HloInstruction* root =
-      async_update->async_wrapped_computation()->root_instruction();
+      HloInstruction::IsThreadIncluded(async_update->async_execution_thread(),
+                                       execution_threads_)
+          ? async_update->async_wrapped_computation()->root_instruction()
+          : nullptr;
   // AsyncUpdate forwards all of the operand values to corresponding elements of
   // its output.
   ShapeUtil::ForEachSubshape(
@@ -680,13 +681,16 @@
             value_set = operand_value_set;
             changed = true;
           }
-        } else {
+        } else if (root != nullptr) {
           // If this subshape is an output (index {1}), we need to create the
           // union with the async wrapped computation root.
           ShapeIndex root_index(index.begin() + 1, index.end());
           const HloValueSet& root_value_set = GetValueSet(root, root_index);
           changed |=
               value_set.AssignUnionOf({&operand_value_set, &root_value_set});
+        } else if (value_set != operand_value_set) {
+          value_set = operand_value_set;
+          changed = true;
         }
       });
   return changed;
@@ -696,7 +700,10 @@
   CHECK_EQ(async_done->opcode(), HloOpcode::kAsyncDone);
   bool changed = false;
   HloInstruction* root =
-      async_done->async_wrapped_computation()->root_instruction();
+      HloInstruction::IsThreadIncluded(async_done->async_execution_thread(),
+                                       execution_threads_)
+          ? async_done->async_wrapped_computation()->root_instruction()
+          : nullptr;
   // AsyncDone creates a union of the operand values at {1} and the async
   // wrapped computation root to element {} of its output.
   ShapeUtil::ForEachSubshape(
@@ -710,9 +717,14 @@
 
         ShapeIndex output_index(index.begin() + 1, index.end());
         HloValueSet& value_set = GetValueSet(async_done, output_index);
-        const HloValueSet& root_value_set = GetValueSet(root, output_index);
-        changed |=
-            value_set.AssignUnionOf({&operand_value_set, &root_value_set});
+        if (root != nullptr) {
+          const HloValueSet& root_value_set = GetValueSet(root, output_index);
+          changed |=
+              value_set.AssignUnionOf({&operand_value_set, &root_value_set});
+        } else if (value_set != operand_value_set) {
+          value_set = operand_value_set;
+          changed = true;
+        }
       });
   return changed;
 }
@@ -1179,10 +1191,6 @@
       changed = UpdateBitcastValueSet(instruction);
       break;
     }
-    case HloOpcode::kSetDimensionSize: {
-      changed = UpdateSetDimensionSizeValueSet(instruction);
-      break;
-    }
     case HloOpcode::kDomain: {
       changed = UpdateDomainValueSet(instruction);
       break;
@@ -1287,6 +1295,10 @@
 
   auto comps = module_.MakeComputationPostOrder();
   for (HloComputation* computation : comps) {
+    if (!HloInstruction::IsThreadIncluded(computation->execution_thread(),
+                                          execution_threads_)) {
+      continue;
+    }
     for (HloInstruction* instruction :
          computation->MakeInstructionPostOrder()) {
       add_to_worklist(instruction);
@@ -1296,12 +1308,6 @@
 
   while (!worklist.empty()) {
     HloInstruction* instruction = worklist.top().second;
-    auto add_to_worklist = [&](HloInstruction* todo) {
-      if (workset.insert(todo).second) {
-        VLOG(1) << "  Adding todo : " << todo->name();
-        worklist.emplace(priority_map[todo], todo);
-      }
-    };
     worklist.pop();
 
     workset.erase(workset.find(instruction));
@@ -1341,18 +1347,25 @@
         }
       } else if (user->opcode() == HloOpcode::kAsyncUpdate ||
                  user->opcode() == HloOpcode::kAsyncDone) {
-        // For async update and async done, we cannot distinguish which
-        // parameter needs to be updated so add all to the worklist.
-        for (int64_t parameter_number = 0;
-             parameter_number <
-             user->async_wrapped_computation()->num_parameters();
-             ++parameter_number) {
-          add_to_worklist(
-              user->async_wrapped_computation()->parameter_instruction(
-                  parameter_number));
+        if (HloInstruction::IsThreadIncluded(user->async_execution_thread(),
+                                             execution_threads_)) {
+          // For async update and async done, we cannot distinguish which
+          // parameter needs to be updated so add all to the worklist.
+          for (int64_t parameter_number = 0;
+               parameter_number <
+               user->async_wrapped_computation()->num_parameters();
+               ++parameter_number) {
+            add_to_worklist(
+                user->async_wrapped_computation()->parameter_instruction(
+                    parameter_number));
+          }
         }
       } else {
         for (HloComputation* called_computation : user->called_computations()) {
+          if (!HloInstruction::IsThreadIncluded(
+                  called_computation->execution_thread(), execution_threads_)) {
+            continue;
+          }
           const CallGraphNode& call_graph_node =
               call_graph_->GetNode(called_computation);
           if (call_graph_node.context() == CallContext::kControlFlow) {
@@ -1399,6 +1412,10 @@
 
 Status HloDataflowAnalysis::InitializeInstructionValueSets() {
   for (const HloComputation* computation : module_.MakeComputationSorted()) {
+    if (!HloInstruction::IsThreadIncluded(computation->execution_thread(),
+                                          execution_threads_)) {
+      continue;
+    }
     const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
     for (HloInstruction* instruction :
          computation->MakeInstructionPostOrder()) {
@@ -1445,7 +1462,6 @@
             define_all_values();
           }
           break;
-        case HloOpcode::kSetDimensionSize:
         case HloOpcode::kAddDependency:
         case HloOpcode::kWhile:
         case HloOpcode::kCall:
@@ -1483,16 +1499,22 @@
           // values flow from their operands.
           define_value_at(/*index=*/{});
           break;
-        case HloOpcode::kAsyncStart:
+        case HloOpcode::kAsyncStart: {
           // AsyncStart produces a tuple of {{aliased operands}, {destination},
           // contexts}. It defines all of the tuple-shaped values and the
           // contexts.
+          // If the thread is excluded, we don't track the contained
+          // dataflow, so we also define the destination values here.
+          bool thread_included = HloInstruction::IsThreadIncluded(
+              instruction->async_execution_thread(), execution_threads_);
           define_all_values([&](const ShapeIndex& index) {
             return ShapeUtil::GetSubshape(instruction->shape(), index)
                        .IsTuple() ||
-                   index.front() > 1;
+                   (!thread_included && index.front() == 1) ||
+                   (index.front() > 1);
           });
           break;
+        }
         case HloOpcode::kAsyncUpdate:
           // AsyncUpdate produces a tuple of {{aliased operands}, {destination},
           // contexts} where all of the array-typed values alias with the
@@ -1607,6 +1629,10 @@
   XLA_VLOG_LINES(1, phi_graph_.ToString());
 
   for (const HloComputation* computation : module_.computations()) {
+    if (!HloInstruction::IsThreadIncluded(computation->execution_thread(),
+                                          execution_threads_)) {
+      continue;
+    }
     for (HloInstruction* instruction : computation->instructions()) {
       InstructionValueSet& instruction_value_set =
           GetInstructionValueSet(instruction);
@@ -1636,14 +1662,14 @@
 /* static */
 StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
     const HloModule& module, bool ssa_form, bool bitcast_defines_value,
-    const CanShareBuffer& can_share_buffer,
-    const ForwardsValue& forwards_value) {
+    const CanShareBuffer& can_share_buffer, const ForwardsValue& forwards_value,
+    absl::flat_hash_set<absl::string_view> execution_threads) {
   VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
   XLA_VLOG_LINES(2, module.ToString());
 
-  auto dataflow_analysis = absl::WrapUnique(
-      new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value,
-                              can_share_buffer, forwards_value));
+  auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
+      module, ssa_form, bitcast_defines_value, can_share_buffer, forwards_value,
+      execution_threads));
 
   TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
   dataflow_analysis->Propagate();
@@ -1659,6 +1685,10 @@
   std::vector<std::vector<HloPosition>> value_positions(
       dataflow_analysis->next_value_id_);
   for (const HloComputation* computation : module.computations()) {
+    if (!HloInstruction::IsThreadIncluded(computation->execution_thread(),
+                                          execution_threads)) {
+      continue;
+    }
     for (HloInstruction* instruction : computation->instructions()) {
       for (const auto& pair :
            dataflow_analysis->GetInstructionValueSet(instruction)) {
@@ -1708,6 +1738,10 @@
   // For each value in each value set, verify that the value set's position
   // appears in the value's positions().
   for (const auto& computation : module_.computations()) {
+    if (!HloInstruction::IsThreadIncluded(computation->execution_thread(),
+                                          execution_threads_)) {
+      continue;
+    }
     for (const auto& instruction : computation->instructions()) {
       for (const auto& pair : GetInstructionValueSet(instruction)) {
         const ShapeIndex& index = pair.first;
@@ -1958,6 +1992,14 @@
       }
     }
     return in_place_pairs;
+  } else if (instruction->opcode() == HloOpcode::kSetDimensionSize) {
+    int64_t dimension = instruction->dimension();
+    std::vector<std::pair<HloOperandIndex, ShapeIndex>> in_place_pairs;
+    if (instruction->shape().is_dynamic_dimension(dimension) ==
+        instruction->operand(0)->shape().is_dynamic_dimension(dimension)) {
+      in_place_pairs.push_back({HloOperandIndex{0, {}}, {}});
+    }
+    return in_place_pairs;
   }
 
   return {};
@@ -2074,7 +2116,8 @@
 
   if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
       user->opcode() == HloOpcode::kScatter ||
-      user->opcode() == HloOpcode::kTriangularSolve) {
+      user->opcode() == HloOpcode::kTriangularSolve ||
+      user->opcode() == HloOpcode::kSetDimensionSize) {
     // We eliminated other users in HloOrdering::LiveRangeStrictlyBefore
     // so here we just need to check that the use is at the right operand index.
     const auto operand_indices = user->OperandIndices(operand);
diff --git a/third_party/xla/xla/service/hlo_dataflow_analysis.h b/third_party/xla/xla/service/hlo_dataflow_analysis.h
index d0cbb63..99453fd 100644
--- a/third_party/xla/xla/service/hlo_dataflow_analysis.h
+++ b/third_party/xla/xla/service/hlo_dataflow_analysis.h
@@ -20,23 +20,27 @@
 #ifndef XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
 #define XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
 
+#include <cstdint>
 #include <functional>
-#include <iterator>
 #include <memory>
+#include <optional>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/service/call_graph.h"
 #include "xla/service/hlo_phi_graph.h"
+#include "xla/service/hlo_value.h"
 #include "xla/shape_util.h"
 #include "xla/status.h"
 #include "xla/statusor.h"
-#include "xla/types.h"
 #include "xla/xla_data.pb.h"
 
 namespace xla {
@@ -110,7 +114,8 @@
       const HloModule& module, bool ssa_form = false,
       bool bitcast_defines_value = false,
       const CanShareBuffer& can_share_buffer = nullptr,
-      const ForwardsValue& forwards_value = nullptr);
+      const ForwardsValue& forwards_value = nullptr,
+      absl::flat_hash_set<absl::string_view> execution_threads = {});
 
   // Returns true if 'instruction' defines an HLO value at the given shape index
   // of its output.
@@ -227,9 +232,10 @@
   static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst);
 
   HloDataflowAnalysis(const HloModule& module, bool ssa_form,
-                      bool bitcast_defines_value = false,
-                      const CanShareBuffer& can_share_buffer = nullptr,
-                      const ForwardsValue& forwards_value = nullptr);
+                      bool bitcast_defines_value,
+                      const CanShareBuffer& can_share_buffer,
+                      const ForwardsValue& forwards_value,
+                      absl::flat_hash_set<absl::string_view> execution_threads);
 
   // 1. During value propagation (Propagate function), always create phi
   // values once it see multiple inputs merging at the same point. It then
@@ -293,7 +299,6 @@
   bool UpdateOptimizationBarrierValueSet(HloInstruction* barrier);
   bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
   bool UpdateSendValueSet(HloInstruction* send);
-  bool UpdateSetDimensionSizeValueSet(HloInstruction* set_dimension_size);
   bool UpdateTupleValueSet(HloInstruction* tuple);
   bool UpdateWhileValueSet(HloInstruction* xla_while);
   bool UpdateAddDependencyValueSet(HloInstruction* add_dependency);
@@ -327,6 +332,7 @@
       const InstructionValueSet* prev_value_set = nullptr);
 
   const HloModule& module_;
+  const absl::flat_hash_set<absl::string_view> execution_threads_;
   const bool ssa_form_;
   const bool bitcast_defines_value_;
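The execution_threads parameter added above scopes the analysis to computations running on the listed threads; an empty set (the default) keeps the previous behavior of analyzing every computation. A minimal usage sketch, assuming a module already built or parsed and assuming the default thread name "main":

#include <memory>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "xla/service/hlo_dataflow_analysis.h"

// Sketch only: restrict dataflow analysis to the "main" execution thread.
// Computations on excluded threads (such as async bodies launched on
// "parallel_thread") are skipped during value-set initialization,
// propagation, and verification, per the changes in this file.
xla::StatusOr<std::unique_ptr<xla::HloDataflowAnalysis>> RunMainThreadOnly(
    const xla::HloModule& module) {
  return xla::HloDataflowAnalysis::Run(
      module, /*ssa_form=*/false, /*bitcast_defines_value=*/false,
      /*can_share_buffer=*/nullptr, /*forwards_value=*/nullptr,
      /*execution_threads=*/{"main"});
}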
 
diff --git a/third_party/xla/xla/service/hlo_dataflow_analysis_test.cc b/third_party/xla/xla/service/hlo_dataflow_analysis_test.cc
index 2a01eee..c5f67ab 100644
--- a/third_party/xla/xla/service/hlo_dataflow_analysis_test.cc
+++ b/third_party/xla/xla/service/hlo_dataflow_analysis_test.cc
@@ -15,27 +15,36 @@
 
 #include "xla/service/hlo_dataflow_analysis.h"
 
+#include <cstdint>
+#include <initializer_list>
+#include <memory>
 #include <string>
+#include <utility>
+#include <vector>
 
+#include <gtest/gtest.h>
+#include "absl/log/check.h"
+#include "absl/strings/str_cat.h"
+#include "xla/comparison_util.h"
 #include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_matchers.h"
-#include "xla/literal.h"
+#include "xla/hlo/ir/hlo_schedule.h"
+#include "xla/literal_util.h"
 #include "xla/service/async_op_canonicalizer.h"
 #include "xla/service/flatten_call_graph.h"
 #include "xla/service/hlo_creation_utils.h"
 #include "xla/service/hlo_dce.h"
-#include "xla/service/hlo_graph_dumper.h"
 #include "xla/service/hlo_ordering.h"
-#include "xla/service/instruction_fusion.h"
+#include "xla/service/hlo_value.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
-#include "xla/status_macros.h"
+#include "xla/status.h"
 #include "xla/test.h"
-#include "xla/test_helpers.h"
 #include "xla/tests/hlo_test_base.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/lib/core/status_test_util.h"
-#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
 #include "tsl/platform/test.h"
 
 namespace xla {
@@ -1239,7 +1248,7 @@
               UnorderedElementsAre(&analysis.GetValueDefinedAt(param)));
 }
 
-TEST_P(HloDataflowAnalysisTest, SetDimensionSizeForwardsValue) {
+TEST_P(HloDataflowAnalysisTest, SetDimensionSizeCreatesValue) {
   auto builder = HloComputation::Builder(TestName());
   auto param = builder.AddInstruction(
       HloInstruction::CreateParameter(0, vector_shape_, "param"));
@@ -1254,11 +1263,11 @@
   bool ssa_form = GetParam();
   {
     const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
-    EXPECT_EQ(analysis.values().size(), 2);
+    EXPECT_EQ(analysis.values().size(), 3);
 
     EXPECT_TRUE(analysis.ValueIsDefinedAt(param));
-    EXPECT_FALSE(analysis.ValueIsDefinedAt(sds));
-    EXPECT_TRUE(analysis.GetValueDefinedAt(param).live_out_of_module());
+    EXPECT_TRUE(analysis.ValueIsDefinedAt(sds));
+    EXPECT_TRUE(analysis.GetValueDefinedAt(sds).live_out_of_module());
   }
 }
 
diff --git a/third_party/xla/xla/service/hlo_dce.cc b/third_party/xla/xla/service/hlo_dce.cc
index 5325873..ed24416 100644
--- a/third_party/xla/xla/service/hlo_dce.cc
+++ b/third_party/xla/xla/service/hlo_dce.cc
@@ -15,11 +15,12 @@
 
 #include "xla/service/hlo_dce.h"
 
-#include <memory>
-#include <utility>
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -27,12 +28,11 @@
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/status.h"
-#include "xla/status_macros.h"
 #include "xla/statusor.h"
-#include "xla/types.h"
 #include "xla/util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 
@@ -127,9 +127,7 @@
   return module->RemoveEmbeddedComputation(computation);
 }
 
-StatusOr<bool> HloDCE::RecursivelyRemoveDeadComputations(
-    HloModule* module,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+StatusOr<bool> HloDCE::RecursivelyRemoveDeadComputations(HloModule* module) {
   // Tracks whether any dead code is eliminated by this pass.
   bool module_contains_dead_code = false;
 
@@ -153,8 +151,7 @@
 
   // Find dead computations.
   absl::flat_hash_set<HloComputation*> dead_computations;
-  for (auto* computation :
-       module->MakeComputationPostOrder(execution_threads)) {
+  for (auto* computation : module->MakeComputationPostOrder()) {
     // Finds all "top-level" dead computations not called by any instructions.
     // contains(comp) = true and live_computation_call_count[comp] = 0 also
     // implies that the computation is dead, but is nested in other dead
@@ -189,9 +186,8 @@
   // Now DCE HloComputations.  Keep doing passes through the module until no
   // more computations can be eliminated. The function removes all
   // subcomputations that can be proved to have no remaining live callers.
-  TF_ASSIGN_OR_RETURN(
-      bool module_contains_dead_code,
-      RecursivelyRemoveDeadComputations(module, execution_threads));
+  TF_ASSIGN_OR_RETURN(bool module_contains_dead_code,
+                      RecursivelyRemoveDeadComputations(module));
   changed |= module_contains_dead_code;
 
   VLOG(2) << "After dce:";
diff --git a/third_party/xla/xla/service/hlo_dce.h b/third_party/xla/xla/service/hlo_dce.h
index 9625371..5d9d6ea 100644
--- a/third_party/xla/xla/service/hlo_dce.h
+++ b/third_party/xla/xla/service/hlo_dce.h
@@ -17,10 +17,13 @@
 #define XLA_SERVICE_HLO_DCE_H_
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/service/hlo_pass_interface.h"
+#include "xla/status.h"
 #include "xla/statusor.h"
 
 namespace xla {
@@ -57,9 +60,7 @@
  private:
   // Finds all computations that are not called by any instruction and removes
   // them from the module. Returns whether any dead code was removed.
-  StatusOr<bool> RecursivelyRemoveDeadComputations(
-      HloModule* module,
-      const absl::flat_hash_set<absl::string_view>& execution_threads);
+  StatusOr<bool> RecursivelyRemoveDeadComputations(HloModule* module);
 
   // Given a dead computation, decrements the ref count of all its called
   // computations and checks if any of the subcomputations become dead after the
diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/service/hlo_parser.cc
index 2743f5b..3178700 100644
--- a/third_party/xla/xla/service/hlo_parser.cc
+++ b/third_party/xla/xla/service/hlo_parser.cc
@@ -1955,7 +1955,10 @@
                                      to_apply.value(), is_stable.value()));
     }
     case HloOpcode::kTuple: {
-      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
+      if ((!preset_operands &&
+           !(shape.has_value()
+                 ? ParseOperands(&operands, builder, shape->tuple_shapes_size())
+                 : ParseOperands(&operands, builder))) ||
           !ParseAttributes(attrs, allow_attributes)) {
         return nullptr;
       }
diff --git a/third_party/xla/xla/service/hlo_parser.h b/third_party/xla/xla/service/hlo_parser.h
index b3a4633..7ac65a6 100644
--- a/third_party/xla/xla/service/hlo_parser.h
+++ b/third_party/xla/xla/service/hlo_parser.h
@@ -29,9 +29,6 @@
 
 namespace xla {
 
-// For details about the syntax accepted by this parser, see
-// g3doc/hlo_parser.md.
-
 // Given a string in the HloModule::ToString() format, parses the string and
 // creates a HloModule with the given config.
 // Note: Tests derived from HloTestBase should use
diff --git a/third_party/xla/xla/service/hlo_parser_test.cc b/third_party/xla/xla/service/hlo_parser_test.cc
index 48bdeae..4baf3be 100644
--- a/third_party/xla/xla/service/hlo_parser_test.cc
+++ b/third_party/xla/xla/service/hlo_parser_test.cc
@@ -4404,6 +4404,19 @@
       ShapeUtil::MakeScalarShape(S32)));
 }
 
+TEST_F(HloParserTest, TupleTypo) {
+  constexpr char text[] = R"(HloModule TupleTypoTest
+ENTRY TupleTypo {
+  pow = s32[] constant(42)
+  ROOT v = (s32[]) tuple(power)
+}
+)";
+  auto result = ParseAndReturnVerifiedModule(text);
+  EXPECT_THAT(result.status(),
+              tsl::testing::StatusIs(tsl::error::INVALID_ARGUMENT,
+                                     HasSubstr("instruction does not exist")));
+}
+
 TEST_F(HloParserTest, InferDotShape) {
   constexpr char text[] = R"(HloModule InferDotShapeTest
 ENTRY InferDotShape {
diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc
index b625290..dfb029a 100644
--- a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc
+++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc
@@ -410,8 +410,16 @@
   auto depth_iter = einsum_depth_map_.find(conditional);
   CHECK(depth_iter != einsum_depth_map_.end());
   const ShapeTree<int> depth_tree = depth_iter->second;
-  return HandleCalledComputation(*conditional->called_computations()[0],
-                                 depth_tree, conditional->operands());
+  // Conditionals have one more operand than the number of branches. The first
+  // operand is the pred.
+  TF_RETURN_IF_ERROR(
+      SetInstructionDepth(conditional->operands()[0], depth_tree));
+  for (int i = 0; i < conditional->branch_count(); ++i) {
+    TF_RETURN_IF_ERROR(
+        HandleCalledComputation(*conditional->called_computations()[i],
+                                depth_tree, {conditional->operands()[i + 1]}));
+  }
+  return OkStatus();
 }
 
 Status EinsumDepthAnalysis::HandleCalledComputation(
@@ -615,6 +623,7 @@
     value_semantics_[parameter] = std::move(semantics_shape_tree);
   }
 }
+
 Status HloValueSemanticsAnalysis::RunOnComputation(
     const HloComputation& computation,
     absl::Span<const HloInstruction* const> operands) {
@@ -1154,8 +1163,9 @@
 Status HloValueSemanticsPropagation::HandleConditional(
     HloInstruction* conditional) {
   for (int i = 0; i < conditional->called_computations().size(); ++i) {
-    TF_RETURN_IF_ERROR(analysis_->RunOnComputation(
-        *conditional->called_computations()[i], conditional->operands()));
+    TF_RETURN_IF_ERROR(
+        analysis_->RunOnComputation(*conditional->called_computations()[i],
+                                    {conditional->operands()[i + 1]}));
   }
   HloComputation* computation = conditional->called_computations()[0];
   const ShapeTree<const HloValueSemantics*>& root_semantics =
diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis_test.cc b/third_party/xla/xla/service/hlo_value_semantics_analysis_test.cc
index 41c934f..fd1704f 100644
--- a/third_party/xla/xla/service/hlo_value_semantics_analysis_test.cc
+++ b/third_party/xla/xla/service/hlo_value_semantics_analysis_test.cc
@@ -581,5 +581,44 @@
   EXPECT_EQ(GetInstructionDepth(einsum_depth_map, computation, "dot.85"), 0);
 }
 
+TEST_F(EinsumDepthAnalysisTest, HandleConditional) {
+  const char* const hlo_string = R"(
+    HloModule Module
+
+    branch0 {
+      tparam = f32[4] parameter(0)
+      ROOT tgte1 = f32[4] ceil(tparam)
+    }
+
+    branch1 {
+      fparam = f32[4] parameter(0)
+      %async-start = ((f32[4]), f32[4], s32[]) custom-call-start(f32[4] fparam), async_execution_thread="parallel_thread", custom_call_target="foo"
+      ROOT %async-done = f32[4] custom-call-done(((f32[4]), f32[4], s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="foo"
+    }
+
+    branch2 {
+      sparam = f32[4] parameter(0)
+      ROOT sgte1 = f32[4] ceil(sparam)
+    }
+
+    ENTRY entry {
+      p0 = f32[4] parameter(0)
+      b0 = s32[] parameter(1)
+      ROOT conditional = f32[4] conditional(b0, p0, p0, p0),
+        branch_computations={branch0, branch1, branch2}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<EinsumDepthAnalysis> einsum_depth_analysis,
+      EinsumDepthAnalysis::Run(*module->entry_computation()));
+  const EinsumDepthMap& einsum_depth_map =
+      einsum_depth_analysis->GetEinsumDepthMap();
+  HloComputation* computation = module->GetComputationWithName("entry");
+  EXPECT_EQ(GetInstructionDepth(einsum_depth_map, computation, "conditional"),
+            0);
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc
index fd34bc4..d980024 100644
--- a/third_party/xla/xla/service/hlo_verifier.cc
+++ b/third_party/xla/xla/service/hlo_verifier.cc
@@ -2761,7 +2761,12 @@
             operand_shape.rank() == result_shape.rank() &&
             operand_shape.has_layout()) {
           const Layout& operand_layout = operand_shape.layout();
-          TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
+          Layout::Equal equal_predicate = Layout::Equal();
+          if (instruction->opcode() == HloOpcode::kConvert) {
+            // Convert instructions can change element_size_in_bits
+            equal_predicate.IgnoreElementSize();
+          }
+          TF_RET_CHECK(equal_predicate(result_layout, operand_layout))
               << "Instruction shouldn't change layouts "
               << instruction->ToString() << " From " << result_shape << " To "
               << operand_shape;
@@ -2876,14 +2881,18 @@
       TF_RETURN_IF_ERROR(module->schedule().Verify());
     }
 
-    TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(
-        *module, [this](const Shape& shape) -> int64_t {
-          if (target_metadata_->GetVerifierOpts().IsLayoutSensitive()) {
-            return target_metadata_->GetVerifierOpts().ShapeSize(shape);
-          } else {
-            return 0;
-          }
-        }));
+    if (HloInstruction::IsThreadIncluded(
+            module->entry_computation()->execution_thread(),
+            execution_threads)) {
+      TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(
+          *module, [this](const Shape& shape) -> int64_t {
+            if (target_metadata_->GetVerifierOpts().IsLayoutSensitive()) {
+              return target_metadata_->GetVerifierOpts().ShapeSize(shape);
+            } else {
+              return 0;
+            }
+          }));
+    }
 
     TF_RETURN_IF_ERROR(module->buffer_donor_config().Verify(*module));
     TF_RETURN_IF_ERROR(VerifyLayoutConstrainedAllReduce(*module));
diff --git a/third_party/xla/xla/service/hlo_verifier_test.cc b/third_party/xla/xla/service/hlo_verifier_test.cc
index 8a4cbe5..a26c153 100644
--- a/third_party/xla/xla/service/hlo_verifier_test.cc
+++ b/third_party/xla/xla/service/hlo_verifier_test.cc
@@ -1620,7 +1620,7 @@
   ENTRY entry {
     p0 = f32[2,3] parameter(0)
     p1 = u32[] parameter(1)
-    tuple = (f32[2,3], f32[2,3]) tuple(p0, p0, p1, p1)
+    tuple = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p0, p1, p1)
     ROOT done = f32[2,3] all-reduce-done(tuple)
   }
   )";
diff --git a/third_party/xla/xla/service/layout_normalization.cc b/third_party/xla/xla/service/layout_normalization.cc
index 773763e..a4c002a 100644
--- a/third_party/xla/xla/service/layout_normalization.cc
+++ b/third_party/xla/xla/service/layout_normalization.cc
@@ -26,6 +26,7 @@
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
+#include "xla/layout_util.h"
 #include "xla/permutation_util.h"
 #include "xla/service/hlo_creation_utils.h"
 #include "xla/service/shape_inference.h"
@@ -67,6 +68,12 @@
     const Shape& shape = hlo->shape();
     Shape normalized_shape = Normalize(shape);
     *literal.mutable_shape_do_not_use() = normalized_shape;
+    // Ensure element_size_in_bits of literal is 0, because literals do not
+    // support packed values.
+    literal.mutable_shape_do_not_use()
+        ->mutable_layout()
+        ->set_element_size_in_bits(0);
+
     HloInstruction* bc_to_orig = MakeBitcastHlo(hlo, shape);
     *hlo->mutable_shape() = normalized_shape;
     TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWithDifferentShape(bc_to_orig));
@@ -257,7 +264,8 @@
     auto operand_shape = operand->shape();
 
     // Precondition: elementwise unary leaves layout intact.
-    TF_RET_CHECK(s.layout() == operand_shape.layout())
+    TF_RET_CHECK(
+        Layout::Equal().IgnoreElementSize()(s.layout(), operand_shape.layout()))
         << "Unexpected non-layout preserving elementwise unary: "
         << hlo->ToString();
     TF_ASSIGN_OR_RETURN(auto normalized_input, GetNormalizedInput(operand));
@@ -631,8 +639,9 @@
         << "Unexpected HLO input: " << hlo->ToString();
     auto input = hlo->mutable_operand(0);
     auto input_shape = input->shape();
-    TF_RET_CHECK(input_shape.layout() ==
-                 LayoutUtil::GetDefaultLayoutForShape(input_shape));
+    TF_RET_CHECK(Layout::Equal().IgnoreElementSize()(
+        input_shape.layout(),
+        LayoutUtil::GetDefaultLayoutForShape(input_shape)));
     return input;
   }
 
diff --git a/third_party/xla/xla/service/llvm_ir/BUILD b/third_party/xla/xla/service/llvm_ir/BUILD
index f48d8be..7c266d0 100644
--- a/third_party/xla/xla/service/llvm_ir/BUILD
+++ b/third_party/xla/xla/service/llvm_ir/BUILD
@@ -281,6 +281,7 @@
     hdrs = ["buffer_assignment_util.h"],
     visibility = ["//visibility:public"],
     deps = [
+        "//xla/hlo/ir:hlo",
         "//xla/service:buffer_assignment",
         "@com_google_absl//absl/strings",
     ],
diff --git a/third_party/xla/xla/service/llvm_ir/buffer_assignment_util.cc b/third_party/xla/xla/service/llvm_ir/buffer_assignment_util.cc
index e20bb3a..f92b373 100644
--- a/third_party/xla/xla/service/llvm_ir/buffer_assignment_util.cc
+++ b/third_party/xla/xla/service/llvm_ir/buffer_assignment_util.cc
@@ -18,10 +18,11 @@
 #include <algorithm>
 
 #include "absl/strings/str_cat.h"
+#include "xla/hlo/ir/hlo_instruction.h"
 
 namespace xla {
 namespace llvm_ir {
-static const HloInstruction& InstrForConstantBufferAllocation(
+const HloInstruction& InstrForConstantBufferAllocation(
     const BufferAllocation& allocation) {
   CHECK(allocation.is_constant());
   HloInstruction* const_instr = nullptr;
diff --git a/third_party/xla/xla/service/llvm_ir/buffer_assignment_util.h b/third_party/xla/xla/service/llvm_ir/buffer_assignment_util.h
index d8b80a5..90805fe 100644
--- a/third_party/xla/xla/service/llvm_ir/buffer_assignment_util.h
+++ b/third_party/xla/xla/service/llvm_ir/buffer_assignment_util.h
@@ -38,6 +38,10 @@
 // Returns the Literal corresponding to `allocation`, which must be a constant
 // allocation.
 const Literal& LiteralForConstantAllocation(const BufferAllocation& allocation);
+// Returns the constant HloInstruction corresponding to `allocation`, which must
+// be a constant allocation.
+const HloInstruction& InstrForConstantBufferAllocation(
+    const BufferAllocation& allocation);
 }  // namespace llvm_ir
 }  // namespace xla
 
diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc
index 676c9c3..b3f6e7a 100644
--- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc
+++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc
@@ -23,6 +23,7 @@
 #include <optional>
 #include <string>
 #include <utility>
+#include <vector>
 
 #include "absl/base/casts.h"
 #include "absl/strings/str_cat.h"
@@ -58,6 +59,7 @@
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "xla/layout_util.h"
 #include "xla/literal.h"
+#include "xla/primitive_util.h"
 #include "xla/service/cpu/cpu_options.h"
 #include "xla/service/dump.h"
 #include "xla/service/hlo_module_config.h"
@@ -141,10 +143,22 @@
     auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
     return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
   } else {
-    auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
+    // logic: isNaN(lhs) || (!isNaN(rhs) && lhs >= rhs) ? lhs : rhs
+    // See also: IEEE Std 754-2008 5.11.
+    //
+    // This also works, but we wanted to make it similar to minimum.
+    // logic: isNaN(lhs) || lhs >= rhs ? lhs : rhs
+    //
+    // b->CreateMaximum() doesn't work on GPU before SM80.
+    //
+    // A test with a strange LLVM version breaks if we use OGT here, so we use
+    // OGE.
     auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
-    auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
-    return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
+    auto rhs_is_not_nan = b->CreateFCmpOEQ(rhs_value, rhs_value);
+    auto lhs_is_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
+    return b->CreateSelect(
+        b->CreateOr(lhs_is_nan, b->CreateAnd(rhs_is_not_nan, lhs_is_ge)),
+        lhs_value, rhs_value, name.data());
   }
 }
 
@@ -155,10 +169,23 @@
     auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
     return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
   } else {
-    auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
+    // logic: isNaN(lhs) || (!isNaN(rhs) && lhs <= rhs) ? lhs : rhs
+    // See also: IEEE Std 754-2008 5.11.
+    //
+    // This should also work, but the tests show that it doesn't work for
+    // minimum(x, NaN) on GPU:
+    // logic: isNaN(lhs) || lhs <= rhs ? lhs : rhs
+    //
+    // b->CreateMinimum() doesn't work on GPU before SM80.
+    //
+    // A test with a strange LLVM version breaks if we use OLT here, so we use
+    // OLE.
     auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
-    auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
-    return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
+    auto rhs_is_not_nan = b->CreateFCmpOEQ(rhs_value, rhs_value);
+    auto lhs_is_le = b->CreateFCmpOLE(lhs_value, rhs_value);
+    return b->CreateSelect(
+        b->CreateOr(lhs_is_nan, b->CreateAnd(rhs_is_not_nan, lhs_is_le)),
+        lhs_value, rhs_value, name.data());
   }
 }
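For readers of the two emitters above, here is the same NaN-propagating select expressed as plain C++ rather than IRBuilder calls; it illustrates the predicate, not the emitted IR. Either operand being NaN makes the result NaN, otherwise the ordered comparison picks the winner:

#include <cmath>
#include <iostream>

// Sketch of the maximum predicate above:
//   isNaN(lhs) || (!isNaN(rhs) && lhs >= rhs) ? lhs : rhs
float EmulatedFloatMax(float lhs, float rhs) {
  bool lhs_is_nan = std::isnan(lhs);       // FCmpUNE(lhs, lhs)
  bool rhs_is_not_nan = !std::isnan(rhs);  // FCmpOEQ(rhs, rhs)
  bool lhs_is_ge = lhs >= rhs;             // FCmpOGE(lhs, rhs)
  return (lhs_is_nan || (rhs_is_not_nan && lhs_is_ge)) ? lhs : rhs;
}

int main() {
  float nan = std::nanf("");
  std::cout << EmulatedFloatMax(1.0f, 2.0f) << "\n";  // 2
  std::cout << EmulatedFloatMax(nan, 2.0f) << "\n";   // nan (lhs NaN wins)
  std::cout << EmulatedFloatMax(2.0f, nan) << "\n";   // nan (falls through to rhs)
  return 0;
}

The minimum emitter mirrors this with <= and OLE in place of >= and OGE.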
 
@@ -190,7 +217,7 @@
                                   llvm::Module* module) {
   switch (element_type) {
     case PRED:
-    // Int8 is used as there is no LLVM S4/U4 dtype
+    // i8 is used for S4/U4 as arrays of i4 values are not packed
     case S4:
     case U4:
     case S8:
@@ -311,10 +338,18 @@
 llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
                                            llvm::Module* module) {
   const char* data = static_cast<const char*>(literal.untyped_data());
+  int64_t size_bytes = literal.size_bytes();
   CHECK_EQ(module->getDataLayout().isLittleEndian(), tsl::port::kLittleEndian);
-  return llvm::ConstantDataArray::getString(
-      module->getContext(), llvm::StringRef(data, literal.size_bytes()),
-      /*AddNull=*/false);
+  std::vector<char> packed_data;
+  if (primitive_util::Is4BitType(literal.shape().element_type())) {
+    packed_data.resize((size_bytes + 1) / 2);
+    PackInt4(absl::MakeSpan(data, size_bytes), absl::MakeSpan(packed_data));
+    data = packed_data.data();
+    size_bytes = packed_data.size();
+  }
+  return llvm::ConstantDataArray::getString(module->getContext(),
+                                            llvm::StringRef(data, size_bytes),
+                                            /*AddNull=*/false);
 }
 
 llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
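For the 4-bit literal path above: a Literal stores one s4/u4 value per byte, but the constant data handed to LLVM must hold two values per byte, hence the (size_bytes + 1) / 2 buffer and the PackInt4 call. A small illustration of such packing; the nibble order shown (even elements in the high nibble) is an assumption for this sketch, not a claim about the real PackInt4 helper:

#include <cstddef>
#include <vector>

// Illustrative packer: each input byte carries one 4-bit value in its low
// nibble; consecutive pairs of values share one output byte.
std::vector<char> PackNibbles(const std::vector<char>& unpacked) {
  std::vector<char> packed((unpacked.size() + 1) / 2, 0);
  for (size_t i = 0; i < unpacked.size(); ++i) {
    char val = unpacked[i] & 0x0f;
    if (i % 2 == 0) {
      packed[i / 2] = static_cast<char>(val << 4);
    } else {
      packed[i / 2] |= val;
    }
  }
  return packed;
}

// Example: {0x1, 0x2, 0x3} packs into {0x12, 0x30}, so a 3-element s4 literal
// occupies 2 bytes of constant data.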
diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD
index 5951248..e16247c 100644
--- a/third_party/xla/xla/service/memory_space_assignment/BUILD
+++ b/third_party/xla/xla/service/memory_space_assignment/BUILD
@@ -48,10 +48,12 @@
         "//xla:util",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/utils:hlo_live_range",
+        "//xla/service:buffer_value",
         "//xla/service:heap_simulator",
         "//xla/service:hlo_cost_analysis",
         "//xla/service:hlo_proto_cc",
         "//xla/service:hlo_value",
+        "//xla/service:time_utils",
         "//xla/service:tuple_util",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:btree",
diff --git a/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.cc b/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.cc
index 1215116..aa129e8 100644
--- a/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.cc
+++ b/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.cc
@@ -153,13 +153,13 @@
 std::vector<const AllocationBlock*> SortAllocationBlocks(const T& container) {
   std::vector<const AllocationBlock*> result;
   result.insert(result.end(), container.begin(), container.end());
-  absl::c_sort(result,
-               [](const AllocationBlock* lhs, const AllocationBlock* rhs) {
-                 return std::make_tuple(lhs->start_time, lhs->end_time,
-                                        lhs->initial_offset, lhs->size) <
-                        std::make_tuple(rhs->start_time, rhs->end_time,
-                                        rhs->initial_offset, rhs->size);
-               });
+  absl::c_sort(
+      result, [](const AllocationBlock* lhs, const AllocationBlock* rhs) {
+        return std::make_tuple(lhs->inclusive_start_time, lhs->end_time,
+                               lhs->initial_offset, lhs->size) <
+               std::make_tuple(rhs->inclusive_start_time, rhs->end_time,
+                               rhs->initial_offset, rhs->size);
+      });
 
   return result;
 }
@@ -198,13 +198,14 @@
             allocation_block);
         need_allocation = false;
       }
-      full_buffer_interval_map_.insert(std::make_pair(
-          allocation_block, BufferInterval{allocation_block,
-                                           allocation_block->size,
-                                           allocation_block->start_time,
-                                           allocation_block->end_time,
-                                           {},
-                                           need_allocation}));
+      full_buffer_interval_map_.insert(
+          std::make_pair(allocation_block,
+                         BufferInterval{allocation_block,
+                                        allocation_block->size,
+                                        allocation_block->inclusive_start_time,
+                                        allocation_block->end_time,
+                                        {},
+                                        need_allocation}));
     }
 
     // Now that full_buffer_interval_map_ has full colocation specifications,
@@ -229,8 +230,8 @@
         CHECK(!original_slice_data.slices_sorted_by_offset.empty());
 
         sliced_buffer_interval.Slice(original_slice_data.SizesSortedByOffset());
-        sliced_buffer_interval.UpdateSliceStartTimes(
-            original_slice_data.SortedStartTimes());
+        sliced_buffer_interval.UpdateInclusiveSliceStartTimes(
+            original_slice_data.SortedInclusiveStartTimes());
       }
 
       // We use buffer_intervals_ to store the minimum buffer interval for
@@ -338,11 +339,11 @@
       repacked_slice_data->slices_sorted_by_offset.reserve(chunks.size());
 
       // Chunks and start times are sorted in start time order.
-      std::vector<int64_t> sorted_start_times =
-          original_slice_data.SortedStartTimes();
+      std::vector<int64_t> sorted_inclusive_start_times =
+          original_slice_data.SortedInclusiveStartTimes();
       for (int i = 0; i < chunks.size(); ++i) {
         const Chunk& chunk = chunks[i];
-        int64_t start_time = sorted_start_times[i];
+        int64_t start_time = sorted_inclusive_start_times[i];
         result_.heap_size = result_.UpdatedHeapSize(chunk);
         VLOG(2) << "Adding sliced chunk " << chunk.ToString() << " at ["
                 << start_time << ", " << allocation_block->end_time << "]";
@@ -361,9 +362,9 @@
       new_offset = chunks.front().offset;
       result_.heap_size = result_.UpdatedHeapSize(chunks.front());
       VLOG(2) << "Adding unsliced chunk " << chunks.front().ToString()
-              << " at [" << allocation_block->start_time << ", "
+              << " at [" << allocation_block->inclusive_start_time << ", "
               << allocation_block->end_time << ")";
-      interval_tree_.Add(allocation_block->start_time,
+      interval_tree_.Add(allocation_block->inclusive_start_time,
                          allocation_block->end_time, chunks.front());
     }
 
@@ -522,13 +523,13 @@
               block->repacked_slice_data->slices_sorted_by_offset[i];
           timed_chunks.push_back(
               TimedChunk{absl::StrCat(((int64_t)block), "_slice_", i), block,
-                         slice.start_time, block->end_time,
+                         slice.inclusive_start_time, block->end_time,
                          Chunk::FromOffsetSize(slice.offset, slice.size)});
         }
       } else {
         timed_chunks.push_back(
-            TimedChunk{absl::StrCat(((int64_t)block)), block, block->start_time,
-                       block->end_time,
+            TimedChunk{absl::StrCat(((int64_t)block)), block,
+                       block->inclusive_start_time, block->end_time,
                        Chunk::FromOffsetSize(block->offset, block->size)});
       }
     }
diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc
index bf051d3..a665217 100644
--- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc
+++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc
@@ -22,6 +22,7 @@
 #include <functional>
 #include <iterator>
 #include <limits>
+#include <list>
 #include <memory>
 #include <optional>
 #include <ostream>
@@ -41,6 +42,7 @@
 #include "absl/strings/str_split.h"
 #include "absl/types/span.h"
 #include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/utils/hlo_live_range.h"
@@ -49,6 +51,7 @@
 #include "xla/service/memory_space_assignment/repacking.h"
 #include "xla/service/memory_space_assignment/tuning_utils.h"
 #include "xla/service/memory_space_assignment/utils.h"
+#include "xla/service/time_utils.h"
 #include "xla/service/tuple_util.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
@@ -240,6 +243,11 @@
          });
 }
 
+struct CrossProgramPrefetchBufferSortValues {
+  int64_t latest_use = 0;
+  int64_t use_size = 0;
+};
+
 std::vector<MemorySpaceAssignment::BufferInterval>
 FindCrossProgramPrefetchCandidates(const HloAliasAnalysis& alias_analysis,
                                    const HloLiveRange& hlo_live_range,
@@ -260,37 +268,21 @@
     }
   }
 
-  // The BufferIntervalCompare function used to sort buffers implements the
-  // greater-than operator so that the most beneficial buffers are allocated
-  // first. The size_compare function below hence uses the greater-than operator
-  // to pick the largest buffer.
-  auto size_compare = [](const auto& x, const auto& y) {
-    if (x.size == y.size) {
-      // When both buffers are of same size, we prefer the one that is used to
-      // produce larger tensors in its consumer instructions.
-      auto get_use_size =
-          [](const MemorySpaceAssignment::BufferInterval& bi) -> int64_t {
-        int64_t use_size = 0;
-        for (const auto& use : bi.buffer->GetUses()) {
-          use_size += ShapeUtil::ElementsInRecursive(use.instruction->shape());
-        }
-        return use_size;
-      };
-      return get_use_size(x) > get_use_size(y);
-    }
-    return x.size > y.size;
-  };
-  auto& compare = options.default_cross_program_prefetch_heuristic &&
-                          options.buffer_interval_compare
-                      ? *options.buffer_interval_compare
-                      : size_compare;
+  DefaultCrossProgramPrefetchBufferIntervalComparator default_comparator(
+      hlo_live_range);
+  MemorySpaceAssignment::BufferIntervalComparator* comparator =
+      (options.default_cross_program_prefetch_heuristic &&
+               options.buffer_interval_comparator
+           ? options.buffer_interval_comparator
+           : &default_comparator);
+  absl::c_sort(candidates, comparator->GetComparisonFunctor());
 
-  absl::c_sort(candidates, compare);
-
-  VLOG(3) << "Cross-program prefetch candidates: " << candidates.size();
+  VLOG(3) << "Cross-program prefetch candidates: " << candidates.size()
+          << ". Sorting criteria: " << comparator->DescribeComparisonCriteria();
   for (auto& candidate : candidates) {
-    VLOG(3) << "Cross-program prefetch candidate picked: "
-            << candidate.buffer->ToString();
+    VLOG(3) << "Cross-program prefetch candidate. Sorting criteria: "
+            << comparator->CriteriaToString(candidate)
+            << ". Candidate: " << candidate.buffer->ToString();
   }
   return candidates;
 }
@@ -1710,8 +1702,9 @@
       hlo_live_range_(hlo_live_range),
       peak_memory_usage_(hlo_live_range.schedule_end_time() + 1) {
   // Override buffer interval compare if provided.
-  if (options.buffer_interval_compare) {
-    buffer_interval_compare_ = *options.buffer_interval_compare;
+  if (options.buffer_interval_comparator) {
+    buffer_interval_compare_ =
+        options.buffer_interval_comparator->GetComparisonFunctor();
   }
 
   call_graph_ = CallGraph::Build(&alias_analysis_.dataflow_analysis().module());
@@ -3195,8 +3188,8 @@
       std::make_unique<MemorySpaceAssignment::CopyAllocation>(
           *value->allocations.back(),
           MemorySpaceAssignment::MemorySpace::kAlternate, std::nullopt,
-          ((*copy_start_time - 1) + loop_size_) % loop_size_,
-          last_use_idx_sentinel, first_use_idx));
+          ((*copy_start_time - 1) + loop_size_) % loop_size_, first_use_idx,
+          last_use_idx_sentinel));
   AddAllLoopPositionsAndUses(*value, /*allocate_next_iteration_uses=*/true);
 
   // Account for the additional memory used by early forcing the already
@@ -3648,6 +3641,15 @@
   // Calculate the memory pressure for the buffers that can be assigned in the
   // alternate memory.
   memory_pressure_ = 0;
+  VLOG(5) << [&]() {
+    std::string s("Sorted BufferInterval order.");
+    if (options_.buffer_interval_comparator) {
+      absl::StrAppend(
+          &s, " Pre-autotuning sort criteria: ",
+          options_.buffer_interval_comparator->DescribeComparisonCriteria());
+    }
+    return s;
+  }();
   for (auto& interval : sorted_buffer_intervals) {
     if (!interval.need_allocation ||
         !MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
@@ -3655,6 +3657,16 @@
         interval.size > available_heap_size()) {
       continue;
     }
+    VLOG(5) << [&]() {
+      std::string s("SortedBufferInterval.");
+      if (options_.buffer_interval_comparator) {
+        absl::StrAppend(
+            &s, " Criteria: ",
+            options_.buffer_interval_comparator->CriteriaToString(interval));
+      }
+      absl::StrAppend(&s, " Buffer: ", interval.buffer->ToShortString());
+      return s;
+    }();
     memory_pressure_ += interval.size;
   }
   VLOG(1) << "Memory pressure = " << memory_pressure_;
@@ -4424,7 +4436,7 @@
         // Rarely, (e.g., when conditional true and false parameters are the
         // same), definition time can be the time of the conditional and use
         // time is the parameter use, which is less.
-        request.start_time = std::min(definition_time, use_time);
+        request.inclusive_start_time = std::min(definition_time, use_time);
         request.end_time = use_time;
         request.latest_prefetch_time = latest_prefetch_time;
         request.size = allocation_value.size();
@@ -4556,44 +4568,45 @@
 }
 
 void AsynchronousCopyOrdering::AddCopy(const AsynchronousCopy& copy) {
-  auto it = ranges_.find({copy.start_time, copy.end_time});
+  auto it = ranges_.find({copy.exclusive_start_time, copy.end_time});
   if (it != ranges_.end()) {
-    CHECK_EQ(it->first.start_time, copy.start_time);
+    CHECK_EQ(it->first.exclusive_start_time, copy.exclusive_start_time);
     CHECK(it->second.insert(copy).second);
   } else {
-    ranges_[{copy.start_time, copy.end_time}] = {copy};
+    ranges_[{copy.exclusive_start_time, copy.end_time}] = {copy};
   }
 }
 
 void AsynchronousCopyOrdering::RemoveCopy(const AsynchronousCopy& copy) {
-  auto copy_it = ranges_.find({copy.start_time, copy.end_time});
+  auto copy_it = ranges_.find({copy.exclusive_start_time, copy.end_time});
   CHECK(copy_it != ranges_.end());
-  CHECK_EQ(copy_it->first.start_time, copy.start_time);
+  CHECK_EQ(copy_it->first.exclusive_start_time, copy.exclusive_start_time);
   CHECK_EQ(copy_it->second.erase(copy), 1);
   if (copy_it->second.empty()) {
     ranges_.erase(copy_it);
   }
 }
 
-bool AsynchronousCopyOrdering::ViolatesOrdering(int64_t start_time,
+bool AsynchronousCopyOrdering::ViolatesOrdering(int64_t exclusive_start_time,
                                                 int64_t end_time) const {
   // We allow identical start and end times. It is enough to check for just the
   // start time in case we find a match in ranges_ because the found value will
   // either be identical to {start_time, estimated_end_time} (and this doesn't
   // violate) or its start_time will be smaller and estimated_end_time will be
   // larger (this violates).
-  auto copy_it = ranges_.find({start_time, end_time});
-  if (copy_it != ranges_.end() && copy_it->first.start_time != start_time) {
-    VLOG(4) << "Violates ordering: (" << start_time << ", " << end_time
-            << ") and (" << copy_it->first.start_time << ", "
-            << copy_it->first.end_time << ")";
+  auto copy_it = ranges_.find({exclusive_start_time, end_time});
+  if (copy_it != ranges_.end() &&
+      copy_it->first.exclusive_start_time != exclusive_start_time) {
+    VLOG(4) << "Violates ordering: (" << exclusive_start_time << ", "
+            << end_time << ") and (" << copy_it->first.exclusive_start_time
+            << ", " << copy_it->first.end_time << ")";
     return true;
   }
   return false;
 }
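
A toy reproduction of the overlap lookup described in the comment above. It assumes, as that comment states, that any two stored ranges are either identical or nested, which keeps the "overlapping keys compare equivalent" trick consistent; the types are stand-ins rather than the XLA classes:

#include <cstdint>
#include <iostream>
#include <map>

struct Range {
  int64_t exclusive_start_time;
  int64_t end_time;
  // Overlapping ranges compare equivalent, so find() returns any stored range
  // that overlaps the query. This stays a valid ordering here because stored
  // ranges are either identical or strictly nested.
  bool operator<(const Range& other) const {
    return end_time < other.exclusive_start_time;
  }
};

int main() {
  std::map<Range, int> ranges;
  ranges[{3, 10}] = 1;  // an outstanding copy over (3, 10)

  Range query{5, 8};  // a new copy strictly inside (3, 10)
  auto it = ranges.find(query);
  // A match with a *different* start time means the new copy nests inside an
  // existing one, which violates the FIFO ordering.
  bool violates = it != ranges.end() &&
                  it->first.exclusive_start_time != query.exclusive_start_time;
  std::cout << (violates ? "violates FIFO order" : "ok") << "\n";
  return 0;
}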
 
 bool AsynchronousCopyResource::ConsumeResource(
-    int64_t start_time, int64_t end_time, float resource,
+    int64_t exclusive_start_time, int64_t end_time, float resource,
     absl::flat_hash_map<int64_t, float>* delay_change_map,
     float resource_to_free) {
   std::list<AsynchronousCopy>::iterator current_copy = async_copies_.end();
@@ -4605,10 +4618,16 @@
     // resource is modified below. We save its initial value for logging below.
     const float amount_requested = resource;
 
-    VLOG(3) << "Consume resource: start time = " << start_time
-            << ", end time = " << end_time << ", resource = " << resource
-            << ", delay = " << delay_[start_time + 1]
+    VLOG(3) << "Consume resource: start time_exclusive = "
+            << exclusive_start_time << ", end time = " << end_time
+            << ", resource = " << resource << ", delay = "
+            << delay_[ExclusiveToInclusiveStartTime(exclusive_start_time)]
             << ", free = " << resource_to_free;
+    VLOG(5) << "Available resources: "
+            << VectorToString(
+                   GetCurrentResources(), /*include_indices=*/true,
+                   ExclusiveToInclusiveStartTime(exclusive_start_time),
+                   end_time);
 
     // Nothing to do if we're not adding or removing any resources.
     if (resource == 0.0 && resource_to_free == 0.0) {
@@ -4619,7 +4638,7 @@
     // this copy would have to be delayed because of an earlier copy that wasn't
     // finished when this copy starts.
     if (current_copy == async_copies_.end()) {
-      resource += delay_[start_time + 1];
+      resource += delay_[ExclusiveToInclusiveStartTime(exclusive_start_time)];
     }
 
     // Find the copy that is right after this one. If there are leftover
@@ -4629,7 +4648,8 @@
     if (current_copy != async_copies_.end()) {
       next_copy = std::next(current_copy);
     } else {
-      auto async_copy_time_it = async_copy_time_map_.upper_bound(start_time);
+      auto async_copy_time_it =
+          async_copy_time_map_.upper_bound(exclusive_start_time);
       if (async_copy_time_it != async_copy_time_map_.end()) {
         next_copy = async_copy_time_it->second;
       }
@@ -4640,13 +4660,14 @@
     // earlier in time).
     std::optional<float> delay_for_next_copy = std::nullopt;
     float resource_freed = 0.0;
-    for (int64_t time = start_time + 1; time < end_time && resource != 0;
-         ++time) {
+    for (int64_t time = ExclusiveToInclusiveStartTime(exclusive_start_time);
+         time < end_time && resource != 0; ++time) {
       // Iterate over the logical times that this copy spans. Note that the
       // start and end time ranges are exclusive.
       float used_resource = std::min(resource, initial_resources_[time]);
       if (next_copy != async_copies_.end() &&
-          next_copy->start_time == time - 1) {
+          next_copy->exclusive_start_time ==
+              InclusiveToExclusiveStartTime(time)) {
         // This is the time where the next copy begins. If the resource is
         // non-zero at this point, the copy didn't finish by the time the next
         // copy started, so the next copy would need to be pushed later in time.
@@ -4672,9 +4693,6 @@
 
     // If resource isn't satisfied by the end, we didn't have enough resources.
     if (resource > 0) {
-      VLOG(5) << "Available resources: "
-              << VectorToString(GetCurrentResources(), /*include_indices=*/true,
-                                start_time + 1, end_time);
       VLOG(3) << "Doesn't have enough resource; requested resource = "
               << amount_requested << "; leftover resources = " << resource;
       return false;
@@ -4686,7 +4704,7 @@
     // If this copy overlapped with another one, we run for another iteration
     // with the next copy  with the amount of resource that needs to be added or
     // removed.
-    start_time = next_copy->start_time;
+    exclusive_start_time = next_copy->exclusive_start_time;
     end_time = next_copy->end_time;
     resource = *delay_for_next_copy + next_copy->resource;
     current_copy = next_copy;
@@ -4694,11 +4712,13 @@
 }
 
 void AsynchronousCopyResource::AddCopy(const AsynchronousCopy& copy) {
-  CHECK(ConsumeResource(copy.start_time, copy.end_time, copy.resource));
+  CHECK(
+      ConsumeResource(copy.exclusive_start_time, copy.end_time, copy.resource));
 
   // Find the iterator for the copy that would be right after this copy and put
   // this copy right before it in async_copies_.
-  auto async_copy_time_it = async_copy_time_map_.upper_bound(copy.start_time);
+  auto async_copy_time_it =
+      async_copy_time_map_.upper_bound(copy.exclusive_start_time);
   auto insertion_it = (async_copy_time_it == async_copy_time_map_.end())
                           ? async_copies_.end()
                           : async_copy_time_it->second;
@@ -4709,9 +4729,9 @@
   // start index. If there are multiple asynchronous copies that have the same
   // start time, the memory space assignment algorithm schedules them in the
   // same order that AddCopy was called.
-  if (async_copy_time_map_.find(copy.start_time) ==
+  if (async_copy_time_map_.find(copy.exclusive_start_time) ==
       async_copy_time_map_.end()) {
-    async_copy_time_map_[copy.start_time] = inserted_it;
+    async_copy_time_map_[copy.exclusive_start_time] = inserted_it;
   }
 }
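
A sketch of the insertion scheme AddCopy() uses: a list kept sorted by exclusive start time plus a side map that records only the first copy per start time, so copies sharing a start time keep the order in which AddCopy() was called. Types and values are invented for illustration:

#include <cstdint>
#include <iostream>
#include <list>
#include <map>

struct Copy {
  int64_t exclusive_start_time;
  int64_t end_time;
};

int main() {
  std::list<Copy> copies;  // kept sorted by exclusive start time
  // Maps a start time to the *first* copy in `copies` with that start time.
  std::map<int64_t, std::list<Copy>::iterator> first_copy_at;

  auto add_copy = [&](const Copy& c) {
    // Insert right before the first copy with a strictly larger start time,
    // so copies with equal start times stay in insertion (FIFO) order.
    auto it = first_copy_at.upper_bound(c.exclusive_start_time);
    auto pos = (it == first_copy_at.end()) ? copies.end() : it->second;
    auto inserted = copies.insert(pos, c);
    // emplace() is a no-op if this start time is already tracked.
    first_copy_at.emplace(c.exclusive_start_time, inserted);
  };

  add_copy({2, 7});
  add_copy({5, 9});
  add_copy({2, 8});  // same start time as the first copy; lands after it
  for (const Copy& c : copies) {
    std::cout << "(" << c.exclusive_start_time << ", " << c.end_time << ") ";
  }
  std::cout << "\n";  // prints: (2, 7) (2, 8) (5, 9)
  return 0;
}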
 
@@ -4725,7 +4745,8 @@
   // remove the copies until we find the copy we actually want to remove. After
   // we remove the copy that we actually want to remove, we add back the
   // temporarily removed copies one by one in the same order.
-  auto async_copy_time_it = async_copy_time_map_.upper_bound(copy.start_time);
+  auto async_copy_time_it =
+      async_copy_time_map_.upper_bound(copy.exclusive_start_time);
   auto copy_it = (async_copy_time_it == async_copy_time_map_.end())
                      ? async_copies_.end()
                      : async_copy_time_it->second;
@@ -4736,10 +4757,10 @@
   auto prev_copy_it = copy_it;
   for (; *copy_it != copy; copy_it = prev_copy_it) {
     CHECK(copy_it != async_copies_.begin());
-    CHECK_EQ(copy_it->start_time, copy.start_time);
+    CHECK_EQ(copy_it->exclusive_start_time, copy.exclusive_start_time);
     copies_to_add_back.push_front(*copy_it);
     VLOG(4) << "RemoveCopy found a copy to temporarily remove and add back: "
-            << copy_it->start_time << " " << copy_it->end_time << " "
+            << copy_it->exclusive_start_time << " " << copy_it->end_time << " "
             << copy_it->resource;
     prev_copy_it = std::prev(copy_it);
     RemoveCopy(copy_it);
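
The "temporarily remove later copies, erase the target, re-add the removed ones in order" pattern described above, sketched on a plain std::list<int> of start times. The real RemoveCopy() additionally reverses and replays the resource accounting for every copy it touches:

#include <iostream>
#include <iterator>
#include <list>

int main() {
  // Exclusive start times of the outstanding copies, already sorted.
  std::list<int> copies = {2, 5, 5, 5, 9};
  auto target = std::next(copies.begin(), 2);  // erase the middle 5

  // Walk backwards from the last copy sharing the target's start time,
  // stashing each one so it can be re-added in its original order.
  std::list<int> to_add_back;
  auto it = std::prev(copies.end(), 2);  // the last 5
  while (it != target) {
    to_add_back.push_front(*it);
    auto prev = std::prev(it);
    copies.erase(it);
    it = prev;
  }
  copies.erase(target);

  // Re-insert the stashed copies right after the remaining 5.
  copies.splice(std::next(copies.begin(), 2), to_add_back);
  for (int t : copies) std::cout << t << " ";  // prints: 2 5 5 9
  std::cout << "\n";
  return 0;
}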
@@ -4756,19 +4777,21 @@
     std::list<AsynchronousCopy>::iterator& copy_it) {
   // This method works only for the latest copy for the given start time.
   CHECK(std::next(copy_it) == async_copies_.end() ||
-        std::next(copy_it)->start_time > copy_it->start_time);
-  CHECK(ConsumeResource(copy_it->start_time, copy_it->end_time, /*resource=*/0,
+        std::next(copy_it)->exclusive_start_time >
+            copy_it->exclusive_start_time);
+  CHECK(ConsumeResource(copy_it->exclusive_start_time, copy_it->end_time,
+                        /*resource=*/0,
                         /*delay_change_map=*/nullptr,
                         /*resource_to_free=*/copy_it->resource));
   // If the copy to be removed is the value pointed by async_copy_time_map_, we
   // make the next copy with the same start time to be pointed by
   // async_copy_time_map_. If there are no such copies, we remove the key for
   // this copy start time.
-  int64_t start_time = copy_it->start_time;
-  auto async_copy_time_it = async_copy_time_map_.find(start_time);
+  int64_t exclusive_start_time = copy_it->exclusive_start_time;
+  auto async_copy_time_it = async_copy_time_map_.find(exclusive_start_time);
   if (copy_it == async_copy_time_it->second) {
     if (std::next(copy_it) != async_copies_.end() &&
-        std::next(copy_it)->start_time == start_time) {
+        std::next(copy_it)->exclusive_start_time == exclusive_start_time) {
       async_copy_time_it->second = std::next(copy_it);
     } else {
       async_copy_time_map_.erase(async_copy_time_it);
@@ -4777,11 +4800,12 @@
   async_copies_.erase(copy_it);
 }
 
-bool AsynchronousCopyResource::HasEnoughResource(int64_t start_time,
+bool AsynchronousCopyResource::HasEnoughResource(int64_t exclusive_start_time,
                                                  int64_t end_time,
                                                  float resource) {
   absl::flat_hash_map<int64_t, float> delay_changes;
-  bool result = ConsumeResource(start_time, end_time, resource, &delay_changes);
+  bool result =
+      ConsumeResource(exclusive_start_time, end_time, resource, &delay_changes);
   for (const auto& change_pair : delay_changes) {
     delay_[change_pair.first] = change_pair.second;
   }
@@ -4792,8 +4816,8 @@
     const std::vector<ResourceSpec>& specs) {
   absl::flat_hash_map<int64_t, float> delay_changes;
   bool result = absl::c_all_of(specs, [&](const ResourceSpec& spec) {
-    return ConsumeResource(spec.start_time, spec.end_time, spec.resource,
-                           &delay_changes);
+    return ConsumeResource(spec.exclusive_start_time, spec.end_time,
+                           spec.resource, &delay_changes);
   });
   for (const auto& change_pair : delay_changes) {
     delay_[change_pair.first] = change_pair.second;
@@ -4835,11 +4859,12 @@
     if (copy.destination != memory_space_filter) {
       continue;
     }
-    int64_t overlap_start = std::max(start_time, copy.start_time);
+    int64_t overlap_start = std::max(start_time, copy.exclusive_start_time);
     int64_t overlap_end = std::min(end_time, copy.end_time);
     if (overlap_start < overlap_end) {
       lines.push_back(absl::StrCat(
-          "copy(id: ", copy.id, ", start: ", copy.start_time,
+          "copy(id: ", copy.id,
+          ", exclusive_start: ", copy.exclusive_start_time,
           ", end: ", copy.end_time, ", resource: ", copy.resource, ")"));
     }
     for (int i = overlap_start; i < overlap_end; ++i) {
@@ -4962,20 +4987,21 @@
       options_.prefetch_interval_picker->LatestPrefetchStartTime(
           buffer->defining_position().shape(), last_use_time,
           end_of_program_prefetch_end_time, nullptr);
-  int64_t end_of_program_prefetch_start_time =
+  int64_t end_of_program_inclusive_prefetch_start_time =
       options_.prefetch_interval_picker->PreferredPrefetchStartTime(
           buffer->defining_position().shape(), last_use_time,
           end_of_program_prefetch_latest_start_time,
           end_of_program_prefetch_end_time);
   VLOG(2) << "last use time = " << last_use_time
-          << ", end-of-program prefetch start time = "
-          << end_of_program_prefetch_start_time;
+          << ", end-of-program inclusive prefetch start time = "
+          << end_of_program_inclusive_prefetch_start_time;
   float total_execution_time =
       options_.prefetch_interval_picker->GetLogicalIntervalElapsed(
           0, instruction_schedule.size());
   float buffer_occupied_time =
       options_.prefetch_interval_picker->GetLogicalIntervalElapsed(
-          end_of_program_prefetch_start_time, end_of_program_prefetch_end_time);
+          end_of_program_inclusive_prefetch_start_time,
+          end_of_program_prefetch_end_time);
   if (options_.cost_analysis) {
     buffer_occupied_time = std::max(buffer_occupied_time,
                                     options_.cost_analysis->GetAsyncCopyElapsed(
@@ -4995,14 +5021,17 @@
       (options_.enable_cross_program_prefetch_freeing &&
        memory_pressure_ > options_.max_size_in_bytes &&
        buffer_occupied_ratio < kCrossProgramPrefetchOccupyFreeingLimit &&
-       end_of_program_prefetch_start_time > last_use_time &&
-       end_of_program_prefetch_start_time < end_of_program_prefetch_end_time);
+       end_of_program_inclusive_prefetch_start_time > last_use_time &&
+       end_of_program_inclusive_prefetch_start_time <
+           end_of_program_prefetch_end_time);
   int64_t cross_program_prefetch_end_time =
       free_buffer ? last_use_time : prefetch_candidate.end;
 
   AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate, chunk_candidate,
-               prefetch_candidate.start, cross_program_prefetch_end_time,
-               latest_prefetch_time, &allocations, /*aliased_offset=*/nullptr,
+               /*exclusive_start_time=*/
+               InclusiveToExclusiveStartTime(prefetch_candidate.start),
+               cross_program_prefetch_end_time, latest_prefetch_time,
+               &allocations, /*aliased_offset=*/nullptr,
                /*resource=*/0.0, cross_program_prefetch_index);
 
   absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); });
@@ -5013,7 +5042,9 @@
     VLOG(2) << "Adding an end-of-program prefetch for freed "
                "cross-program-prefetched buffer.";
     AddAsyncCopy(*allocations.front(), MemorySpace::kAlternate, chunk_candidate,
-                 end_of_program_prefetch_start_time,
+                 /*exclusive_start_time=*/
+                 InclusiveToExclusiveStartTime(
+                     end_of_program_inclusive_prefetch_start_time),
                  end_of_program_prefetch_end_time,
                  end_of_program_prefetch_end_time, &allocations,
                  cross_program_prefetch_offset,
@@ -5454,12 +5485,14 @@
     MemorySpaceAssignmentRepacker::SlicedAllocationData original_slice_data;
     for (const SliceDetail* slice_detail : slice_details_sorted_by_offset) {
       CHECK_EQ(slice_detail->copy_start_after_time,
-               slice_detail->slice_decision.start_time);
+               slice_detail->slice_decision.exclusive_start_time);
       original_slice_data.slices_sorted_by_offset.push_back(
           MemorySpaceAssignmentRepacker::Slice{
               slice_detail->slice_decision.chunk.size,
               slice_detail->slice_decision.chunk.offset,
-              slice_detail->slice_decision.start_time});
+              /*inclusive_start_time=*/
+              ExclusiveToInclusiveStartTime(
+                  slice_detail->slice_decision.exclusive_start_time)});
     }
 
     allocation_block.original_slice_data = std::move(original_slice_data);
@@ -5489,7 +5522,7 @@
   block.initial_offset = repacked_offset;
   block.offset = -1;
   interval_tree_.Add(
-      block.start_time, block.end_time,
+      block.inclusive_start_time, block.end_time,
       HeapSimulator::Chunk::FromOffsetSize(repacked_offset, block.size));
 
   VLOG(3) << "Repacking move. offset: " << original_offset << " -> "
@@ -5530,8 +5563,10 @@
   // we don't need to worry about modifying the chunks here.
   for (const SliceDetail& slice_detail :
        allocation->slice_details_sorted_by_start_time()) {
-    interval_tree_.Add(slice_detail.copy_start_after_time, block.end_time,
-                       slice_detail.slice_decision.chunk);
+    interval_tree_.Add(
+        /*start=*/
+        ExclusiveToInclusiveStartTime(slice_detail.copy_start_after_time),
+        block.end_time, slice_detail.slice_decision.chunk);
   }
 
   VLOG(3) << "Repacking move. offset: " << original_offset << " -> "
@@ -5574,18 +5609,22 @@
     }
     interval_tree_.Remove(interval.start, interval.end, chunk);
   }
-  for (const auto& interval : pending_async_copies_) {
-    if (interval.destination == MemorySpace::kAlternate) {
-      prefetch_interval_tree_.Remove(interval.start_time, interval.end_time,
-                                     kDummyChunk);
-      prefetch_async_copy_resource_.RemoveCopy(interval);
+  for (const AsynchronousCopy& async_copy : pending_async_copies_) {
+    if (async_copy.destination == MemorySpace::kAlternate) {
+      prefetch_interval_tree_.Remove(
+          /*start=*/
+          ExclusiveToInclusiveStartTime(async_copy.exclusive_start_time),
+          async_copy.end_time, kDummyChunk);
+      prefetch_async_copy_resource_.RemoveCopy(async_copy);
       if (options_.enforce_prefetch_fifo_order) {
-        async_copy_ordering_.RemoveCopy(interval);
+        async_copy_ordering_.RemoveCopy(async_copy);
       }
     } else {
-      eviction_interval_tree_.Remove(interval.start_time, interval.end_time,
-                                     kDummyChunk);
-      eviction_async_copy_resource_.RemoveCopy(interval);
+      eviction_interval_tree_.Remove(
+          /*start=*/
+          ExclusiveToInclusiveStartTime(async_copy.exclusive_start_time),
+          async_copy.end_time, kDummyChunk);
+      eviction_async_copy_resource_.RemoveCopy(async_copy);
     }
   }
   for (const auto& value_and_required_assignment :
@@ -5674,43 +5713,44 @@
 void AlternateMemoryBestFitHeap::AddToPendingChunks(
     const BufferInterval& buffer_interval, const Chunk& chunk_candidate) {
   VLOG(3) << "Committing chunk: " << buffer_interval.start << "-"
-          << buffer_interval.end << " : [" << chunk_candidate.offset << ", "
-          << chunk_candidate.size << "]";
+          << buffer_interval.end << " : " << chunk_candidate.ToString();
   pending_chunks_.emplace_back(buffer_interval, chunk_candidate);
   for (int i = buffer_interval.start; i <= buffer_interval.end; ++i) {
     peak_memory_usage_[i] += chunk_candidate.size;
     CHECK_LE(peak_memory_usage_[i], options_.max_size_in_bytes)
         << "Peak memory usage at " << i
         << " exceeds the max size of alternate memory. "
-        << buffer_interval.start << "-" << buffer_interval.end << " : ["
-        << chunk_candidate.offset << ", " << chunk_candidate.size << "]";
+        << buffer_interval.start << "-" << buffer_interval.end << " : "
+        << chunk_candidate.ToString();
   }
   CommitChunk(buffer_interval, chunk_candidate);
 }
 
 std::optional<int>
-AlternateMemoryBestFitHeap::FindEarliestTimeToSatisfyPeakMemory(
-    int start_time, int end_time, int64_t size) const {
-  int earliest_time;
-  for (earliest_time = end_time;
-       earliest_time >= start_time &&
-       peak_memory_usage_[earliest_time] + size <= options_.max_size_in_bytes;
-       --earliest_time) {
+AlternateMemoryBestFitHeap::FindEarliestExclusiveTimeToSatisfyPeakMemory(
+    int exclusive_start_time, int end_time, int64_t size) const {
+  std::optional<int> earliest_time_exclusive = std::nullopt;
+  for (int time_inclusive = ExclusiveToInclusiveEndTime(end_time);
+       time_inclusive > exclusive_start_time; --time_inclusive) {
+    if (peak_memory_usage_[time_inclusive] + size <=
+        options_.max_size_in_bytes) {
+      earliest_time_exclusive = InclusiveToExclusiveStartTime(time_inclusive);
+    } else {
+      break;
+    }
   }
-  if (earliest_time == end_time) {
-    return std::nullopt;
-  }
-  return earliest_time + 1;
+
+  return earliest_time_exclusive;
 }
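
A self-contained sketch of the backward scan above, with simplified stand-ins for the start/end-time conversion helpers: starting from the end time, the window keeps growing while the extra bytes still fit under the limit, and the earliest exclusive time found is returned (or nullopt if even the last step does not fit):

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Simplified stand-ins for the conversion helpers used in the diff.
int64_t ExclusiveToInclusiveEndTime(int64_t t) { return t - 1; }
int64_t InclusiveToExclusiveStartTime(int64_t t) { return t - 1; }

std::optional<int> FindEarliestExclusiveTime(
    const std::vector<int64_t>& peak_usage, int64_t max_size,
    int exclusive_start_time, int end_time, int64_t size) {
  std::optional<int> earliest_exclusive;
  for (int t = ExclusiveToInclusiveEndTime(end_time); t > exclusive_start_time;
       --t) {
    if (peak_usage[t] + size > max_size) break;  // first time that would OOM
    earliest_exclusive = InclusiveToExclusiveStartTime(t);
  }
  return earliest_exclusive;
}

int main() {
  std::vector<int64_t> peak = {0, 90, 50, 30, 20, 10};
  // With a 100-byte limit, only times 3..5 can absorb 60 more bytes, so the
  // earliest exclusive start time that still fits is 2.
  std::optional<int> t = FindEarliestExclusiveTime(peak, /*max_size=*/100,
                                                   /*exclusive_start_time=*/0,
                                                   /*end_time=*/6, /*size=*/60);
  std::cout << (t.has_value() ? std::to_string(*t) : "none") << "\n";
  return 0;
}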
 
 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment(
     const AllocationRequest& request) {
   auto allocation_sequence =
       request.allocation_value->mutable_allocation_sequence();
-  // start_time == end_time is a special case where the value is consumed
-  // multiple times by the same instruction. We can just find the previous
-  // allocation and use that allocation.
-  if (request.start_time == request.end_time) {
+  // inclusive_start_time == end_time is a special case where the value is
+  // consumed multiple times by the same instruction. We can just find the
+  // previous allocation and use that allocation.
+  if (request.inclusive_start_time == request.end_time) {
     MemorySpaceAssignment::Allocation* allocation =
         GetLiveAllocationAt(*allocation_sequence, request.end_time);
     CHECK_NE(allocation, nullptr);
@@ -5721,14 +5761,14 @@
   const HloPosition& defining_position =
       request.allocation_value->defining_position();
   VLOG(2) << "Finding allocation for "
-          << request.allocation_value->ToShortString() << " ("
-          << request.start_time << ", " << request.end_time
+          << request.allocation_value->ToShortString() << " ["
+          << request.inclusive_start_time << ", " << request.end_time
           << ") latest prefetch = " << request.latest_prefetch_time
           << " last use = " << request.allocation_value->uses().back().time
           << " use = " << request.use->hlo_use.ToString()
           << ". Size = " << request.size
           << ", def pos = " << defining_position.ToString();
-  CHECK_LE(request.start_time, request.end_time);
+  CHECK_LE(request.inclusive_start_time, request.end_time);
   if (VLOG_IS_ON(3) && options_.cost_analysis) {
     const HloPosition& defining_position =
         request.allocation_value->defining_position();
@@ -5753,7 +5793,7 @@
   // memory, we cannot prefetch it because if we did, it would be in alternate
   // memory instead.
   auto required_assignment_at_start = RequiredMemoryAssignmentAt(
-      request.allocation_value->value(), request.start_time);
+      request.allocation_value->value(), request.inclusive_start_time);
   std::optional<MemorySpace> required_memory_space_at_start;
   if (required_assignment_at_start) {
     required_memory_space_at_start = required_assignment_at_start->memory_space;
@@ -5787,7 +5827,7 @@
             return allocation->memory_space() == required_memory_space_at_start;
           });
       if (prev_allocation_it != allocation_sequence->rend()) {
-        (*prev_allocation_it)->set_end_time(request.start_time);
+        (*prev_allocation_it)->set_end_time(request.inclusive_start_time);
         needs_required_allocation = false;
       }
     }
@@ -5801,7 +5841,8 @@
       allocation_sequence->push_back(
           std::make_unique<MemorySpaceAssignment::Allocation>(
               defining_position, required_assignment_at_start->memory_space,
-              aliased_chunk, request.start_time, request.start_time,
+              aliased_chunk, request.inclusive_start_time,
+              request.inclusive_start_time,
               /*is_scoped_allocation=*/false));
       if (required_assignment_at_start->memory_space ==
           MemorySpace::kAlternate) {
@@ -5848,7 +5889,8 @@
     allocation_sequence->push_back(
         std::make_unique<MemorySpaceAssignment::Allocation>(
             defining_position, MemorySpace::kDefault,
-            /*chunk=*/std::nullopt, request.start_time, request.end_time,
+            /*chunk=*/std::nullopt, request.inclusive_start_time,
+            request.end_time,
             /*is_scoped_allocation=*/false));
     prev_allocation_in_default_mem_it = allocation_sequence->rbegin();
   }
@@ -5950,8 +5992,9 @@
 
 void AlternateMemoryBestFitHeap::AddAsyncCopy(
     MemorySpaceAssignment::Allocation& prev_allocation,
-    MemorySpace memory_space, std::optional<Chunk> chunk, int64_t start_time,
-    int64_t end_time, int64_t copy_done_schedule_before_time,
+    MemorySpace memory_space, std::optional<Chunk> chunk,
+    int64_t exclusive_start_time, int64_t end_time,
+    int64_t copy_done_schedule_before_time,
     MemorySpaceAssignment::AllocationSequence* allocations,
     AliasedOffset* aliased_offset, float resource,
     std::optional<int> cross_program_prefetch_index) {
@@ -5959,32 +6002,37 @@
           << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
                   ? "default"
                   : "alternate")
-          << " memory between " << start_time << " and "
-          << copy_done_schedule_before_time << " keeping until " << end_time
+          << " memory in (" << exclusive_start_time << ", "
+          << copy_done_schedule_before_time << "), keeping until " << end_time
           << ", estimated copy resource is " << resource;
-  CHECK_LT(start_time, copy_done_schedule_before_time);
+  CHECK_LT(exclusive_start_time, copy_done_schedule_before_time);
 
   allocations->push_back(
       std::make_unique<MemorySpaceAssignment::CopyAllocation>(
-          prev_allocation, memory_space, chunk, start_time, end_time,
-          copy_done_schedule_before_time, cross_program_prefetch_index));
+          prev_allocation, memory_space, chunk, exclusive_start_time,
+          copy_done_schedule_before_time, end_time,
+          cross_program_prefetch_index));
 
   // Register the additional async copy with the interval tree to keep track of
   // the limit at any given time.
-  pending_async_copies_.push_back({start_time, copy_done_schedule_before_time,
-                                   resource, memory_space,
-                                   next_async_copy_id_++});
+  pending_async_copies_.push_back({exclusive_start_time,
+                                   copy_done_schedule_before_time, resource,
+                                   memory_space, next_async_copy_id_++});
   if (memory_space == MemorySpaceAssignment::MemorySpace::kAlternate) {
-    prefetch_interval_tree_.Add(start_time, copy_done_schedule_before_time,
-                                kDummyChunk);
+    prefetch_interval_tree_.Add(
+        /*start=*/
+        ExclusiveToInclusiveStartTime(exclusive_start_time),
+        copy_done_schedule_before_time, kDummyChunk);
     prefetch_async_copy_resource_.AddCopy(pending_async_copies_.back());
     if (options_.enforce_prefetch_fifo_order) {
       async_copy_ordering_.AddCopy(pending_async_copies_.back());
     }
     CreateOrAddToAliasedOffset(*allocations->back(), aliased_offset);
   } else {
-    eviction_interval_tree_.Add(start_time, copy_done_schedule_before_time,
-                                kDummyChunk);
+    eviction_interval_tree_.Add(
+        /*start=*/
+        ExclusiveToInclusiveStartTime(exclusive_start_time),
+        copy_done_schedule_before_time, kDummyChunk);
     eviction_async_copy_resource_.AddCopy(pending_async_copies_.back());
   }
 }
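
The start-time convention this change moves to, in a small runnable form. An exclusive start time names the schedule index the CopyStart is placed after, so the interval trees above convert it to the first index the copy actually occupies; the one-line helper definitions here are inferred from the `start_time + 1` arithmetic this diff replaces and are assumptions, not the XLA definitions:

#include <cstdint>
#include <iostream>

// Assumed definitions for the sake of the example.
int64_t ExclusiveToInclusiveStartTime(int64_t exclusive_start) {
  return exclusive_start + 1;
}
int64_t InclusiveToExclusiveStartTime(int64_t inclusive_start) {
  return inclusive_start - 1;
}

int main() {
  int64_t exclusive_start_time = 4;   // CopyStart goes right after index 4
  int64_t copy_done_before_time = 9;  // CopyDone must be scheduled before 9
  // The interval trees and peak-memory bookkeeping want the first index the
  // copy actually occupies, i.e. the inclusive start.
  int64_t tree_start = ExclusiveToInclusiveStartTime(exclusive_start_time);
  std::cout << "tree interval starts at " << tree_start
            << ", copy done before " << copy_done_before_time << "\n";
  // Round-tripping the conversion is the identity.
  std::cout << InclusiveToExclusiveStartTime(tree_start) << "\n";  // prints 4
  return 0;
}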
@@ -6005,7 +6053,7 @@
 
   for (const auto& slice_decision : slice_decisions) {
     std::vector<std::string> details;
-    details.push_back(absl::StrCat(slice_decision.start_time));
+    details.push_back(absl::StrCat(slice_decision.exclusive_start_time));
     details.push_back(absl::StrCat(prefetch_end));
     details.push_back(absl::StrCat(allocation_end));
     details.push_back(absl::StrCat(slice_decision.copy_resource_consumed));
@@ -6033,27 +6081,27 @@
           << SliceTimesAndCopyResourcesToString(
                  slice_decisions_sorted_by_start_time, prefetch_end_time,
                  allocation_end_time);
-  CHECK(absl::c_all_of(slice_decisions_sorted_by_start_time,
-                       [&](const auto& slice_decision) {
-                         return slice_decision.start_time < prefetch_end_time;
-                       }));
+  CHECK(absl::c_all_of(
+      slice_decisions_sorted_by_start_time, [&](const auto& slice_decision) {
+        return slice_decision.exclusive_start_time < prefetch_end_time;
+      }));
 
   allocations->push_back(
       std::make_unique<MemorySpaceAssignment::SlicedCopyAllocation>(
           prev_allocation, MemorySpaceAssignment::MemorySpace::kAlternate,
-          slice_decisions_sorted_by_start_time, allocation_end_time,
-          prefetch_end_time, options_.update_layout_fn));
+          slice_decisions_sorted_by_start_time, prefetch_end_time,
+          allocation_end_time, options_.update_layout_fn));
 
   // Register the additional async copy with the interval tree to keep track of
   // the limit at any given time.
   for (const auto& slice_decision : slice_decisions_sorted_by_start_time) {
     pending_async_copies_.push_back(
-        {slice_decision.start_time, prefetch_end_time,
+        {slice_decision.exclusive_start_time, prefetch_end_time,
          slice_decision.copy_resource_consumed,
          MemorySpaceAssignment::MemorySpace::kAlternate,
          next_async_copy_id_++});
-    prefetch_interval_tree_.Add(slice_decision.start_time, prefetch_end_time,
-                                kDummyChunk);
+    prefetch_interval_tree_.Add(slice_decision.exclusive_start_time,
+                                prefetch_end_time, kDummyChunk);
     prefetch_async_copy_resource_.AddCopy(pending_async_copies_.back());
     if (options_.enforce_prefetch_fifo_order) {
       async_copy_ordering_.AddCopy(pending_async_copies_.back());
@@ -6063,7 +6111,7 @@
 }
 
 bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
-    int64_t start_time, int64_t end_time, bool is_prefetch,
+    int64_t inclusive_start_time, int64_t end_time, bool is_prefetch,
     int64_t extra_async_copy_limit, int64_t num_additional_copies) const {
   if (options_.max_outstanding_prefetches < 0 && is_prefetch) {
     return false;
@@ -6075,14 +6123,16 @@
   // Count the prefetches/evictions in the interval tree for the given interval.
   if (is_prefetch) {
     int64_t num_prefetches =
-        prefetch_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
+        prefetch_interval_tree_
+            .ChunksOverlappingInTime(inclusive_start_time, end_time)
             .size() +
         num_additional_copies;
     return num_prefetches >=
            options_.max_outstanding_prefetches + extra_async_copy_limit;
   } else {
     int64_t num_evictions =
-        eviction_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
+        eviction_interval_tree_
+            .ChunksOverlappingInTime(inclusive_start_time, end_time)
             .size() +
         num_additional_copies;
     return num_evictions >=
@@ -6120,7 +6170,8 @@
   // duration checks.
   if (!request.prefer_no_copy_alternate_mem_allocation &&
       !options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
-          defining_position.shape(), request.start_time, request.end_time)) {
+          defining_position.shape(), request.inclusive_start_time,
+          request.end_time)) {
     VLOG(3) << "Live range is too long.";
     return Result::kFailLiveRangeTooLong;
   }
@@ -6129,7 +6180,7 @@
   alternate_mem_interval.buffer = request.allocation_value->value();
   alternate_mem_interval.size = request.size;
   alternate_mem_interval.end = request.end_time;
-  alternate_mem_interval.start = request.start_time;
+  alternate_mem_interval.start = request.inclusive_start_time;
 
   // Prefer the offset that was previously used for the previous allocation.
   AliasedOffset* preferred_offset = nullptr;
@@ -6182,7 +6233,9 @@
             << ", heap_size = " << result_.UpdatedHeapSize(*chunk_candidate)
             << ", prefetch picker = "
             << options_.prefetch_interval_picker->ToNoCopyDebugString(
-                   defining_position.shape(), request.start_time,
+                   defining_position.shape(),
+                   /*start_time=*/
+                   InclusiveToExclusiveStartTime(request.inclusive_start_time),
                    request.end_time);
     AddToPendingChunks(alternate_mem_interval, *chunk_candidate);
 
@@ -6196,7 +6249,7 @@
       request.allocation_value->mutable_allocation_sequence()->push_back(
           std::make_unique<MemorySpaceAssignment::Allocation>(
               defining_position, MemorySpace::kAlternate, chunk_candidate,
-              request.start_time, request.end_time,
+              request.inclusive_start_time, request.end_time,
               /*is_scoped_allocation=*/false));
       CreateOrAddToAliasedOffset(
           *request.allocation_value->allocation_sequence()->back(),
@@ -6218,14 +6271,27 @@
   CHECK_GT(request.allocation_value->allocation_sequence()->size(), 0);
   MemorySpaceAssignment::Allocation* prev_allocation =
       request.allocation_value->allocation_sequence()->back().get();
-  int64_t eviction_start_time = prev_allocation->start_time();
+  // We never expect an Evict() to be immediately preceded by a prefetch. If
+  // that ever occurs, the eviction_exclusive_start_time below would be
+  // computed incorrectly, because the eviction would need to start after the
+  // prefetch finishes copying data.
+  CHECK(!prev_allocation->is_copy_like_allocation())
+      << "Evict has been given copy-like previous allocation.\nEvict "
+         "candidate:\n"
+      << request.allocation_value->ToString() << "\nPrevious allocation:\n"
+      << prev_allocation->ToString();
+
+  // The previous allocation's inclusive start time is the eviction's exclusive
+  // start time to ensure that the value is created before we start copying
+  // back to default memory.
+  int64_t eviction_exclusive_start_time = prev_allocation->start_time();
   int64_t eviction_end_time = prev_allocation->end_time();
-  CHECK(eviction_start_time <= eviction_end_time);
+  CHECK(eviction_exclusive_start_time <= eviction_end_time);
 
   int64_t preferred_eviction_end_time =
       std::max(options_.prefetch_interval_picker->PreferredEvictionEndTime(
                    request.allocation_value->defining_position().shape(),
-                   eviction_start_time, request.end_time),
+                   eviction_exclusive_start_time, request.end_time),
                eviction_end_time);
   // Evictions must complete by the time of this use.
   preferred_eviction_end_time =
@@ -6239,8 +6305,8 @@
   eviction_mem_interval.start = eviction_end_time + 1;
   eviction_mem_interval.end = preferred_eviction_end_time;
   int64_t preferred_offset = prev_allocation->chunk().offset;
-  VLOG(3) << "Eviction (" << eviction_start_time << ", " << eviction_end_time
-          << ") preferred end time = " << eviction_mem_interval.end;
+  VLOG(3) << "Considering eviction after" << eviction_exclusive_start_time
+          << ", with preferred end time = " << eviction_mem_interval.end;
 
   for (; eviction_mem_interval.end > eviction_end_time;
        --eviction_mem_interval.end) {
@@ -6254,7 +6320,7 @@
   eviction_end_time = eviction_mem_interval.end;
 
   VLOG(3) << "Evicting buffer at " << prev_allocation->chunk().offset << " ("
-          << eviction_start_time << ", " << eviction_end_time << ")";
+          << eviction_exclusive_start_time << ", " << eviction_end_time << ")";
 
   float eviction_resource =
       options_.cost_analysis
@@ -6262,10 +6328,11 @@
                 request.allocation_value->defining_position().shape())
           : 0.1;
 
-  bool eviction_interval_too_short = (eviction_start_time == eviction_end_time);
+  bool eviction_interval_too_short =
+      (eviction_exclusive_start_time == eviction_end_time);
   bool eviction_violates_resource =
       !eviction_async_copy_resource_.HasEnoughResource(
-          eviction_start_time, eviction_end_time, eviction_resource);
+          eviction_exclusive_start_time, eviction_end_time, eviction_resource);
   if (eviction_violates_resource) {
     // If we're in the last retry, set resource to 0.
     if (options_.prefetch_interval_picker->retry_number() ==
@@ -6275,19 +6342,22 @@
     }
     eviction_violates_resource =
         !eviction_async_copy_resource_.HasEnoughResource(
-            eviction_start_time, eviction_end_time, eviction_resource);
+            eviction_exclusive_start_time, eviction_end_time,
+            eviction_resource);
   }
   bool eviction_violates_outstanding_copies =
-      ViolatesMaximumOutstandingAsyncCopies(eviction_start_time,
-                                            eviction_end_time,
-                                            /*is_prefetch=*/false);
+      ViolatesMaximumOutstandingAsyncCopies(
+          /*inclusive_start_time=*/ExclusiveToInclusiveStartTime(
+              eviction_exclusive_start_time),
+          eviction_end_time,
+          /*is_prefetch=*/false);
 
   // See if this interval would violate the asynchronous copy limit.
   if (!eviction_interval_too_short && !eviction_violates_outstanding_copies &&
       !eviction_violates_resource) {
     prev_allocation->set_end_time(eviction_end_time);
     AddAsyncCopy(*prev_allocation, MemorySpace::kDefault,
-                 /*chunk=*/std::nullopt, eviction_start_time,
+                 /*chunk=*/std::nullopt, eviction_exclusive_start_time,
                  prev_allocation->end_time(), eviction_end_time,
                  request.allocation_value->mutable_allocation_sequence(),
                  /*aliased_offset=*/nullptr, eviction_resource);
@@ -6297,8 +6367,9 @@
     } else if (eviction_violates_resource) {
       VLOG(3) << "This violates resource.";
     } else {
-      VLOG(3) << "Eviction interval is too short (" << eviction_start_time
-              << ", " << eviction_end_time << ").";
+      VLOG(3) << "Eviction interval is too short ("
+              << eviction_exclusive_start_time << ", " << eviction_end_time
+              << ").";
     }
     // If the original interval violated the limit, try sub-intervals within
     // this interval.
@@ -6309,17 +6380,16 @@
       // kept in the default memory.
       VLOG(3) << "Bailing: Could not evict " << request.use->hlo_use.ToString()
               << " because we hit the limit of maximum asynchronous copies "
-              << "between "
+              << "between ("
               << hlo_live_range_.flattened_instruction_sequence()
-                     .instructions()[eviction_start_time]
-              << " and "
+                     .instructions()[eviction_exclusive_start_time]
+              << ", "
               << hlo_live_range_.flattened_instruction_sequence()
-                     .instructions()[eviction_end_time];
-      // return false;
+                     .instructions()[eviction_end_time]
+              << ")";
       return Result::kFailOutOfAsyncCopies;
     }
   }
-  // return true;
   return Result::kSuccess;
 }
 
@@ -6340,9 +6410,9 @@
   slice_strings.reserve(slice_decisions.size());
 
   for (const auto& slice_decision : slice_decisions) {
-    slice_strings.push_back(absl::StrCat("(", slice_decision.start_time, ", ",
-                                         slice_decision.chunk.offset, ", ",
-                                         slice_decision.chunk.size, ")"));
+    slice_strings.push_back(absl::StrCat(
+        "(", slice_decision.exclusive_start_time, ", ",
+        slice_decision.chunk.offset, ", ", slice_decision.chunk.size, ")"));
   }
 
   return absl::StrCat(
@@ -6410,10 +6480,12 @@
   Result result = Result::kSuccess;
   while (!options_.prefetch_interval_picker->Done()) {
     // Get the prefetch start time from the interval picker.
-    context.prefetch_start_time = options_.prefetch_interval_picker->Next();
-    CHECK_LT(context.prefetch_start_time, context.prefetch_end_time);
-    if (context.out_of_mem_start.has_value() &&
-        context.prefetch_start_time <= *context.out_of_mem_start) {
+    context.exclusive_prefetch_start_time =
+        options_.prefetch_interval_picker->Next();
+    CHECK_LT(context.exclusive_prefetch_start_time, context.prefetch_end_time);
+    if (context.exclusive_out_of_mem_start.has_value() &&
+        context.exclusive_prefetch_start_time <=
+            *context.exclusive_out_of_mem_start) {
       VLOG(4) << "This would OOM (cached).";
       return Result::kFailOutOfMemory;
     }
@@ -6471,8 +6543,10 @@
     return Result::kSuccess;
   }
   if (context.unsliced_solution) {
-    VLOG(3) << "Move the buffer to alternate memory at time "
-            << context.unsliced_solution_intervals.full.start << ". Offset = "
+    VLOG(3) << "Move the buffer to alternate memory after time "
+            << InclusiveToExclusiveStartTime(
+                   context.unsliced_solution_intervals.full.start)
+            << ". Offset = "
             << context.unsliced_solution->chunk_candidate.offset
             << ", size = " << context.unsliced_solution->chunk_candidate.size
             << ", heap_size = "
@@ -6485,7 +6559,7 @@
     AddAsyncCopy(
         *context.prev_allocation_in_default_mem, MemorySpace::kAlternate,
         context.unsliced_solution->chunk_candidate,
-        context.unsliced_solution_intervals.full.start,
+        context.unsliced_solution_intervals.full.start - 1,
         context.request->end_time, context.prefetch_end_time,
         context.request->allocation_value->mutable_allocation_sequence(),
         context.request->preferred_offset,
@@ -6592,24 +6666,25 @@
 AlternateMemoryBestFitHeap::Result
 AlternateMemoryBestFitHeap::InitializePrefetchIntervalPicker(
     PrefetchContext& context) {
-  int64_t earliest_prefetch_time =
+  int64_t earliest_exclusive_prefetch_time =
       context.prev_allocation_in_default_mem->earliest_available_time();
   if (context.request->earliest_prefetch_time) {
-    earliest_prefetch_time = std::max(earliest_prefetch_time,
-                                      *context.request->earliest_prefetch_time);
+    earliest_exclusive_prefetch_time =
+        std::max(earliest_exclusive_prefetch_time,
+                 *context.request->earliest_prefetch_time);
   }
   context.prefetch_end_time =
-      FindPrefetchEndTime(*context.request, earliest_prefetch_time);
+      FindPrefetchEndTime(*context.request, earliest_exclusive_prefetch_time);
 
   // As a compile time optimization, use the peak memory usage to filter out
   // allocation times that would push us to OOM.
-  std::optional<int> earliest_non_oom_prefetch_time =
-      FindEarliestTimeToSatisfyPeakMemory(earliest_prefetch_time,
-                                          context.prefetch_end_time,
-                                          context.request->size);
-  if (!earliest_non_oom_prefetch_time) {
-    VLOG(3) << "Any prefetch in range (" << earliest_prefetch_time << ", "
-            << context.prefetch_end_time << ") for size "
+  std::optional<int> earliest_exclusive_non_oom_prefetch_time =
+      FindEarliestExclusiveTimeToSatisfyPeakMemory(
+          earliest_exclusive_prefetch_time, context.prefetch_end_time,
+          context.request->size);
+  if (!earliest_exclusive_non_oom_prefetch_time) {
+    VLOG(3) << "Any prefetch in range (" << earliest_exclusive_prefetch_time
+            << ", " << context.prefetch_end_time << ") for size "
             << context.request->size << " would go out of memory.";
     return Result::kFailOutOfMemory;
   }
@@ -6619,20 +6694,21 @@
     // buffer will fit, but we may be able to start slices before that time. So,
     // we leave earliest_prefetch_time at its initial value.
     VLOG(4) << "After peak memory check, prefetch range is ("
-            << *earliest_non_oom_prefetch_time << ", "
+            << *earliest_exclusive_non_oom_prefetch_time << ", "
             << context.prefetch_end_time
             << "). Original earliest prefetch time is "
-            << earliest_prefetch_time;
-    earliest_prefetch_time = *earliest_non_oom_prefetch_time;
+            << earliest_exclusive_prefetch_time;
+    earliest_exclusive_prefetch_time =
+        *earliest_exclusive_non_oom_prefetch_time;
   }
   std::optional<int64_t> preferred_prefetch_time =
       context.request->preferred_prefetch_time;
   if (preferred_prefetch_time) {
     preferred_prefetch_time =
-        std::max(*preferred_prefetch_time, earliest_prefetch_time);
+        std::max(*preferred_prefetch_time, earliest_exclusive_prefetch_time);
   }
   options_.prefetch_interval_picker->Begin(
-      context.request->use->hlo_use, earliest_prefetch_time,
+      context.request->use->hlo_use, earliest_exclusive_prefetch_time,
       context.prefetch_end_time, preferred_prefetch_time);
   VLOG(3) << "Trying prefetch picker = "
           << options_.prefetch_interval_picker->ToDebugString();
@@ -6648,9 +6724,9 @@
            ? context.sliced_solution_intervals.sliced.get()
            : context.unsliced_solution_intervals.sliced.get());
 
-  // Note, UpdateSliceStartTimes() will correctly update start times for both
-  // sliced and unsliced solutions.
-  interval->UpdateSliceStartTimes(
+  // Note, UpdateExclusiveSliceStartTimes() will correctly update start times
+  // for both sliced and unsliced solutions.
+  interval->UpdateExclusiveSliceStartTimes(
       std::vector<int64_t>(interval->num_slices(),
                            options_.prefetch_interval_picker->latest_time()));
   std::vector<Chunk> chunk_candidates = FindBestChunkCandidates(
@@ -6729,7 +6805,8 @@
         absl::StrJoin(specs, ", ",
                       [](std::string* out,
                          const AsynchronousCopyResource::ResourceSpec& spec) {
-                        absl::StrAppend(out, "{start: ", spec.start_time,
+                        absl::StrAppend(out, "{exclusive start: ",
+                                        spec.exclusive_start_time,
                                         ", end: ", spec.end_time,
                                         ", resource: ", spec.resource, "}");
                       }),
@@ -6787,20 +6864,23 @@
   }
 
   // Update the prefetch start time in our working solution.
-  std::vector<int64_t> slice_start_times = PickSliceStartTimes(
-      sliced_buffer_interval->num_slices(), context.prefetch_start_time,
-      context.prefetch_end_time);
-  CHECK_EQ(sliced_buffer_interval->num_slices(), slice_start_times.size());
-  sliced_buffer_interval->UpdateSliceStartTimes(slice_start_times);
+  std::vector<int64_t> exclusive_slice_start_times = PickSliceStartTimes(
+      sliced_buffer_interval->num_slices(),
+      context.exclusive_prefetch_start_time, context.prefetch_end_time);
+  CHECK_EQ(sliced_buffer_interval->num_slices(),
+           exclusive_slice_start_times.size());
+  sliced_buffer_interval->UpdateExclusiveSliceStartTimes(
+      exclusive_slice_start_times);
   VLOG(4) << AlternateMemoryAllocationAttemptToString(for_sliced_solution,
                                                       context);
 
   // Check if all slices have the same start time. If so, we might as well
   // resort to a full copy.
   if (for_sliced_solution &&
-      absl::c_all_of(slice_start_times, [&](int64_t slice_start_time) {
-        return slice_start_time == slice_start_times.front();
-      })) {
+      absl::c_all_of(
+          exclusive_slice_start_times, [&](int64_t slice_start_time) {
+            return slice_start_time == exclusive_slice_start_times.front();
+          })) {
     return Result::kAllSlicesHaveTheSameStartTime;
   }
 
@@ -6811,7 +6891,7 @@
   // resources here.
   if (context.request->preferred_prefetch_time) {
     copy_resource_per_slice_sorted_by_start_time =
-        std::vector<float>(slice_start_times.size(), 0.0);
+        std::vector<float>(exclusive_slice_start_times.size(), 0.0);
   } else if (for_sliced_solution) {
     // In a sliced setting, we don't yet know when each slice will be
     // prefetched. Given the proposed slice times, the most conservative copy
@@ -6832,7 +6912,8 @@
   CHECK_EQ(sliced_buffer_interval->num_slices(),
            copy_resource_per_slice_sorted_by_start_time.size());
 
-  if (!DoWeHaveEnoughCopyResource(slice_start_times, context.prefetch_end_time,
+  if (!DoWeHaveEnoughCopyResource(exclusive_slice_start_times,
+                                  context.prefetch_end_time,
                                   copy_resource_per_slice_sorted_by_start_time,
                                   prefetch_async_copy_resource_)) {
     return Result::kFailViolatesAsyncCopyResource;
@@ -6841,20 +6922,20 @@
   // Check if the copies we would add for the prefetch would violate copy
   // ordering.
   if (options_.enforce_prefetch_fifo_order &&
-      std::any_of(slice_start_times.begin(), slice_start_times.end(),
-                  [&](int64_t slice_start_time) {
-                    return async_copy_ordering_.ViolatesOrdering(
-                        slice_start_time, context.prefetch_end_time);
-                  })) {
+      absl::c_any_of(exclusive_slice_start_times,
+                     [&](int64_t slice_start_time) {
+                       return async_copy_ordering_.ViolatesOrdering(
+                           slice_start_time, context.prefetch_end_time);
+                     })) {
     VLOG(4) << "This would violate asynchronous copy ordering.";
     return Result::kFailViolatesAsyncCopyResource;
   }
 
   // Check if the copies we would add for the prefetch violate the maximum
   // number of outstanding async copies.
-  for (int i = 0; i < slice_start_times.size(); ++i) {
+  for (int i = 0; i < exclusive_slice_start_times.size(); ++i) {
     if (ViolatesMaximumOutstandingAsyncCopies(
-            slice_start_times[i], context.prefetch_end_time,
+            exclusive_slice_start_times[i], context.prefetch_end_time,
             /*is_prefetch=*/true, context.extra_async_copy_limit, i)) {
       VLOG(4) << "This would violate the outstanding async copy limit.";
       return Result::kFailOutOfAsyncCopies;
@@ -6893,14 +6974,14 @@
           CopyResourceForShape(options_, proposal.slice_shape);
       slice_decisions_sorted_by_start_time.push_back(
           MemorySpaceAssignment::SliceDecision{
-              chunk_candidates[slice_time], slice_start_times[slice_time],
-              proposal,
+              chunk_candidates[slice_time],
+              exclusive_slice_start_times[slice_time], proposal,
               copy_resource_per_slice_sorted_by_start_time[slice_time]});
     }
 
     // Check that we have enough copy resources for all the slice decisions.
     if (!DoWeHaveEnoughCopyResource(
-            slice_start_times, context.prefetch_end_time,
+            exclusive_slice_start_times, context.prefetch_end_time,
             copy_resource_per_slice_sorted_by_start_time,
             prefetch_async_copy_resource_)) {
       return Result::kFailViolatesAsyncCopyResource;
@@ -6921,7 +7002,8 @@
     BufferInterval final_buffer_interval{
         context.request->allocation_value->value(),
         /*size=*/final_chunk.size,
-        /*start=*/slice_start_times.back(),
+        /*start=*/
+        ExclusiveToInclusiveStartTime(exclusive_slice_start_times.back()),
         /*end=*/context.request->end_time,
         /*colocations=*/
         sliced_buffer_interval->full_buffer_interval().colocations,
@@ -6929,20 +7011,23 @@
     for (int64_t slice_time = 0;
          slice_time < sliced_buffer_interval->num_slices(); ++slice_time) {
       const Chunk& chunk = chunk_candidates[slice_time];
-      int64_t start_time = slice_start_times[slice_time];
-      if (start_time == slice_start_times.back()) {
+      int64_t inclusive_start_time = ExclusiveToInclusiveStartTime(
+          exclusive_slice_start_times[slice_time]);
+      if (inclusive_start_time ==
+          ExclusiveToInclusiveStartTime(exclusive_slice_start_times.back())) {
         // This and the following chunks will be merged into the final chunk.
         // Note, it's possible for more than one slice to start at the same
         // time.
         break;
       }
-      CHECK_LE(start_time, slice_start_times.back() - 1);
+      CHECK_LT(inclusive_start_time, ExclusiveToInclusiveStartTime(
+                                         exclusive_slice_start_times.back()));
       slices_for_pending_chunks.push_back(std::make_pair(
           BufferInterval{
               context.request->allocation_value->value(),
               /*size=*/chunk.size,
-              /*start=*/start_time,
-              /*end=*/slice_start_times.back() - 1,
+              /*start=*/inclusive_start_time,
+              /*end=*/exclusive_slice_start_times.back(),
               // We only use the final_buffer_interval for colocations because
               // slices start at different offsets, and the colocation
               // infrastructure expects all colocated buffers to start at the
@@ -6980,12 +7065,14 @@
   // Thus, if we are considering a sliced prefetch for the current request,
   // we can only update out_of_mem_start when we check with slices.
   if (for_sliced_solution || !context.slice_proposal_collection) {
-    CHECK_GT(slice_start_times.size(), 0);
-    context.out_of_mem_start =
-        std::max(context.out_of_mem_start ? *context.out_of_mem_start : -1,
-                 slice_start_times.front());
+    CHECK_GT(exclusive_slice_start_times.size(), 0);
+    context.exclusive_out_of_mem_start = std::max(
+        context.exclusive_out_of_mem_start ? *context.exclusive_out_of_mem_start
+                                           : -1,
+        exclusive_slice_start_times.front());
   }
 
+  VLOG(4) << "Out of memory.";
   return Result::kFailOutOfMemory;
 }
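
A toy version of the out-of-memory caching used while retrying prefetch start times: once a start time is known to OOM, any earlier start covers a superset of those logical times and can be rejected without redoing the chunk search. The descending loop below merely stands in for the prefetch interval picker:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>

int main() {
  std::optional<int64_t> exclusive_out_of_mem_start;
  // Pretend any exclusive start time earlier than 6 runs out of memory.
  auto fits_in_memory = [](int64_t exclusive_start) {
    return exclusive_start >= 6;
  };

  for (int64_t candidate = 9; candidate >= 2; --candidate) {
    if (exclusive_out_of_mem_start.has_value() &&
        candidate <= *exclusive_out_of_mem_start) {
      std::cout << candidate << ": OOM (cached)\n";
      continue;
    }
    if (!fits_in_memory(candidate)) {
      exclusive_out_of_mem_start = std::max(
          exclusive_out_of_mem_start.value_or(int64_t{-1}), candidate);
      std::cout << candidate << ": OOM\n";
      continue;
    }
    std::cout << candidate << ": fits\n";
  }
  return 0;
}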
 
@@ -7229,23 +7316,6 @@
   return stats;
 }
 
-/*static*/ MemorySpaceAssignment::BufferIntervalCompare
-MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
-    const MemorySpaceAssignmentCostAnalysis& cost_analysis,
-    MemorySpaceAssignmentCostAnalysis::Cache* cache) {
-  return [&cost_analysis, cache](const BufferInterval& x,
-                                 const BufferInterval& y) {
-    float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
-    float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
-    if (x_memory_boundedness != y_memory_boundedness) {
-      return x_memory_boundedness > y_memory_boundedness;
-    }
-    // Tie-break if the memory boundedness is the same.
-    return GlobalDecreasingSizeBestFitHeap<
-        HloValue>::GetSpatialBufferIntervalCompare()(x, y);
-  };
-}
-
 /*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
 MemorySpaceAssignment::Run(HloModule* module,
                            const HloLiveRange& hlo_live_range,
@@ -7513,10 +7583,11 @@
 }
 
 std::string MemorySpaceAssignment::SliceDecision::ToString() const {
-  return absl::StrCat(
-      "{ chunk: ", chunk.ToString(), ", start_time: ", start_time,
-      ", sizing: ", sizing.ToString(),
-      ", copy_resource_consumed: ", copy_resource_consumed, " }");
+  return absl::StrCat("{ chunk: ", chunk.ToString(),
+                      ", (exclusive) start_time: ", exclusive_start_time,
+                      ", sizing: ", sizing.ToString(),
+                      ", copy_resource_consumed: ", copy_resource_consumed,
+                      " }");
 }
 
 namespace {
@@ -7524,9 +7595,9 @@
 std::tuple<const MemorySpaceAssignment::Chunk&, int64_t,
            const MemorySpaceAssignment::SliceProposal&, float>
 SliceDecisionToTuple(const MemorySpaceAssignment::SliceDecision& decision) {
-  return std::make_tuple(std::ref(decision.chunk), decision.start_time,
-                         std::ref(decision.sizing),
-                         decision.copy_resource_consumed);
+  return std::make_tuple(
+      std::ref(decision.chunk), decision.exclusive_start_time,
+      std::ref(decision.sizing), decision.copy_resource_consumed);
 }
 
 }  // namespace
@@ -7652,38 +7723,43 @@
 }
 
 // Helper function to compute the start time for a SlicedCopyAllocation.
-int64_t GetSlicedCopyAllocationStartTime(
+int64_t GetSlicedCopyAllocationExclusiveStartTime(
     const std::vector<MemorySpaceAssignment::SliceDecision>&
-        slice_decisions_sorted_by_start_time) {
-  if (slice_decisions_sorted_by_start_time.empty()) {
+        slice_decisions_sorted_by_exclusive_start_time) {
+  if (slice_decisions_sorted_by_exclusive_start_time.empty()) {
     return -1;
   }
 
-  return slice_decisions_sorted_by_start_time.front().start_time;
+  return slice_decisions_sorted_by_exclusive_start_time.front()
+      .exclusive_start_time;
 }
 
 }  // namespace
 
 MemorySpaceAssignment::SlicedCopyAllocation::SlicedCopyAllocation(
     const Allocation& prev_allocation, MemorySpace memory_space,
-    std::vector<SliceDecision> slice_decisions_sorted_by_start_time,
-    int64_t end_time, int64_t copy_done_schedule_before_time,
+    std::vector<SliceDecision> slice_decisions_sorted_by_exclusive_start_time,
+    int64_t copy_done_schedule_before_time, int64_t end_time,
     absl::FunctionRef<void(Shape*)> update_layout_fn)
     : Allocation(
           /*defining_position=*/{nullptr, {}}, memory_space,
-          GetSlicedCopyAllocationChunk(slice_decisions_sorted_by_start_time),
-          GetSlicedCopyAllocationStartTime(
-              slice_decisions_sorted_by_start_time),
+          GetSlicedCopyAllocationChunk(
+              slice_decisions_sorted_by_exclusive_start_time),
+          // Allocation uses an inclusive start time
+          ExclusiveToInclusiveStartTime(
+              GetSlicedCopyAllocationExclusiveStartTime(
+                  slice_decisions_sorted_by_exclusive_start_time)),
           end_time,
           /*is_scoped_allocation=*/false),
       original_shape_to_slice_(prev_allocation.defining_position().shape()),
       prev_allocation_(prev_allocation),
       update_layout_fn_(update_layout_fn) {
-  CHECK_GE(slice_decisions_sorted_by_start_time.size(), 2);
+  CHECK_GE(slice_decisions_sorted_by_exclusive_start_time.size(), 2);
   slice_details_sorted_by_start_time_.reserve(
-      slice_decisions_sorted_by_start_time.size());
-  for (SliceDecision& decision : slice_decisions_sorted_by_start_time) {
-    int64_t copy_done_schedule_after_time = decision.start_time;
+      slice_decisions_sorted_by_exclusive_start_time.size());
+  for (SliceDecision& decision :
+       slice_decisions_sorted_by_exclusive_start_time) {
+    int64_t copy_done_schedule_after_time = decision.exclusive_start_time;
     slice_details_sorted_by_start_time_.push_back(SliceDetail{
         std::move(decision),
         copy_done_schedule_after_time,
@@ -7845,8 +7921,10 @@
     const MemorySpaceAssignmentRepacker::Slice& repacked_slice_data =
         data.slices_sorted_by_offset[i];
     chunk = Chunk::FromOffsetSize(repacked_slice_data.offset, chunk.size);
-    slice_detail->copy_start_after_time = repacked_slice_data.start_time;
-    slice_detail->slice_decision.start_time = repacked_slice_data.start_time;
+    slice_detail->copy_start_after_time =
+        repacked_slice_data.inclusive_start_time - 1;
+    slice_detail->slice_decision.exclusive_start_time =
+        InclusiveToExclusiveStartTime(repacked_slice_data.inclusive_start_time);
   }
 
   absl::c_sort(slice_details_sorted_by_start_time_,
@@ -7919,6 +7997,21 @@
                       original_allocation_.ToString());
 }
 
+MemorySpaceAssignment::CopyAllocation::CopyAllocation(
+    Allocation& prev_allocation, MemorySpace memory_space,
+    std::optional<Chunk> chunk, int64_t copy_start_schedule_after_time,
+    int64_t copy_done_schedule_before_time, int64_t end_time,
+    std::optional<int64_t> cross_program_prefetch_index)
+    : Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk,
+                 // Allocation uses an inclusive start time
+                 ExclusiveToInclusiveStartTime(copy_start_schedule_after_time),
+                 end_time,
+                 /*is_scoped_allocation=*/false),
+      prev_allocation_(prev_allocation),
+      copy_start_schedule_after_(copy_start_schedule_after_time),
+      copy_done_schedule_before_(copy_done_schedule_before_time),
+      cross_program_prefetch_index_(cross_program_prefetch_index) {}
+
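
A note on the two time conventions used above: an exclusive start time t means the copy may only be scheduled strictly after t, so the buffer first becomes live at t + 1, which is the inclusive start time that Allocation stores. A minimal sketch of the conversion helpers, assuming they are the one-step shifts implied by the repacking code above (inclusive_start_time - 1 paired with InclusiveToExclusiveStartTime); the actual definitions live elsewhere in this file:

// Sketch only (assumed): convert between exclusive and inclusive start times.
inline int64_t ExclusiveToInclusiveStartTime(int64_t exclusive_start_time) {
  // The buffer is first live one step after the exclusive start.
  return exclusive_start_time + 1;
}

inline int64_t InclusiveToExclusiveStartTime(int64_t inclusive_start_time) {
  // Inverse of the above.
  return inclusive_start_time - 1;
}
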
 Status MemorySpaceAssignment::CopyAllocation::Process() {
   // Copy allocations need to insert asynchronous copy nodes.
   Shape shape = defining_position().shape();
@@ -8491,8 +8584,15 @@
 
         // Accessing flattened_instructions_ here without checking if it is
         // nullptr is safe because this method is called before SimplifyGraph.
-        while (async_copy_step->defining_position().instruction->parent() !=
-               flattened_instructions_[copy_start_schedule_after]->parent()) {
+        while (
+            async_copy_step->defining_position().instruction->parent() !=
+            flattened_instructions_[
+                // We can't use -1 to index into flattened_instructions_.
+                // However, if we want to place the copy as the first
+                // instruction, i.e., after the -1 scheduling position, its
+                // parent will be the same as the first instruction, i.e., the
+                // one at the 0th position.
+                std::max<int64_t>(0, copy_start_schedule_after)]
+                ->parent()) {
           VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after
                   << " to " << (copy_start_schedule_after + 1) << ") for "
                   << start_phase->instruction->ToString()
@@ -8815,5 +8915,107 @@
 
   return OkStatus();
 }
+
+DefaultCrossProgramPrefetchBufferIntervalComparator::
+    DefaultCrossProgramPrefetchBufferIntervalComparator(
+        const HloLiveRange& hlo_live_range)
+    : MemorySpaceAssignment::BufferIntervalComparator(),
+      hlo_live_range_(hlo_live_range) {}
+
+std::string DefaultCrossProgramPrefetchBufferIntervalComparator::
+    DescribeComparisonCriteria() const {
+  return "[ -size, -cumulative use size, latest use, instruction id ]";
+}
+
+std::string
+DefaultCrossProgramPrefetchBufferIntervalComparator::CriteriaToString(
+    const BufferInterval& buffer_interval) {
+  return absl::StrCat("[ ", absl::StrJoin(GetTuple(buffer_interval), ", "),
+                      " ]");
+}
+
+bool DefaultCrossProgramPrefetchBufferIntervalComparator::LessThan(
+    const BufferInterval& lhs, const BufferInterval& rhs) {
+  return GetTuple(lhs) < GetTuple(rhs);
+}
+
+DefaultCrossProgramPrefetchBufferIntervalComparator::ComparisonTuple
+DefaultCrossProgramPrefetchBufferIntervalComparator::GetTuple(
+    const BufferInterval& buffer_interval) {
+  auto sort_data_it = additional_sort_data_.find(buffer_interval.buffer);
+  if (sort_data_it == additional_sort_data_.end()) {
+    AdditionalSortData sort_data;
+    absl::c_for_each(buffer_interval.buffer->GetUses(), [&](const HloUse& use) {
+      auto it = hlo_live_range_.instruction_schedule().find(use.instruction);
+      if (it == hlo_live_range_.instruction_schedule().end()) {
+        return;
+      }
+      sort_data.latest_use = std::max(sort_data.latest_use, it->second);
+      sort_data.cumulative_use_size +=
+          ShapeUtil::ElementsInRecursive(use.instruction->shape());
+    });
+    sort_data_it = additional_sort_data_
+                       .insert(std::make_pair(buffer_interval.buffer,
+                                              std::move(sort_data)))
+                       .first;
+  }
+
+  return std::make_tuple(
+      -1 * buffer_interval.size, -1 * sort_data_it->second.cumulative_use_size,
+      sort_data_it->second.latest_use, buffer_interval.buffer->id());
+}
+
+MemoryBoundednessBufferIntervalComparator::
+    MemoryBoundednessBufferIntervalComparator(
+        const MemorySpaceAssignmentCostAnalysis& cost_analysis,
+        MemorySpaceAssignmentCostAnalysis::Cache* cost_analysis_cache)
+    : MemorySpaceAssignment::BufferIntervalComparator(),
+      cost_analysis_(cost_analysis),
+      cost_analysis_cache_(cost_analysis_cache) {}
+
+std::string
+MemoryBoundednessBufferIntervalComparator::DescribeComparisonCriteria() const {
+  return "[ -memory boundedness, -size, -buffer duration, latest use time, "
+         "(inclusive) start time, instruction id ]";
+}
+
+std::string MemoryBoundednessBufferIntervalComparator::CriteriaToString(
+    const BufferInterval& buffer_interval) {
+  return absl::StrCat("[ ", absl::StrJoin(GetTuple(buffer_interval), ", "),
+                      " ]");
+}
+
+bool MemoryBoundednessBufferIntervalComparator::LessThan(
+    const BufferInterval& lhs, const BufferInterval& rhs) {
+  return GetTuple(lhs) < GetTuple(rhs);
+}
+
+MemoryBoundednessBufferIntervalComparator::ComparisonTuple
+MemoryBoundednessBufferIntervalComparator::GetTuple(
+    const BufferInterval& buffer_interval) {
+  auto latest_use_it = buffer_to_latest_use_.find(buffer_interval.buffer);
+  if (latest_use_it == buffer_to_latest_use_.end()) {
+    int64_t latest_use_time = 0;
+    for (const HloUse& use : buffer_interval.buffer->GetUses()) {
+      auto it = cost_analysis_.hlo_live_range().instruction_schedule().find(
+          use.instruction);
+      if (it != cost_analysis_.hlo_live_range().instruction_schedule().end()) {
+        latest_use_time = std::max(latest_use_time, it->second);
+      }
+    }
+    latest_use_it =
+        buffer_to_latest_use_
+            .insert(std::make_pair(buffer_interval.buffer, latest_use_time))
+            .first;
+  }
+
+  return std::make_tuple(-1.0 * cost_analysis_.GetMemoryBoundedness(
+                                    buffer_interval, cost_analysis_cache_),
+                         -1 * buffer_interval.size,
+                         buffer_interval.start - buffer_interval.end,
+                         latest_use_it->second, buffer_interval.start,
+                         buffer_interval.buffer->id());
+}
+
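The negated entries in this tuple are what turn std::tuple's ascending lexicographic comparison into the mixed ordering announced by DescribeComparisonCriteria(): negated fields sort descending, the remaining fields ascending. A worked illustration with made-up values:

// Two hypothetical buffer intervals that tie on memory boundedness (2.0):
//   A: size 1024, duration 10, latest use 7, inclusive start 3
//   B: size 2048, duration  5, latest use 4, inclusive start 2
// GetTuple(A) = (-2.0, -1024, -10, 7, 3, idA)
// GetTuple(B) = (-2.0, -2048,  -5, 4, 2, idB)
// GetTuple(B) < GetTuple(A) because -2048 < -1024, so the tie on memory
// boundedness is broken by decreasing size and B is considered first.
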
 }  // namespace memory_space_assignment
 }  // namespace xla
diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h
index 384c5bc..f518e9f 100644
--- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h
+++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h
@@ -38,6 +38,8 @@
 #include "absl/functional/function_ref.h"
 #include "absl/types/span.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/utils/hlo_live_range.h"
+#include "xla/service/buffer_value.h"
 #include "xla/service/heap_simulator.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/service/hlo_cost_analysis.h"
@@ -48,8 +50,8 @@
 #include "xla/statusor.h"
 
 namespace xla {
-
 namespace memory_space_assignment {
+
 // Forward Declaration of Options.
 class Options;
 
@@ -574,6 +576,39 @@
       const absl::flat_hash_set<ShapeIndex>& /*outputs_in_alternate_memory*/)>;
   using UpdateLayoutFunction = std::function<void(Shape*)>;
 
+  // The BufferInterval sorting interface that MemorySpaceAssignment expects.
+  class BufferIntervalComparator {
+   public:
+    using BufferInterval = MemorySpaceAssignment::BufferInterval;
+
+    virtual ~BufferIntervalComparator() = default;
+
+    // A logging string explaining the sorting criteria. E.g., [ -size, offset ]
+    // indicates we sort by (desc) size, then by (asc) offset.
+    virtual std::string DescribeComparisonCriteria() const = 0;
+
+    // A logging string containing the values used to sort buffer_interval.
+    // E.g., we might return [ -1024, 100 ] if the criteria are [ -size,
+    // offset ].
+    virtual std::string CriteriaToString(
+        const BufferInterval& buffer_interval) = 0;
+
+    // comparator.LessThan(lhs, rhs) will be used for BufferIntervalCompare.
+    virtual bool LessThan(const BufferInterval& lhs,
+                          const BufferInterval& rhs) = 0;
+
+    // Used to create a functor that can be passed to a method like std::sort.
+    // E.g., absl::c_sort(v, comparator.GetComparisonFunctor());
+    BufferIntervalCompare GetComparisonFunctor() {
+      return [this](const BufferInterval& lhs, const BufferInterval& rhs) {
+        return LessThan(lhs, rhs);
+      };
+    }
+
+   protected:
+    BufferIntervalComparator() = default;
+  };
+
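
To show how this interface is meant to be plugged in, here is a hedged sketch of a trivial comparator; SizeOnlyComparator is a hypothetical name used only for illustration and is not part of this change:

// Hypothetical comparator that orders buffers by decreasing size.
class SizeOnlyComparator
    : public MemorySpaceAssignment::BufferIntervalComparator {
 public:
  std::string DescribeComparisonCriteria() const override {
    return "[ -size ]";
  }
  std::string CriteriaToString(const BufferInterval& buffer_interval) override {
    return absl::StrCat("[ ", -buffer_interval.size, " ]");
  }
  bool LessThan(const BufferInterval& lhs, const BufferInterval& rhs) override {
    return lhs.size > rhs.size;
  }
};

// Usage sketch: adapt the comparator to APIs that expect a callable.
//   SizeOnlyComparator comparator;
//   absl::c_sort(buffer_intervals, comparator.GetComparisonFunctor());
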
   // MemorySpaceAssignment uses a notion of a slow and large default memory
   // space and a fast and small alternate memory space.
   enum class MemorySpace { kDefault, kAlternate };
@@ -717,22 +752,18 @@
   };
 
   // This class represents an allocation as a result of an asynchronous copy.
-  // Note: CopyStart instructions are inserted after `start_time` or later,
-  // while CopyDone instructions are inserted before
-  // `copy_done_schedule_before_time` or earlier.
+  // Note: CopyStart instructions are inserted after
+  // `copy_start_schedule_after`, while CopyDone instructions are inserted
+  // before `copy_done_schedule_before_time`.
   class CopyAllocation : public Allocation {
    public:
+    // TODO(b/307342076): Reorder scheduling times to be
+    // copy_start_schedule_after_time, copy_done_schedule_before_time, end_time
     CopyAllocation(
         Allocation& prev_allocation, MemorySpace memory_space,
-        std::optional<Chunk> chunk, int64_t start_time, int64_t end_time,
-        int64_t copy_done_schedule_before_time,
-        std::optional<int64_t> cross_program_prefetch_index = std::nullopt)
-        : Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk,
-                     start_time, end_time, /*is_scoped_allocation=*/false),
-          prev_allocation_(prev_allocation),
-          copy_start_schedule_after_(start_time),
-          copy_done_schedule_before_(copy_done_schedule_before_time),
-          cross_program_prefetch_index_(cross_program_prefetch_index) {}
+        std::optional<Chunk> chunk, int64_t copy_start_schedule_after_time,
+        int64_t copy_done_schedule_before_time, int64_t end_time,
+        std::optional<int64_t> cross_program_prefetch_index = std::nullopt);
 
     bool is_copy_allocation() const override { return true; }
 
@@ -853,7 +884,7 @@
     bool operator==(const SliceDecision& other) const;
 
     Chunk chunk;
-    int64_t start_time;
+    int64_t exclusive_start_time;
     SliceProposal sizing;
     float copy_resource_consumed;
   };
@@ -914,7 +945,7 @@
     SlicedCopyAllocation(
         const Allocation& prev_allocation, MemorySpace memory_space,
         std::vector<SliceDecision> slice_decisions_sorted_by_start_time,
-        int64_t end_time, int64_t copy_done_schedule_before_time,
+        int64_t copy_done_schedule_before_time, int64_t end_time,
         absl::FunctionRef<void(Shape*)> update_layout_fn);
 
     bool is_sliced_copy_allocation() const override { return true; }
@@ -1191,10 +1222,6 @@
   // Calculates asynchronous copy statistics.
   StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const;
 
-  static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
-      const MemorySpaceAssignmentCostAnalysis& cost_analysis,
-      MemorySpaceAssignmentCostAnalysis::Cache* cache = nullptr);
-
   // Verify that the memory space assignment is free of overlapping buffers and
   // export heap simulator trace to be used by buffer_assignment.
   Status VerifyAndExportHeapSimulatorTrace();
@@ -1283,6 +1310,69 @@
   absl::flat_hash_map<int64_t, std::vector<HloInstruction*>> schedule_before_;
 };
 
+// A BufferIntervalComparator that utilizes MemoryBoundedness as its primary
+// sorting criteria.
+//
+// This comparator caches HloValue -> latest use time.
+class MemoryBoundednessBufferIntervalComparator
+    : public MemorySpaceAssignment::BufferIntervalComparator {
+ public:
+  MemoryBoundednessBufferIntervalComparator(
+      const MemorySpaceAssignmentCostAnalysis& cost_analysis,
+      MemorySpaceAssignmentCostAnalysis::Cache* cost_analysis_cache);
+
+  ~MemoryBoundednessBufferIntervalComparator() override = default;
+
+  std::string DescribeComparisonCriteria() const override;
+  std::string CriteriaToString(const BufferInterval& buffer_interval) override;
+  bool LessThan(const BufferInterval& lhs, const BufferInterval& rhs) override;
+
+ private:
+  // See the value returned by DescribeComparisonCriteria() for the meaning of
+  // each tuple element.
+  using ComparisonTuple =
+      std::tuple<float, int64_t, int64_t, int64_t, int64_t, BufferValue::Id>;
+
+  ComparisonTuple GetTuple(const BufferInterval& buffer_interval);
+
+  absl::flat_hash_map<const HloValue*, int64_t> buffer_to_latest_use_;
+  const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
+  MemorySpaceAssignmentCostAnalysis::Cache* cost_analysis_cache_;
+};
+
+// The default BufferIntervalComparator used for cross-program prefetching.
+//
+// This class caches HloValue -> { latest use, cumulative use size }.
+class DefaultCrossProgramPrefetchBufferIntervalComparator
+    : public MemorySpaceAssignment::BufferIntervalComparator {
+ public:
+  explicit DefaultCrossProgramPrefetchBufferIntervalComparator(
+      const HloLiveRange& hlo_live_range);
+
+  ~DefaultCrossProgramPrefetchBufferIntervalComparator() override = default;
+
+  std::string DescribeComparisonCriteria() const override;
+  std::string CriteriaToString(const BufferInterval& buffer_interval) override;
+  bool LessThan(const BufferInterval& lhs, const BufferInterval& rhs) override;
+
+ private:
+  // See the value returned by DescribeComparisonCriteria() for the meaning of
+  // each tuple element.
+  using ComparisonTuple =
+      std::tuple<int64_t, int64_t, int64_t, BufferValue::Id>;
+
+  struct AdditionalSortData {
+    int64_t latest_use = 0;
+    int64_t cumulative_use_size = 0;
+  };
+
+  ComparisonTuple GetTuple(const BufferInterval& buffer_interval);
+
+  absl::flat_hash_map<const HloValue*, AdditionalSortData>
+      additional_sort_data_;
+  const HloLiveRange& hlo_live_range_;
+};
+
 // Filters prefetches by matching against multiple filters and overrides the
 // preferred prefetch time for matching prefetches by the provided override
 // strategy.
@@ -1376,10 +1466,10 @@
   // Memory alignment of the alternate memory space.
   int64_t alignment_in_bytes = 1;
 
-  // If provided, we sort the buffers using this comparison function
-  // otherwise, we use GlobalDecreasingSizeBestFitHeap::kSpatial.
-  std::optional<MemorySpaceAssignment::BufferIntervalCompare>
-      buffer_interval_compare = std::nullopt;
+  // If provided, we sort the buffers using this comparator. Otherwise, we use
+  // GlobalDecreasingSizeBestFitHeap::kSpatial.
+  MemorySpaceAssignment::BufferIntervalComparator* buffer_interval_comparator =
+      nullptr;
 
   // This object determines how early and how late prefetches can occur.
   PrefetchIntervalPicker* prefetch_interval_picker = nullptr;
@@ -1564,7 +1654,7 @@
 // time (time that copy done is scheduled), the resource this copy would use,
 // its destination memory space, and a unique ID.
 struct AsynchronousCopy {
-  int64_t start_time;
+  int64_t exclusive_start_time;
   int64_t end_time;
   float resource;
   MemorySpaceAssignment::MemorySpace destination;
@@ -1573,7 +1663,8 @@
   std::tuple<int64_t, int64_t, float, MemorySpaceAssignment::MemorySpace,
              int64_t>
   AsTuple() const {
-    return std::make_tuple(start_time, end_time, resource, destination, id);
+    return std::make_tuple(exclusive_start_time, end_time, resource,
+                           destination, id);
   }
 };
 
@@ -1610,13 +1701,13 @@
   // The new asynchronous copy would violate the ordering guarantee because the
   // copy start is after an already committed asynchronous copy while its copy
   // done is before the committed copy.
-  bool ViolatesOrdering(int64_t start_time, int64_t end_time) const;
+  bool ViolatesOrdering(int64_t exclusive_start_time, int64_t end_time) const;
 
  private:
   // We use this data structure for keys into the map that has a custom
   // comparator for the ordering guarantees.
   struct Interval {
-    int64_t start_time;
+    int64_t exclusive_start_time;
     int64_t end_time;
 
     // We allow multiple prefetches that have one or both of the same start and
@@ -1625,8 +1716,10 @@
     // intervals that evaluate to be equal are those with the same start and end
     // times or those with intervals that violate the FIFO order.
     bool operator<(const Interval& other) const {
-      return (start_time < other.start_time && end_time <= other.end_time) ||
-             (start_time <= other.start_time && end_time < other.end_time);
+      return (exclusive_start_time < other.exclusive_start_time &&
+              end_time <= other.end_time) ||
+             (exclusive_start_time <= other.exclusive_start_time &&
+              end_time < other.end_time);
     }
   };
   // Stores asynchronous copies in a tree set respecting the pipelining order.
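
Because operator< above is intentionally not a total order, two intervals compare equivalent in the tree set exactly when they share an endpoint or when one copy would overtake the other. A worked illustration with made-up times:

// Committed copy:  { exclusive_start_time = 2, end_time = 10 }
// Candidate copy:  { exclusive_start_time = 5, end_time =  8 }
// Neither interval is less than the other (5 > 2 but 8 < 10), so the set
// treats them as equivalent and ViolatesOrdering() reports a violation: the
// candidate starts after the committed copy but would finish before it,
// breaking the FIFO pipelining order. A candidate like { 5, 12 } compares
// strictly greater than the committed copy and is accepted.
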
@@ -1642,7 +1735,7 @@
  public:
   // A specification of needed asynchronous copy resources.
   struct ResourceSpec {
-    int64_t start_time;
+    int64_t exclusive_start_time;
     int64_t end_time;
     float resource;
   };
@@ -1664,7 +1757,8 @@
 
   // Returns true if a copy with the given start and end times and resource can
   // be satisfied.
-  bool HasEnoughResource(int64_t start_time, int64_t end_time, float resource);
+  bool HasEnoughResource(int64_t exclusive_start_time, int64_t end_time,
+                         float resource);
 
   // Returns true if a set of copy specifications can be satisfied in the
   // order specified.
@@ -1693,7 +1787,7 @@
   // for any change to delay_[i], {i, delay_[i]} will be added to
   // delay_change_map, allowing callers to undo any modifications.
   bool ConsumeResource(
-      int64_t start_time, int64_t end_time, float resource,
+      int64_t exclusive_start_time, int64_t end_time, float resource,
       absl::flat_hash_map<int64_t, float>* delay_change_map = nullptr,
       float resource_to_free = 0.0);
 
@@ -2053,7 +2147,7 @@
   // If earliest_prefetch_time is set, prefetches cannot start before this
   // value.
   struct AllocationRequest {
-    int64_t start_time;
+    int64_t inclusive_start_time;
     int64_t end_time;
     int64_t latest_prefetch_time;
     int64_t size;
@@ -2203,14 +2297,14 @@
 
     // Intermediate calculations common to both the sliced and unsliced
     // solutions.
-    int64_t prefetch_start_time = -1;
+    int64_t exclusive_prefetch_start_time = -1;
     int64_t prefetch_end_time = -1;
     const Shape* full_shape;
     int64_t extra_async_copy_limit = 0;
     // As a compilation time optimization, store the prefetch start time where
     // we have first seen out of memory. There is no point of exploring prefetch
     // start times earlier than this point.
-    std::optional<int64_t> out_of_mem_start = std::nullopt;
+    std::optional<int64_t> exclusive_out_of_mem_start = std::nullopt;
 
     // Data structures used to compute and store the sliced solution.
     std::optional<MemorySpaceAssignment::SliceProposalCollection>
@@ -2487,7 +2581,7 @@
   // copies. An extra  async copy limit can be provided to increase the limit of
   // asynchronous copies for this instance.
   bool ViolatesMaximumOutstandingAsyncCopies(
-      int64_t start_time, int64_t end_time, bool is_prefetch,
+      int64_t inclusive_start_time, int64_t end_time, bool is_prefetch,
       int64_t extra_async_copy_limit = 0,
       int64_t num_additional_copies = 1) const;
 
@@ -2512,8 +2606,9 @@
   // Adds an asynchronous copy to allocations.
   void AddAsyncCopy(
       MemorySpaceAssignment::Allocation& prev_allocation,
-      MemorySpace memory_space, std::optional<Chunk> chunk, int64_t start_time,
-      int64_t end_time, int64_t copy_done_schedule_before_time,
+      MemorySpace memory_space, std::optional<Chunk> chunk,
+      int64_t exclusive_start_time, int64_t end_time,
+      int64_t copy_done_schedule_before_time,
       MemorySpaceAssignment::AllocationSequence* allocations,
       AliasedOffset* aliased_offset, float resource,
       std::optional<int> cross_program_prefetch_index = std::nullopt);
@@ -2563,12 +2658,11 @@
     return options_.max_size_in_bytes - reserved_in_bytes_;
   }
 
-  // Returns the earliest time in the [start_time, end_time] range that a new
-  // allocation with the given size would fit in the alternate memory. If it
-  // doesn't fit, it returns nullopt.
-  std::optional<int> FindEarliestTimeToSatisfyPeakMemory(int start_time,
-                                                         int end_time,
-                                                         int64_t size) const;
+  // Returns the earliest time in the (exclusive_start_time, end_time) range
+  // that a new allocation with the given size would fit in the alternate
+  // memory. If it doesn't fit, it returns nullopt.
+  std::optional<int> FindEarliestExclusiveTimeToSatisfyPeakMemory(
+      int exclusive_start_time, int end_time, int64_t size) const;
 
   // Creates and returns a RepackAllocationBlock.
   static RepackAllocationBlock MakeRepackAllocationBlock(
@@ -2576,7 +2670,7 @@
       int64_t initial_offset, int64_t id,
       MemorySpaceAssignment::Allocation* allocation) {
     RepackAllocationBlock allocation_block;
-    allocation_block.start_time = start_time;
+    allocation_block.inclusive_start_time = start_time;
     allocation_block.end_time = end_time;
     allocation_block.size = size;
     allocation_block.offset = -1;
diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc
index 381470c..dc20c25 100644
--- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc
+++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc
@@ -110,6 +110,32 @@
   return ShapeSize(value.shape());
 }
 
+class TestBufferIntervalComparator
+    : public MemorySpaceAssignment::BufferIntervalComparator {
+ public:
+  explicit TestBufferIntervalComparator(
+      GlobalDecreasingSizeBestFitHeap<HloValue>::BufferIntervalCompare
+          compare_method)
+      : MemorySpaceAssignment::BufferIntervalComparator(),
+        compare_method_(compare_method) {}
+
+  ~TestBufferIntervalComparator() override = default;
+
+  std::string DescribeComparisonCriteria() const override {
+    return "internal to test";
+  }
+  std::string CriteriaToString(const BufferInterval& buffer_interval) override {
+    return "internal to test";
+  }
+  bool LessThan(const BufferInterval& lhs, const BufferInterval& rhs) override {
+    return compare_method_(lhs, rhs);
+  }
+
+ private:
+  GlobalDecreasingSizeBestFitHeap<HloValue>::BufferIntervalCompare
+      compare_method_;
+};
+
 class MemorySpaceAssignmentTestBase : public HloTestBase {
  protected:
   // We use the following two memory space values to describe the default (slow
@@ -183,10 +209,14 @@
             /*preferred_overlap_to_async_copy_ratio=*/1.5,
             /*max_overlap_to_mem_size_async_copy_ratio=*/10.0,
             /*mem_size_bytes=*/memory_space_options.max_size_in_bytes));
+    memory_space_assignment::MemoryBoundednessBufferIntervalComparator
+        comparator(*cost_analysis, &cache_);
     return AssignMemorySpace(
         module, memory_space_options,
-        MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
-            *cost_analysis, &cache_),
+        [&comparator](const MemorySpaceAssignment::BufferInterval& lhs,
+                      const MemorySpaceAssignment::BufferInterval& rhs) {
+          return comparator.LessThan(lhs, rhs);
+        },
         &prefetch_interval_picker);
   }
 
@@ -243,7 +273,12 @@
     if (options_override) {
       options = *options_override;
     }
-    options.buffer_interval_compare = buffer_interval_compare;
+    std::unique_ptr<TestBufferIntervalComparator> test_comparator;
+    if (buffer_interval_compare.has_value()) {
+      test_comparator = std::make_unique<TestBufferIntervalComparator>(
+          *buffer_interval_compare);
+      options.buffer_interval_comparator = test_comparator.get();
+    }
     options.prefetch_interval_picker = prefetch_interval_picker;
     options.size_fn = size_fn;
     if (options.is_allowed_in_alternate_mem_fn == nullptr) {
@@ -6389,11 +6424,13 @@
         absl::StrAppend(&colocations_str, colocation->id, ", ");
         colocations.insert(colocation->id);
       }
-      VLOG(1) << "Alloc id: " << block->id << " time: [" << block->start_time
-              << ", " << block->end_time << "] size: " << block->size
+      VLOG(1) << "Alloc id: " << block->id << " time: ["
+              << block->inclusive_start_time << ", " << block->end_time
+              << "] size: " << block->size
               << " init offset: " << block->initial_offset << " colocations: {"
               << colocations_str << "}";
-      auto it = repack_map_.find({block->start_time, block->initial_offset});
+      auto it = repack_map_.find(
+          {block->inclusive_start_time, block->initial_offset});
       if (it != repack_map_.end()) {
         modified = true;
         block->offset = it->second;
@@ -6877,7 +6914,7 @@
              allocations) {
         for (MemorySpaceAssignmentRepacker::AllocationBlock* block :
              allocations) {
-          if (block->start_time == block->end_time) {
+          if (block->inclusive_start_time == block->end_time) {
             EXPECT_GT(block->colocations.size(), 0);
           }
         }
@@ -8844,6 +8881,129 @@
                          op::Fusion()));
 }
 
+// Test description:
+// - Setup: Make sure p1 cannot be prefetched to alternate memory until after
+//   instruction c. We do this by causing p0 to be prefetched to alternate
+//   memory for use in c. Since p0 is larger than 1/2 of alternate memory, we
+//   will not be able to prefetch p1 until after p0 is unallocated.
+// - Test: prefetch p1 after p0 is unallocated from alternate memory (after
+//   instruction c).
+TEST_P(MemorySpaceAssignmentTest, CopyResourceIntegration) {
+  std::string_view hlo_string = R"(
+HloModule module, is_scheduled=true
+
+ENTRY main {
+  p0 = s32[8,8] parameter(0)
+  p1 = s32[8,8] parameter(1)
+  p2 = s32[] parameter(2)
+  a = negate(p2)
+  b = negate(a)
+  c = add(p0, p0)
+  d = negate(b)
+  e = negate(d)
+  f = add(p1, p1)
+
+  ROOT result = tuple(e,c,f)
+}
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  Options options = DefaultMemorySpaceOptions();
+  options.max_size_in_bytes = 300;
+
+  // Setup cost analysis so it takes 2 instructions to prefetch anything.
+  HloCostAnalysis hlo_cost_analysis(ShapeSize);
+  TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
+                          FakeMemorySpaceAssignmentCostAnalysis::Create(
+                              hlo_cost_analysis, *module, options));
+  cost_analysis->SetOverrideForGetInstructionElapsed(
+      [](const HloInstruction& instruction) -> float { return 10.0; });
+  cost_analysis->SetOverrideForGetAsyncCopyElapsed(
+      [](const Shape& shape) -> float { return 20.0; });
+  options.cost_analysis = cost_analysis.get();
+  CostAnalysisPrefetchIntervalPicker prefetch_interval_picker(
+      CostAnalysisPrefetchIntervalPicker(
+          *cost_analysis, /*min_overlap_to_async_copy_ratio=*/0.8,
+          /*preferred_overlap_to_async_copy_ratio=*/1.5,
+          /*max_overlap_to_mem_size_async_copy_ratio=*/10.0,
+          /*mem_size_bytes=*/options.max_size_in_bytes));
+
+  // p0 has the highest priority, followed by p1, followed by everything else.
+  MemorySpaceAssignment::BufferIntervalCompare compare =
+      [](const MemorySpaceAssignment::BufferInterval& lhs,
+         const MemorySpaceAssignment::BufferInterval& rhs) -> bool {
+    auto lookup = [](const MemorySpaceAssignment::BufferInterval& x) {
+      // An arbitrary value that is greater than that for p0 and p1.
+      int priority = 100;
+      if (x.buffer->instruction()->name() == "p0") {
+        priority = 0;
+      } else if (x.buffer->instruction()->name() == "p1") {
+        priority = 1;
+      }
+      return std::make_tuple(priority, x.buffer->instruction()->name());
+    };
+
+    return lookup(lhs) < lookup(rhs);
+  };
+
+  // Run test.
+  AssignMemorySpace(module.get(), options, compare, &prefetch_interval_picker);
+
+  // - Make sure the setup occurred, i.e., that p0 is prefetched to alternate
+  //   memory for use by c.
+  // - Make sure p1 is prefetched.
+  ASSERT_THAT(
+      module->entry_computation()->root_instruction(),
+      op::Tuple(_,
+                // p0 is prefetched to alternate memory for use by c.
+                op::Add(op::AsyncCopy(kAlternateMemorySpace,
+                                      kDefaultMemorySpace, op::Parameter(0)),
+                        op::AsyncCopy(kAlternateMemorySpace,
+                                      kDefaultMemorySpace, op::Parameter(0))),
+                // p1 is prefetched to alternate memory for use by f.
+                op::Add(op::AsyncCopy(kAlternateMemorySpace,
+                                      kDefaultMemorySpace, op::Parameter(1)),
+                        op::AsyncCopy(kAlternateMemorySpace,
+                                      kDefaultMemorySpace, op::Parameter(1)))));
+
+  // Check the schedule
+  const std::vector<HloInstruction*>& schedule =
+      module->schedule().sequence(module->entry_computation()).instructions();
+  auto find_schedule_index = [&schedule](std::string_view name) -> int {
+    for (int i = 0; i < schedule.size(); ++i) {
+      if (schedule[i]->name() == name) {
+        return i;
+      }
+    }
+    LOG(FATAL) << "Unable to find index of instruction with name " << name;
+  };
+  int c_index = find_schedule_index("c");
+  int p1_copy_start = find_schedule_index(module->entry_computation()
+                                              ->root_instruction()  // result
+                                              ->operand(2)          // f
+                                              ->operand(0)          // copy done
+                                              ->operand(0)  // copy start
+                                              ->name());
+  int d_index = find_schedule_index("d");
+  int e_index = find_schedule_index("e");
+  int p1_copy_end = find_schedule_index(module->entry_computation()
+                                            ->root_instruction()  // result
+                                            ->operand(2)          // f
+                                            ->operand(0)          // copy done
+                                            ->name());
+  int f_index = find_schedule_index("f");
+  // We expect to start copying p1 after c.
+  EXPECT_EQ(p1_copy_start, c_index + 1);
+  // d and e should come between p1's copy start and end.
+  EXPECT_EQ(d_index, p1_copy_start + 1);
+  EXPECT_EQ(e_index, d_index + 1);
+  EXPECT_EQ(p1_copy_end, e_index + 1);
+  // f should immediately follow the end of p1's copy.
+  EXPECT_EQ(f_index, p1_copy_end + 1);
+}
+
 using CostAnalysisPrefetchIntervalPickerTest = HloTestBase;
 
 TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) {
@@ -9624,9 +9784,9 @@
       TF_RETURN_IF_ERROR(Initialize(module, alternate_memory_size));
     }
     MemorySpaceAssignmentCostAnalysis::Cache cache;
-    options_.buffer_interval_compare =
-        MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
-            *cost_analysis_, &cache);
+    memory_space_assignment::MemoryBoundednessBufferIntervalComparator
+        comparator(*cost_analysis_, &cache);
+    options_.buffer_interval_comparator = &comparator;
     CostAnalysisPrefetchIntervalPicker prefetch_interval_picker(
         CostAnalysisPrefetchIntervalPicker(
             *cost_analysis_, /*min_overlap_to_async_copy_ratio=*/0.8,
@@ -11752,29 +11912,28 @@
                          ShapeSize(f32_16_16)}),
       })));
 
-  // Force MSA to prefer prefetching (in order) p0, p1, p2, p3, p4, and then
+  // Force MSA to prefer prefetching (in order) p1, p2, p3, p4, and then
   // anything else.
   MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
-      [](const MemorySpaceAssignment::BufferInterval& a,
-         const MemorySpaceAssignment::BufferInterval& b) {
-        auto get_priority = [](const HloInstruction* instruction) {
-          if (instruction->name() == "p1") {
-            return 1;
+      [](const MemorySpaceAssignment::BufferInterval& lhs,
+         const MemorySpaceAssignment::BufferInterval& rhs) {
+        auto lookup = [](const MemorySpaceAssignment::BufferInterval& x) {
+          // An arbitrary value that is greater than that for p1, p2, p3, and
+          // p4.
+          int priority = 100;
+          if (x.buffer->instruction()->name() == "p1") {
+            priority = 1;
+          } else if (x.buffer->instruction()->name() == "p2") {
+            priority = 2;
+          } else if (x.buffer->instruction()->name() == "p3") {
+            priority = 3;
+          } else if (x.buffer->instruction()->name() == "p4") {
+            priority = 4;
           }
-          if (instruction->name() == "p2") {
-            return 2;
-          }
-          if (instruction->name() == "p3") {
-            return 3;
-          }
-          if (instruction->name() == "p4") {
-            return 4;
-          }
-          return 100;
+          return std::make_tuple(priority, x.buffer->instruction()->name());
         };
 
-        return get_priority(a.buffer->defining_instruction()) <
-               get_priority(b.buffer->defining_instruction());
+        return lookup(lhs) < lookup(rhs);
       };
 
   // Configure MSA.
@@ -11806,9 +11965,9 @@
         for (MockRepacker::AllocationBlock* block : allocations) {
           VLOG(1) << "Allocation block: " << block->ToString();
 
-          // Move "p2" from offset 1024 -> 2048.
-          if (block->start_time == 2 && block->initial_offset == 1024 &&
-              block->size == 2048) {
+          if (block->inclusive_start_time == 3 &&
+              block->initial_offset == 1024 && block->size == 2048) {
+            // Move "p2" from offset 1024 -> 2048.
             found_p2 = true;
             block->offset = 2048;
             // We expect p2 to be sliced. Check that it has slicing information
@@ -11816,23 +11975,26 @@
             EXPECT_TRUE(block->original_slice_data.has_value());
             if (block->original_slice_data.has_value()) {
               SlicedAllocationData expected(
-                  {{Slice{1024, 1024, 2}, Slice{1024, 2048, 6}}});
+                  {{Slice{1024, 1024, /*inclusive_start_time=*/3},
+                    Slice{1024, 2048, /*inclusive_start_time=*/7}}});
               EXPECT_EQ(*block->original_slice_data, expected)
                   << "\nExpected: " << expected.ToString()
                   << "\nGot: " << block->original_slice_data->ToString();
               // Set the first slice for p2 to be place at the larger offset.
               block->repacked_slice_data = SlicedAllocationData(
-                  {{Slice{1024, 2048, 6}, Slice{1024, 3072, 2}}});
+                  {{Slice{1024, 2048, /*inclusive_start_time=*/7},
+                    Slice{1024, 3072, /*inclusive_start_time=*/3}}});
             }
-          }
-          // Move "p3" from offset 3072 -> 1024.
-          if (block->start_time == 3 && block->initial_offset == 3072 &&
-              block->size == 1024) {
+          } else if (block->inclusive_start_time == 4 &&
+                     block->initial_offset == 3072 && block->size == 1024) {
+            // Move "p3" from offset 3072 -> 1024.
             found_p3 = true;
             block->offset = 1024;
             // We do not expect p3 to be sliced. Thus, it should not have
             // slicing information in its AllocationBlock.
             EXPECT_FALSE(block->original_slice_data.has_value());
+          } else {
+            block->offset = block->initial_offset;
           }
         }
 
diff --git a/third_party/xla/xla/service/memory_space_assignment/repacking.h b/third_party/xla/xla/service/memory_space_assignment/repacking.h
index 0379fc7..1556bd3 100644
--- a/third_party/xla/xla/service/memory_space_assignment/repacking.h
+++ b/third_party/xla/xla/service/memory_space_assignment/repacking.h
@@ -43,15 +43,16 @@
   struct Slice {
     int64_t size;
     int64_t offset;
-    int64_t start_time;
+    int64_t inclusive_start_time;
 
     std::string ToString() const {
       return absl::StrCat("{ size: ", size, ", offset: ", offset,
-                          ", start_time: ", start_time, " }");
+                          ", inclusive_start_time: ", inclusive_start_time,
+                          " }");
     }
 
     std::tuple<int64_t, int64_t, int64_t> ToTuple() const {
-      return std::make_tuple(size, offset, start_time);
+      return std::make_tuple(size, offset, inclusive_start_time);
     }
 
     bool operator==(const Slice& rhs) const {
@@ -73,15 +74,15 @@
       return sizes_sorted_by_offset;
     }
 
-    std::vector<int64_t> SortedStartTimes() const {
-      std::vector<int64_t> sorted_start_times;
-      sorted_start_times.reserve(slices_sorted_by_offset.size());
-      absl::c_for_each(slices_sorted_by_offset,
-                       [&sorted_start_times](const Slice& slice) {
-                         sorted_start_times.push_back(slice.start_time);
-                       });
-      absl::c_sort(sorted_start_times);
-      return sorted_start_times;
+    std::vector<int64_t> SortedInclusiveStartTimes() const {
+      std::vector<int64_t> sorted_inclusive_start_times;
+      sorted_inclusive_start_times.reserve(slices_sorted_by_offset.size());
+      absl::c_for_each(slices_sorted_by_offset, [&sorted_inclusive_start_times](
+                                                    const Slice& slice) {
+        sorted_inclusive_start_times.push_back(slice.inclusive_start_time);
+      });
+      absl::c_sort(sorted_inclusive_start_times);
+      return sorted_inclusive_start_times;
     }
 
     std::string ToString() const {
@@ -113,7 +114,7 @@
   // the information in the original_slice_data field to achieve an even more
   // efficient repacking.
   struct AllocationBlock {
-    int64_t start_time;
+    int64_t inclusive_start_time;
     int64_t end_time;
     int64_t size;
     int64_t offset;
@@ -137,8 +138,8 @@
         repacked_slicing_str = absl::StrCat("; repacked_slice_data: ",
                                             repacked_slice_data->ToString());
       }
-      return absl::StrCat("[", start_time, ", ", end_time, "]; size: ", size,
-                          "; offset: ", offset,
+      return absl::StrCat("[", inclusive_start_time, ", ", end_time,
+                          "]; size: ", size, "; offset: ", offset,
                           "; initial offset: ", initial_offset,
                           "; # colocations: ", colocations.size(),
                           original_slicing_str, repacked_slicing_str);
diff --git a/third_party/xla/xla/service/reduce_scatter_decomposer.cc b/third_party/xla/xla/service/reduce_scatter_decomposer.cc
index 1fb197b..5936663 100644
--- a/third_party/xla/xla/service/reduce_scatter_decomposer.cc
+++ b/third_party/xla/xla/service/reduce_scatter_decomposer.cc
@@ -55,11 +55,15 @@
       }
 
       // Create an all-reduce
+      HloComputation *apply_clone = module->AddComputationAndUnifyNamesAndIds(
+          rs->to_apply()->Clone(), /*is_entry=*/false);
       HloInstruction *ar =
           computation->AddInstruction(HloInstruction::CreateAllReduce(
-              rs->operand(0)->shape(), rs->operands(), rs->to_apply(),
+              rs->operand(0)->shape(), rs->operands(), apply_clone,
               rs->replica_groups(), rs->constrain_layout(), channel_id,
               rs->use_global_device_ids()));
+      apply_clone->SetCollectiveCallInstruction(ar);
+
       // Create start indices for a dynamic slice to decompose the all-reduce
       // results.
       TF_ASSIGN_OR_RETURN(
diff --git a/third_party/xla/xla/service/service.cc b/third_party/xla/xla/service/service.cc
index 8a25b21..3aa2f0c 100644
--- a/third_party/xla/xla/service/service.cc
+++ b/third_party/xla/xla/service/service.cc
@@ -951,8 +951,19 @@
     if (!LayoutUtil::HasLayout(return_shape)) {
       return InvalidArgument("shape_with_layout must have layout if present.");
     }
+    if (return_shape.has_layout() &&
+        return_shape.layout().element_size_in_bits() != 0) {
+      return InvalidArgument(
+          "shape_with_layout cannot have layout's element_size_in_bits field "
+          "set");
+    }
   } else {
     return_shape = Shape(shaped_buffer->on_device_shape());
+    if (return_shape.has_layout() &&
+        return_shape.layout().element_size_in_bits() != 0) {
+      // Literals do not support element_size_in_bits
+      return_shape.mutable_layout()->set_element_size_in_bits(0);
+    }
   }
 
   TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(
diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc
index fbab0ea..dcb5d7e 100644
--- a/third_party/xla/xla/service/shape_inference.cc
+++ b/third_party/xla/xla/service/shape_inference.cc
@@ -17,28 +17,39 @@
 
 #include <algorithm>
 #include <cstddef>
+#include <cstdint>
 #include <iterator>
+#include <limits>
 #include <numeric>
+#include <optional>
 #include <set>
 #include <string>
+#include <utility>
+#include <vector>
 
 #include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/permutation_util.h"
 #include "xla/primitive_util.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
+#include "xla/status.h"
 #include "xla/status_macros.h"
-#include "xla/types.h"
+#include "xla/statusor.h"
 #include "xla/util.h"
 #include "xla/window_util.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/logging.h"
-#include "tsl/platform/protobuf.h"
+#include "tsl/platform/status.h"
 #include "tsl/platform/statusor.h"
 
 namespace xla {
@@ -1332,8 +1343,9 @@
   }
 
   const int64_t feature_count = operand_shape.dimensions(feature_index);
-  Shape output_shape_for_mean_and_var =
-      ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
+  bool dynamic_feature = operand_shape.is_dynamic_dimension(feature_index);
+  Shape output_shape_for_mean_and_var = ShapeUtil::MakeShape(
+      operand_shape.element_type(), {feature_count}, {dynamic_feature});
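
With the extra dynamic-dimension argument, the dynamism of the feature dimension now carries through to the inferred mean/variance shape. An illustrative example (shapes are assumptions, not taken from a test):

// operand:       f32[8, <=128] with feature_index = 1 (dimension 1 dynamic)
// mean/variance: f32[<=128]    -- previously inferred as the static f32[128]
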
 
   if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
     return InvalidArgument(
@@ -2731,7 +2743,18 @@
     VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size);
   }
 
-  return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
+  Shape result =
+      ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
+
+  for (int64_t dimension = 0; dimension < operand_shape.rank(); ++dimension) {
+    if (operand_shape.is_dynamic_dimension(dimension) &&
+        slice_sizes[dimension] > 1 &&
+        slice_sizes[dimension] == operand_shape.dimensions(dimension)) {
+      result.set_dynamic_dimension(dimension, true);
+    }
+  }
+
+  return result;
 }
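
The loop above keeps a dynamic-slice result dimension dynamic only when the slice spans the entire (dynamic) operand dimension; narrower slices have a fixed size and stay static. An illustrative example with assumed shapes:

// operand: f32[<=32, 16] (dimension 0 dynamic), slice_sizes = {32, 1}
// dim 0: slice_sizes[0] == 32 == operand bound and > 1 -> result dim 0 dynamic
// dim 1: slice_sizes[1] == 1                           -> result dim 1 static
// result: f32[<=32, 1]
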
 
 /* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
@@ -3336,8 +3359,15 @@
         ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(pred));
   }
 
-  return ShapeUtil::ChangeElementType(
+  Shape result = ShapeUtil::ChangeElementType(
       pred, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
+  for (int64_t dimension = 0; dimension < pred.rank(); ++dimension) {
+    result.set_dynamic_dimension(dimension,
+                                 pred.is_dynamic_dimension(dimension) ||
+                                     on_true.is_dynamic_dimension(dimension) ||
+                                     on_false.is_dynamic_dimension(dimension));
+  }
+  return std::move(result);
 }
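
In other words, a select output dimension is now dynamic if that dimension is dynamic in any of pred, on_true, or on_false. An illustrative example with assumed shapes:

// pred:     pred[8, <=16] (dimension 1 dynamic)
// on_true:  f32[8, 16]
// on_false: f32[8, 16]
// result:   f32[8, <=16]  -- dimension 1 is dynamic because pred's is
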
 
 /* static */ StatusOr<Shape> ShapeInference::InferCallShape(
diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc
index cbf7e2b..ad89c05 100644
--- a/third_party/xla/xla/service/sharding_propagation.cc
+++ b/third_party/xla/xla/service/sharding_propagation.cc
@@ -33,6 +33,7 @@
 #include "absl/log/check.h"
 #include "absl/strings/str_join.h"
 #include "absl/types/span.h"
+#include "xla/array.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -536,121 +537,6 @@
   return sharding;
 }
 
-bool InferDotShardingFromOperands(
-    HloInstruction* instruction, const CallGraph& call_graph,
-    const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
-    bool may_combine_partial_sharding, bool is_spmd) {
-  auto from_operand = [&](int64_t operand_index) {
-    auto operand = instruction->operand(operand_index);
-    const HloSharding& operand_sharding = operand->sharding();
-    if (operand_sharding.IsTileMaximal()) {
-      return operand_sharding;
-    }
-    std::vector<int64_t> contracting_dims;
-    contracting_dims.reserve(dnums.contracting_dims.size());
-    for (const auto& dim : dnums.contracting_dims) {
-      contracting_dims.push_back(operand_index == 0 ? dim.lhs : dim.rhs);
-    }
-    // It's possible that some size-1 spatial dims of convolutions are parsed as
-    // non-contracting dims. We might have tiled dimensions on them.
-    for (const auto& dim : operand_index == 0
-                               ? dnums.rhs_non_contracting_dims
-                               : dnums.lhs_non_contracting_dims) {
-      int64_t d = operand_index == 0 ? dim.lhs : dim.rhs;
-      if (d >= 0) {
-        contracting_dims.push_back(d);
-      }
-    }
-    auto replicate_contracting_dims =
-        hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
-            operand_sharding, contracting_dims);
-    std::vector<int64_t> out_dims_to_op_perm(instruction->shape().rank(), -1);
-    std::vector<int64_t> op_dims_to_output_perm(operand->shape().rank(), -1);
-    for (const auto& dim : dnums.batch_dims) {
-      out_dims_to_op_perm[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs;
-      op_dims_to_output_perm[operand_index == 0 ? dim.lhs : dim.rhs] =
-          dim.output;
-    }
-    for (const auto& dim : operand_index == 0
-                               ? dnums.lhs_non_contracting_dims
-                               : dnums.rhs_non_contracting_dims) {
-      out_dims_to_op_perm[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs;
-      op_dims_to_output_perm[operand_index == 0 ? dim.lhs : dim.rhs] =
-          dim.output;
-    }
-    return *hlo_sharding_util::TransposeShardingWithCollapsedDims(
-        replicate_contracting_dims, op_dims_to_output_perm,
-        out_dims_to_op_perm);
-  };
-  std::optional<HloSharding> improved_operand_0;
-  std::optional<HloSharding> improved_operand_1;
-  if (IsSpatiallyPartitioned(instruction->operand(0))) {
-    improved_operand_0 = ReturnImprovedSharding(
-        from_operand(0), instruction, may_combine_partial_sharding,
-        /*allow_aggressive_resharding=*/false);
-  }
-  if (IsSpatiallyPartitioned(instruction->operand(1))) {
-    improved_operand_1 = ReturnImprovedSharding(
-        from_operand(1), instruction, may_combine_partial_sharding,
-        /*allow_aggressive_resharding=*/false);
-  }
-  // If not improved sharding found then do not set any sharding.
-  if (!improved_operand_0.has_value() && !improved_operand_1.has_value()) {
-    return false;
-  }
-  // Sharding found from operand 0 but not operand 1. Set sharding from operand
-  // 0
-  if (improved_operand_0.has_value() && !improved_operand_1.has_value()) {
-    instruction->set_sharding(*improved_operand_0);
-    return true;
-  }
-  // Sharding found from operand 1 but not operand 0. Set sharding from operand
-  // 1
-  if (!improved_operand_0.has_value() && improved_operand_1.has_value()) {
-    instruction->set_sharding(*improved_operand_1);
-    return true;
-  }
-  CHECK(improved_operand_0.has_value() && improved_operand_1.has_value());
-  std::optional<HloSharding> lookahead_sharding =
-      LookaheadUserSharding(instruction, is_spmd, call_graph);
-  std::array<HloSharding, 2> sharding_priority = {*improved_operand_0,
-                                                  *improved_operand_1};
-  bool priority_defined_with_lookahead = false;
-  // Found sharding from lookahead.
-  if (lookahead_sharding.has_value()) {
-    const bool operand_0_is_lookahead_subtiling =
-        hlo_sharding_util::IsSubTilingOrEqualSharding(
-            instruction->shape(), *lookahead_sharding, *improved_operand_0);
-    const bool operand_1_is_lookahead_subtiling =
-        hlo_sharding_util::IsSubTilingOrEqualSharding(
-            instruction->shape(), *lookahead_sharding, *improved_operand_1);
-    // If the sharding from operand 0 is a subtiling of the user, but not the
-    // one from operand 1 prioritize that sharding.
-    if (operand_0_is_lookahead_subtiling && !operand_1_is_lookahead_subtiling) {
-      priority_defined_with_lookahead = true;
-    }
-    // If the sharding from operand 1 is a subtiling of the user, but not the
-    // one from operand 0 prioritize that sharding.
-    if (!operand_0_is_lookahead_subtiling && operand_1_is_lookahead_subtiling) {
-      instruction->set_sharding(*improved_operand_1);
-      std::swap(sharding_priority[0], sharding_priority[1]);
-      priority_defined_with_lookahead = true;
-    }
-  }
-  // If lookahead didn't define a priority then use size.
-  if (!priority_defined_with_lookahead &&
-      ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) <
-          ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) {
-    std::swap(sharding_priority[0], sharding_priority[1]);
-  }
-  // Set primary sharding to the instruction and then try to improve it with
-  // the secondary sharding.
-  instruction->set_sharding(sharding_priority[0]);
-  MaybeImproveInstructionSharding(sharding_priority[1], instruction,
-                                  may_combine_partial_sharding);
-  return true;
-}
-
 // Infer output sharding on index parallel dimensions for gather/scatter from
 // gather operand/indices or scatter operands/indices/updates.
 HloSharding InferParallelShardingFromOperand(
@@ -793,92 +679,6 @@
   return changed;
 }
 
-// Convolution handling for InferShardingFromOperands().
-bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
-                                          const CallGraph& call_graph,
-                                          int64_t aggressiveness,
-                                          bool may_combine_partial_sharding,
-                                          bool is_spmd) {
-  auto get_partitions_for_dims =
-      [&](const HloInstruction* inst,
-          absl::Span<
-              const dot_as_convolution_util::DotConvolutionDimsInfo::DimNums>
-              dims,
-          int lhs_or_rhs) {
-        int64_t partitions = 1;
-        if (!inst->has_sharding()) {
-          return partitions;
-        }
-        const auto& sharding = inst->sharding();
-        if (sharding.IsTileMaximal()) {
-          return partitions;
-        }
-        for (const auto& dim : dims) {
-          if (lhs_or_rhs == 0) {
-            partitions *= sharding.tile_assignment().dim(dim.lhs);
-          } else {
-            CHECK_EQ(lhs_or_rhs, 1);
-            partitions *= sharding.tile_assignment().dim(dim.rhs);
-          }
-        }
-        return partitions;
-      };
-  auto dot_dims =
-      dot_as_convolution_util::ParseConvolutionDimsInfo(instruction);
-  const int64_t lhs_conv_spatial_partitions = get_partitions_for_dims(
-      instruction->operand(0), dot_dims.conv_spatial_dims, 0);
-  const int64_t rhs_conv_spatial_partitions = get_partitions_for_dims(
-      instruction->operand(1), dot_dims.conv_spatial_dims, 1);
-  if (dot_dims.conv_spatial_dims.empty() ||
-      (lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 &&
-       instruction->batch_group_count() == 1 &&
-       instruction->feature_group_count() == 1)) {
-    return InferDotShardingFromOperands(instruction, call_graph, dot_dims,
-                                        may_combine_partial_sharding, is_spmd);
-  }
-  const auto& dnums = instruction->convolution_dimension_numbers();
-  const HloInstruction* lhs = instruction->operand(0);
-  auto get_tiled_sharding_based_on_lhs = [&] {
-    CHECK(!lhs->sharding().IsTileMaximal());
-    std::vector<int64_t> output_to_lhs_indices(instruction->shape().rank());
-    output_to_lhs_indices[dnums.output_batch_dimension()] =
-        dnums.input_batch_dimension();
-    output_to_lhs_indices[dnums.output_feature_dimension()] =
-        dnums.input_feature_dimension();
-    for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
-      output_to_lhs_indices[dnums.output_spatial_dimensions(i)] =
-          dnums.input_spatial_dimensions(i);
-    }
-    return hlo_sharding_util::TransposeSharding(lhs->sharding(),
-                                                output_to_lhs_indices);
-  };
-  if (!IsSpatiallyPartitioned(lhs)) {
-    return false;
-  }
-  if (lhs->sharding().IsTileMaximal()) {
-    return MaybeImproveInstructionSharding(lhs->sharding(), instruction,
-                                           may_combine_partial_sharding);
-  }
-
-  if (IsConvolutionKernelSmall(instruction)) {
-    // If the kernel is small compared to the input then we can generate an
-    // output what is sharded the same way as the input.
-    const auto& tile_assignment = lhs->sharding().tile_assignment();
-    if (tile_assignment.dim(dnums.input_feature_dimension()) > 1) {
-      return false;
-    }
-    return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_lhs(),
-                                           instruction,
-                                           may_combine_partial_sharding);
-  }
-  // If the kernel is large (e.g backward convolution) then we only support
-  // replicated output.
-  return MaybeImproveInstructionSharding(
-      hlo_sharding_util::ReplicateAllDataDims(lhs->sharding(),
-                                              instruction->shape().rank()),
-      instruction, may_combine_partial_sharding);
-}
-
 bool CanPropagateThroughAtAggressiveLevel(const HloInstruction& inst,
                                           int64_t aggressiveness) {
   // At minimum aggressiveness, only allow pass-through ops.
@@ -1419,6 +1219,207 @@
 
 }  // namespace
 
+bool InferDotShardingFromOperands(
+    HloInstruction* instruction, const CallGraph& call_graph,
+    const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
+    bool may_combine_partial_sharding, bool is_spmd) {
+  auto from_operand = [&](int64_t operand_index) {
+    auto operand = instruction->operand(operand_index);
+    const HloSharding& operand_sharding = operand->sharding();
+    if (operand_sharding.IsTileMaximal()) {
+      return operand_sharding;
+    }
+    std::vector<int64_t> contracting_dims;
+    contracting_dims.reserve(dnums.contracting_dims.size());
+    for (const auto& dim : dnums.contracting_dims) {
+      contracting_dims.push_back(operand_index == 0 ? dim.lhs : dim.rhs);
+    }
+    // It's possible that some size-1 spatial dims of convolutions are parsed as
+    // non-contracting dims. We might have tiled dimensions on them.
+    for (const auto& dim : operand_index == 0
+                               ? dnums.rhs_non_contracting_dims
+                               : dnums.lhs_non_contracting_dims) {
+      int64_t d = operand_index == 0 ? dim.lhs : dim.rhs;
+      if (d >= 0) {
+        contracting_dims.push_back(d);
+      }
+    }
+    auto replicate_contracting_dims =
+        hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
+            operand_sharding, contracting_dims);
+    std::vector<int64_t> out_dims_to_op_perm(instruction->shape().rank(), -1);
+    std::vector<int64_t> op_dims_to_output_perm(operand->shape().rank(), -1);
+    for (const auto& dim : dnums.batch_dims) {
+      out_dims_to_op_perm[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs;
+      op_dims_to_output_perm[operand_index == 0 ? dim.lhs : dim.rhs] =
+          dim.output;
+    }
+    for (const auto& dim : operand_index == 0
+                               ? dnums.lhs_non_contracting_dims
+                               : dnums.rhs_non_contracting_dims) {
+      out_dims_to_op_perm[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs;
+      op_dims_to_output_perm[operand_index == 0 ? dim.lhs : dim.rhs] =
+          dim.output;
+    }
+    return *hlo_sharding_util::TransposeShardingWithCollapsedDims(
+        replicate_contracting_dims, op_dims_to_output_perm,
+        out_dims_to_op_perm);
+  };
+  std::optional<HloSharding> improved_operand_0;
+  std::optional<HloSharding> improved_operand_1;
+  if (IsSpatiallyPartitioned(instruction->operand(0))) {
+    improved_operand_0 = ReturnImprovedSharding(
+        from_operand(0), instruction, may_combine_partial_sharding,
+        /*allow_aggressive_resharding=*/false);
+  }
+  if (IsSpatiallyPartitioned(instruction->operand(1))) {
+    improved_operand_1 = ReturnImprovedSharding(
+        from_operand(1), instruction, may_combine_partial_sharding,
+        /*allow_aggressive_resharding=*/false);
+  }
+  // If no improved sharding was found, do not set any sharding.
+  if (!improved_operand_0.has_value() && !improved_operand_1.has_value()) {
+    return false;
+  }
+  // Sharding found from operand 0 but not operand 1. Set the sharding from
+  // operand 0.
+  if (improved_operand_0.has_value() && !improved_operand_1.has_value()) {
+    instruction->set_sharding(*improved_operand_0);
+    return true;
+  }
+  // Sharding found from operand 1 but not operand 0. Set the sharding from
+  // operand 1.
+  if (!improved_operand_0.has_value() && improved_operand_1.has_value()) {
+    instruction->set_sharding(*improved_operand_1);
+    return true;
+  }
+  CHECK(improved_operand_0.has_value() && improved_operand_1.has_value());
+  std::optional<HloSharding> lookahead_sharding =
+      LookaheadUserSharding(instruction, is_spmd, call_graph);
+  std::array<HloSharding, 2> sharding_priority = {*improved_operand_0,
+                                                  *improved_operand_1};
+  bool priority_defined_with_lookahead = false;
+  // Found sharding from lookahead.
+  if (lookahead_sharding.has_value()) {
+    const bool operand_0_is_lookahead_subtiling =
+        hlo_sharding_util::IsSubTilingOrEqualSharding(
+            instruction->shape(), *lookahead_sharding, *improved_operand_0);
+    const bool operand_1_is_lookahead_subtiling =
+        hlo_sharding_util::IsSubTilingOrEqualSharding(
+            instruction->shape(), *lookahead_sharding, *improved_operand_1);
+    // If the sharding from operand 0 is a subtiling of the user's sharding but
+    // the one from operand 1 is not, prioritize the sharding from operand 0.
+    if (operand_0_is_lookahead_subtiling && !operand_1_is_lookahead_subtiling) {
+      priority_defined_with_lookahead = true;
+    }
+    // If the sharding from operand 1 is a subtiling of the user's sharding but
+    // the one from operand 0 is not, prioritize the sharding from operand 1.
+    if (!operand_0_is_lookahead_subtiling && operand_1_is_lookahead_subtiling) {
+      instruction->set_sharding(*improved_operand_1);
+      std::swap(sharding_priority[0], sharding_priority[1]);
+      priority_defined_with_lookahead = true;
+    }
+  }
+  // If lookahead didn't define a priority then use size.
+  if (!priority_defined_with_lookahead &&
+      ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) <
+          ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) {
+    std::swap(sharding_priority[0], sharding_priority[1]);
+  }
+  // Set primary sharding to the instruction and then try to improve it with
+  // the secondary sharding.
+  instruction->set_sharding(sharding_priority[0]);
+  MaybeImproveInstructionSharding(sharding_priority[1], instruction,
+                                  may_combine_partial_sharding);
+  return true;
+}
+
+// Convolution handling for InferShardingFromOperands().
+bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
+                                          const CallGraph& call_graph,
+                                          int64_t aggressiveness,
+                                          bool may_combine_partial_sharding,
+                                          bool is_spmd) {
+  auto get_partitions_for_dims =
+      [&](const HloInstruction* inst,
+          absl::Span<
+              const dot_as_convolution_util::DotConvolutionDimsInfo::DimNums>
+              dims,
+          int lhs_or_rhs) {
+        int64_t partitions = 1;
+        if (!inst->has_sharding()) {
+          return partitions;
+        }
+        const auto& sharding = inst->sharding();
+        if (sharding.IsTileMaximal()) {
+          return partitions;
+        }
+        for (const auto& dim : dims) {
+          if (lhs_or_rhs == 0) {
+            partitions *= sharding.tile_assignment().dim(dim.lhs);
+          } else {
+            CHECK_EQ(lhs_or_rhs, 1);
+            partitions *= sharding.tile_assignment().dim(dim.rhs);
+          }
+        }
+        return partitions;
+      };
+  auto dot_dims =
+      dot_as_convolution_util::ParseConvolutionDimsInfo(instruction);
+  const int64_t lhs_conv_spatial_partitions = get_partitions_for_dims(
+      instruction->operand(0), dot_dims.conv_spatial_dims, 0);
+  const int64_t rhs_conv_spatial_partitions = get_partitions_for_dims(
+      instruction->operand(1), dot_dims.conv_spatial_dims, 1);
+  if (dot_dims.conv_spatial_dims.empty() ||
+      (lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 &&
+       instruction->batch_group_count() == 1 &&
+       instruction->feature_group_count() == 1)) {
+    return InferDotShardingFromOperands(instruction, call_graph, dot_dims,
+                                        may_combine_partial_sharding, is_spmd);
+  }
+  const auto& dnums = instruction->convolution_dimension_numbers();
+  const HloInstruction* lhs = instruction->operand(0);
+  auto get_tiled_sharding_based_on_lhs = [&] {
+    CHECK(!lhs->sharding().IsTileMaximal());
+    std::vector<int64_t> output_to_lhs_indices(instruction->shape().rank());
+    output_to_lhs_indices[dnums.output_batch_dimension()] =
+        dnums.input_batch_dimension();
+    output_to_lhs_indices[dnums.output_feature_dimension()] =
+        dnums.input_feature_dimension();
+    for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
+      output_to_lhs_indices[dnums.output_spatial_dimensions(i)] =
+          dnums.input_spatial_dimensions(i);
+    }
+    return hlo_sharding_util::TransposeSharding(lhs->sharding(),
+                                                output_to_lhs_indices);
+  };
+  if (!IsSpatiallyPartitioned(lhs)) {
+    return false;
+  }
+  if (lhs->sharding().IsTileMaximal()) {
+    return MaybeImproveInstructionSharding(lhs->sharding(), instruction,
+                                           may_combine_partial_sharding);
+  }
+
+  if (IsConvolutionKernelSmall(instruction)) {
+    // If the kernel is small compared to the input then we can generate an
+    // output that is sharded the same way as the input.
+    const auto& tile_assignment = lhs->sharding().tile_assignment();
+    if (tile_assignment.dim(dnums.input_feature_dimension()) > 1) {
+      return false;
+    }
+    return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_lhs(),
+                                           instruction,
+                                           may_combine_partial_sharding);
+  }
+  // If the kernel is large (e.g., backward convolution) then we only support
+  // replicated output.
+  return MaybeImproveInstructionSharding(
+      hlo_sharding_util::ReplicateAllDataDims(lhs->sharding(),
+                                              instruction->shape().rank()),
+      instruction, may_combine_partial_sharding);
+}
+
 std::optional<HloSharding> InferBroadcastOperandSharding(
     const HloInstruction& instruction, bool is_spmd) {
   if (instruction.sharding().IsReplicated() ||
@@ -2745,8 +2746,7 @@
     return false;
   }
   // Propagate manual sharding.
-  if (!instruction->has_sharding() ||
-      instruction->sharding().IsTileMaximal()) {
+  if (!instruction->has_sharding() || instruction->sharding().IsTileMaximal()) {
     for (const HloInstruction* user : instruction->users()) {
       if (!user->has_sharding() || user->IsCustomCall("SPMDFullToShardShape"))
         continue;
diff --git a/third_party/xla/xla/service/sharding_propagation.h b/third_party/xla/xla/service/sharding_propagation.h
index 8fe65ce..2cdf11a 100644
--- a/third_party/xla/xla/service/sharding_propagation.h
+++ b/third_party/xla/xla/service/sharding_propagation.h
@@ -26,11 +26,27 @@
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/service/call_graph.h"
 #include "xla/service/custom_call_sharding_helper.h"
+#include "xla/service/dot_as_convolution_util.h"
 #include "xla/service/hlo_pass_interface.h"
 #include "xla/statusor.h"
 
 namespace xla {
 
+// Infers the shardings for a dot HLO op from the shardings on its operands,
+// which are expected to have sharding annotations.
+bool InferDotShardingFromOperands(
+    HloInstruction* instruction, const CallGraph& call_graph,
+    const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
+    bool may_combine_partial_sharding, bool is_spmd);
+
+// Infers the shardings for a convolution HLO op from the shardings on its
+// operands, which are expected to have sharding annotations.
+bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
+                                          const CallGraph& call_graph,
+                                          int64_t aggressiveness,
+                                          bool may_combine_partial_sharding,
+                                          bool is_spmd);
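+
+// A minimal caller sketch (illustrative only; it assumes `dot` is a dot
+// HloInstruction* whose operands already carry sharding annotations,
+// `call_graph` is a CallGraph built for its module, and that the dimension
+// info comes from a dot_as_convolution_util helper such as
+// ParseDotGeneralFromDot):
+//
+//   dot_as_convolution_util::DotConvolutionDimsInfo dims =
+//       dot_as_convolution_util::ParseDotGeneralFromDot(dot);
+//   bool changed = InferDotShardingFromOperands(
+//       dot, call_graph, dims,
+//       /*may_combine_partial_sharding=*/true, /*is_spmd=*/true);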
+
 // Remove Sharding custom-call instruction by folding the sharding attribute
 // to its operand. If the operand already has a different sharding, insert a
 // copy node for reshard. Depending on whether propagating the spmd sharding to
diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc
index 429181d..98438aa 100644
--- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc
+++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc
@@ -48,7 +48,6 @@
 #include "xla/service/hlo_cse.h"
 #include "xla/service/hlo_dce.h"
 #include "xla/service/hlo_pass_pipeline.h"
-#include "xla/service/pattern_matcher.h"
 #include "xla/service/shape_inference.h"
 #include "xla/service/spmd/custom_call_handler.h"
 #include "xla/service/spmd/spmd_partitioner_util.h"
@@ -4730,10 +4729,16 @@
           for (int64_t i = 0; i < num_replicas; ++i) {
             groups[i].add_replica_ids(i);
           }
-          return b->AddInstruction(HloInstruction::CreateAllReduce(
-              operand->shape(), {operand}, reduction, groups,
-              /*constrain_layout=*/false, channel_id,
-              /*use_global_device_ids=*/false));
+          HloComputation* reduction_clone =
+              reduction->parent()->AddComputationAndUnifyNamesAndIds(
+                  reduction->Clone(), false);
+          HloInstruction* all_reduce =
+              b->AddInstruction(HloInstruction::CreateAllReduce(
+                  operand->shape(), {operand}, reduction_clone, groups,
+                  /*constrain_layout=*/false, channel_id,
+                  /*use_global_device_ids=*/false));
+          reduction_clone->SetCollectiveCallInstruction(all_reduce);
+          return all_reduce;
         }
 
         std::vector<ReplicaGroup> device_groups;
@@ -4746,10 +4751,16 @@
             }
           }
         }
-        return b->AddInstruction(HloInstruction::CreateAllReduce(
-            operand->shape(), {operand}, reduction, device_groups,
-            /*constrain_layout=*/false, channel_id,
-            /*use_global_device_ids=*/true));
+        HloComputation* reduction_clone =
+            reduction->parent()->AddComputationAndUnifyNamesAndIds(
+                reduction->Clone(), false);
+        HloInstruction* all_reduce =
+            b->AddInstruction(HloInstruction::CreateAllReduce(
+                operand->shape(), {operand}, reduction_clone, device_groups,
+                /*constrain_layout=*/false, channel_id,
+                /*use_global_device_ids=*/true));
+        reduction_clone->SetCollectiveCallInstruction(all_reduce);
+        return all_reduce;
       },
       [num_partitions](SpmdBuilder* b, HloInstruction* operand,
                        std::vector<std::pair<int64_t, int64_t>>& src_dst_pairs,
diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.h b/third_party/xla/xla/service/spmd/spmd_partitioner.h
index 26d7a5e..de9063d 100644
--- a/third_party/xla/xla/service/spmd/spmd_partitioner.h
+++ b/third_party/xla/xla/service/spmd/spmd_partitioner.h
@@ -605,6 +605,7 @@
       std::vector<std::vector<int64_t>>& groups);
 
   const CallGraph& call_graph() { return call_graph_; }
+  int64_t num_partitions() const { return num_partitions_; }
 
   // Information about a loop created for windowed dot-general. Used when
   // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
diff --git a/third_party/xla/xla/service/sub_byte_normalization.cc b/third_party/xla/xla/service/sub_byte_normalization.cc
index 66d3a56..60a8c79 100644
--- a/third_party/xla/xla/service/sub_byte_normalization.cc
+++ b/third_party/xla/xla/service/sub_byte_normalization.cc
@@ -15,32 +15,73 @@
 
 #include "xla/service/sub_byte_normalization.h"
 
+#include <cstdint>
+
 #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/layout.h"
+#include "xla/primitive_util.h"
+#include "xla/shape.h"
+#include "xla/shape_layout.h"
 #include "tsl/platform/errors.h"
+#include "tsl/platform/status.h"
 
 namespace xla {
 
 namespace {
 
-bool RemoveInt4SizeFromShape(Shape* shape) {
+// Updates the layout by setting element_size_in_bits to the appropriate value.
+// Returns true if the layout was changed.
+bool UpdateLayout(Layout* layout, PrimitiveType type,
+                  SubByteNormalization::Mode mode) {
+  auto set_element_size = [layout](int64_t element_size) {
+    if (layout->element_size_in_bits() != element_size) {
+      layout->set_element_size_in_bits(element_size);
+      return true;
+    }
+    return false;
+  };
+
+  switch (mode) {
+    case SubByteNormalization::REMOVE_ELEMENT_SIZE:
+      return set_element_size(0);
+    case SubByteNormalization::SET_ELEMENT_SIZE:
+      if (primitive_util::Is4BitType(type)) {
+        return set_element_size(4);
+      } else {
+        return set_element_size(0);
+      }
+  }
+}
+
+// Updates the shape by setting set_element_size_in_bits on the shape's layout.
+// Returns true if a layout was changed.
+bool UpdateShape(Shape* shape, SubByteNormalization::Mode mode) {
   if (shape->IsTuple()) {
     bool changed = false;
     for (int idx = 0; idx < shape->tuple_shapes_size(); ++idx) {
-      changed |= RemoveInt4SizeFromShape(shape->mutable_tuple_shapes(idx));
+      changed |= UpdateShape(shape->mutable_tuple_shapes(idx), mode);
     }
     return changed;
   }
-  if (shape->IsArray()) {
-    const int64_t element_size_in_bits = shape->layout().element_size_in_bits();
-    if (element_size_in_bits != 0 && element_size_in_bits < 8) {
-      shape->mutable_layout()->set_element_size_in_bits(0);
-      return true;
-    }
+  if (shape->IsArray() && shape->has_layout()) {
+    return UpdateLayout(shape->mutable_layout(), shape->element_type(), mode);
   }
   return false;
 }
 
+// Sets element_size_in_bits on a ShapeLayout's layout. Returns true if the
+// layout was changed.
+bool ProcessInputOrOutputLayout(ShapeLayout* shape_layout,
+                                SubByteNormalization::Mode mode) {
+  Shape shape = shape_layout->shape();
+  bool changed = UpdateShape(&shape, mode);
+  if (changed) {
+    TF_CHECK_OK(shape_layout->CopyLayoutFromShape(shape));
+  }
+  return changed;
+}
+
 }  // namespace
 
 StatusOr<bool> SubByteNormalization::Run(
@@ -49,7 +90,7 @@
   bool changed = false;
   FunctionVisitor visitor([&](HloInstruction* hlo) -> Status {
     auto* shape = hlo->mutable_shape();
-    changed |= RemoveInt4SizeFromShape(shape);
+    changed |= UpdateShape(shape, mode_);
     return OkStatus();
   });
   for (HloComputation* computation :
@@ -60,16 +101,10 @@
   for (int param_no = 0; param_no < computation_layout->parameter_count();
        ++param_no) {
     auto* shape_layout = computation_layout->mutable_parameter_layout(param_no);
-    if (shape_layout->LayoutIsSet() && shape_layout->shape().IsArray()) {
-      Layout layout = shape_layout->layout();
-      const int64_t element_size_in_bits = layout.element_size_in_bits();
-      if (element_size_in_bits != 0 && element_size_in_bits < 8) {
-        layout.set_element_size_in_bits(0);
-        shape_layout->ResetLayout(layout);
-        changed = true;
-      }
-    }
+    changed |= ProcessInputOrOutputLayout(shape_layout, mode_);
   }
+  auto* output_layout = computation_layout->mutable_result_layout();
+  changed |= ProcessInputOrOutputLayout(output_layout, mode_);
   if (changed) {
     XLA_VLOG_LINES(2, "SubByteNormalization::Run() modified hlo_module:\n" +
                           module->ToString());
diff --git a/third_party/xla/xla/service/sub_byte_normalization.h b/third_party/xla/xla/service/sub_byte_normalization.h
index 010fd2f..39d4b60 100644
--- a/third_party/xla/xla/service/sub_byte_normalization.h
+++ b/third_party/xla/xla/service/sub_byte_normalization.h
@@ -24,21 +24,41 @@
 
 namespace xla {
 
-// A pass that unconditionally removes the sub-byte element_size_in_bits
-// annotation for platforms that doesn't support nibble-packed types. After this
-// pass, a sub-byte type is treated as int8 for space occupation and arithmetic
-// operations. This pass is used in HloEvaluation and testing only.
+// A pass that can modify the sub-byte element_size_in_bits annotation on
+// layouts. Depending on the constructor argument, it either removes the
+// element_size_in_bits annotation for platforms that don't support
+// nibble-packed types, or it sets element_size_in_bits to 4 for 4-bit values.
 class SubByteNormalization : public HloModulePass {
  public:
-  SubByteNormalization() = default;
+  enum Mode {
+    // Remove element_size_in_bits on all layouts. Useful for platforms which
+    // do not support nibble-packed types.
+    REMOVE_ELEMENT_SIZE,
+    // Set element_size_in_bits to 4 for layouts of int4 types (S4, U4), and to
+    // 0 for all other layouts. Useful for platforms which support nibble-packed
+    // types.
+    SET_ELEMENT_SIZE,
+  };
+
+  explicit SubByteNormalization(Mode mode) : mode_(mode) {}
 
   ~SubByteNormalization() override = default;
 
-  absl::string_view name() const override { return "int4-size-removal"; }
+  absl::string_view name() const override {
+    switch (mode_) {
+      case REMOVE_ELEMENT_SIZE:
+        return "int4-size-removal";
+      case SET_ELEMENT_SIZE:
+        return "int4-size-setter";
+    }
+  }
   using HloPassInterface::Run;
   StatusOr<bool> Run(
       HloModule* module,
       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  Mode mode_;
 };
 
 }  // namespace xla
diff --git a/third_party/xla/xla/service/symbol_repository.h b/third_party/xla/xla/service/symbol_repository.h
new file mode 100644
index 0000000..dd88a8b
--- /dev/null
+++ b/third_party/xla/xla/service/symbol_repository.h
@@ -0,0 +1,121 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_SYMBOL_REPOSITORY_H_
+#define XLA_SERVICE_SYMBOL_REPOSITORY_H_
+
+// Functionality to do lookups in HLO repositories. See export_hlo.h for
+// uploads.
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/log/log.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/compiler.h"
+#include "xla/xla.pb.h"
+
+namespace xla {
+
+// Different backends that repositories might store symbols for. This enum could
+// change to a string in the future if required, but ideally repositories only
+// care about the class of hardware, not the specific make/model, and so an enum
+// is fine.
+enum class BackendType {
+  kCpu,
+  kGpu,
+  kTpu,
+};
+
+// Dummy struct for individual backends to add their data to.
+struct BackendSpecificData {
+  virtual ~BackendSpecificData() = default;
+};
+
+// A module and some collected metadata that allow for pure compilation of an
+// HLO module. Implementations may want to subclass to add additional
+// functionality or data.
+struct HloModuleAndMetadata {
+  virtual ~HloModuleAndMetadata() = default;
+
+  std::unique_ptr<HloModule> hlo_module;
+  std::unique_ptr<Compiler::TargetConfig> target_config;
+  // Use static_cast to cast this to a concrete type.
+  std::unique_ptr<BackendSpecificData> backend_specific_data;
+};
+
+// Looks up HLO in a repository. The only non-dummy implementation is
+// Google-internal as of 2023-10.
+class SymbolRepository {
+ public:
+  virtual ~SymbolRepository() = default;
+  virtual absl::StatusOr<std::unique_ptr<HloModuleAndMetadata>> Lookup(
+      absl::string_view symbol_reference, BackendType backend) const = 0;
+};
+
+// Registry for SymbolRepository implementations.
+class SymbolRepositoryRegistry {
+ public:
+  void Register(const std::string& name,
+                std::unique_ptr<SymbolRepository> repo) {
+    absl::MutexLock lock(&mu_);
+    VLOG(1) << "Registering SymbolRepository " << name;
+    repo_[name] = std::move(repo);
+  }
+
+  SymbolRepository* repo(absl::string_view name) {
+    absl::MutexLock lock(&mu_);
+    const auto it = repo_.find(name);
+    if (it == repo_.end()) {
+      return nullptr;
+    }
+
+    return it->second.get();
+  }
+
+ private:
+  absl::Mutex mu_;
+  absl::flat_hash_map<std::string, std::unique_ptr<SymbolRepository>> repo_
+      ABSL_GUARDED_BY(mu_);
+};
+
+inline SymbolRepositoryRegistry& GetGlobalSymbolRepositoryRegistry() {
+  static auto* const registry = new SymbolRepositoryRegistry;
+  return *registry;
+}
+
+// Entry points start here.
+
+inline StatusOr<std::unique_ptr<HloModuleAndMetadata>> LookupSymbolInRepository(
+    absl::string_view repository, absl::string_view symbol_reference,
+    BackendType backend) {
+  if (SymbolRepository* repo =
+          GetGlobalSymbolRepositoryRegistry().repo(repository);
+      repo != nullptr) {
+    return repo->Lookup(symbol_reference, backend);
+  }
+
+  return nullptr;
+}
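+
+// A minimal usage sketch (illustrative; `MyRepository` is a hypothetical
+// SymbolRepository implementation and "my_repo"/"some/symbol" are placeholder
+// names):
+//
+//   GetGlobalSymbolRepositoryRegistry().Register(
+//       "my_repo", std::make_unique<MyRepository>());
+//   auto module_or = LookupSymbolInRepository("my_repo", "some/symbol",
+//                                             BackendType::kGpu);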
+
+}  // namespace xla
+
+#endif  // XLA_SERVICE_SYMBOL_REPOSITORY_H_
diff --git a/third_party/xla/xla/service/time_utils.cc b/third_party/xla/xla/service/time_utils.cc
new file mode 100644
index 0000000..227193f
--- /dev/null
+++ b/third_party/xla/xla/service/time_utils.cc
@@ -0,0 +1,38 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/time_utils.h"
+
+#include <cstdint>
+
+namespace xla {
+
+int64_t ExclusiveToInclusiveStartTime(int64_t exclusive_time) {
+  return exclusive_time + 1;
+}
+
+int64_t InclusiveToExclusiveStartTime(int64_t inclusive_time) {
+  return inclusive_time - 1;
+}
+
+int64_t ExclusiveToInclusiveEndTime(int64_t exclusive_time) {
+  return exclusive_time - 1;
+}
+
+int64_t InclusiveToExclusiveEndTime(int64_t inclusive_time) {
+  return inclusive_time + 1;
+}
+
+}  // namespace xla
diff --git a/third_party/xla/xla/service/time_utils.h b/third_party/xla/xla/service/time_utils.h
new file mode 100644
index 0000000..c3ea709
--- /dev/null
+++ b/third_party/xla/xla/service/time_utils.h
@@ -0,0 +1,31 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_TIME_UTILS_H_
+#define XLA_SERVICE_TIME_UTILS_H_
+
+#include <cstdint>
+
+namespace xla {
+
+// Convert between inclusive/exclusive start/end times.
+int64_t ExclusiveToInclusiveStartTime(int64_t exclusive_time);
+int64_t InclusiveToExclusiveStartTime(int64_t inclusive_time);
+int64_t ExclusiveToInclusiveEndTime(int64_t exclusive_time);
+int64_t InclusiveToExclusiveEndTime(int64_t inclusive_time);
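+
+// For example, following the implementations in time_utils.cc, the inclusive
+// interval [3, 5] corresponds to the exclusive interval (2, 6):
+//   ExclusiveToInclusiveStartTime(2) == 3
+//   InclusiveToExclusiveEndTime(5) == 6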
+
+}  // namespace xla
+
+#endif  // XLA_SERVICE_TIME_UTILS_H_
diff --git a/third_party/xla/xla/service/tuple_util.cc b/third_party/xla/xla/service/tuple_util.cc
index c0fdf02..01f523c 100644
--- a/third_party/xla/xla/service/tuple_util.cc
+++ b/third_party/xla/xla/service/tuple_util.cc
@@ -15,13 +15,33 @@
 
 #include "xla/service/tuple_util.h"
 
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/hlo_value.h"
+#include "xla/shape.h"
+#include "xla/shape_tree.h"
+#include "xla/shape_util.h"
+#include "xla/statusor.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 
 /*static*/ HloInstruction* TupleUtil::ExtractPrefix(HloInstruction* input_tuple,
-                                                    int64_t elements) {
+                                                    int64_t elements,
+                                                    absl::string_view name) {
   CHECK(input_tuple->shape().IsTuple());
 
   HloComputation* computation = input_tuple->parent();
@@ -30,13 +50,18 @@
   std::vector<HloInstruction*> tuple_elements;
   tuple_elements.reserve(elements);
   for (int i = 0; i < elements; i++) {
-    tuple_elements.push_back(
-        computation->AddInstruction(HloInstruction::CreateGetTupleElement(
-            input_shape.tuple_shapes(i), input_tuple, i)));
+    std::string element_name;
+    if (!name.empty()) {
+      element_name = absl::StrCat(name, ".element.", i);
+    }
+    tuple_elements.push_back(computation->AddInstruction(
+        HloInstruction::CreateGetTupleElement(input_shape.tuple_shapes(i),
+                                              input_tuple, i),
+        element_name));
   }
 
   return computation->AddInstruction(
-      HloInstruction::CreateTuple(tuple_elements));
+      HloInstruction::CreateTuple(tuple_elements), name);
 }
 
 /*static*/ HloInstruction* TupleUtil::AppendSuffix(
@@ -171,4 +196,54 @@
   return instruction;
 }
 
+ShapeTree<HloInstruction*> TupleUtil::DisassembleTupleInstruction(
+    HloInstruction* tuple) {
+  const Shape& shape = tuple->shape();
+  ShapeTree<HloInstruction*> result(shape);
+  result.ForEachMutableElement([&](ShapeIndexView index,
+                                   HloInstruction** element) {
+    if (index.empty()) {
+      *element = tuple;
+    } else {
+      ShapeIndexView parent_index = index.subspan(0, index.size() - 1);
+      HloInstruction* parent = result.element(parent_index);
+      std::string name = absl::StrCat(tuple->name(), ".disassembled.",
+                                      absl::StrJoin(index, "."));
+      *element = tuple->parent()->AddInstruction(
+          HloInstruction::CreateGetTupleElement(parent, index.back()), name);
+    }
+  });
+  return result;
+}
+
+HloInstruction* TupleUtil::AssembleTupleInstruction(
+    HloComputation* computation, ShapeTree<HloInstruction*> elements,
+    absl::string_view name) {
+  elements.ForEachMutableElementPostOrder(
+      [&](const ShapeIndex& index, HloInstruction** element) {
+        const Shape& subshape = ShapeUtil::GetSubshape(elements.shape(), index);
+        if (subshape.IsTuple()) {
+          absl::InlinedVector<HloInstruction*, 2> children;
+          ShapeIndex child_index = index;
+          for (int i = 0; i < subshape.tuple_shapes_size(); ++i) {
+            child_index.push_back(i);
+            children.push_back(elements.element(child_index));
+            child_index.pop_back();
+          }
+          std::string new_name;
+          if (!name.empty()) {
+            if (index.empty()) {
+              new_name = std::string(name);
+            } else {
+              new_name =
+                  absl::StrCat(name, ".assembled.", absl::StrJoin(index, "."));
+            }
+          }
+          *element = computation->AddInstruction(
+              HloInstruction::CreateTuple(children), new_name);
+        }
+      });
+  return elements.element({});
+}
+
 }  // namespace xla
diff --git a/third_party/xla/xla/service/tuple_util.h b/third_party/xla/xla/service/tuple_util.h
index 59f241f..5fd434f 100644
--- a/third_party/xla/xla/service/tuple_util.h
+++ b/third_party/xla/xla/service/tuple_util.h
@@ -16,8 +16,15 @@
 #ifndef XLA_SERVICE_TUPLE_UTIL_H_
 #define XLA_SERVICE_TUPLE_UTIL_H_
 
+#include <cstdint>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/service/hlo_value.h"
+#include "xla/shape_tree.h"
+#include "xla/shape_util.h"
+#include "xla/statusor.h"
 
 namespace xla {
 class TupleUtil {
@@ -29,7 +36,8 @@
   // The instructions are generated into the computation containing
   // `input_tuple`.
   static HloInstruction* ExtractPrefix(HloInstruction* input_tuple,
-                                       int64_t elements);
+                                       int64_t elements,
+                                       absl::string_view name = "");
 
   // Generates HLO instructions to create a tuple that consists of the values in
   // `trailing_values` appended to `input_tuple` (which must be of tuple shape).
@@ -59,6 +67,24 @@
   // Recursively create kGetTupleElement instructions if the defining position
   // shape is not an array. Returns the new instruction that has array shape.
   static HloInstruction* AddGetTupleElements(const HloPosition& position);
+
+  // Returns a ShapeTree where each index is a GetTupleElement instruction for
+  // that subshape of the tuple.  The root index is the original argument.
+  // The new instructions are added to the parent computation of the argument.
+  // This function is similar to `xla::DisassembleTuple` except it operates
+  // directly on `HloInstruction*`.
+  static ShapeTree<HloInstruction*> DisassembleTupleInstruction(
+      HloInstruction* tuple);
+
+  // Assembles a tuple from a ShapeTree that contains the leaves of the tuple.
+  // Non-leaf elements of the ShapeTree are ignored.  DisassembleTuple and
+  // AssembleTuple are essentially inverse operations.
+  // The new instructions are added to the given computation.
+  // This function is similar to `xla::AssembleTuple` except it operates
+  // directly on `HloInstruction*`.
+  static HloInstruction* AssembleTupleInstruction(
+      HloComputation* computation, ShapeTree<HloInstruction*> elements,
+      absl::string_view name = "");
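+
+  // A minimal round-trip sketch (illustrative; `tuple` is an HloInstruction*
+  // of tuple shape and `computation` is its parent computation):
+  //
+  //   ShapeTree<HloInstruction*> parts =
+  //       TupleUtil::DisassembleTupleInstruction(tuple);
+  //   HloInstruction* rebuilt = TupleUtil::AssembleTupleInstruction(
+  //       computation, std::move(parts), "rebuilt");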
 };
 }  // namespace xla
 
diff --git a/third_party/xla/xla/service/while_util.cc b/third_party/xla/xla/service/while_util.cc
index e11e11a..e3c7a61 100644
--- a/third_party/xla/xla/service/while_util.cc
+++ b/third_party/xla/xla/service/while_util.cc
@@ -15,31 +15,51 @@
 
 #include "xla/service/while_util.h"
 
+#include <cstdint>
+#include <iterator>
 #include <memory>
+#include <tuple>
+#include <utility>
+#include <vector>
 
 #include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/inlined_vector.h"
+#include "absl/functional/function_ref.h"
+#include "absl/log/check.h"
 #include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "xla/comparison_util.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/layout_util.h"
 #include "xla/literal_util.h"
+#include "xla/service/call_inliner.h"
 #include "xla/service/hlo_creation_utils.h"
 #include "xla/service/tuple_util.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/statusor.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 
 using absl::StrCat;
 
-static StatusOr<HloComputation*> WidenWhileCondition(
-    HloComputation* narrow_condition, const Shape& wide_shape) {
+static StatusOr<std::pair<HloComputation*, CallInliner::InlinedInstructionMap>>
+WidenWhileCondition(HloComputation* narrow_condition, const Shape& wide_shape) {
   const Shape& narrow_shape =
       narrow_condition->parameter_instruction(0)->shape();
 
   HloComputation* wide_while_cond = [&]() {
     HloComputation::Builder builder(StrCat("wide.", narrow_condition->name()));
-    builder.AddInstruction(
-        HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
+    builder.AddInstruction(HloInstruction::CreateParameter(
+        0, wide_shape,
+        absl::StrCat("wide.",
+                     narrow_condition->parameter_instruction(0)->name())));
 
     // This is needed so that the root instruction is shaped as a PRED[] -- we
     // need to get this right to begin with since we can't mutate the type of
@@ -50,17 +70,20 @@
     return narrow_condition->parent()->AddEmbeddedComputation(builder.Build());
   }();
 
-  HloInstruction* truncated_parameter =
-      TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0),
-                               narrow_shape.tuple_shapes_size());
+  HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
+      wide_while_cond->parameter_instruction(0),
+      narrow_shape.tuple_shapes_size(),
+      absl::StrCat("renarrowed.",
+                   wide_while_cond->parameter_instruction(0)->name()));
   HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction(
       HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}),
                                  {truncated_parameter}, narrow_condition));
 
   wide_while_cond->set_root_instruction(call_narrow_cond);
 
-  TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status());
-  return wide_while_cond;
+  TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
+                      CallInliner::Inline(call_narrow_cond));
+  return {{wide_while_cond, std::move(inlined_instructions_map)}};
 }
 
 static StatusOr<std::pair<HloComputation*, CallInliner::InlinedInstructionMap>>
@@ -69,14 +92,17 @@
 
   HloComputation* wide_while_body = [&]() {
     HloComputation::Builder builder(StrCat("wide.", narrow_body->name()));
-    builder.AddInstruction(
-        HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
+    builder.AddInstruction(HloInstruction::CreateParameter(
+        0, wide_shape,
+        absl::StrCat("wide.", narrow_body->parameter_instruction(0)->name())));
     return narrow_body->parent()->AddEmbeddedComputation(builder.Build());
   }();
 
   HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0);
   HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
-      wide_parameter, narrow_shape.tuple_shapes_size());
+      wide_parameter, narrow_shape.tuple_shapes_size(),
+      absl::StrCat("renarrowed.",
+                   wide_while_body->parameter_instruction(0)->name()));
   HloInstruction* call_narrow_body =
       wide_while_body->AddInstruction(HloInstruction::CreateCall(
           narrow_shape, {truncated_parameter}, narrow_body));
@@ -84,9 +110,11 @@
   std::vector<HloInstruction*> live_through_values;
   for (int i = narrow_shape.tuple_shapes_size();
        i < wide_shape.tuple_shapes_size(); i++) {
-    live_through_values.push_back(
-        wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
-            wide_shape.tuple_shapes(i), wide_parameter, i)));
+    live_through_values.push_back(wide_while_body->AddInstruction(
+        HloInstruction::CreateGetTupleElement(wide_shape.tuple_shapes(i),
+                                              wide_parameter, i),
+        absl::StrCat(wide_while_body->name(), ".through.",
+                     i - narrow_shape.tuple_shapes_size())));
   }
 
   wide_while_body->set_root_instruction(
@@ -109,8 +137,10 @@
     *new_while_shape.add_tuple_shapes() = instruction->shape();
   }
 
+  HloComputation* new_while_condition;
+  CallInliner::InlinedInstructionMap inlined_condition_instructions_map;
   TF_ASSIGN_OR_RETURN(
-      HloComputation * new_while_condition,
+      std::tie(new_while_condition, inlined_condition_instructions_map),
       WidenWhileCondition(while_instr->while_condition(), new_while_shape));
 
   HloComputation* new_while_body;
@@ -138,10 +168,12 @@
   std::vector<HloInstruction*> live_in_instructions;
   for (int64_t i = elements_in_old_while_shape;
        i < new_while_shape.tuple_shapes_size(); i++) {
-    live_in_instructions.push_back(
-        new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
+    live_in_instructions.push_back(new_while_body->AddInstruction(
+        HloInstruction::CreateGetTupleElement(
             instructions[i - elements_in_old_while_shape]->shape(),
-            while_body_param, i)));
+            while_body_param, i),
+        absl::StrCat(new_while_body->name(), ".in.",
+                     i - elements_in_old_while_shape)));
   }
 
   WhileUtil::MakeInstructionsLiveInResult result;
@@ -150,6 +182,8 @@
   result.replacement_instr = replacement_instr;
   result.while_body_live_in_values = std::move(live_in_instructions);
   result.while_body_instruction_map = std::move(inlined_instructions_map);
+  result.while_condition_instruction_map =
+      std::move(inlined_condition_instructions_map);
 
   return std::move(result);
 }
diff --git a/third_party/xla/xla/service/while_util.h b/third_party/xla/xla/service/while_util.h
index 97ad67a..ac6aac8 100644
--- a/third_party/xla/xla/service/while_util.h
+++ b/third_party/xla/xla/service/while_util.h
@@ -16,11 +16,17 @@
 #ifndef XLA_SERVICE_WHILE_UTIL_H_
 #define XLA_SERVICE_WHILE_UTIL_H_
 
+#include <cstdint>
+#include <memory>
+#include <vector>
+
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/functional/function_ref.h"
+#include "absl/types/span.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/service/call_inliner.h"
+#include "xla/statusor.h"
 
 namespace xla {
 class WhileUtil {
@@ -42,6 +48,11 @@
     // to the corresponding instructions in the body for the newly created while
     // operation.
     CallInliner::InlinedInstructionMap while_body_instruction_map;
+
+    // `while_condition_instruction_map` maps instructions in the original
+    // while condition to the corresponding instructions in the condition of
+    // the newly created while operation.
+    CallInliner::InlinedInstructionMap while_condition_instruction_map;
   };
 
   // Replaces `while_instr` with a new while instruction that is equivalent to
diff --git a/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt b/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt
index 1859aa9..f6b15b1 100644
--- a/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt
+++ b/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt
@@ -15,7 +15,7 @@
 version: 2
 results {
   device: "sm_6.0 with 17071734784B RAM, 56 cores, 1480500KHz clock, 715000KHz mem clock, 4194304B L2$"
-  hlo: "f32[3,3]{1,0} custom-call(f32[3,3]{1,0}, f32[3,3]{1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}"
+  hlo: "(f32[3,3]{1,0}, s8[72]{0}) custom-call(f32[3,3]{1,0}, f32[3,3]{1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"9\",\"rhs_stride\":\"9\",\"grad_x\":false,\"grad_y\":false}"
   result {
     gemm {
       algorithm: 13
diff --git a/third_party/xla/xla/service/xla_compile_main.cc b/third_party/xla/xla/service/xla_compile_main.cc
index 9dac2ee..ec73062 100644
--- a/third_party/xla/xla/service/xla_compile_main.cc
+++ b/third_party/xla/xla/service/xla_compile_main.cc
@@ -15,34 +15,50 @@
 
 #include <iostream>
 #include <memory>
+#include <optional>
 #include <string>
 #include <string_view>
 #include <utility>
 #include <vector>
 
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/DialectRegistry.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/Parser/Parser.h"  // from @llvm-project
 #include "stablehlo/dialect/Register.h"  // from @stablehlo
 #include "xla/autotune_results.pb.h"
 #include "xla/debug_options_flags.h"
+#include "xla/hlo/ir/hlo_module_group.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
 #include "xla/pjrt/mlir_to_hlo.h"
 #include "xla/service/compiler.h"
 #include "xla/service/cpu/cpu_compiler.h"
 #include "xla/service/cpu/cpu_executable.h"
+#include "xla/service/executable.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/service/symbol_repository.h"
 #include "xla/statusor.h"
+#include "xla/stream_executor/device_memory_allocator.h"
+#include "xla/stream_executor/stream_executor_pimpl.h"
 #include "xla/tools/hlo_module_loader.h"
+#include "xla/util.h"
 #include "tsl/platform/env.h"
+#include "tsl/platform/errors.h"
 #include "tsl/platform/init_main.h"
 #include "tsl/platform/path.h"
 #include "tsl/platform/protobuf.h"
 #include "tsl/util/command_line_flags.h"
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#include "xla/service/gpu/autotuner_util.h"
 #include "xla/service/gpu/executable.pb.h"
 #include "xla/service/gpu/gpu_compiler.h"
+#include "xla/service/gpu/gpu_symbol_repository.h"
+#include "xla/stream_executor/gpu/gpu_init.h"
 #endif
 #if GOOGLE_CUDA
 #include "xla/service/gpu/nvptx_compiler.h"
@@ -55,12 +71,24 @@
 
 const char kUsageHeader[] =
     "xla_compile performs ahead-of-time compilation of an MHLO, StableHLO or "
-    "HLO "
-    "module,\nresulting in an AotCompilationResult compiled for CPU or GPU.\n"
+    "HLO module,\nresulting in an AotCompilationResult compiled for CPU or GPU."
+    "\n"
     "A typical invocation looks like this:\n"
     "\n"
     "   $ xla_compile --module_file=mymodule.mlir --output_file=output "
     "--platform=cpu"
+    "\n"
+    "For GPU, either the attached GPU or a simulated one may be used. To use "
+    "a simulated device, set --gpu_target_config to a textproto file "
+    "containing a GpuTargetConfigProto forthe device you wish to simulate. To "
+    "use the attached GPU, do not set this flag. When compiling with the "
+    "attached device, --output_file will contain a text-format HLO module "
+    "instead of an AotCompilationResult."
+    "\n"
+    "HLO may also be looked up in a symbol repository (see symbol_repository.h"
+    ") by passing --symbol_repository to a linked-in symbol repository "
+    "implementation and setting --symbol_reference to a reference of a symbol "
+    "understood by that repository."
     "\n";
 
 StatusOr<std::string> AotCompileCpuExecutable(
@@ -76,31 +104,60 @@
 }
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-StatusOr<std::string> AotCompileGpuExecutable(
+StatusOr<std::string> CompileGpuExecutable(
     std::unique_ptr<HloModule> hlo_module,
-    const Compiler::TargetConfig& target_config) {
+    const std::optional<Compiler::TargetConfig> target_config) {
+  const bool aot = target_config.has_value();
+
 #if GOOGLE_CUDA
   auto gpu_compiler = gpu::NVPTXCompiler();
 #elif TENSORFLOW_USE_ROCM
   auto gpu_compiler = gpu::AMDGPUCompiler();
 #endif
   Compiler::CompileOptions compile_options;
-  compile_options.target_config = target_config;
+
+  stream_executor::StreamExecutor* stream_executor = nullptr;
+  std::unique_ptr<stream_executor::StreamExecutorMemoryAllocator> allocator;
+  if (aot) {
+    compile_options.target_config = *target_config;
+  } else {
+    TF_RETURN_IF_ERROR(stream_executor::ValidateGPUMachineManager());
+    TF_ASSIGN_OR_RETURN(
+        stream_executor,
+        stream_executor::GPUMachineManager()->ExecutorForDevice(0));
+    allocator =
+        std::make_unique<stream_executor::StreamExecutorMemoryAllocator>(
+            stream_executor);
+    compile_options.device_allocator = allocator.get();
+  }
+
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<HloModule> module_after_opt,
-      gpu_compiler.RunHloPasses(std::move(hlo_module),
-                                /*stream_exec=*/nullptr, compile_options));
+      gpu_compiler.RunHloPasses(std::move(hlo_module), stream_executor,
+                                compile_options));
 
-  auto module_group =
-      std::make_unique<HloModuleGroup>(std::move(module_after_opt));
-  AotCompilationOptions aot_options(gpu_compiler.PlatformId());
-  aot_options.set_target_config(target_config);
+  if (aot) {
+    auto module_group =
+        std::make_unique<HloModuleGroup>(std::move(module_after_opt));
+
+    AotCompilationOptions aot_options(gpu_compiler.PlatformId());
+    aot_options.set_target_config(*target_config);
+
+    TF_ASSIGN_OR_RETURN(
+        std::vector<std::unique_ptr<AotCompilationResult>> aot_results,
+        gpu_compiler.CompileAheadOfTime(std::move(module_group), aot_options));
+    TF_ASSIGN_OR_RETURN(std::string result,
+                        aot_results[0]->SerializeAsString());
+    return result;
+  }
+
   TF_ASSIGN_OR_RETURN(
-      std::vector<std::unique_ptr<AotCompilationResult>> aot_results,
-      gpu_compiler.CompileAheadOfTime(std::move(module_group), aot_options));
-  TF_ASSIGN_OR_RETURN(std::string result, aot_results[0]->SerializeAsString());
-  return result;
+      std::unique_ptr<Executable> executable,
+      gpu_compiler.RunBackend(std::move(module_after_opt), stream_executor,
+                              compile_options));
+  return executable->module().ToString();
 }
+
 #endif
 
 xla::StatusOr<std::unique_ptr<HloModule>> LoadModule(
@@ -141,13 +198,32 @@
   return HloModule::CreateFromProto(hlo_module_proto, config);
 }
 
-xla::Status XlaCompileMain(const std::string& module_path,
-                           const std::string& output_path,
-                           const std::string& platform,
-                           const std::string& gpu_target_config_path,
-                           const std::string& autotune_results_path) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
-                      LoadModule(module_path));
+Status XlaCompileMain(
+    const std::string& module_path, const std::string& output_path,
+    const std::string& platform, const std::string& gpu_target_config_path,
+    const std::string& autotune_results_path, const std::string& symbol_repo,
+    const std::string& symbol_id, const bool use_attached_device) {
+  std::unique_ptr<HloModule> hlo_module;
+  std::unique_ptr<Compiler::TargetConfig> target_config;
+  if (!symbol_id.empty()) {
+    TF_ASSIGN_OR_RETURN(
+        std::unique_ptr<HloModuleAndMetadata> mod,
+        LookupSymbolInRepository(symbol_repo, symbol_id, BackendType::kGpu));
+    if (mod == nullptr) {
+      return absl::NotFoundError(
+          absl::StrCat("Could not find ", symbol_id, " in ", symbol_repo));
+    }
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+    if (auto* data = static_cast<gpu::GpuBackendSpecificData*>(
+            mod->backend_specific_data.get());
+        data != nullptr) {
+      target_config = std::move(mod->target_config);
+    }
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+    hlo_module = std::move(mod->hlo_module);
+  } else {
+    TF_ASSIGN_OR_RETURN(hlo_module, LoadModule(module_path));
+  }
 
   // Run AOT compilation.
   std::string result;
@@ -155,26 +231,33 @@
     TF_ASSIGN_OR_RETURN(result, AotCompileCpuExecutable(std::move(hlo_module)));
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   } else if (platform == "gpu") {
-    // Parse GpuTargetConfig.
-    std::string gpu_target_config_string;
-    TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(),
-                                             gpu_target_config_path,
-                                             &gpu_target_config_string));
-    stream_executor::GpuTargetConfigProto gpu_target_config_proto;
+    if (!gpu_target_config_path.empty()) {
+      // Parse GpuTargetConfig.
+      std::string gpu_target_config_string;
+      TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(),
+                                               gpu_target_config_path,
+                                               &gpu_target_config_string));
+      stream_executor::GpuTargetConfigProto gpu_target_config_proto;
 
-    if (!tsl::protobuf::TextFormat::ParseFromString(gpu_target_config_string,
-                                                    &gpu_target_config_proto)) {
-      return FailedPrecondition("Failed to parse GpuTargetConfigProto");
+      if (!tsl::protobuf::TextFormat::ParseFromString(
+              gpu_target_config_string, &gpu_target_config_proto)) {
+        return FailedPrecondition("Failed to parse GpuTargetConfigProto");
+      }
+
+      target_config =
+          std::make_unique<Compiler::TargetConfig>(gpu_target_config_proto);
+
+      if (!autotune_results_path.empty()) {
+        TF_RETURN_IF_ERROR(gpu::AutotunerUtil::LoadAutotuneResultsFromFile(
+            autotune_results_path));
+      }
     }
 
-    Compiler::TargetConfig gpu_target_config(gpu_target_config_proto);
-
-    if (!autotune_results_path.empty()) {
-      TF_RETURN_IF_ERROR(gpu::AutotunerUtil::LoadAutotuneResultsFromFile(
-          autotune_results_path));
-    }
-    TF_ASSIGN_OR_RETURN(result, AotCompileGpuExecutable(std::move(hlo_module),
-                                                        gpu_target_config));
+    std::optional<Compiler::TargetConfig> cfg =
+        (use_attached_device) ? std::nullopt
+                              : std::make_optional(*std::move(target_config));
+    TF_ASSIGN_OR_RETURN(result,
+                        CompileGpuExecutable(std::move(hlo_module), cfg));
 #endif
   } else {
     return Unimplemented("platform %s not supported", platform);
@@ -189,13 +272,16 @@
 }  // end namespace xla
 
 // Read the input file containing the MHLO module, and write a Serialized
-// AotCompilationResult to the output file.
+// AotCompilationResult or Executable to the output file.
 int main(int argc, char* argv[]) {
   std::string module_path;
   std::string output_path;
   std::string platform;
   std::string gpu_target_config_path;
   std::string autotune_results_path;
+  std::string symbol_repository;
+  std::string symbol_id;
+  bool use_attached_device = false;
   std::vector<tsl::Flag> flag_list = {
       tsl::Flag("module_file", &module_path,
                 "The path to the HLO, MHLO or StableHLO file"),
@@ -203,11 +289,24 @@
       tsl::Flag("platform", &platform,
                 "The platform on which the built executable runs"),
       tsl::Flag("gpu_target_config", &gpu_target_config_path,
-                "The path to serialized GpuTargetConfig, required when"
-                " compiling for GPU"),
+                "The path to a text-format GpuTargetConfig. If not provided, "
+                "an attached GPU will be used."),
       tsl::Flag("autotune_results", &autotune_results_path,
                 "The path to AutotuneResults, optional when compiling for"
-                " GPU")};
+                " GPU"),
+      tsl::Flag("symbol_repo", &symbol_repository,
+                "Which SymbolRepository to look up --symbol_reference in. If "
+                "the repository contains a GpuTargetConfig, "
+                "--gpu_target_config will take precedence if it is also set."),
+      tsl::Flag("symbol_reference", &symbol_id,
+                "Symbol ID to look up in a SymbolRepository. Overrides "
+                "--module_file."),
+      tsl::Flag("use_attached_device", &use_attached_device,
+                "Whether to use the attached GPU or not. Overrides the "
+                "AOT-vs-device-backed inference based on the presence of "
+                "--gpu_target_config, which is relevant when a GpuTargetConfig "
+                "can be found in the symbol repository."),
+  };
 
   tsl::string usage = xla::xla_compile::kUsageHeader;
   usage += tsl::Flags::Usage(argv[0], flag_list);
@@ -223,7 +322,7 @@
 
   xla::Status result = xla::xla_compile::XlaCompileMain(
       module_path, output_path, platform, gpu_target_config_path,
-      autotune_results_path);
+      autotune_results_path, symbol_repository, symbol_id, use_attached_device);
   if (!result.ok()) {
     LOG(ERROR) << "Compilation failed: " << result;
     return 1;
diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h
index 3519f91..760fc42 100644
--- a/third_party/xla/xla/shape_util.h
+++ b/third_party/xla/xla/shape_util.h
@@ -33,6 +33,7 @@
 
 #include "absl/container/inlined_vector.h"
 #include "absl/functional/function_ref.h"
+#include "absl/log/check.h"
 #include "absl/types/span.h"
 #include "xla/layout.h"
 #include "xla/layout_util.h"
@@ -565,6 +566,57 @@
     return ForEachMutableSubshapeWithStatusHelper(shape, fn, &index);
   }
 
+  // Calls the given visitor function for each subshape of the given shape.
+  // Subshapes are visited in DFS post-order, ending with the entire shape
+  // (index {}).
+  //
+  // The visitor function must have the signature
+  //
+  //   void fn(const Shape& subshape, const ShapeIndex& index), or
+  //   void fn(Shape* subshape, const ShapeIndex& index) (mutable version)
+  template <typename Fn>
+  static void ForEachSubshapePostOrder(const Shape& shape, Fn&& fn) {
+    ForEachSubshapePostOrderWithStatus(shape, [&](const Shape& subshape,
+                                                  const ShapeIndex& index) {
+      fn(subshape, index);
+      return OkStatus();
+    }).IgnoreError();
+  }
+  template <typename Fn>
+  static void ForEachMutableSubshapePostOrder(Shape* shape, Fn&& fn) {
+    ForEachMutableSubshapePostOrderWithStatus(
+        shape,
+        [&](Shape* subshape, const ShapeIndex& index) {
+          fn(subshape, index);
+          return OkStatus();
+        })
+        .IgnoreError();
+  }
+
+  // Variants of ForEach(Mutable)SubshapePostOrder which propagate Status from
+  // the visitor function.
+  //
+  // Visitor function must have the signature
+  //
+  //   Status fn(const Shape& subshape, const ShapeIndex& index), or
+  //   Status fn(Shape* subshape, const ShapeIndex& index) (mutable version)
+  //
+  template <typename Fn>
+  static Status ForEachSubshapePostOrderWithStatus(const Shape& shape,
+                                                   Fn&& fn) {
+    return ForEachMutableSubshapePostOrderWithStatus(
+        const_cast<Shape*>(&shape),
+        [&](Shape* subshape, const ShapeIndex& index) -> Status {
+          return fn(*const_cast<const Shape*>(subshape), index);
+        });
+  }
+  template <typename Fn>
+  static Status ForEachMutableSubshapePostOrderWithStatus(Shape* shape,
+                                                          Fn&& fn) {
+    ShapeIndex index;
+    return ForEachMutableSubshapePostOrderWithStatusHelper(shape, fn, &index);
+  }
+
   // Returns true if `shape` (which must be an array) has degenerate dimensions
   // (dimensions with bound 1).
   static bool HasDegenerateDimensions(const Shape& shape);
@@ -933,6 +985,23 @@
     return OkStatus();
   }
 
+  // Helper for ForEach(Mutable)SubshapePostOrder which visits the subshapes of
+  // the given shape in DFS post-order.
+  template <typename Fn>
+  static Status ForEachMutableSubshapePostOrderWithStatusHelper(
+      Shape* shape, Fn&& fn, ShapeIndex* index) {
+    if (shape->IsTuple()) {
+      for (int64_t i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) {
+        index->push_back(i);
+        TF_RETURN_IF_ERROR(ForEachMutableSubshapePostOrderWithStatusHelper(
+            shape->mutable_tuple_shapes(i), fn, index));
+        index->pop_back();
+      }
+    }
+    TF_RETURN_IF_ERROR(fn(shape, *index));
+    return OkStatus();
+  }
+
   // Keeps track of the iteration state for the ForEach...Internal routines
   struct ForEachState {
     ForEachState(const Shape& s, absl::Span<const int64_t> b,
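For illustration only, a hedged sketch of the new traversal added above (the printing helper is hypothetical; only `ForEachSubshapePostOrder` comes from this change):

```cpp
#include <iostream>

#include "xla/shape.h"
#include "xla/shape_util.h"

// Sketch: for a tuple shape (f32[2], (s32[], f32[3])) the indices are visited
// as {0}, {1,0}, {1,1}, {1}, {} -- every child before its enclosing tuple,
// with the whole shape (index {}) last.
void PrintSubshapesPostOrder(const xla::Shape& shape) {
  xla::ShapeUtil::ForEachSubshapePostOrder(
      shape, [](const xla::Shape& subshape, const xla::ShapeIndex& index) {
        std::cout << index.ToString() << " : "
                  << xla::ShapeUtil::HumanString(subshape) << "\n";
      });
}
```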
diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD
index 42a3409..7204b26 100644
--- a/third_party/xla/xla/stream_executor/BUILD
+++ b/third_party/xla/xla/stream_executor/BUILD
@@ -1,9 +1,3 @@
-# GPU executor library for data-parallel kernel launches and cross-platform
-# HPC-library APIs.
-#
-# Throughout this file, all targets are built with the standard crosstool and
-# do not link against restricted binary blobs.
-
 load("//xla:xla.bzl", "xla_cc_test")
 load("//xla/stream_executor:build_defs.bzl", "stream_executor_friends", "stream_executor_internal")
 load("@local_tsl//tsl:tsl.bzl", "set_external_visibility", "transitive_hdrs")
@@ -71,7 +65,6 @@
         "event.h",
         "executor_cache.h",
         "kernel.h",
-        "kernel_cache_config.h",
         "kernel_spec.h",
         "launch_dim.h",
         "module_spec.h",
@@ -105,7 +98,6 @@
 # making `stream_executor_headers` self-contained, which means that you can include any of the
 # public API headers and don't worry about adding dependencies).
 STREAM_EXECUTOR_DEPENDENCIES = [
-    ":allocator_stats",
     ":device_description_proto_cc",
     ":host_or_device_scalar",
     ":multi_platform_manager",
@@ -119,7 +111,6 @@
     "@com_google_absl//absl/status",
     "@com_google_absl//absl/strings",
     "@com_google_absl//absl/synchronization",
-    "@com_google_absl//absl/types:optional",
     "@com_google_absl//absl/types:span",
     "//xla/stream_executor/platform",
     "@local_tsl//tsl/framework:device_id",
@@ -156,10 +147,6 @@
 # `stream_executor` and `stream_executor_headers` targets). This is mostly a historical artifact of
 # an era when StreamExecutor was a part of Tensorflow.
 
-# TODO(ezhulenev): Consider merging some (all?) of these libraries into StreamExecutor target, e.g.
-# does it really make sense to have a separate `device_memory` library which is not usable without
-# StreamExecutor.
-
 tf_proto_library(
     name = "device_description_proto",
     srcs = ["device_description.proto"],
@@ -193,6 +180,7 @@
     deps = ["//xla/stream_executor/platform"],
 )
 
+# TODO(ezhulenev): Merge this target into `stream_executor`.
 cc_library(
     name = "device_memory_allocator",
     hdrs = ["device_memory_allocator.h"],
@@ -247,6 +235,25 @@
 )
 
 cc_library(
+    name = "multi_platform_manager",
+    srcs = ["multi_platform_manager.cc"],
+    hdrs = ["multi_platform_manager.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":platform",
+        "//xla/stream_executor/platform",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_library(
     name = "numeric_options",
     hdrs = ["numeric_options.h"],
     visibility = ["//visibility:public"],
@@ -269,6 +276,80 @@
 )
 
 #===--------------------------------------------------------------------------------------------===#
+# StreamExecutor plugins
+#===--------------------------------------------------------------------------------------------===#
+
+# TODO(ezhulenev): Today all StreamExecutor plugins are linked by default into the `stream_executor`
+# target and leak into "core" APIs. We should decouple all plugins into optional dependencies, and
+# make sure that they are not exposed via "core" APIs (se::Stream, se::StreamExecutor, etc.).
+
+cc_library(
+    name = "blas",
+    srcs = ["blas.cc"],
+    hdrs = ["blas.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":stream_executor_headers",
+        "//xla/stream_executor/platform",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/protobuf:dnn_proto_cc",
+    ],
+)
+
+cc_library(
+    name = "dnn",
+    srcs = ["dnn.cc"],
+    hdrs = ["dnn.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":device_description_proto_cc",
+        ":device_memory",
+        ":numeric_options",
+        ":stream_executor_headers",
+        "//xla/stream_executor/platform",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:btree",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/lib/strings:proto_serialization",
+        "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:status",
+        "@local_tsl//tsl/platform:statusor",
+        "@local_tsl//tsl/protobuf:dnn_proto_cc",
+    ] + if_static(["@com_google_protobuf//:protobuf"]),
+)
+
+cc_library(
+    name = "fft",
+    hdrs = ["fft.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//xla/stream_executor/platform",
+    ],
+)
+
+cc_library(
+    name = "lazy_op_runner",
+    hdrs = ["lazy_op_runner.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":dnn",
+        ":stream_executor_headers",
+        "@com_google_absl//absl/base",
+    ],
+)
+
+# TODO(ezhulenev): This should be removed.
+exports_files(
+    ["lazy_op_runner.h"],
+    visibility = ["//visibility:public"],
+)
+
+#===--------------------------------------------------------------------------------------------===#
 # StreamExecutor platform-dependent interfaces
 #===--------------------------------------------------------------------------------------------===#
 
@@ -336,17 +417,23 @@
 # StreamExecutor private implementation (has private visibility)
 #===--------------------------------------------------------------------------------------------===#
 
-# TODO(ezhulenev): We need a clear separation between StreamExecutor "core" (event, stream, etc.),
-# and plugins (FFT, Blas, etc.). We should not be mixing headers and implementation of core
-# libraries with plugins. Today `stream_executor` exports all plugin headers, and we should remove
-# this, however it requires more work to break dependency from "core" to plugin implementation.
-
 # Targets that implement StreamExecutor APIs are private, and should not be used outside of
 # `stream_executor` package. Clients should depend on `stream_executor` (headers and
 # implementation) or `stream_executor_headers` (only headers, if there is a reason not to link
 # implementation) if they want to use StreamExecutor.
 
 cc_library(
+    name = "allocator_stats",
+    srcs = ["allocator_stats.cc"],
+    hdrs = ["allocator_stats.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//xla/stream_executor/platform",
+        "@com_google_absl//absl/strings:str_format",
+    ],
+)
+
+cc_library(
     name = "command_buffer",
     srcs = ["command_buffer.cc"],
     hdrs = ["command_buffer.h"],
@@ -450,27 +537,6 @@
 )
 
 cc_library(
-    name = "stream",
-    srcs = ["stream.cc"],
-    hdrs = ["stream.h"],
-    visibility = ["//visibility:public"],
-    deps = [
-        ":blas",
-        ":dnn",
-        ":stream_executor_headers",
-        ":stream_executor_pimpl",
-        "//xla/stream_executor/platform",
-        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/functional:any_invocable",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/synchronization",
-        "@eigen_archive//:eigen3",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:stacktrace",
-    ],
-)
-
-cc_library(
     name = "temporary_device_memory",
     srcs = ["temporary_device_memory.cc"],
     hdrs = ["temporary_device_memory.h"],
@@ -511,49 +577,15 @@
     deps = [":stream_executor_headers"],
 )
 
-cc_library(
-    name = "allocator_stats",
-    srcs = ["allocator_stats.cc"],
-    hdrs = ["allocator_stats.h"],
-    visibility = ["//visibility:public"],
-    deps = [
-        "//xla/stream_executor/platform",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/types:optional",
-    ],
-)
-
-cc_library(
-    name = "fft",
-    hdrs = ["fft.h"],
-    visibility = ["//visibility:public"],
-    deps = [
-        "//xla/stream_executor/platform",
-    ],
-)
-
-cc_library(
-    name = "multi_platform_manager",
-    srcs = ["multi_platform_manager.cc"],
-    hdrs = ["multi_platform_manager.h"],
-    visibility = ["//visibility:public"],
-    deps = [
-        ":platform",
-        "//xla/stream_executor/platform",
-        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/synchronization",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
+# TODO(ezhulenev): This should be merged into the regular `stream_executor` target and `stream.cc`
+# moved into its own target; however, today we have problems with backward references when we try
+# to link everything together. See: https://lld.llvm.org/ELF/warn_backrefs.html.
 cc_library(
     name = "stream_executor_pimpl",
-    srcs = ["stream_executor_pimpl.cc"],
+    srcs = [
+        "stream.cc",
+        "stream_executor_pimpl.cc",
+    ],
     hdrs = ["stream_executor_pimpl.h"],
     visibility = ["//visibility:public"],
     deps = [
@@ -582,7 +614,6 @@
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
-        "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
         "@eigen_archive//:eigen3",
         "@local_tsl//tsl/platform:env",
@@ -596,47 +627,11 @@
     ],
 )
 
-cc_library(
-    name = "blas",
-    srcs = ["blas.cc"],
-    hdrs = ["blas.h"],
-    visibility = ["//visibility:public"],
-    deps = [
-        ":stream_executor_headers",
-        "//xla/stream_executor/platform",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/protobuf:dnn_proto_cc",
-    ],
-)
-
-cc_library(
-    name = "dnn",
-    srcs = ["dnn.cc"],
-    hdrs = ["dnn.h"],
-    visibility = ["//visibility:public"],
-    deps = [
-        ":device_description_proto_cc",
-        ":device_memory",
-        ":numeric_options",
-        ":stream_executor_headers",
-        "//xla/stream_executor/platform",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:btree",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/types:optional",
-        "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/lib/strings:proto_serialization",
-        "@local_tsl//tsl/platform:logging",
-        "@local_tsl//tsl/platform:status",
-        "@local_tsl//tsl/platform:statusor",
-        "@local_tsl//tsl/protobuf:dnn_proto_cc",
-    ] + if_static(["@com_google_protobuf//:protobuf"]),
-)
-
+# We have a separate `stream_executor_impl` target because in open source we build multiple shared
+# libraries and then link them together (this is an implementation detail of the TensorFlow
+# framework), and we take extra care not to define symbols in multiple objects. Otherwise we can
+# end up with static singletons defined in multiple objects, ODR violations, and many other bad
+# things that lead to nearly impossible-to-debug runtime crashes.
 cc_library(
     name = "stream_executor_impl",
     visibility = ["//visibility:public"],
@@ -653,7 +648,6 @@
         ":multi_platform_manager",
         ":platform",
         ":scratch_allocator",
-        ":stream",
         ":stream_executor_headers",
         ":stream_executor_pimpl",
         ":temporary_device_memory",
@@ -662,6 +656,10 @@
     ],
 )
 
+#===--------------------------------------------------------------------------------------------===#
+# StreamExecutor tests
+#===--------------------------------------------------------------------------------------------===#
+
 xla_cc_test(
     name = "stream_test",
     size = "small",
@@ -685,21 +683,9 @@
     ],
 )
 
-cc_library(
-    name = "lazy_op_runner",
-    hdrs = ["lazy_op_runner.h"],
-    visibility = ["//visibility:public"],
-    deps = [
-        ":dnn",
-        ":stream_executor_headers",
-        "@com_google_absl//absl/base",
-    ],
-)
-
-exports_files(
-    ["lazy_op_runner.h"],
-    visibility = ["//visibility:public"],
-)
+#===--------------------------------------------------------------------------------------------===#
+# Aliases for StreamExecutor platforms
+#===--------------------------------------------------------------------------------------------===#
 
 alias(
     name = "cuda_platform",
@@ -713,7 +699,7 @@
     visibility = ["//visibility:public"],
 )
 
-# TODO(se-owner): document or remove this.
+# TODO(ezhulenev): This should be removed.
 cc_library(
     name = "stream_executor_bundle",
     visibility = ["//visibility:public"],
diff --git a/third_party/xla/xla/stream_executor/allocator_stats.h b/third_party/xla/xla/stream_executor/allocator_stats.h
index d073d08..1e15d0c 100644
--- a/third_party/xla/xla/stream_executor/allocator_stats.h
+++ b/third_party/xla/xla/stream_executor/allocator_stats.h
@@ -18,7 +18,6 @@
 
 #include <string>
 
-#include "absl/types/optional.h"
 #include "xla/stream_executor/platform/port.h"
 
 namespace stream_executor {
diff --git a/third_party/xla/xla/stream_executor/blas.cc b/third_party/xla/xla/stream_executor/blas.cc
index efcda8d..594acf4 100644
--- a/third_party/xla/xla/stream_executor/blas.cc
+++ b/third_party/xla/xla/stream_executor/blas.cc
@@ -18,10 +18,32 @@
 #include <cstdint>
 
 #include "absl/strings/str_cat.h"
+#include "xla/stream_executor/device_memory.h"
 
 namespace stream_executor {
 namespace blas {
 
+// TODO(ezhulenev): We need a scoped thread-local map-like container to make
+// sure that multiple BlasSupport instances do not overwrite each other's
+// workspaces. For now it's OK, as we know that this can't happen.
+static thread_local DeviceMemoryBase* workspace_thread_local = nullptr;
+
+BlasSupport::ScopedWorkspace::ScopedWorkspace(BlasSupport* blas,
+                                              DeviceMemoryBase* workspace)
+    : blas_(blas) {
+  blas->SetWorkspace(workspace);
+}
+
+BlasSupport::ScopedWorkspace::~ScopedWorkspace() { blas_->ResetWorkspace(); }
+
+DeviceMemoryBase* BlasSupport::GetWorkspace() { return workspace_thread_local; }
+
+void BlasSupport::SetWorkspace(DeviceMemoryBase* workspace) {
+  workspace_thread_local = workspace;
+}
+
+void BlasSupport::ResetWorkspace() { workspace_thread_local = nullptr; }
+
 std::string TransposeString(Transpose t) {
   switch (t) {
     case Transpose::kNoTranspose:
diff --git a/third_party/xla/xla/stream_executor/blas.h b/third_party/xla/xla/stream_executor/blas.h
index a62ec9a..0cdb0dd 100644
--- a/third_party/xla/xla/stream_executor/blas.h
+++ b/third_party/xla/xla/stream_executor/blas.h
@@ -121,6 +121,16 @@
   kTF32AsF32,  // Allow downcast to TF32 precision.
 };
 
+// Call context information for GEMM API calls.
+// This is extra information that can optionally be passed down to the BLAS
+// library so that it can pick an efficient implementation based on context.
+enum class CallContext {
+  kNone = 0,            // No information
+  kForward = 1,         // call happens in "forward" pass
+  kBackpropInput1 = 2,  // call happens in "backprop" pass for the first input
+  kBackpropInput2 = 4,  // call happens in "backprop" pass for the second input
+};
+
 // Converts a ComputationType to a string.
 std::string ComputationTypeString(ComputationType ty);
 
@@ -323,7 +333,8 @@
                                  const DeviceMemoryBase &a, int lda,
                                  const DeviceMemoryBase &b, int ldb,
                                  const void *beta, DeviceMemoryBase *c, int ldc,
-                                 const NumericOptions &numeric_options) = 0;
+                                 const NumericOptions &numeric_options,
+                                 blas::CallContext context) = 0;
 
   // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm.
   virtual bool GetBlasGemmAlgorithms(
@@ -348,8 +359,7 @@
       DeviceMemoryBase *c, DataType type_c, int ldc,
       ComputationType computation_type, AlgorithmType algorithm,
       const NumericOptions &numeric_options,
-      ProfileResult *output_profile_result) = 0;
-
+      ProfileResult *output_profile_result, blas::CallContext context) = 0;
   virtual tsl::Status DoBlasGemmStridedBatchedWithAlgorithm(
       Stream *stream, blas::Transpose transa, blas::Transpose transb,
       uint64_t m, uint64_t n, uint64 k, const void *alpha,
@@ -358,7 +368,7 @@
       const void *beta, DeviceMemoryBase *c, DataType type_c, int ldc,
       int64_t stride_c, int batch_count, ComputationType computation_type,
       AlgorithmType algorithm, const NumericOptions &numeric_options,
-      ProfileResult *output_profile_result) = 0;
+      ProfileResult *output_profile_result, blas::CallContext context) = 0;
 
   // Computes a batch of matrix-matrix product with general matrices.
   // This is a batched version of DoBlasGemm.
@@ -372,35 +382,30 @@
                                  float beta, DeviceMemorySlice<Eigen::half> c,
                                  int ldc, int batch_count,
                                  const NumericOptions &numeric_options,
-                                 ScratchAllocator *scratch_allocator) = 0;
-  virtual bool DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
-                                 blas::Transpose transb, uint64_t m, uint64_t n,
-                                 uint64 k, float alpha,
-                                 DeviceMemorySlice<Eigen::bfloat16> a, int lda,
-                                 DeviceMemorySlice<Eigen::bfloat16> b, int ldb,
-                                 float beta,
-                                 DeviceMemorySlice<Eigen::bfloat16> c, int ldc,
-                                 int batch_count,
-                                 const NumericOptions &numeric_options,
-                                 ScratchAllocator *scratch_allocator) = 0;
-  virtual bool DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
-                                 blas::Transpose transb, uint64_t m, uint64_t n,
-                                 uint64 k, float alpha,
-                                 DeviceMemorySlice<float> a, int lda,
-                                 DeviceMemorySlice<float> b, int ldb,
-                                 float beta, DeviceMemorySlice<float> c,
-                                 int ldc, int batch_count,
-                                 const NumericOptions &numeric_options,
-                                 ScratchAllocator *scratch_allocator) = 0;
-  virtual bool DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
-                                 blas::Transpose transb, uint64_t m, uint64_t n,
-                                 uint64 k, double alpha,
-                                 DeviceMemorySlice<double> a, int lda,
-                                 DeviceMemorySlice<double> b, int ldb,
-                                 double beta, DeviceMemorySlice<double> c,
-                                 int ldc, int batch_count,
-                                 const NumericOptions &numeric_options,
-                                 ScratchAllocator *scratch_allocator) = 0;
+                                 ScratchAllocator *scratch_allocator,
+                                 blas::CallContext context) = 0;
+  virtual bool DoBlasGemmBatched(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb,
+      uint64_t m, uint64_t n, uint64 k, float alpha,
+      DeviceMemorySlice<Eigen::bfloat16> a, int lda,
+      DeviceMemorySlice<Eigen::bfloat16> b, int ldb, float beta,
+      DeviceMemorySlice<Eigen::bfloat16> c, int ldc, int batch_count,
+      const NumericOptions &numeric_options,
+      ScratchAllocator *scratch_allocator, blas::CallContext context) = 0;
+  virtual bool DoBlasGemmBatched(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb,
+      uint64_t m, uint64_t n, uint64 k, float alpha, DeviceMemorySlice<float> a,
+      int lda, DeviceMemorySlice<float> b, int ldb, float beta,
+      DeviceMemorySlice<float> c, int ldc, int batch_count,
+      const NumericOptions &numeric_options,
+      ScratchAllocator *scratch_allocator, blas::CallContext context) = 0;
+  virtual bool DoBlasGemmBatched(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb,
+      uint64_t m, uint64_t n, uint64 k, double alpha,
+      DeviceMemorySlice<double> a, int lda, DeviceMemorySlice<double> b,
+      int ldb, double beta, DeviceMemorySlice<double> c, int ldc,
+      int batch_count, const NumericOptions &numeric_options,
+      ScratchAllocator *scratch_allocator, blas::CallContext context) = 0;
   virtual bool DoBlasGemmBatched(
       Stream *stream, blas::Transpose transa, blas::Transpose transb,
       uint64_t m, uint64_t n, uint64 k, std::complex<float> alpha,
@@ -408,7 +413,7 @@
       DeviceMemorySlice<std::complex<float>> b, int ldb,
       std::complex<float> beta, DeviceMemorySlice<std::complex<float>> c,
       int ldc, int batch_count, const NumericOptions &numeric_options,
-      ScratchAllocator *scratch_allocator) = 0;
+      ScratchAllocator *scratch_allocator, blas::CallContext context) = 0;
   virtual bool DoBlasGemmBatched(
       Stream *stream, blas::Transpose transa, blas::Transpose transb,
       uint64_t m, uint64_t n, uint64 k, std::complex<double> alpha,
@@ -416,8 +421,7 @@
       DeviceMemorySlice<std::complex<double>> b, int ldb,
       std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c,
       int ldc, int batch_count, const NumericOptions &numeric_options,
-      ScratchAllocator *scratch_allocator) = 0;
-
+      ScratchAllocator *scratch_allocator, blas::CallContext context) = 0;
   // Batched gemm with strides instead of pointer arrays.
   virtual tsl::Status DoBlasGemmStridedBatched(
       Stream *stream, blas::Transpose transa, blas::Transpose transb,
@@ -425,7 +429,7 @@
       const DeviceMemoryBase &a, int lda, int64_t stride_a,
       const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
       DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count,
-      const NumericOptions &numeric_options) = 0;
+      const NumericOptions &numeric_options, blas::CallContext context) = 0;
 
   // Solves a triangular matrix equation.
   //
@@ -490,12 +494,44 @@
                                  DeviceMemory<std::complex<double> *> *bs,
                                  int ldb, int batch_count) = 0;
 
+  // TODO(ezhulenev): We should never pass ScratchAllocator to any of the APIs
+  // in this file, because it makes them incompatible with command buffers (CUDA
+  // graphs). We should pass workspace memory explicitly to all APIs. However,
+  // this is a giant change, so currently we work around it by setting a
+  // thread-local workspace and relying on the `ScopedWorkspace` RAII helper to
+  // reset it.
+  //
+  // APIs that get ScratchAllocator ignore this workspace, and continue
+  // allocating scratch memory on demand.
+  class ScopedWorkspace {
+   public:
+    ScopedWorkspace(BlasSupport *blas, DeviceMemoryBase *workspace);
+    ~ScopedWorkspace();
+
+   private:
+    BlasSupport *blas_;
+  };
+
   virtual tsl::Status GetVersion(std::string *version) = 0;
 
  protected:
+  DeviceMemoryBase *GetWorkspace();
+
   BlasSupport() {}
 
  private:
+  // The workspace memory pointer is thread-local; once it is set, all BLAS
+  // operations issued from the calling thread may use it if it is large
+  // enough. It is the user's responsibility to make sure that the workspace
+  // outlives all issued BLAS operations.
+  //
+  // TODO(ezhulenev): This is a giant footgun! We have to remove it and use
+  // explicit workspace memory argument for all BLAS operations.
+  void SetWorkspace(DeviceMemoryBase *workspace);
+
+  // Resets the user-defined workspace memory so that BLAS operations can use
+  // their own memory pool for allocating workspace.
+  void ResetWorkspace();
+
   BlasSupport(const BlasSupport &) = delete;
   void operator=(const BlasSupport &) = delete;
 };
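For illustration only, a hedged sketch of how the new RAII helper above could be used by a caller that owns a scratch buffer (the callback and function names are placeholders, not part of this change):

```cpp
#include <functional>

#include "xla/stream_executor/blas.h"
#include "xla/stream_executor/device_memory.h"

namespace se = stream_executor;

// Sketch: while `scoped` is alive, BLAS calls issued from this thread use the
// caller-owned `workspace` instead of the library's internal memory pool,
// which keeps them usable inside command buffers (CUDA graphs). The caller
// must keep `workspace` alive until all issued BLAS operations complete.
void RunWithExplicitWorkspace(se::blas::BlasSupport* blas,
                              se::DeviceMemoryBase* workspace,
                              const std::function<void()>& issue_blas_calls) {
  se::blas::BlasSupport::ScopedWorkspace scoped(blas, workspace);
  issue_blas_calls();
}  // ~ScopedWorkspace resets the thread-local workspace pointer here.
```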
@@ -576,7 +612,8 @@
       uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \
       const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb,  \
       const void *beta, DeviceMemoryBase *c, int ldc,                          \
-      const NumericOptions &numeric_options) override;                         \
+      const NumericOptions &numeric_options, blas::CallContext context)        \
+      override;                                                                \
   bool GetBlasGemmAlgorithms(Stream *stream,                                   \
                              std::vector<blas::AlgorithmType> *out_algorithms) \
       override;                                                                \
@@ -588,7 +625,8 @@
       const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc,   \
       blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
       const NumericOptions &numeric_options,                                   \
-      blas::ProfileResult *output_profile_result) override;                    \
+      blas::ProfileResult *output_profile_result, blas::CallContext context)   \
+      override;                                                                \
   bool DoBlasGemmBatched(                                                      \
       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
       uint64_t m, uint64 n, uint64 k, float alpha,                             \
@@ -596,7 +634,8 @@
       DeviceMemorySlice<Eigen::half> b, int ldb, float beta,                   \
       DeviceMemorySlice<Eigen::half> c, int ldc, int batch_count,              \
       const NumericOptions &numeric_options,                                   \
-      ScratchAllocator *scratch_allocator) override;                           \
+      ScratchAllocator *scratch_allocator, blas::CallContext context)          \
+      override;                                                                \
   bool DoBlasGemmBatched(                                                      \
       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
       uint64_t m, uint64 n, uint64 k, float alpha,                             \
@@ -604,21 +643,24 @@
       DeviceMemorySlice<Eigen::bfloat16> b, int ldb, float beta,               \
       DeviceMemorySlice<Eigen::bfloat16> c, int ldc, int batch_count,          \
       const NumericOptions &numeric_options,                                   \
-      ScratchAllocator *scratch_allocator) override;                           \
+      ScratchAllocator *scratch_allocator, blas::CallContext context)          \
+      override;                                                                \
   bool DoBlasGemmBatched(                                                      \
       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
       uint64_t m, uint64 n, uint64 k, float alpha, DeviceMemorySlice<float> a, \
       int lda, DeviceMemorySlice<float> b, int ldb, float beta,                \
       DeviceMemorySlice<float> c, int ldc, int batch_count,                    \
       const NumericOptions &numeric_options,                                   \
-      ScratchAllocator *scratch_allocator) override;                           \
+      ScratchAllocator *scratch_allocator, blas::CallContext context)          \
+      override;                                                                \
   bool DoBlasGemmBatched(                                                      \
       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
       uint64_t m, uint64 n, uint64 k, double alpha,                            \
       DeviceMemorySlice<double> a, int lda, DeviceMemorySlice<double> b,       \
       int ldb, double beta, DeviceMemorySlice<double> c, int ldc,              \
       int batch_count, const NumericOptions &numeric_options,                  \
-      ScratchAllocator *scratch_allocator) override;                           \
+      ScratchAllocator *scratch_allocator, blas::CallContext context)          \
+      override;                                                                \
   bool DoBlasGemmBatched(                                                      \
       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
       uint64_t m, uint64 n, uint64 k, std::complex<float> alpha,               \
@@ -626,7 +668,8 @@
       DeviceMemorySlice<std::complex<float>> b, int ldb,                       \
       std::complex<float> beta, DeviceMemorySlice<std::complex<float>> c,      \
       int ldc, int batch_count, const NumericOptions &numeric_options,         \
-      ScratchAllocator *scratch_allocator) override;                           \
+      ScratchAllocator *scratch_allocator, blas::CallContext context)          \
+      override;                                                                \
   bool DoBlasGemmBatched(                                                      \
       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
       uint64_t m, uint64 n, uint64 k, std::complex<double> alpha,              \
@@ -634,14 +677,16 @@
       DeviceMemorySlice<std::complex<double>> b, int ldb,                      \
       std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c,    \
       int ldc, int batch_count, const NumericOptions &numeric_options,         \
-      ScratchAllocator *scratch_allocator) override;                           \
+      ScratchAllocator *scratch_allocator, blas::CallContext context)          \
+      override;                                                                \
   tsl::Status DoBlasGemmStridedBatched(                                        \
       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
       uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \
       const DeviceMemoryBase &a, int lda, int64_t stride_a,                    \
       const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,  \
       DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count,         \
-      const NumericOptions &numeric_options) override;                         \
+      const NumericOptions &numeric_options, blas::CallContext context)        \
+      override;                                                                \
   tsl::Status DoBlasGemmStridedBatchedWithAlgorithm(                           \
       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
       uint64_t m, uint64 n, uint64 k, const void *alpha,                       \
@@ -651,7 +696,8 @@
       blas::DataType type_c, int ldc, int64_t stride_c, int batch_count,       \
       blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
       const NumericOptions &numeric_options,                                   \
-      blas::ProfileResult *output_profile_result) override;                    \
+      blas::ProfileResult *output_profile_result, blas::CallContext context)   \
+      override;                                                                \
   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                   blas::Transpose transa, blas::Diagonal diag, uint64_t m,     \
                   uint64_t n, float alpha, const DeviceMemory<float> &a,       \
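For illustration only, a hedged sketch of how a caller might tag a GEMM with the new `blas::CallContext` (buffer sizes, the `NumericOptions`, and the exact include set are assumed to come from the caller; this helper is not part of the change):

```cpp
#include <cstdint>

#include "tsl/platform/status.h"
#include "xla/stream_executor/blas.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/numeric_options.h"

namespace se = stream_executor;

// Sketch: an f32 GEMM issued while computing the gradient w.r.t. the first
// matmul operand is tagged kBackpropInput1, so the BLAS backend may pick a
// different algorithm than it would for a forward-pass GEMM.
tsl::Status BackpropLhsGemm(se::Stream* stream, se::blas::BlasSupport* blas,
                            uint64_t m, uint64_t n, uint64_t k,
                            const se::DeviceMemoryBase& a, int lda,
                            const se::DeviceMemoryBase& b, int ldb,
                            se::DeviceMemoryBase* c, int ldc,
                            const se::NumericOptions& numeric_options) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  return blas->DoBlasGemm(stream, se::blas::Transpose::kNoTranspose,
                          se::blas::Transpose::kNoTranspose, m, n, k,
                          se::blas::DataType::kFloat, &alpha, a, lda, b, ldb,
                          &beta, c, ldc, numeric_options,
                          se::blas::CallContext::kBackpropInput1);
}
```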
diff --git a/third_party/xla/xla/stream_executor/command_buffer.cc b/third_party/xla/xla/stream_executor/command_buffer.cc
index e753b16..305872e 100644
--- a/third_party/xla/xla/stream_executor/command_buffer.cc
+++ b/third_party/xla/xla/stream_executor/command_buffer.cc
@@ -31,6 +31,10 @@
 
 namespace stream_executor {
 
+CommandBuffer::~CommandBuffer() = default;
+CommandBuffer::CommandBuffer(CommandBuffer&&) = default;
+CommandBuffer& CommandBuffer::operator=(CommandBuffer&&) = default;
+
 /*static*/ tsl::StatusOr<CommandBuffer> CommandBuffer::Create(
     StreamExecutor* executor, Mode mode) {
   TF_ASSIGN_OR_RETURN(
diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h
index 15200b7..44791ae 100644
--- a/third_party/xla/xla/stream_executor/command_buffer.h
+++ b/third_party/xla/xla/stream_executor/command_buffer.h
@@ -48,6 +48,10 @@
 // device.
 class CommandBuffer {
  public:
+  ~CommandBuffer();
+  CommandBuffer(CommandBuffer&&);
+  CommandBuffer& operator=(CommandBuffer&&);
+
   // Command buffer state:
   //
   //   (1) kCreate:    a new command buffer under construction
@@ -136,9 +140,6 @@
     return implementation_.get();
   }
 
-  CommandBuffer(CommandBuffer&&) = default;
-  CommandBuffer& operator=(CommandBuffer&&) = default;
-
  private:
   CommandBuffer(
       StreamExecutor* executor,
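A plausible reading of this change (an assumption; the diff itself does not state the motivation) is the usual pimpl constraint: `CommandBuffer` owns its implementation through a smart pointer whose pointee is incomplete in the header, so the destructor and moves must be generated in the .cc file where the type is complete. A generic, hypothetical sketch of that pattern:

```cpp
// widget.h (hypothetical)
#include <memory>

class WidgetImpl;  // Incomplete in the header.

class Widget {
 public:
  Widget();
  ~Widget();                    // Declared here,
  Widget(Widget&&);             // defined (= default) in widget.cc where
  Widget& operator=(Widget&&);  // WidgetImpl is a complete type, so
                                // unique_ptr's deleter can be instantiated.
 private:
  std::unique_ptr<WidgetImpl> impl_;
};

// widget.cc (hypothetical)
// #include "widget.h"
// class WidgetImpl { int state = 0; };
// Widget::Widget() : impl_(std::make_unique<WidgetImpl>()) {}
// Widget::~Widget() = default;
// Widget::Widget(Widget&&) = default;
// Widget& Widget::operator=(Widget&&) = default;
```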
diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD
index b8be5cc..e049313 100644
--- a/third_party/xla/xla/stream_executor/cuda/BUILD
+++ b/third_party/xla/xla/stream_executor/cuda/BUILD
@@ -263,6 +263,7 @@
         ":cuda_stream",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
         "@eigen_archive//:eigen3",
@@ -365,18 +366,19 @@
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
         "@cudnn_frontend_archive//:cudnn_frontend",
         "@eigen_archive//:eigen3",
         "@local_config_cuda//cuda:cuda_headers",
         "@local_config_cuda//cuda:cudnn_header",
-        "//xla/stream_executor",
         "//xla/stream_executor:dnn",
         "//xla/stream_executor:plugin_registry",
         "//xla/stream_executor:stream_executor_headers",
+        "//xla/stream_executor",
         "//xla/stream_executor/gpu:gpu_executor_header",
         "//xla/stream_executor/gpu:gpu_timer_header",
         "//xla/stream_executor/platform",
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc
index a790aca..843b347 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc
@@ -18,8 +18,11 @@
 #include <complex>
 #include <cstdint>
 
+#include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
 #include "Eigen/Core"  // from @eigen_archive
 #include "third_party/gpus/cuda/include/cublas_v2.h"
 #include "third_party/gpus/cuda/include/cuda.h"
@@ -179,6 +182,8 @@
     "not built with support for the GPU in your machine.";
 
 bool CUDABlas::Init() {
+  absl::MutexLock lock(&mu_);
+
   gpu::ScopedActivateExecutorContext sac{parent_};
   cublasStatus_t ret = cublasCreate(&blas_);
   if (ret != CUBLAS_STATUS_SUCCESS) {
@@ -222,6 +227,7 @@
   CHECK(AsGpuStreamValue(stream) != nullptr);
   CHECK(blas_ != nullptr);
   gpu::ScopedActivateExecutorContext sac{parent_};
+
   cublasStatus_t ret = cublasSetStream(blas_, AsGpuStreamValue(stream));
   if (ret != CUBLAS_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to set stream for cuBLAS calls: " << ToString(ret);
@@ -358,6 +364,18 @@
     return tsl::errors::Internal("Failed setting stream");
   }
 
+  // Set the workspace to a user-owned buffer; otherwise cuBLAS will use its
+  // own memory pool, which is not compatible with CUDA graphs.
+  if (auto *workspace = GetWorkspace();
+      workspace && workspace->opaque() && workspace->size() > 0) {
+    cublasStatus_t ret =
+        cublasSetWorkspace(blas_, workspace->opaque(), workspace->size());
+    if (ret != CUBLAS_STATUS_SUCCESS) {
+      return absl::InternalError(
+          absl::StrCat("Failed setting cuBlas workspace: ", ToString(ret)));
+    }
+  }
+
   ScopedCublasMathMode math_mode{blas_};
 #if CUBLAS_VER_MAJOR >= 11
   if (math_type == CUBLAS_TF32_TENSOR_OP_MATH &&
@@ -583,7 +601,8 @@
                                  const void *alpha, const DeviceMemoryBase &a,
                                  int lda, const DeviceMemoryBase &b, int ldb,
                                  const void *beta, DeviceMemoryBase *c, int ldc,
-                                 const NumericOptions &numeric_options) {
+                                 const NumericOptions &numeric_options,
+                                 blas::CallContext context) {
   cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
 
 #if CUDA_VERSION < 11000
@@ -778,7 +797,7 @@
     blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c,
     blas::DataType type_c, int ldc, blas::ComputationType computation_type,
     blas::AlgorithmType algorithm, const NumericOptions &numeric_options,
-    blas::ProfileResult *output_profile_result) {
+    blas::ProfileResult *output_profile_result, blas::CallContext context) {
   TF_ASSIGN_OR_RETURN(
       cublasMath_t math_type,
       GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, numeric_options));
@@ -812,7 +831,7 @@
     DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c,
     int batch_count, blas::ComputationType computation_type,
     blas::AlgorithmType algorithm, const NumericOptions &numeric_options,
-    blas::ProfileResult *output_profile_result) {
+    blas::ProfileResult *output_profile_result, blas::CallContext context) {
   TF_ASSIGN_OR_RETURN(
       cublasMath_t math_type,
       GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, numeric_options));
@@ -1106,7 +1125,8 @@
       DeviceMemory<T> *c_matrix = c_ptrs_to_wrappers[b];
       TF_RETURN_IF_ERROR(DoBlasGemm(
           stream, transa, transb, m, n, k, blas::ToDataType<T>::value, &alpha,
-          a_matrix, lda, b_matrix, ldb, &beta, c_matrix, ldc, numeric_options));
+          a_matrix, lda, b_matrix, ldb, &beta, c_matrix, ldc, numeric_options,
+          blas::CallContext::kNone));
     }
     return ::tsl::OkStatus();
   }
@@ -1117,8 +1137,8 @@
     uint64_t n, uint64 k, float alpha, DeviceMemorySlice<Eigen::half> a_array,
     int lda, DeviceMemorySlice<Eigen::half> b_array, int ldb, float beta,
     DeviceMemorySlice<Eigen::half> c_array, int ldc, int batch_count,
-    const NumericOptions &numeric_options,
-    ScratchAllocator *scratch_allocator) {
+    const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator,
+    blas::CallContext context) {
   // Note: The func passed here (cublasSgemmBatched) is not actually called,
   // due to special handling of fp16 inside DoBlasGemmBatchedInternal.
   tsl::Status status = DoBlasGemmBatchedInternal(
@@ -1137,8 +1157,8 @@
     DeviceMemorySlice<Eigen::bfloat16> a_array, int lda,
     DeviceMemorySlice<Eigen::bfloat16> b_array, int ldb, float beta,
     DeviceMemorySlice<Eigen::bfloat16> c_array, int ldc, int batch_count,
-    const NumericOptions &numeric_options,
-    ScratchAllocator *scratch_allocator) {
+    const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator,
+    blas::CallContext context) {
   // Note: The func passed here (cublasSgemmBatched) is not actually called,
   // due to special handling of bf16 inside DoBlasGemmBatchedInternal.
   tsl::Status status = DoBlasGemmBatchedInternal(
@@ -1151,15 +1171,13 @@
   return status.ok();
 }
 
-bool CUDABlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
-                                 blas::Transpose transb, uint64_t m, uint64_t n,
-                                 uint64 k, float alpha,
-                                 DeviceMemorySlice<float> a_array, int lda,
-                                 DeviceMemorySlice<float> b_array, int ldb,
-                                 float beta, DeviceMemorySlice<float> c_array,
-                                 int ldc, int batch_count,
-                                 const NumericOptions &numeric_options,
-                                 ScratchAllocator *scratch_allocator) {
+bool CUDABlas::DoBlasGemmBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
+    uint64_t n, uint64 k, float alpha, DeviceMemorySlice<float> a_array,
+    int lda, DeviceMemorySlice<float> b_array, int ldb, float beta,
+    DeviceMemorySlice<float> c_array, int ldc, int batch_count,
+    const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator,
+    blas::CallContext context) {
   tsl::Status status = DoBlasGemmBatchedInternal(
       cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
       b_array, ldb, beta, c_array, ldc, batch_count, numeric_options,
@@ -1170,18 +1188,17 @@
   return status.ok();
 }
 
-bool CUDABlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
-                                 blas::Transpose transb, uint64_t m, uint64_t n,
-                                 uint64 k, double alpha,
-                                 DeviceMemorySlice<double> a_array, int lda,
-                                 DeviceMemorySlice<double> b_array, int ldb,
-                                 double beta, DeviceMemorySlice<double> c_array,
-                                 int ldc, int batch_count,
-                                 const NumericOptions &numeric_options,
-                                 ScratchAllocator *scratch_allocator) {
+bool CUDABlas::DoBlasGemmBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
+    uint64_t n, uint64 k, double alpha, DeviceMemorySlice<double> a_array,
+    int lda, DeviceMemorySlice<double> b_array, int ldb, double beta,
+    DeviceMemorySlice<double> c_array, int ldc, int batch_count,
+    const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator,
+    blas::CallContext context) {
   tsl::Status status = DoBlasGemmBatchedInternal(
       cublasDgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
       b_array, ldb, beta, c_array, ldc, batch_count, numeric_options,
       scratch_allocator);
   if (!status.ok()) {
     LOG(ERROR) << status;
@@ -1196,10 +1213,11 @@
     DeviceMemorySlice<std::complex<float>> b_array, int ldb,
     std::complex<float> beta, DeviceMemorySlice<std::complex<float>> c_array,
     int ldc, int batch_count, const NumericOptions &numeric_options,
-    ScratchAllocator *scratch_allocator) {
+    ScratchAllocator *scratch_allocator, blas::CallContext context) {
   tsl::Status status = DoBlasGemmBatchedInternal(
       cublasCgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
       b_array, ldb, beta, c_array, ldc, batch_count, numeric_options,
       scratch_allocator);
   if (!status.ok()) {
     LOG(ERROR) << status;
@@ -1214,7 +1232,7 @@
     DeviceMemorySlice<std::complex<double>> b_array, int ldb,
     std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c_array,
     int ldc, int batch_count, const NumericOptions &numeric_options,
-    ScratchAllocator *scratch_allocator) {
+    ScratchAllocator *scratch_allocator, blas::CallContext context) {
   tsl::Status status = DoBlasGemmBatchedInternal(
       cublasZgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
       b_array, ldb, beta, c_array, ldc, batch_count, numeric_options,
@@ -1231,7 +1249,7 @@
     const DeviceMemoryBase &a, int lda, int64_t stride_a,
     const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
     DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count,
-    const NumericOptions &numeric_options) {
+    const NumericOptions &numeric_options, blas::CallContext context) {
   cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
 #if CUDA_VERSION < 11000
   if (dtype == dnn::kHalf) {
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas.h
index 151a63e..3dc54f9 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_blas.h
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas.h
@@ -25,6 +25,7 @@
 #include "third_party/gpus/cuda/include/cublas_v2.h"
 #include "xla/stream_executor/blas.h"
 #include "xla/stream_executor/cuda/cuda_blas_lt.h"
+#include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/platform/port.h"
 #include "xla/stream_executor/plugin_registry.h"
 
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc
index 8edf649..a8bc2b0 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc
@@ -191,7 +191,8 @@
 /*static*/ tsl::StatusOr<BlasLt::MatmulDesc> BlasLt::MatmulDesc::Create(
     blas::ComputationType compute_type, blas::DataType scale_type,
     blas::Transpose trans_a, blas::Transpose trans_b,
-    gpu::BlasLt::Epilogue epilogue, PointerMode pointer_mode) {
+    gpu::BlasLt::Epilogue epilogue, bool enable_fast_accum,
+    PointerMode pointer_mode) {
   VLOG(2) << "MatmulDesc::Create: compute_type: " << (int)compute_type
           << " scale:" << (int)scale_type << " trans a/b: " << (int)trans_a
           << "," << (int)trans_b << " epilogue:" << (int)epilogue
@@ -210,6 +211,9 @@
                              AsCublasOperation(trans_b)));
   TF_ASSIGN_OR_RETURN(cublasLtEpilogue_t epi, AsCublasLtEpilogue(epilogue));
   TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, epi));
+  // TODO(b/259609697): Set the CUBLASLT_MATMUL_DESC_FAST_ACCUM attribute if
+  // enable_fast_accum is true, once Flax/Praxis properly pass a PrecisionConfig
+  // of HIGH or HIGHEST on the backwards pass.
   return std::move(desc);
 }
 
@@ -315,11 +319,17 @@
                                           cfg.compute_precision));
   }
 
+  // FP8 matmuls have a fast accumulation mode that is less precise than the
+  // default accumulation mode. Use the fast accumulation mode if the compute
+  // precision is DEFAULT.
+  bool enable_fast_accum = (xla::primitive_util::IsF8Type(lhs_layout.dtype) ||
+                            xla::primitive_util::IsF8Type(rhs_layout.dtype)) &&
+                           cfg.compute_precision == 0;
   TF_ASSIGN_OR_RETURN(
       auto op_desc,
       MatmulDesc::Create(*compute_type,
                          gpu::GetScaleType(output_dtype, *compute_type),
-                         trans_a, trans_b, epilogue));
+                         trans_a, trans_b, epilogue, enable_fast_accum));
 
   TF_ASSIGN_OR_RETURN(auto a_desc, MatrixLayout::Create(lhs_layout));
   TF_ASSIGN_OR_RETURN(auto b_desc, MatrixLayout::Create(rhs_layout));
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h
index 7ca5f64..7a758f4 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h
@@ -70,7 +70,7 @@
         blas::ComputationType compute_type, blas::DataType scale_type,
         blas::Transpose trans_a = blas::Transpose::kNoTranspose,
         blas::Transpose trans_b = blas::Transpose::kNoTranspose,
-        Epilogue epilogue = Epilogue::kDefault,
+        Epilogue epilogue = Epilogue::kDefault, bool enable_fast_accum = false,
         PointerMode pointer_mode = PointerMode::kHost);
 
     cublasComputeType_t compute_type() const;
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc
index b3c820f..32be286 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc
@@ -33,6 +33,7 @@
 #include "absl/base/thread_annotations.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/memory/memory.h"
+#include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "Eigen/Core"  // from @eigen_archive
@@ -3624,12 +3625,16 @@
   MASK_ID,
   ZERO_VAL_ID,
   ONE_VAL_ID,
+  NEG_INFINITY_ID,
   ALPHA_SCALE_ID,
   DROPOUT_SCALE_ID,
+  SCALE_PROB_ID,
   Q_SEQLEN_ID,
   K_SEQLEN_ID,
   D_OFFSET_ID,
   D_SEED_ID,
+  S_SUM_ID,
+  d_Q_accum_ID,
   VIRTUAL_ID = 34857
 };
 
@@ -3688,10 +3693,10 @@
 }
 
 // Returns a cudnn tensor that's the output of the mask op
-tsl::StatusOr<cudnn_frontend::Tensor> CreateCudnnMaskTensor(
+tsl::StatusOr<cudnn_frontend::Tensor> CreateCudnnMaskFwdTensor(
     std::vector<cudnn_frontend::Operation>& ops, absl::Span<const int64_t> dims,
     absl::Span<const int64_t> strides, dnn::DataType dtype,
-    std::shared_ptr<cudnn_frontend::Tensor> input_tensor) {
+    cudnn_frontend::Tensor& input_tensor) {
   std::vector<int64_t> mask_dim(dims.size(), 1);
   std::vector<int64_t> mask_stride(strides.size(), 1);
 
@@ -3703,7 +3708,7 @@
   // Create the mask output tensor
   TF_ASSIGN_OR_RETURN(
       auto mask_out_tensor,
-      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 300,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 400,
                         dnn::DataType::kFloat, 1, -1,
                         /*is_virtual=*/true));
 
@@ -3715,7 +3720,7 @@
   // Create the mask op.
   auto mask_op = cudnn_frontend::OperationBuilder(
                      CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
-                     .setxDesc((*input_tensor))
+                     .setxDesc(input_tensor)
                      .setbDesc(mask_tensor)
                      .setyDesc(mask_out_tensor)
                      .setpwDesc(mask_desc)
@@ -3734,7 +3739,7 @@
 tsl::StatusOr<cudnn_frontend::Tensor> CreateCudnnScaleTensor(
     std::vector<cudnn_frontend::Operation>& ops, absl::Span<const int64_t> dims,
     absl::Span<const int64_t> strides, dnn::DataType dtype,
-    std::shared_ptr<cudnn_frontend::Tensor> input_tensor) {
+    cudnn_frontend::Tensor& input_tensor) {
   std::vector<int64_t> scale_dims(dims.size(), 1);
   std::vector<int64_t> scale_strides(strides.size(), 1);
 
@@ -3747,12 +3752,13 @@
           /*is_value*/ true));
   TF_ASSIGN_OR_RETURN(auto scale_desc,
                       CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_MUL));
-  TF_ASSIGN_OR_RETURN(auto tensor_alpha_scale_out,
-                      CreateCudnnTensor(dims, strides, VIRTUAL_ID + 600, dtype,
-                                        1, -1, /* is_virtual */ true));
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_alpha_scale_out,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 200, dtype, 1,
+                        -1, /* is_virtual */ true));
 
   TF_ASSIGN_OR_RETURN(auto scale_op,
-                      CreateBinaryPwOp((*input_tensor), tensor_alpha_scale,
+                      CreateBinaryPwOp(input_tensor, tensor_alpha_scale,
                                        tensor_alpha_scale_out, scale_desc));
   // Add scale to op list
   ops.push_back(std::move(scale_op));
@@ -3764,7 +3770,7 @@
 tsl::StatusOr<cudnn_frontend::Tensor> CreateCudnnBiasTensor(
     std::vector<cudnn_frontend::Operation>& ops, absl::Span<const int64_t> dims,
     absl::Span<const int64_t> strides, dnn::DataType dtype,
-    std::shared_ptr<cudnn_frontend::Tensor> input_tensor, bool use_mask) {
+    cudnn_frontend::Tensor& input_tensor, bool use_mask) {
   // Create the bias tensor.
   TF_ASSIGN_OR_RETURN(
       auto bias_tensor,
@@ -3774,7 +3780,7 @@
   dnn::DataType bias_out_type = use_mask ? dtype : dnn::DataType::kFloat;
   TF_ASSIGN_OR_RETURN(
       auto bias_out_tensor,
-      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 200,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 300,
                         bias_out_type, 1, -1,
                         /*is_virtual=*/true));
 
@@ -3786,15 +3792,13 @@
   // Create the bias op.
   auto bias_op = cudnn_frontend::OperationBuilder(
                      CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
-                     .setxDesc((*input_tensor))
+                     .setxDesc(input_tensor)
                      .setbDesc(bias_tensor)
                      .setyDesc(bias_out_tensor)
                      .setpwDesc(bias_desc)
                      .build();
 
   RETURN_MSG_IF_CUDNN_ERROR(bias_op);
-
-  RETURN_MSG_IF_CUDNN_ERROR(bias_out_tensor);
   // Add bias to op list
   ops.push_back(std::move(bias_op));
 
@@ -3805,8 +3809,7 @@
 tsl::StatusOr<cudnn_frontend::Tensor> CreateCudnnSoftmaxFwdTensor(
     std::vector<cudnn_frontend::Operation>& ops, absl::Span<const int64_t> dims,
     absl::Span<const int64_t> strides, dnn::DataType dtype,
-    std::shared_ptr<cudnn_frontend::Tensor> input_tensor,
-    bool is_virtual = false) {
+    cudnn_frontend::Tensor& input_tensor, bool is_virtual = false) {
   // softmax's typical computation is:
   // exp(input - reduce_max(input)) / reduce_sum(exp(input - reduce_max(input)))
   // We need to create each op and add it to the op list sequentially.
@@ -3823,25 +3826,26 @@
   }
 
   // Softmax output should be float
-  cudnnDataType_t softmax_output_type = CUDNN_DATA_FLOAT;
+  dnn::DataType softmax_output_type = dnn::DataType::kFloat;
 
   // Create output tensor of the first max reduction.
   TF_ASSIGN_OR_RETURN(
       auto max_reduction_output_tensor,
       CreateCudnnTensor(reduction_output_dim, reduction_output_stride,
-                        CudnnfMHAUid::VIRTUAL_ID + 400, dnn::DataType::kFloat,
+                        CudnnfMHAUid::VIRTUAL_ID + 500, dnn::DataType::kFloat,
                         1, -1, /*is_virtual=*/true));
 
   // Create the reduction descriptor
-  auto max_reduction_desc = cudnn_frontend::ReductionDescBuilder()
-                                .setComputeType(softmax_output_type)
-                                .setReductionOp(CUDNN_REDUCE_TENSOR_MAX)
-                                .build();
+  auto max_reduction_desc =
+      cudnn_frontend::ReductionDescBuilder()
+          .setComputeType(ToCudnnDataType(softmax_output_type))
+          .setReductionOp(CUDNN_REDUCE_TENSOR_MAX)
+          .build();
 
   // Create a reduction max node.
   auto max_reduction_op = cudnn_frontend::OperationBuilder(
                               CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
-                              .setxDesc((*input_tensor))
+                              .setxDesc(input_tensor)
                               .setyDesc(max_reduction_output_tensor)
                               .setreductionDesc(max_reduction_desc)
                               .build();
@@ -3850,56 +3854,45 @@
   // Create output tensor of the subtraction op.
   TF_ASSIGN_OR_RETURN(
       auto subtract_output_tensor,
-      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 401,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 501,
                         dnn::DataType::kFloat, 1, -1,
                         /*is_virtual=*/true));
   // Create the subtraction descriptor
-  auto subtract_desc = cudnn_frontend::PointWiseDescBuilder()
-                           .setMode(CUDNN_POINTWISE_SUB)
-                           .setComputeType(softmax_output_type)
-                           .build();
+  TF_ASSIGN_OR_RETURN(auto subtract_desc,
+                      CreatePwDesc(softmax_output_type, CUDNN_POINTWISE_SUB));
 
   // Create a subtraction node.
-  auto subtract_op = cudnn_frontend::OperationBuilder(
-                         CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
-                         .setxDesc((*input_tensor))
-                         .setbDesc(max_reduction_output_tensor)
-                         .setyDesc(subtract_output_tensor)
-                         .setpwDesc(subtract_desc)
-                         .build();
-  RETURN_MSG_IF_CUDNN_ERROR(subtract_op);
+  TF_ASSIGN_OR_RETURN(
+      auto subtract_op,
+      CreateBinaryPwOp(input_tensor, max_reduction_output_tensor,
+                       subtract_output_tensor, subtract_desc));
   // Create output tensor of the exp op.
   TF_ASSIGN_OR_RETURN(
       auto exp_output_tensor,
-      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 402,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 502,
                         dnn::DataType::kFloat, 1, -1,
                         /*is_virtual=*/true));
   // Create the exponetial descriptor
-  auto exp_desc = cudnn_frontend::PointWiseDescBuilder()
-                      .setMode(CUDNN_POINTWISE_EXP)
-                      .setComputeType(softmax_output_type)
-                      .build();
+  TF_ASSIGN_OR_RETURN(auto exp_desc,
+                      CreatePwDesc(softmax_output_type, CUDNN_POINTWISE_EXP));
 
   // Create a exponetial node.
-  auto exp_op = cudnn_frontend::OperationBuilder(
-                    CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
-                    .setxDesc(subtract_output_tensor)
-                    .setyDesc(exp_output_tensor)
-                    .setpwDesc(exp_desc)
-                    .build();
-  RETURN_MSG_IF_CUDNN_ERROR(exp_op);
+  TF_ASSIGN_OR_RETURN(
+      auto exp_op,
+      CreateUnaryPwOp(subtract_output_tensor, exp_output_tensor, exp_desc));
 
   // Create output tensor of the sum reduction.
   TF_ASSIGN_OR_RETURN(
       auto sum_reduction_output_tensor,
       CreateCudnnTensor(reduction_output_dim, reduction_output_stride,
-                        CudnnfMHAUid::VIRTUAL_ID + 403, dnn::DataType::kFloat,
+                        CudnnfMHAUid::VIRTUAL_ID + 503, dnn::DataType::kFloat,
                         1, -1, /*is_virtual=*/true));
   // Create the reduction descriptor
-  auto sum_reduction_desc = cudnn_frontend::ReductionDescBuilder()
-                                .setComputeType(softmax_output_type)
-                                .setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
-                                .build();
+  auto sum_reduction_desc =
+      cudnn_frontend::ReductionDescBuilder()
+          .setComputeType(ToCudnnDataType(softmax_output_type))
+          .setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
+          .build();
 
   // Create a reduction sum node.
   auto sum_reduction_op = cudnn_frontend::OperationBuilder(
@@ -3911,7 +3904,7 @@
   RETURN_MSG_IF_CUDNN_ERROR(sum_reduction_op);
 
   // Create output tensor of the divide op.
-  auto uid = is_virtual ? CudnnfMHAUid::VIRTUAL_ID + 404 : CudnnfMHAUid::P_ID;
+  auto uid = is_virtual ? CudnnfMHAUid::VIRTUAL_ID + 504 : CudnnfMHAUid::P_ID;
   TF_ASSIGN_OR_RETURN(
       auto divide_output_tensor,
       CreateCudnnTensor(
@@ -3919,22 +3912,14 @@
           /*is_virtual*/ is_virtual,
           /*cudnn_tensor_order_type*/ CUDNN_TENSOR_REORDERING_F16x16));
   // Create the divide descriptor
-  auto divide_desc = cudnn_frontend::PointWiseDescBuilder()
-                         .setMode(CUDNN_POINTWISE_DIV)
-                         .setComputeType(softmax_output_type)
-                         .build();
+  TF_ASSIGN_OR_RETURN(auto divide_desc,
+                      CreatePwDesc(softmax_output_type, CUDNN_POINTWISE_DIV));
 
   // Create a divide node.
-  auto divide_op = cudnn_frontend::OperationBuilder(
-                       CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
-                       .setxDesc(exp_output_tensor)
-                       .setbDesc(sum_reduction_output_tensor)
-                       .setyDesc(divide_output_tensor)
-                       .setpwDesc(divide_desc)
-                       .build();
-  RETURN_MSG_IF_CUDNN_ERROR(divide_op);
-
-  RETURN_MSG_IF_CUDNN_ERROR(divide_output_tensor);
+  TF_ASSIGN_OR_RETURN(
+      auto divide_op,
+      CreateBinaryPwOp(exp_output_tensor, sum_reduction_output_tensor,
+                       divide_output_tensor, divide_desc));
 
   // Add max reduction to op list
   ops.push_back(std::move(max_reduction_op));
@@ -3951,11 +3936,11 @@
 }
 
 // Returns a cudnn tensor that's the output of the dropout op
-tsl::StatusOr<cudnn_frontend::Tensor> CreateCudnnDropoutTensor(
+tsl::StatusOr<cudnn_frontend::Tensor> CreateCudnnDropoutFwdTensor(
     std::vector<cudnn_frontend::Operation>& ops, absl::Span<const int64_t> dims,
     absl::Span<const int64_t> strides, dnn::DataType dtype,
-    std::shared_ptr<cudnn_frontend::Tensor> input_tensor, double dropout_rate,
-    int64_t seed, bool is_virtual = false) {
+    cudnn_frontend::Tensor& input_tensor, double dropout_rate, int64_t seed,
+    bool is_virtual = false) {
   // Create scale tensor
   std::vector<int64_t> scale_dims(dims.size(), 1);
   std::vector<int64_t> scale_strides(strides.size(), 1);
@@ -3963,11 +3948,11 @@
   // Create tensor for dropout's mask.
   TF_ASSIGN_OR_RETURN(
       auto mask_tensor,
-      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 500,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 600,
                         dnn::DataType::kFloat, 1, -1,
                         /*is_virtual*/ true));
   // Create output tensor of dropout node
-  auto uid = is_virtual ? CudnnfMHAUid::VIRTUAL_ID + 501 : CudnnfMHAUid::P_ID;
+  auto uid = is_virtual ? CudnnfMHAUid::VIRTUAL_ID + 601 : CudnnfMHAUid::P_ID;
   TF_ASSIGN_OR_RETURN(
       auto dropout_out_tensor,
       CreateCudnnTensor(
@@ -4010,19 +3995,13 @@
   RETURN_MSG_IF_CUDNN_ERROR(rng_op);
 
   // Create the masking node desc after mask tensor
-  auto masking_desc = cudnn_frontend::PointWiseDescBuilder()
-                          .setMode(CUDNN_POINTWISE_MUL)
-                          .setComputeType(CUDNN_DATA_FLOAT)
-                          .build();
+  TF_ASSIGN_OR_RETURN(auto masking_desc,
+                      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_MUL));
 
-  auto masking_op = cudnn_frontend::OperationBuilder(
-                        CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
-                        .setxDesc((*input_tensor))
-                        .setbDesc(mask_tensor)
-                        .setyDesc(dropout_out_tensor)
-                        .setpwDesc(masking_desc)
-                        .build();
-  RETURN_MSG_IF_CUDNN_ERROR(masking_op);
+  // Create the scaling op
+  TF_ASSIGN_OR_RETURN(auto masking_op,
+                      CreateBinaryPwOp(input_tensor, mask_tensor,
+                                       dropout_out_tensor, masking_desc));
 
   TF_ASSIGN_OR_RETURN(
       auto dropout_scale_tensor,
@@ -4036,24 +4015,16 @@
   // Create output of scale node
   TF_ASSIGN_OR_RETURN(
       auto dropout_scale_out_tensor,
-      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 502, dtype, 1,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 602, dtype, 1,
                         -1, /*is_virtual*/ true));
   // Create the scaling desc
-  auto scale_desc = cudnn_frontend::PointWiseDescBuilder()
-                        .setMode(CUDNN_POINTWISE_MUL)
-                        .setComputeType(CUDNN_DATA_FLOAT)
-                        .build();
-  // Create the scaling op
-  auto scale_op = cudnn_frontend::OperationBuilder(
-                      CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
-                      .setxDesc(dropout_out_tensor)
-                      .setbDesc(dropout_scale_tensor)
-                      .setyDesc(dropout_scale_out_tensor)
-                      .setpwDesc(scale_desc)
-                      .build();
-  RETURN_MSG_IF_CUDNN_ERROR(scale_op);
+  TF_ASSIGN_OR_RETURN(auto scale_desc,
+                      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_MUL));
 
-  RETURN_MSG_IF_CUDNN_ERROR(dropout_scale_out_tensor);
+  // Create the scaling op
+  TF_ASSIGN_OR_RETURN(auto scale_op,
+                      CreateBinaryPwOp(dropout_out_tensor, dropout_scale_tensor,
+                                       dropout_scale_out_tensor, scale_desc));
   // Add rng op to op list
   ops.push_back(std::move(rng_op));
   // Add masking op to op list
@@ -5172,10 +5143,7 @@
       auto tensor_k,
       CreateCudnnTensor(bmm1_rhs_dims, bmm1_rhs_strides, CudnnfMHAUid::K_ID,
                         bmm1_rhs_descriptor.type(), 1, -1));
-  VLOG(4) << "\nTensor_k: " << tensor_k.describe();
 
-  std::shared_ptr<cudnn_frontend::Tensor> bmm2_input_tensor =
-      std::make_shared<cudnn_frontend::Tensor>(std::move(tensor_k));
   std::vector<int64_t> intermediate_bmm2_lhs_dims =
       intermediate_bmm2_lhs_descriptor.GetCudnnCompatibleDimensions(true);
   std::vector<int64_t> intermediate_bmm2_lhs_strides =
@@ -5205,7 +5173,8 @@
   TF_ASSIGN_OR_RETURN(
       auto alpha_scale_out,
       CreateCudnnScaleTensor(intermediate_ops, bmm1_rhs_dims, bmm1_rhs_strides,
-                             bmm1_rhs_descriptor.type(), bmm2_input_tensor));
+                             bmm1_rhs_descriptor.type(), tensor_k));
+
   auto bmm1_desc = cudnn_frontend::MatMulDescBuilder()
                        .setComputeType(CUDNN_DATA_FLOAT)
                        .build();
@@ -5222,8 +5191,7 @@
   VLOG(4) << "\nTensor_s: " << tensor_s.describe()
           << "\nBMM1_op: " << bmm1_op.describe();
 
-  bmm2_input_tensor =
-      std::make_shared<cudnn_frontend::Tensor>(std::move(tensor_s));
+  cudnn_frontend::Tensor bmm2_input_tensor = std::move(tensor_s);
   intermediate_ops.push_back(std::move(bmm1_op));
 
   if (is_s_virtual) {
@@ -5236,19 +5204,17 @@
                                 intermediate_bmm2_lhs_strides,
                                 (*bias_descriptor).type(), bmm2_input_tensor,
                                 use_mask));
-      bmm2_input_tensor =
-          std::make_shared<cudnn_frontend::Tensor>(std::move(bias_out));
+      bmm2_input_tensor = std::move(bias_out);
     }
     if (use_mask) {
       // Create mask op and tensor
       TF_ASSIGN_OR_RETURN(
           auto mask_out,
-          CreateCudnnMaskTensor(intermediate_ops, intermediate_bmm2_lhs_dims,
-                                intermediate_bmm2_lhs_strides,
-                                intermediate_bmm2_lhs_descriptor.type(),
-                                bmm2_input_tensor));
-      bmm2_input_tensor =
-          std::make_shared<cudnn_frontend::Tensor>(std::move(mask_out));
+          CreateCudnnMaskFwdTensor(intermediate_ops, intermediate_bmm2_lhs_dims,
+                                   intermediate_bmm2_lhs_strides,
+                                   intermediate_bmm2_lhs_descriptor.type(),
+                                   bmm2_input_tensor));
+      bmm2_input_tensor = std::move(mask_out);
     }
     if (kind == dnn::FusedMHAKind::BMM1_OUTPUT_FLOAT || use_bias ||
         use_dropout || use_mask) {
@@ -5262,22 +5228,20 @@
                               intermediate_bmm2_lhs_descriptor.type(),
                               /*input_tensor*/ bmm2_input_tensor,
                               /*is_virtual*/ !should_output_softmax));
-      bmm2_input_tensor =
-          std::make_shared<cudnn_frontend::Tensor>(std::move(softmax_fwd_out));
+      bmm2_input_tensor = std::move(softmax_fwd_out);
     }
 
     if (use_dropout) {
       // Create dropout tensor
       bool dropout_virtual = (activation_descriptor == std::nullopt);
       TF_ASSIGN_OR_RETURN(auto dropout_out,
-                          CreateCudnnDropoutTensor(
+                          CreateCudnnDropoutFwdTensor(
                               intermediate_ops, intermediate_bmm2_lhs_dims,
                               intermediate_bmm2_lhs_strides,
                               intermediate_bmm2_lhs_descriptor.type(),
                               /*input_tensor*/ bmm2_input_tensor, *dropout_rate,
                               *seed, /*is_virtual*/ dropout_virtual));
-      bmm2_input_tensor =
-          std::make_shared<cudnn_frontend::Tensor>(std::move(dropout_out));
+      bmm2_input_tensor = std::move(dropout_out);
     }
   }
   std::vector<int64_t> bmm2_rhs_dims =
@@ -5311,7 +5275,7 @@
   RETURN_MSG_IF_CUDNN_ERROR(bmm2_desc);
   auto bmm2_op = cudnn_frontend::OperationBuilder(
                      CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
-                     .setaMatDesc((*bmm2_input_tensor))
+                     .setaMatDesc(bmm2_input_tensor)
                      .setbMatDesc(tensor_v)
                      .setcMatDesc(tensor_o)
                      .setmatmulDesc(bmm2_desc)
@@ -5728,13 +5692,12 @@
     const dnn::MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor,
     const dnn::MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor,
     const dnn::MatmulTensorDescriptor& d_output_descriptor,
-    const dnn::TensorDescriptor& d_s_descriptor,
     const dnn::TensorDescriptor& d_bmm1_lhs_descriptor,
     const dnn::TensorDescriptor& d_bmm1_rhs_descriptor,
     const dnn::TensorDescriptor& d_bmm2_rhs_descriptor, dnn::FusedMHAKind kind,
     std::optional<double> dropout_rate, std::optional<int64_t> seed,
-    CudnnHandle& cudnn, double scale, bool use_dropout = false,
-    bool use_mask = false, bool use_bias = false) {
+    CudnnHandle& cudnn, double scale, std::vector<int64_t>& intermediate_shape,
+    bool use_dropout = false, bool use_mask = false, bool use_bias = false) {
   if (VLOG_IS_ON(4)) {
     VLOG(4) << "\n bmm1_grad_gemm1_rhs(q): "
             << bmm1_grad_gemm1_rhs_descriptor.ToString()
@@ -5791,11 +5754,15 @@
       auto tensor_k,
       CreateCudnnTensor(k_dims, k_strides, CudnnfMHAUid::K_ID, dtype, 1, -1));
 
+  // P^T is the lhs of bmm2grad1 dV = dot(P^T, dO), so we set is_lhs = false
+  // here to get the correct P dims and strides.
   std::vector<int64_t> p_dims =
-      bmm2_grad_gemm1_lhs_descriptor.GetCudnnCompatibleDimensions(true);
+      bmm2_grad_gemm1_lhs_descriptor.GetCudnnCompatibleDimensions(false);
   std::vector<int64_t> p_strides =
-      bmm2_grad_gemm1_lhs_descriptor.GetCudnnCompatibleStrides(true);
+      bmm2_grad_gemm1_lhs_descriptor.GetCudnnCompatibleStrides(false);
 
+  // used to calculate the offset increment
+  intermediate_shape = p_dims;
   VLOG(2) << "\n cuDNN compatible bmm2_grad_gemm1_lhs_dims: "
           << absl::StrJoin(p_dims, ",")
           << "\n cuDNN compatible bmm2_grad_gemm1_lhs_strides: "
@@ -5879,10 +5846,11 @@
   std::swap(p_transpose_dims[rank - 1], p_transpose_dims[rank - 2]);
   std::swap(p_transpose_strides[rank - 1], p_transpose_strides[rank - 2]);
 
-  TF_ASSIGN_OR_RETURN(auto tensor_p_transpose,
-                      CreateCudnnTensor(p_transpose_dims, p_transpose_strides,
-                                        VIRTUAL_ID + 300, dtype, 1, -1,
-                                        /* is_virtual */ true));
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_p_transpose,
+      CreateCudnnTensor(p_transpose_dims, p_transpose_strides,
+                        CudnnfMHAUid::VIRTUAL_ID + 300, dtype, 1, -1,
+                        /* is_virtual */ true));
 
   auto reshape_op = cudnn_frontend::OperationBuilder(
                         CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
@@ -5902,10 +5870,10 @@
           /*is_value*/ true));
 
   // Create output of scale
-  TF_ASSIGN_OR_RETURN(
-      auto tensor_p_transpose_scale,
-      CreateCudnnTensor(p_transpose_dims, p_transpose_strides, VIRTUAL_ID + 301,
-                        dtype, 1, -1, /*is_virtual*/ true));
+  TF_ASSIGN_OR_RETURN(auto tensor_p_transpose_scale,
+                      CreateCudnnTensor(p_transpose_dims, p_transpose_strides,
+                                        CudnnfMHAUid::VIRTUAL_ID + 301, dtype,
+                                        1, -1, /*is_virtual*/ true));
   // Create the scaling desc
   TF_ASSIGN_OR_RETURN(auto scale_desc,
                       CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_MUL));
@@ -5916,10 +5884,10 @@
                                        tensor_p_transpose_scale, scale_desc));
   // create abs operation here to clear the sign bit
   // sign bit is used to store the mask for dropout
-  TF_ASSIGN_OR_RETURN(
-      auto tensor_p_transpose_scale_abs,
-      CreateCudnnTensor(p_transpose_dims, p_transpose_strides, VIRTUAL_ID + 302,
-                        dtype, 1, -1, /*is_virtual*/ true));
+  TF_ASSIGN_OR_RETURN(auto tensor_p_transpose_scale_abs,
+                      CreateCudnnTensor(p_transpose_dims, p_transpose_strides,
+                                        CudnnfMHAUid::VIRTUAL_ID + 302, dtype,
+                                        1, -1, /*is_virtual*/ true));
 
   TF_ASSIGN_OR_RETURN(auto abs_desc,
                       CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_ABS));
@@ -5937,6 +5905,7 @@
                                   .setComputeType(CUDNN_DATA_FLOAT)
                                   .build();
   RETURN_MSG_IF_CUDNN_ERROR(bmm2_grad_gemm1_desc);
+
   auto bmm2_grad_gemm1_op = cudnn_frontend::OperationBuilder(
                                 CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
                                 .setaMatDesc(tensor_p_transpose_scale_abs)
@@ -5953,7 +5922,7 @@
   // matmul to calculate dp
   TF_ASSIGN_OR_RETURN(
       auto tensor_dp,
-      CreateCudnnTensor(p_dims, p_strides, VIRTUAL_ID + 303,
+      CreateCudnnTensor(p_dims, p_strides, CudnnfMHAUid::VIRTUAL_ID + 303,
                         dnn::DataType::kFloat, 1,
                         -1,  // FMHA TODO TYPE: why it is float here?
                         /* is_virtual */ true));
@@ -5978,7 +5947,8 @@
   // mask out the sign bit here
   TF_ASSIGN_OR_RETURN(
       auto tensor_p_abs,
-      CreateCudnnTensor(p_dims, p_strides, VIRTUAL_ID + 304, dtype, 1, -1,
+      CreateCudnnTensor(p_dims, p_strides, CudnnfMHAUid::VIRTUAL_ID + 304,
+                        dtype, 1, -1,
                         /* is_virtual */ true));
 
   TF_ASSIGN_OR_RETURN(auto p_abs_desc,
@@ -6005,18 +5975,17 @@
       CreateCudnnMaskBwdTensor(intermediate_ops, p_dims, p_strides, dtype,
                                tensor_ds, use_mask));
 
-#if (CUDNN_VERSION >= 8901 && TF_ENABLE_CUDNN_FRONTEND)
   // bias backward
   if (use_bias) {
-    // bias backward
+#if (CUDNN_VERSION >= 8901 && TF_ENABLE_CUDNN_FRONTEND)
     TF_ASSIGN_OR_RETURN(
         auto tensor_dbias,
         CreateCudnnBiasBwdTensor(intermediate_ops, p_dims, p_strides, dtype,
                                  tensor_ds_mask));
-  }
 #else
-  return absl::InternalError("Bias backward op requires cudnn >= 8.9.1");
+    return absl::InternalError("Bias backward op requires cudnn >= 8.9.1");
 #endif
+  }
 
   // calculate dq
   auto bmm1_grad_gemm2_desc = cudnn_frontend::MatMulDescBuilder()
@@ -6037,10 +6006,11 @@
   intermediate_ops.push_back(std::move(bmm1_grad_gemm2_op));
 
   // calculate dk
-  TF_ASSIGN_OR_RETURN(auto tensor_ds_mask_reshape,
-                      CreateCudnnTensor(p_transpose_dims, p_transpose_strides,
-                                        VIRTUAL_ID + 305, dtype, 1, -1,
-                                        /* is_virtual */ true));
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_ds_mask_reshape,
+      CreateCudnnTensor(p_transpose_dims, p_transpose_strides,
+                        CudnnfMHAUid::VIRTUAL_ID + 305, dtype, 1, -1,
+                        /* is_virtual */ true));
 
   auto reshape_2_op = cudnn_frontend::OperationBuilder(
                           CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
@@ -6089,6 +6059,1293 @@
   return std::make_unique<cudnn_frontend::OperationGraph>(std::move(op_graph));
 }
 
+// Returns a cudnn tensor that's the output of the bias addition op
+tsl::StatusOr<cudnn_frontend::Tensor> CreateCudnnFlashAttentionBiasFwdTensor(
+    std::vector<cudnn_frontend::Operation>& ops, absl::Span<const int64_t> dims,
+    absl::Span<const int64_t> strides, dnn::DataType dtype,
+    cudnn_frontend::Tensor& input_tensor) {
+  // Create the bias tensor.
+  TF_ASSIGN_OR_RETURN(
+      auto bias_tensor,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::BIAS_ID, dtype, 1, -1));
+
+  // Create the bias output tensor
+  TF_ASSIGN_OR_RETURN(
+      auto bias_out_tensor,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 300,
+                        dnn::DataType::kFloat, 1, -1, /*is_virtual=*/true));
+
+  // Define the bias descriptor
+  auto bias_desc = cudnn_frontend::PointWiseDescBuilder()
+                       .setMode(CUDNN_POINTWISE_ADD)
+                       .setComputeType(CUDNN_DATA_FLOAT)
+                       .build();
+  // Create the bias op.
+  auto bias_op = cudnn_frontend::OperationBuilder(
+                     CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                     .setxDesc(input_tensor)
+                     .setbDesc(bias_tensor)
+                     .setyDesc(bias_out_tensor)
+                     .setpwDesc(bias_desc)
+                     .build();
+
+  RETURN_MSG_IF_CUDNN_ERROR(bias_op);
+  // Add bias to op list
+  ops.push_back(std::move(bias_op));
+
+  return bias_out_tensor;
+}
+
+tsl::StatusOr<cudnn_frontend::Tensor> CreateCudnnFlashAttentionCausalMaskTensor(
+    std::vector<cudnn_frontend::Operation>& ops, absl::Span<const int64_t> dims,
+    absl::Span<const int64_t> strides, dnn::DataType dtype,
+    cudnn_frontend::Tensor& input_tensor) {
+  std::vector<int64_t> mask_dim(dims.size(), 1);
+  std::vector<int64_t> mask_stride(strides.size(), 1);
+
+  // Create the masked out value tensor.
+  TF_ASSIGN_OR_RETURN(
+      auto masked_val_tensor,
+      CreateCudnnTensor(
+          mask_dim, mask_stride, CudnnfMHAUid::NEG_INFINITY_ID,
+          dnn::DataType::kFloat, 1, -1,
+          /*is_virtual*/ false,
+          /*cudnn_tensor_order_type*/ CUDNN_TENSOR_REORDERING_NONE,
+          /*is_value*/ true));
+
+  // Create the row index tensor
+  TF_ASSIGN_OR_RETURN(
+      auto row_index_tensor,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 401,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual=*/true));
+
+  // Create the column index tensor
+  TF_ASSIGN_OR_RETURN(
+      auto column_index_tensor,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 402,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual=*/true));
+
+  // Create the causal mask tensor
+  auto causal_mask_tensor = cudnn_frontend::TensorBuilder()
+                                .setDim(dims.size(), dims.data())
+                                .setStride(strides.size(), strides.data())
+                                .setId(CudnnfMHAUid::VIRTUAL_ID + 403)
+                                .setAlignment(16)
+                                .setDataType(CUDNN_DATA_BOOLEAN)
+                                .setVectorCountAndDimension(1, -1)
+                                .setVirtual(true)
+                                .setReorderType(CUDNN_TENSOR_REORDERING_NONE)
+                                .setByValue(false)
+                                .build();
+
+  // Create the mask output tensor
+  TF_ASSIGN_OR_RETURN(
+      auto mask_out_tensor,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 400,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual=*/true));
+
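+  // Generate row indices (axis 2) and column indices (axis 3) over the score
+  // matrix, compare them with CMP_GE, and use the resulting boolean tensor in
+  // a binary-select so that positions above the diagonal (row < column) are
+  // replaced with the NEG_INFINITY value.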
+  auto gen_index_row_desc = cudnn_frontend::PointWiseDescBuilder()
+                                .setMode(CUDNN_POINTWISE_GEN_INDEX)
+                                .setAxis(2)
+                                .setComputeType(CUDNN_DATA_FLOAT)
+                                .build();
+  RETURN_MSG_IF_CUDNN_ERROR(gen_index_row_desc);
+
+  TF_ASSIGN_OR_RETURN(
+      auto gen_index_row_op,
+      CreateUnaryPwOp(input_tensor, row_index_tensor, gen_index_row_desc));
+
+  auto gen_index_column_desc = cudnn_frontend::PointWiseDescBuilder()
+                                   .setMode(CUDNN_POINTWISE_GEN_INDEX)
+                                   .setAxis(3)
+                                   .setComputeType(CUDNN_DATA_FLOAT)
+                                   .build();
+  RETURN_MSG_IF_CUDNN_ERROR(gen_index_column_desc);
+
+  TF_ASSIGN_OR_RETURN(auto gen_index_column_op,
+                      CreateUnaryPwOp(input_tensor, column_index_tensor,
+                                      gen_index_column_desc));
+
+  auto row_greater_than_column_desc = cudnn_frontend::PointWiseDescBuilder()
+                                          .setMode(CUDNN_POINTWISE_CMP_GE)
+                                          .setComputeType(CUDNN_DATA_BOOLEAN)
+                                          .build();
+  RETURN_MSG_IF_CUDNN_ERROR(row_greater_than_column_desc);
+
+  TF_ASSIGN_OR_RETURN(
+      auto row_greater_than_column_op,
+      CreateBinaryPwOp(row_index_tensor, column_index_tensor,
+                       causal_mask_tensor, row_greater_than_column_desc));
+
+  TF_ASSIGN_OR_RETURN(
+      auto mask_desc,
+      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_BINARY_SELECT));
+
+  // Create the mask op.
+  TF_ASSIGN_OR_RETURN(
+      auto mask_op,
+      CreateTernaryPwOp(input_tensor, masked_val_tensor, causal_mask_tensor,
+                        mask_out_tensor, mask_desc));
+
+  // Add mask to op list
+  ops.push_back(std::move(gen_index_row_op));
+  ops.push_back(std::move(gen_index_column_op));
+  ops.push_back(std::move(row_greater_than_column_op));
+  ops.push_back(std::move(mask_op));
+
+  return mask_out_tensor;
+}
+
+tsl::StatusOr<cudnn_frontend::Tensor> CreateCudnnFlashAttentionSoftmaxFwdTensor(
+    std::vector<cudnn_frontend::Operation>& ops, absl::Span<const int64_t> dims,
+    absl::Span<const int64_t> strides, dnn::DataType dtype,
+    cudnn_frontend::Tensor& input_tensor, bool is_virtual = false) {
+  // softmax's typical computation is:
+  // exp(input - reduce_max(input)) / reduce_sum(exp(input - reduce_max(input)))
+  // We need to create each op and add it to the op list sequentially.
+
+  // Copy all dims except the last dim since it's reduced to 1.
+  std::vector<int64_t> reduction_output_dim(dims.begin(), dims.end() - 1);
+  reduction_output_dim.push_back(1);
+
+  // Divide every stride by the last dim value.
+  std::vector<int64_t> reduction_output_stride;
+  int64_t reduced_dim_len = dims.back();
+  for (auto stride : strides) {
+    reduction_output_stride.push_back(stride / reduced_dim_len);
+  }
+
+  // Create output tensor of the first max reduction.
+  TF_ASSIGN_OR_RETURN(
+      auto max_reduction_output_tensor,
+      CreateCudnnTensor(reduction_output_dim, reduction_output_stride,
+                        CudnnfMHAUid::VIRTUAL_ID + 500, dnn::DataType::kFloat,
+                        1, -1, /*is_virtual=*/true));
+
+  // Create the reduction descriptor
+  auto max_reduction_desc =
+      cudnn_frontend::ReductionDescBuilder()
+          .setComputeType(ToCudnnDataType(dnn::DataType::kFloat))
+          .setReductionOp(CUDNN_REDUCE_TENSOR_MAX)
+          .build();
+  RETURN_MSG_IF_CUDNN_ERROR(max_reduction_desc);
+  // Create a reduction max node.
+  auto max_reduction_op = cudnn_frontend::OperationBuilder(
+                              CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
+                              .setxDesc(input_tensor)
+                              .setyDesc(max_reduction_output_tensor)
+                              .setreductionDesc(max_reduction_desc)
+                              .build();
+  RETURN_MSG_IF_CUDNN_ERROR(max_reduction_op);
+
+  // Create output tensor of the subtraction op.
+  TF_ASSIGN_OR_RETURN(
+      auto subtract_output_tensor,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 501,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual=*/true));
+  // Create the subtraction descriptor
+  TF_ASSIGN_OR_RETURN(auto subtract_desc,
+                      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_SUB));
+
+  // Create a subtraction node.
+  TF_ASSIGN_OR_RETURN(
+      auto subtract_op,
+      CreateBinaryPwOp(input_tensor, max_reduction_output_tensor,
+                       subtract_output_tensor, subtract_desc));
+  // Create output tensor of the exp op.
+  TF_ASSIGN_OR_RETURN(
+      auto exp_output_tensor,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 502,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual=*/true));
+  // Create the exponential descriptor
+  TF_ASSIGN_OR_RETURN(auto exp_desc,
+                      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_EXP));
+
+  // Create an exponential node.
+  TF_ASSIGN_OR_RETURN(
+      auto exp_op,
+      CreateUnaryPwOp(subtract_output_tensor, exp_output_tensor, exp_desc));
+
+  // Create output tensor of the sum reduction.
+  TF_ASSIGN_OR_RETURN(
+      auto sum_reduction_output_tensor,
+      CreateCudnnTensor(reduction_output_dim, reduction_output_stride,
+                        CudnnfMHAUid::VIRTUAL_ID + 503, dnn::DataType::kFloat,
+                        1, -1, /*is_virtual=*/true));
+  // Create the reduction descriptor
+  auto sum_reduction_desc =
+      cudnn_frontend::ReductionDescBuilder()
+          .setComputeType(ToCudnnDataType(dnn::DataType::kFloat))
+          .setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
+          .build();
+  RETURN_MSG_IF_CUDNN_ERROR(sum_reduction_desc);
+  // Create a reduction sum node.
+  auto sum_reduction_op = cudnn_frontend::OperationBuilder(
+                              CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
+                              .setxDesc(exp_output_tensor)
+                              .setyDesc(sum_reduction_output_tensor)
+                              .setreductionDesc(sum_reduction_desc)
+                              .build();
+  RETURN_MSG_IF_CUDNN_ERROR(sum_reduction_op);
+
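+  // The log and add ops below compute max + log(sum(exp(x - max))), i.e. the
+  // row-wise logsumexp "softmax stats"; it is written out under P_ID when
+  // is_virtual is false.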
+  // Create output tensor of the log op.
+  TF_ASSIGN_OR_RETURN(
+      auto log_tensor,
+      CreateCudnnTensor(reduction_output_dim, reduction_output_stride,
+                        CudnnfMHAUid::VIRTUAL_ID + 504, dnn::DataType::kFloat,
+                        1, -1,
+                        /*is_virtual*/ true));
+
+  // Create the log descriptor
+  TF_ASSIGN_OR_RETURN(auto log_desc,
+                      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_LOG));
+
+  // Create a log node.
+  TF_ASSIGN_OR_RETURN(auto log_op, CreateUnaryPwOp(sum_reduction_output_tensor,
+                                                   log_tensor, log_desc));
+
+  // Create output tensor of the add op.
+  auto ID = is_virtual ? CudnnfMHAUid::VIRTUAL_ID + 505 : CudnnfMHAUid::P_ID;
+  TF_ASSIGN_OR_RETURN(
+      auto softmax_stats_tensor,
+      CreateCudnnTensor(reduction_output_dim, reduction_output_stride, ID,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual*/ is_virtual));
+
+  // Create the add descriptor
+  TF_ASSIGN_OR_RETURN(auto add_desc,
+                      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_ADD));
+
+  // Create an add node.
+  TF_ASSIGN_OR_RETURN(auto add_op,
+                      CreateBinaryPwOp(max_reduction_output_tensor, log_tensor,
+                                       softmax_stats_tensor, add_desc));
+
+  // Create output tensor of the divide op.
+  TF_ASSIGN_OR_RETURN(
+      auto divide_output_tensor,
+      CreateCudnnTensor(
+          dims, strides, CudnnfMHAUid::VIRTUAL_ID + 506, dnn::DataType::kFloat,
+          1, -1,
+          /*is_virtual*/ true,
+          /*cudnn_tensor_order_type*/ CUDNN_TENSOR_REORDERING_F16x16));
+  // Create the divide descriptor
+  TF_ASSIGN_OR_RETURN(auto divide_desc,
+                      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_DIV));
+
+  // Create a divide node.
+  TF_ASSIGN_OR_RETURN(
+      auto divide_op,
+      CreateBinaryPwOp(exp_output_tensor, sum_reduction_output_tensor,
+                       divide_output_tensor, divide_desc));
+
+  // Add max reduction to op list
+  ops.push_back(std::move(max_reduction_op));
+  // Add subtract to op list
+  ops.push_back(std::move(subtract_op));
+  // Add exponential to op list
+  ops.push_back(std::move(exp_op));
+  // Add sum reduction to op list
+  ops.push_back(std::move(sum_reduction_op));
+  // Add Log to op list
+  ops.push_back(std::move(log_op));
+  // Add Add to op list
+  ops.push_back(std::move(add_op));
+  // Add divide to op list
+  ops.push_back(std::move(divide_op));
+  return divide_output_tensor;
+}
+
+tsl::StatusOr<cudnn_frontend::Tensor> CreateCudnnFlashAttentionDropoutFwdTensor(
+    std::vector<cudnn_frontend::Operation>& ops, absl::Span<const int64_t> dims,
+    absl::Span<const int64_t> strides, dnn::DataType dtype,
+    cudnn_frontend::Tensor& input_tensor, double dropout_rate) {
+  // Create scale tensor
+  std::vector<int64_t> scale_dims(dims.size(), 1);
+  std::vector<int64_t> scale_strides(strides.size(), 1);
+
+  // Create tensor for dropout's mask.
+  TF_ASSIGN_OR_RETURN(
+      auto mask_tensor,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 600,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual*/ true));
+  // Create output tensor of dropout node.
+  // Unlike regular attention, the dropout output is always virtual: the mask
+  // is recomputed in the backward pass instead of being stored.
+  TF_ASSIGN_OR_RETURN(
+      auto dropout_out_tensor,
+      CreateCudnnTensor(
+          dims, strides, CudnnfMHAUid::VIRTUAL_ID + 601, dtype, 1, -1,
+          /*is_virtual*/ true,
+          /*cudnn_tensor_order_type*/
+          cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16));
+
+  // Create offset tensor of dropout node
+  TF_ASSIGN_OR_RETURN(
+      auto dropout_offset_tensor,
+      CreateCudnnTensor(
+          scale_dims, scale_strides, CudnnfMHAUid::D_OFFSET_ID,
+          dnn::DataType::kInt64, 1, -1, /*is_virtual*/ false,
+          /*cudnn_tensor_order_type*/
+          cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_NONE,
+          /*is_value*/ CUDNN_VERSION < 8903 ? false : true));
+
+  // Create seed tensor of dropout node
+  TF_ASSIGN_OR_RETURN(
+      auto dropout_seed_tensor,
+      CreateCudnnTensor(
+          scale_dims, scale_strides, CudnnfMHAUid::D_SEED_ID,
+          dnn::DataType::kInt64, 1, -1, /*is_virtual*/ false,
+          /*cudnn_tensor_order_type*/
+          cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_NONE,
+          /*is_value*/ CUDNN_VERSION < 8903 ? false : true));
+
+  // Create description for rng node
+  auto rng_desc = cudnn_frontend::RngDescBuilder()
+                      .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI)
+                      .setBernoulliDistProbability(1.0 - dropout_rate)
+                      .build();
+
+  // Create the rng Node.
+  auto rng_op =
+      cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)
+          .setyDesc(mask_tensor)
+          .setSeedDesc(dropout_seed_tensor)
+          .setOffsetDesc(dropout_offset_tensor)
+          .setRngDesc(rng_desc)
+          .build();
+  RETURN_MSG_IF_CUDNN_ERROR(rng_op);
+
+  // Create the masking node desc after mask tensor
+  TF_ASSIGN_OR_RETURN(auto masking_desc,
+                      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_MUL));
+
+  // Create the scaling op
+  TF_ASSIGN_OR_RETURN(auto masking_op,
+                      CreateBinaryPwOp(input_tensor, mask_tensor,
+                                       dropout_out_tensor, masking_desc));
+
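+  // Scale the masked output by the runtime-supplied DROPOUT_SCALE_ID value
+  // (typically 1 / (1 - dropout_rate)) so the expected activation magnitude
+  // is preserved.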
+  TF_ASSIGN_OR_RETURN(
+      auto dropout_scale_tensor,
+      CreateCudnnTensor(
+          scale_dims, scale_strides, CudnnfMHAUid::DROPOUT_SCALE_ID,
+          dnn::DataType::kFloat, 1, -1,
+          /*is_virtual*/ false,
+          /*cudnn_tensor_order_type*/ CUDNN_TENSOR_REORDERING_NONE,
+          /*is_value*/ true));
+
+  // Create output of scale node
+  TF_ASSIGN_OR_RETURN(
+      auto dropout_scale_out_tensor,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 602, dtype, 1,
+                        -1, /*is_virtual*/ true));
+  // Create the scaling desc
+  TF_ASSIGN_OR_RETURN(auto scale_desc,
+                      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_MUL));
+
+  // Create the scaling op
+  TF_ASSIGN_OR_RETURN(auto scale_op,
+                      CreateBinaryPwOp(dropout_out_tensor, dropout_scale_tensor,
+                                       dropout_scale_out_tensor, scale_desc));
+  // Add rng op to op list
+  ops.push_back(std::move(rng_op));
+  // Add masking op to op list
+  ops.push_back(std::move(masking_op));
+  // Add scaling op to op list
+  ops.push_back(std::move(scale_op));
+
+  return dropout_scale_out_tensor;
+}
+
+tsl::StatusOr<std::unique_ptr<cudnn_frontend::OperationGraph>>
+GetCudnnFlashAttentionOperationGraph(
+    const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor,
+    const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor,
+    const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor,
+    const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor,
+    const dnn::TensorDescriptor& output_descriptor,
+    std::optional<dnn::TensorDescriptor> mask_descriptor,
+    std::optional<dnn::TensorDescriptor> bias_descriptor,
+    std::optional<dnn::TensorDescriptor> activation_descriptor,
+    dnn::FusedMHAKind kind, std::optional<double> dropout_rate,
+    std::optional<int64_t> seed, CudnnHandle& cudnn, double scale,
+    std::vector<int64_t>& intermediate_shape, bool use_dropout = false,
+    bool use_mask = false, bool use_bias = false,
+    bool use_causal_mask = false) {
+  if (VLOG_IS_ON(4)) {
+    VLOG(4) << "\n bmm1_lhs(q): " << bmm1_lhs_descriptor.ToString()
+            << "\n bmm1_rhs(k): " << bmm1_rhs_descriptor.ToString()
+            << "\n bmm2_lhs(s): " << intermediate_bmm2_lhs_descriptor.ToString()
+            << "\n bmm2_rhs(v): " << bmm2_rhs_descriptor.ToString()
+            << "\n out(o): " << output_descriptor.ToString();
+    if (activation_descriptor) {
+      VLOG(4) << "\n activation(s): " << (*activation_descriptor).ToString();
+    }
+  }
+
+  // cnn_infer needs to be preloaded for fMHA as well. Reusing the function
+  // created for convolution for fMHA.
+  PreloadCudnnSubLibsHelper(dnn::ConvolutionKind::FORWARD);
+
+  std::vector<cudnn_frontend::Operation const*> ops;
+  std::vector<cudnn_frontend::Operation> intermediate_ops;
+
+  // Batched Matmul: bmm1_lhs: tensor_q, bmm1_rhs:tensor_k; output: tensor_s
+  // (virtual)
+  // Batched Matmul: bmm2_lhs: tensor_s, bmm2_rhs:tensor_v; output: tensor_o
+  std::vector<int64_t> bmm1_lhs_dims =
+      bmm1_lhs_descriptor.GetCudnnCompatibleDimensions(true);
+  std::vector<int64_t> bmm1_lhs_strides =
+      bmm1_lhs_descriptor.GetCudnnCompatibleStrides(true);
+
+  VLOG(2) << "\n cuDNN compatible bmm1_lhs_dims: "
+          << absl::StrJoin(bmm1_lhs_dims, ",")
+          << "\n cuDNN compatible bmm1_lhs_strides: "
+          << absl::StrJoin(bmm1_lhs_strides, ",");
+
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_q,
+      CreateCudnnTensor(bmm1_lhs_dims, bmm1_lhs_strides, CudnnfMHAUid::Q_ID,
+                        bmm1_lhs_descriptor.type(), 1, -1));
+
+  std::vector<int64_t> bmm1_rhs_dims =
+      bmm1_rhs_descriptor.GetCudnnCompatibleDimensions(false);
+  std::vector<int64_t> bmm1_rhs_strides =
+      bmm1_rhs_descriptor.GetCudnnCompatibleStrides(false);
+
+  VLOG(2) << "\n cuDNN compatible bmm1_rhs_dims: "
+          << absl::StrJoin(bmm1_rhs_dims, ",")
+          << "\n cuDNN compatible bmm1_rhs_strides: "
+          << absl::StrJoin(bmm1_rhs_strides, ",");
+
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_k,
+      CreateCudnnTensor(bmm1_rhs_dims, bmm1_rhs_strides, CudnnfMHAUid::K_ID,
+                        bmm1_rhs_descriptor.type(), 1, -1));
+
+  std::vector<int64_t> intermediate_bmm2_lhs_dims =
+      intermediate_bmm2_lhs_descriptor.GetCudnnCompatibleDimensions(true);
+  std::vector<int64_t> intermediate_bmm2_lhs_strides =
+      intermediate_bmm2_lhs_descriptor.GetCudnnCompatibleStrides(true);
+
+  VLOG(2) << "\n cuDNN compatible intermediate_bmm2_lhs_dims: "
+          << absl::StrJoin(intermediate_bmm2_lhs_dims, ",")
+          << "\n cuDNN compatible intermediate_bmm2_lhs_strides: "
+          << absl::StrJoin(intermediate_bmm2_lhs_strides, ",");
+  intermediate_shape = intermediate_bmm2_lhs_dims;
+  bool has_activation = activation_descriptor != std::nullopt;
+
+  TF_ASSIGN_OR_RETURN(auto tensor_s,
+                      CreateCudnnTensor(intermediate_bmm2_lhs_dims,
+                                        intermediate_bmm2_lhs_strides,
+                                        CudnnfMHAUid::VIRTUAL_ID + 100,
+                                        dnn::DataType::kFloat, 1, -1,
+                                        /*is_virtual=*/true));
+
+  auto bmm1_desc = cudnn_frontend::MatMulDescBuilder()
+                       .setComputeType(CUDNN_DATA_FLOAT)
+                       .build();
+  RETURN_MSG_IF_CUDNN_ERROR(bmm1_desc);
+  auto bmm1_op = cudnn_frontend::OperationBuilder(
+                     CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
+                     .setaMatDesc(tensor_q)
+                     .setbMatDesc(tensor_k)
+                     .setcMatDesc(tensor_s)
+                     .setmatmulDesc(bmm1_desc)
+                     .build();
+  RETURN_MSG_IF_CUDNN_ERROR(bmm1_op);
+  intermediate_ops.push_back(std::move(bmm1_op));
+
+  // Create scale op and tensor
+  TF_ASSIGN_OR_RETURN(
+      auto alpha_scale_out,
+      CreateCudnnScaleTensor(intermediate_ops, intermediate_bmm2_lhs_dims,
+                             intermediate_bmm2_lhs_strides,
+                             dnn::DataType::kFloat, tensor_s));
+
+  auto bmm2_input_tensor = std::move(alpha_scale_out);
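+  // bmm2_input_tensor tracks the most recent output of the fused chain
+  // (scale -> optional bias -> optional causal mask -> softmax -> dropout)
+  // and is consumed as the lhs of BMM2 below.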
+
+  if (use_bias) {
+    // Create bias op and tensor
+    TF_ASSIGN_OR_RETURN(auto bias_out,
+                        CreateCudnnFlashAttentionBiasFwdTensor(
+                            intermediate_ops, intermediate_bmm2_lhs_dims,
+                            intermediate_bmm2_lhs_strides,
+                            (*bias_descriptor).type(), bmm2_input_tensor));
+    bmm2_input_tensor = std::move(bias_out);
+  }
+
+  if (use_causal_mask) {
+    // Create mask op and tensor
+    TF_ASSIGN_OR_RETURN(
+        auto mask_out,
+        CreateCudnnFlashAttentionCausalMaskTensor(
+            intermediate_ops, intermediate_bmm2_lhs_dims,
+            intermediate_bmm2_lhs_strides,
+            intermediate_bmm2_lhs_descriptor.type(), bmm2_input_tensor));
+    bmm2_input_tensor = std::move(mask_out);
+  }
+
+  // Create Softmax tensor
+  // The output is always virtual in inference mode and non-virtual in
+  // training mode, because dropout is recomputed in the backward pass.
+  bool should_output_softmax = has_activation;
+  TF_ASSIGN_OR_RETURN(auto softmax_fwd_out,
+                      CreateCudnnFlashAttentionSoftmaxFwdTensor(
+                          intermediate_ops, intermediate_bmm2_lhs_dims,
+                          intermediate_bmm2_lhs_strides,
+                          intermediate_bmm2_lhs_descriptor.type(),
+                          /*input_tensor*/ bmm2_input_tensor,
+                          /*is_virtual*/ !should_output_softmax));
+  bmm2_input_tensor = std::move(softmax_fwd_out);
+
+  // Create dropout tensor
+  // dropout is always virtual in inference or training for flash attention
+  TF_ASSIGN_OR_RETURN(auto dropout_out,
+                      CreateCudnnFlashAttentionDropoutFwdTensor(
+                          intermediate_ops, intermediate_bmm2_lhs_dims,
+                          intermediate_bmm2_lhs_strides,
+                          intermediate_bmm2_lhs_descriptor.type(),
+                          /*input_tensor*/ softmax_fwd_out, *dropout_rate));
+  bmm2_input_tensor = std::move(dropout_out);
+
+  std::vector<int64_t> bmm2_rhs_dims =
+      bmm2_rhs_descriptor.GetCudnnCompatibleDimensions(false);
+  std::vector<int64_t> bmm2_rhs_strides =
+      bmm2_rhs_descriptor.GetCudnnCompatibleStrides(false);
+
+  VLOG(2) << "\n cuDNN compatible bmm2_rhs_dims: "
+          << absl::StrJoin(bmm2_rhs_dims, ",")
+          << "\n cuDNN compatible bmm2_rhs_strides: "
+          << absl::StrJoin(bmm2_rhs_strides, ",");
+
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_v,
+      CreateCudnnTensor(bmm2_rhs_dims, bmm2_rhs_strides, CudnnfMHAUid::V_ID,
+                        bmm2_rhs_descriptor.type(), 1, -1));
+
+  std::vector<int64_t> output_dims = output_descriptor.dimensions();
+  std::vector<int64_t> output_strides = output_descriptor.GetLogicalStrides();
+
+  VLOG(2) << "\n Out Dims: " << absl::StrJoin(output_dims, ",")
+          << "\n Out Strides: " << absl::StrJoin(output_strides, ",");
+
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_o,
+      CreateCudnnTensor(output_dims, output_strides, CudnnfMHAUid::O_ID,
+                        output_descriptor.type(), 1, -1));
+  auto bmm2_desc = cudnn_frontend::MatMulDescBuilder()
+                       .setComputeType(CUDNN_DATA_FLOAT)
+                       .build();
+  RETURN_MSG_IF_CUDNN_ERROR(bmm2_desc);
+  auto bmm2_op = cudnn_frontend::OperationBuilder(
+                     CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
+                     .setaMatDesc(bmm2_input_tensor)
+                     .setbMatDesc(tensor_v)
+                     .setcMatDesc(tensor_o)
+                     .setmatmulDesc(bmm2_desc)
+                     .build();
+  RETURN_MSG_IF_CUDNN_ERROR(bmm2_op);
+  // Create an Operation Graph. In this case it is gemm-gemm
+  intermediate_ops.push_back(std::move(bmm2_op));
+  ops.reserve(intermediate_ops.size());
+  for (auto& intermediate_op : intermediate_ops) {
+    ops.emplace_back(&intermediate_op);
+  }
+
+  auto op_graph = cudnn_frontend::OperationGraphBuilder()
+                      .setHandle(cudnn.handle())
+                      .setOperationGraph(ops.size(), ops.data())
+                      .build();
+  RETURN_MSG_IF_CUDNN_ERROR(op_graph);
+  VLOG(4) << "\nTensor_q: " << tensor_q.describe()
+          << "\nTensor_k: " << tensor_k.describe()
+          << "\nTensor_s: " << tensor_s.describe()
+          << "\nTensor_v: " << tensor_v.describe()
+          << "\nTensor_o: " << tensor_o.describe()
+          << "\nBMM1: " << bmm1_desc.describe()
+          << "\nBMM2: " << bmm2_desc.describe()
+          << "\nOpGraph: " << op_graph.describe();
+  return std::make_unique<cudnn_frontend::OperationGraph>(std::move(op_graph));
+}
+
+tsl::StatusOr<cudnn_frontend::Tensor> CreateCudnnFlashAttentionDropoutBwdTensor(
+    std::vector<cudnn_frontend::Operation>& ops, absl::Span<const int64_t> dims,
+    absl::Span<const int64_t> strides, dnn::DataType dtype,
+    cudnn_frontend::Tensor& input_tensor, cudnn_frontend::Tensor& mask_tensor,
+    double dropout_rate) {
+  // Create scale tensor
+  std::vector<int64_t> scale_dims(dims.size(), 1);
+  std::vector<int64_t> scale_strides(strides.size(), 1);
+
+  // Create output tensor of dropout node.
+  // Unlike regular attention, the dropout output is always virtual: the mask
+  // is recomputed in the backward pass instead of being stored.
+
+  TF_ASSIGN_OR_RETURN(
+      auto dropout_out_tensor,
+      CreateCudnnTensor(
+          dims, strides, CudnnfMHAUid::VIRTUAL_ID + 601, dtype, 1, -1,
+          /*is_virtual*/ true,
+          /*cudnn_tensor_order_type*/
+          cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_F16x16));
+
+  // flash attention TODO: set byValue to true if host pointer is supported
+  // Create offset tensor of dropout node
+  TF_ASSIGN_OR_RETURN(
+      auto dropout_offset_tensor,
+      CreateCudnnTensor(
+          scale_dims, scale_strides, CudnnfMHAUid::D_OFFSET_ID,
+          dnn::DataType::kInt64, 1, -1, /*is_virtual*/ false,
+          /*cudnn_tensor_order_type*/
+          cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_NONE,
+          /*is_value*/ CUDNN_VERSION < 8903 ? false : true));
+
+  // Create seed tensor of dropout node
+  TF_ASSIGN_OR_RETURN(
+      auto dropout_seed_tensor,
+      CreateCudnnTensor(
+          scale_dims, scale_strides, CudnnfMHAUid::D_SEED_ID,
+          dnn::DataType::kInt64, 1, -1, /*is_virtual*/ false,
+          /*cudnn_tensor_order_type*/
+          cudnnBackendTensorReordering_t::CUDNN_TENSOR_REORDERING_NONE,
+          /*is_value*/ CUDNN_VERSION < 8903 ? false : true));
+
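+  // The rng node below regenerates the dropout mask into the caller-provided
+  // mask_tensor from the same D_SEED_ID / D_OFFSET_ID tensors used in the
+  // forward pass, so the mask is reproduced rather than stored.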
+  // Create description for rng node
+  auto rng_desc = cudnn_frontend::RngDescBuilder()
+                      .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI)
+                      .setBernoulliDistProbability(1.0 - dropout_rate)
+                      .build();
+
+  // Create the rng Node.
+  auto rng_op =
+      cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)
+          .setyDesc(mask_tensor)
+          .setSeedDesc(dropout_seed_tensor)
+          .setOffsetDesc(dropout_offset_tensor)
+          .setRngDesc(rng_desc)
+          .build();
+  RETURN_MSG_IF_CUDNN_ERROR(rng_op);
+
+  // Create the masking node desc after mask tensor
+  TF_ASSIGN_OR_RETURN(auto masking_desc,
+                      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_MUL));
+
+  // Create the scaling op
+  TF_ASSIGN_OR_RETURN(auto masking_op,
+                      CreateBinaryPwOp(input_tensor, mask_tensor,
+                                       dropout_out_tensor, masking_desc));
+
+  TF_ASSIGN_OR_RETURN(
+      auto dropout_scale_tensor,
+      CreateCudnnTensor(
+          scale_dims, scale_strides, CudnnfMHAUid::DROPOUT_SCALE_ID,
+          dnn::DataType::kFloat, 1, -1,
+          /*is_virtual*/ false,
+          /*cudnn_tensor_order_type*/ CUDNN_TENSOR_REORDERING_NONE,
+          /*is_value*/ true));
+
+  // Create output of scale node
+  TF_ASSIGN_OR_RETURN(
+      auto dropout_scale_out_tensor,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 602, dtype, 1,
+                        -1, /*is_virtual*/ true));
+  // Create the scaling desc
+  TF_ASSIGN_OR_RETURN(auto scale_desc,
+                      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_MUL));
+
+  // Create the scaling op
+  TF_ASSIGN_OR_RETURN(auto scale_op,
+                      CreateBinaryPwOp(dropout_out_tensor, dropout_scale_tensor,
+                                       dropout_scale_out_tensor, scale_desc));
+  // Add rng op to op list
+  ops.push_back(std::move(rng_op));
+  // Add masking op to op list
+  ops.push_back(std::move(masking_op));
+  // Add scaling op to op list
+  ops.push_back(std::move(scale_op));
+
+  return dropout_scale_out_tensor;
+}
+
+tsl::StatusOr<std::unique_ptr<cudnn_frontend::OperationGraph>>
+GetCudnnFlashAttentionBackwardOperationGraph(
+    const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor,
+    const dnn::MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor,
+    const dnn::MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor,
+    const dnn::MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor,
+    const dnn::MatmulTensorDescriptor& d_output_descriptor,
+    const dnn::TensorDescriptor& d_bmm1_lhs_descriptor,
+    const dnn::TensorDescriptor& d_bmm1_rhs_descriptor,
+    const dnn::TensorDescriptor& d_bmm2_rhs_descriptor, dnn::FusedMHAKind kind,
+    std::optional<double> dropout_rate, std::optional<int64_t> seed,
+    CudnnHandle& cudnn, double scale, std::vector<int64_t>& intermediate_shape,
+    bool use_dropout = false, bool use_mask = false, bool use_bias = false,
+    bool use_causal_mask = false) {
+  if (VLOG_IS_ON(4)) {
+    VLOG(4) << "\n bmm1_grad_gemm1_rhs(q): "
+            << bmm1_grad_gemm1_rhs_descriptor.ToString()
+            << "\n bmm1_grad_gemm2_rhs(k): "
+            << bmm1_grad_gemm2_rhs_descriptor.ToString()
+            << "\n bmm2_grad_gemm1_lhs(p): "
+            << bmm2_grad_gemm1_lhs_descriptor.ToString()
+            << "\n bmm2_grad_gemm2_rhs(v^t): "
+            << bmm2_grad_gemm2_rhs_descriptor.ToString()
+            << "\n d_output(do): " << d_output_descriptor.ToString()
+            << "\n d_bmm1_lhs(dq): " << d_bmm1_lhs_descriptor.ToString()
+            << "\n d_bmm1_rhs(dk): " << d_bmm1_rhs_descriptor.ToString()
+            << "\n d_bmm2_rhs(dv): " << d_bmm2_rhs_descriptor.ToString();
+  }
+  // cnn_infer needs to be preloaded for fMHA as well; reuse the helper that
+  // was created for convolutions.
+  PreloadCudnnSubLibsHelper(dnn::ConvolutionKind::FORWARD);
+
+  std::vector<cudnn_frontend::Operation const*> ops;
+  std::vector<cudnn_frontend::Operation> intermediate_ops;
+
+  // fp16 or bf16 is required
+  auto dtype = bmm1_grad_gemm1_rhs_descriptor.type();
+  // create input tensor Q
+  std::vector<int64_t> q_dims =
+      bmm1_grad_gemm1_rhs_descriptor.GetCudnnCompatibleDimensions(false);
+  std::vector<int64_t> q_strides =
+      bmm1_grad_gemm1_rhs_descriptor.GetCudnnCompatibleStrides(false);
+
+  // Used to create the scale and zero tensors.
+  std::vector<int64_t> scale_dims(q_dims.size(), 1);
+  std::vector<int64_t> scale_strides(q_strides.size(), 1);
+
+  VLOG(2) << "\n cuDNN compatible bmm1_grad_gemm1_rhs_dims: "
+          << absl::StrJoin(q_dims, ",")
+          << "\n cuDNN compatible bmm1_grad_gemm1_rhs_strides: "
+          << absl::StrJoin(q_strides, ",");
+
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_q,
+      CreateCudnnTensor(q_dims, q_strides, CudnnfMHAUid::Q_ID, dtype, 1, -1));
+
+  // create input tensor K^T
+  std::vector<int64_t> k_transpose_dims =
+      bmm1_grad_gemm2_rhs_descriptor.GetCudnnCompatibleDimensions(true);
+  std::vector<int64_t> k_transpose_strides =
+      bmm1_grad_gemm2_rhs_descriptor.GetCudnnCompatibleStrides(true);
+
+  VLOG(2) << "\n cuDNN compatible bmm1_grad_gemm2_rhs_dims: "
+          << absl::StrJoin(k_transpose_dims, ",")
+          << "\n cuDNN compatible bmm1_grad_gemm2_rhs_strides: "
+          << absl::StrJoin(k_transpose_strides, ",");
+
+  TF_ASSIGN_OR_RETURN(auto tensor_kt,
+                      CreateCudnnTensor(k_transpose_dims, k_transpose_strides,
+                                        CudnnfMHAUid::K_ID, dtype, 1, -1));
+
+  // P^T is the lhs of bmm2_grad_gemm1 (dV = dot(P^T, dO)), so pass
+  // is_lhs = false here to get the correct dims and strides for P.
+  std::vector<int64_t> p_dims =
+      bmm2_grad_gemm1_lhs_descriptor.GetCudnnCompatibleDimensions(false);
+  std::vector<int64_t> p_strides =
+      bmm2_grad_gemm1_lhs_descriptor.GetCudnnCompatibleStrides(false);
+
+  // Used later to calculate the dropout RNG offset increment.
+  intermediate_shape = p_dims;
+  VLOG(2) << "\n cuDNN compatible bmm2_grad_gemm1_lhs_dims: "
+          << absl::StrJoin(p_dims, ",")
+          << "\n cuDNN compatible bmm2_grad_gemm1_lhs_strides: "
+          << absl::StrJoin(p_strides, ",");
+
+  // create input tensor V^T
+  std::vector<int64_t> v_transpose_dims =
+      bmm2_grad_gemm2_rhs_descriptor.GetCudnnCompatibleDimensions(false);
+  std::vector<int64_t> v_transpose_strides =
+      bmm2_grad_gemm2_rhs_descriptor.GetCudnnCompatibleStrides(false);
+
+  VLOG(2) << "\n cuDNN compatible bmm2_grad_gemm2_rhs_dims: "
+          << absl::StrJoin(v_transpose_dims, ",")
+          << "\n cuDNN compatible bmm2_grad_gemm2_rhs_strides: "
+          << absl::StrJoin(v_transpose_strides, ",");
+
+  TF_ASSIGN_OR_RETURN(auto tensor_vt,
+                      CreateCudnnTensor(v_transpose_dims, v_transpose_strides,
+                                        CudnnfMHAUid::V_ID, dtype, 1, -1));
+
+  // create input tensor dO
+  // FLASH ATTENTION TODO: double-check the dimensions here.
+  std::vector<int64_t> do_dims =
+      d_output_descriptor.GetCudnnCompatibleDimensions(false);
+  std::vector<int64_t> do_strides =
+      d_output_descriptor.GetCudnnCompatibleStrides(false);
+
+  VLOG(2) << "\n cuDNN compatible d_output_dims: "
+          << absl::StrJoin(do_dims, ",")
+          << "\n cuDNN compatible d_output_strides: "
+          << absl::StrJoin(do_strides, ",");
+
+  TF_ASSIGN_OR_RETURN(auto tensor_do,
+                      CreateCudnnTensor(do_dims, do_strides,
+                                        CudnnfMHAUid::dO_ID, dtype, 1, -1));
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_o,
+      CreateCudnnTensor(do_dims, do_strides, CudnnfMHAUid::O_ID, dtype, 1, -1));
+
+  // create output tensor dQ
+  std::vector<int64_t> dq_dims = d_bmm1_lhs_descriptor.dimensions();
+  std::vector<int64_t> dq_strides = d_bmm1_lhs_descriptor.GetLogicalStrides();
+
+  VLOG(2) << "\n cuDNN compatible d_bmm1_lhs_dims: "
+          << absl::StrJoin(dq_dims, ",")
+          << "\n cuDNN compatible d_bmm1_lhs_strides: "
+          << absl::StrJoin(dq_strides, ",");
+
+  TF_ASSIGN_OR_RETURN(auto tensor_dq,
+                      CreateCudnnTensor(dq_dims, dq_strides,
+                                        CudnnfMHAUid::dQ_ID, dtype, 1, -1));
+
+  // create output tensor dK
+  std::vector<int64_t> dk_dims = d_bmm1_rhs_descriptor.dimensions();
+  std::vector<int64_t> dk_strides = d_bmm1_rhs_descriptor.GetLogicalStrides();
+
+  VLOG(2) << "\n cuDNN compatible d_bmm1_rhs_dims: "
+          << absl::StrJoin(dk_dims, ",")
+          << "\n cuDNN compatible d_bmm1_rhs_strides: "
+          << absl::StrJoin(dk_strides, ",");
+
+  TF_ASSIGN_OR_RETURN(auto tensor_dk,
+                      CreateCudnnTensor(dk_dims, dk_strides,
+                                        CudnnfMHAUid::dK_ID, dtype, 1, -1));
+
+  // create output tensor dV
+  std::vector<int64_t> dv_dims = d_bmm2_rhs_descriptor.dimensions();
+  std::vector<int64_t> dv_strides = d_bmm2_rhs_descriptor.GetLogicalStrides();
+
+  VLOG(2) << "\n cuDNN compatible d_bmm2_rhs_dims: "
+          << absl::StrJoin(dv_dims, ",")
+          << "\n cuDNN compatible d_bmm2_rhs_strides: "
+          << absl::StrJoin(dv_strides, ",");
+
+  TF_ASSIGN_OR_RETURN(auto tensor_dv,
+                      CreateCudnnTensor(dv_dims, dv_strides,
+                                        CudnnfMHAUid::dV_ID, dtype, 1, -1));
+
+  // Begin backward graph creation
+  // dO * O
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_dot_product,
+      CreateCudnnTensor(do_dims, do_strides, CudnnfMHAUid::VIRTUAL_ID + 100,
+                        dnn::DataType::kFloat, 1, -1, /*is_virtual*/ true));
+
+  TF_ASSIGN_OR_RETURN(auto mul_desc,
+                      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_MUL));
+
+  TF_ASSIGN_OR_RETURN(
+      auto mul_op,
+      CreateBinaryPwOp(tensor_do, tensor_o, tensor_dot_product, mul_desc));
+
+  intermediate_ops.push_back(std::move(mul_op));
+
+  // reduction(dO * O)
+  std::vector<int64_t> do_reduction_dims(do_dims.begin(), do_dims.end() - 1);
+  do_reduction_dims.push_back(1);
+
+  // Divide every stride by the last dim value.
+  std::vector<int64_t> do_reduction_strides;
+  do_reduction_strides.reserve(do_strides.size());
+  int64_t reduced_dim_len = do_dims.back();
+  for (auto stride : do_strides) {
+    do_reduction_strides.push_back(stride / reduced_dim_len);
+  }
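+  // The resulting dims/strides describe the [..., 1]-shaped output of the
+  // row reduction over dO * O.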
+
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_dot_product_reduction,
+      CreateCudnnTensor(do_reduction_dims, do_reduction_strides,
+                        CudnnfMHAUid::VIRTUAL_ID + 101, dnn::DataType::kFloat,
+                        1, -1, /*is_virtual*/ true));
+
+  auto reduction_add_desc = cudnn_frontend::ReductionDescBuilder()
+                                .setComputeType(CUDNN_DATA_FLOAT)
+                                .setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
+                                .build();
+  RETURN_MSG_IF_CUDNN_ERROR(reduction_add_desc);
+  auto reduction_add_op = cudnn_frontend::OperationBuilder(
+                              CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
+                              .setxDesc(tensor_dot_product)
+                              .setyDesc(tensor_dot_product_reduction)
+                              .setreductionDesc(reduction_add_desc)
+                              .build();
+  RETURN_MSG_IF_CUDNN_ERROR(reduction_add_op);
+  intermediate_ops.push_back(std::move(reduction_add_op));
+
+  // reduction(dO * O) * scale prob -> softmax_sum
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_scale_prob,
+      CreateCudnnTensor(
+          scale_dims, scale_strides, CudnnfMHAUid::SCALE_PROB_ID,
+          dnn::DataType::kFloat, 1, -1,
+          /*is_virtual*/ false,
+          /*cudnn_tensor_order_type*/ CUDNN_TENSOR_REORDERING_NONE,
+          /*is_value*/ true));
+
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_softmax_sum,
+      CreateCudnnTensor(do_reduction_dims, do_reduction_strides,
+                        CudnnfMHAUid::S_SUM_ID, dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual*/ false));
+
+  TF_ASSIGN_OR_RETURN(
+      auto mul_0_op,
+      CreateBinaryPwOp(tensor_dot_product_reduction, tensor_scale_prob,
+                       tensor_softmax_sum, mul_desc));
+  intermediate_ops.push_back(std::move(mul_0_op));
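+  // softmax_sum = rowsum(dO * O) * (1 - dropout_rate); it is subtracted from
+  // dS further below as the per-row correction term of the softmax gradient.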
+
+  // Q @ K.T -> P
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_p,
+      CreateCudnnTensor(p_dims, p_strides, CudnnfMHAUid::VIRTUAL_ID + 102,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual*/ true));
+
+  auto bmm1_desc = cudnn_frontend::MatMulDescBuilder()
+                       .setComputeType(CUDNN_DATA_FLOAT)
+                       .build();
+  RETURN_MSG_IF_CUDNN_ERROR(bmm1_desc);
+  auto bmm1_op = cudnn_frontend::OperationBuilder(
+                     CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
+                     .setaMatDesc(tensor_q)
+                     .setbMatDesc(tensor_kt)
+                     .setcMatDesc(tensor_p)
+                     .setmatmulDesc(bmm1_desc)
+                     .build();
+  RETURN_MSG_IF_CUDNN_ERROR(bmm1_op);
+  intermediate_ops.push_back(std::move(bmm1_op));
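+  // Flash attention does not materialize the forward probabilities, so the
+  // backward pass recomputes P = Q @ K^T here.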
+
+  // P * alpha_scale -> p_after_alpha_scale
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_alpha_scale,
+      CreateCudnnTensor(
+          scale_dims, scale_strides, CudnnfMHAUid::ALPHA_SCALE_ID,
+          dnn::DataType::kFloat, 1, -1,
+          /*is_virtual*/ false,
+          /*cudnn_tensor_order_type*/ CUDNN_TENSOR_REORDERING_NONE,
+          /*is_value*/ true));
+
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_p_after_alpha_scale,
+      CreateCudnnTensor(p_dims, p_strides, CudnnfMHAUid::VIRTUAL_ID + 103,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual*/ true));
+  TF_ASSIGN_OR_RETURN(auto mul_1_op,
+                      CreateBinaryPwOp(tensor_p, tensor_alpha_scale,
+                                       tensor_p_after_alpha_scale, mul_desc));
+  intermediate_ops.push_back(std::move(mul_1_op));
+
+  if (use_bias) {
+    // bias -> p_after_bias
+    TF_ASSIGN_OR_RETURN(auto tensor_p_after_bias,
+                        CreateCudnnFlashAttentionBiasFwdTensor(
+                            intermediate_ops, p_dims, p_strides, dtype,
+                            tensor_p_after_alpha_scale));
+    tensor_p_after_alpha_scale = std::move(tensor_p_after_bias);
+  }
+  if (use_causal_mask) {
+    // Causal masking -> p_after_mask
+    TF_ASSIGN_OR_RETURN(auto tensor_p_after_causal_mask,
+                        CreateCudnnFlashAttentionCausalMaskTensor(
+                            intermediate_ops, p_dims, p_strides, dtype,
+                            tensor_p_after_alpha_scale));
+    tensor_p_after_alpha_scale = std::move(tensor_p_after_causal_mask);
+  }
+  auto tensor_p_after_bias_or_mask = std::move(tensor_p_after_alpha_scale);
+  // p_after_mask - softmax_stats -> p_after_sub
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_p_after_sub,
+      CreateCudnnTensor(p_dims, p_strides, CudnnfMHAUid::VIRTUAL_ID + 104,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual*/ true));
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_softmax_stats,
+      CreateCudnnTensor(do_reduction_dims, do_reduction_strides,
+                        CudnnfMHAUid::P_ID, dnn::DataType::kFloat, 1, -1));
+
+  TF_ASSIGN_OR_RETURN(auto sub_desc,
+                      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_SUB));
+  TF_ASSIGN_OR_RETURN(
+      auto sub_0_op,
+      CreateBinaryPwOp(tensor_p_after_bias_or_mask, tensor_softmax_stats,
+                       tensor_p_after_sub, sub_desc));
+  intermediate_ops.push_back(std::move(sub_0_op));
+
+  // e^(p_after_sub) -> p_after_softmax
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_p_after_softmax,
+      CreateCudnnTensor(p_dims, p_strides, CudnnfMHAUid::VIRTUAL_ID + 105,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual*/ true));
+
+  TF_ASSIGN_OR_RETURN(auto exp_0_desc,
+                      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_EXP));
+  TF_ASSIGN_OR_RETURN(
+      auto exp_0_op,
+      CreateUnaryPwOp(tensor_p_after_sub, tensor_p_after_softmax, exp_0_desc));
+  intermediate_ops.push_back(std::move(exp_0_op));
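+  // softmax_stats holds the per-row softmax statistics (logsumexp) produced
+  // by the forward pass, so exp(p - softmax_stats) reconstructs the
+  // probabilities without redoing the full softmax reduction.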
+
+  // Dropout -> p_after_scale_dropout
+  // Create tensor for dropout's mask
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_dropout_mask,
+      CreateCudnnTensor(p_dims, p_strides, CudnnfMHAUid::VIRTUAL_ID + 106,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual*/ true));
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_p_after_scale_dropout,
+      CreateCudnnFlashAttentionDropoutBwdTensor(
+          intermediate_ops, p_dims, p_strides, dtype, tensor_p_after_softmax,
+          tensor_dropout_mask, *dropout_rate));
+
+  // after_scale_dropout -> s_transpose
+  auto p_transpose_dims = p_dims;
+  auto p_transpose_strides = p_strides;
+  auto p_rank = p_transpose_dims.size();
+  std::swap(p_transpose_dims[p_rank - 1], p_transpose_dims[p_rank - 2]);
+  std::swap(p_transpose_strides[p_rank - 1], p_transpose_strides[p_rank - 2]);
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_s_transpose,
+      CreateCudnnTensor(p_transpose_dims, p_transpose_strides,
+                        CudnnfMHAUid::VIRTUAL_ID + 107, dtype, 1, -1,
+                        /*is_virtual*/ true));
+  auto reshape_op = cudnn_frontend::OperationBuilder(
+                        CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
+                        .setxDesc(tensor_p_after_scale_dropout)
+                        .setyDesc(tensor_s_transpose)
+                        .build();
+  RETURN_MSG_IF_CUDNN_ERROR(reshape_op);
+  intermediate_ops.push_back(std::move(reshape_op));
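+  // The strides of the last two dims were swapped above, so this reshape
+  // exposes the dropout output as S^T for the dV matmul below.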
+
+  // s_transpose @ dO -> dV
+  auto bmm2_grad_gemm1_desc = cudnn_frontend::MatMulDescBuilder()
+                                  .setComputeType(CUDNN_DATA_FLOAT)
+                                  .build();
+  RETURN_MSG_IF_CUDNN_ERROR(bmm2_grad_gemm1_desc);
+  auto bmm2_grad_gemm1_op = cudnn_frontend::OperationBuilder(
+                                CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
+                                .setaMatDesc(tensor_s_transpose)
+                                .setbMatDesc(tensor_do)
+                                .setcMatDesc(tensor_dv)
+                                .setmatmulDesc(bmm2_grad_gemm1_desc)
+                                .build();
+  RETURN_MSG_IF_CUDNN_ERROR(bmm2_grad_gemm1_op);
+  intermediate_ops.push_back(std::move(bmm2_grad_gemm1_op));
+
+  // dO @ V^t -> dS
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_ds,
+      CreateCudnnTensor(p_dims, p_strides, CudnnfMHAUid::VIRTUAL_ID + 108,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual*/ true));
+
+  auto bmm2_grad_gemm2_desc = cudnn_frontend::MatMulDescBuilder()
+                                  .setComputeType(CUDNN_DATA_FLOAT)
+                                  .build();
+  RETURN_MSG_IF_CUDNN_ERROR(bmm2_grad_gemm2_desc);
+  auto bmm2_grad_gemm2_op = cudnn_frontend::OperationBuilder(
+                                CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
+                                .setaMatDesc(tensor_do)
+                                .setbMatDesc(tensor_vt)
+                                .setcMatDesc(tensor_ds)
+                                .setmatmulDesc(bmm2_grad_gemm2_desc)
+                                .build();
+  RETURN_MSG_IF_CUDNN_ERROR(bmm2_grad_gemm2_op);
+  intermediate_ops.push_back(std::move(bmm2_grad_gemm2_op));
+
+  // dS * dropout -> dS_after_dropout
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_ds_after_dropout,
+      CreateCudnnTensor(p_dims, p_strides, CudnnfMHAUid::VIRTUAL_ID + 109,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual*/ true));
+
+  TF_ASSIGN_OR_RETURN(auto mul_2_op,
+                      CreateBinaryPwOp(tensor_ds, tensor_dropout_mask,
+                                       tensor_ds_after_dropout, mul_desc));
+  intermediate_ops.push_back(std::move(mul_2_op));
+
+  // dS_after_dropout - softmax_sum -> dS_after_sub
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_ds_after_sub,
+      CreateCudnnTensor(p_dims, p_strides, CudnnfMHAUid::VIRTUAL_ID + 110,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual*/ true));
+
+  TF_ASSIGN_OR_RETURN(
+      auto sub_1_op,
+      CreateBinaryPwOp(tensor_ds_after_dropout, tensor_softmax_sum,
+                       tensor_ds_after_sub, sub_desc));
+  intermediate_ops.push_back(std::move(sub_1_op));
+
+  // dS_after_sub * p_after_softmax -> dP
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_dp,
+      CreateCudnnTensor(p_dims, p_strides, CudnnfMHAUid::VIRTUAL_ID + 111,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual*/ true));
+
+  TF_ASSIGN_OR_RETURN(auto mul_3_op, CreateBinaryPwOp(tensor_ds_after_sub,
+                                                      tensor_p_after_softmax,
+                                                      tensor_dp, mul_desc));
+  intermediate_ops.push_back(std::move(mul_3_op));
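+  // Together with the subtraction above, this is the softmax backward:
+  // dP = p_after_softmax * (dS_after_dropout - softmax_sum).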
+
+  // dP * dropout_scale -> dP_after_dropout_scale
+  // flash attention TODO: make sure the data type is correct here
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_dp_after_dropout_scale,
+      CreateCudnnTensor(p_dims, p_strides, CudnnfMHAUid::VIRTUAL_ID + 112,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual*/ true));
+
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_dropout_scale,
+      CreateCudnnTensor(
+          scale_dims, scale_strides, CudnnfMHAUid::DROPOUT_SCALE_ID,
+          dnn::DataType::kFloat, 1, -1,
+          /*is_virtual*/ false,
+          /*cudnn_tensor_order_type*/ CUDNN_TENSOR_REORDERING_NONE,
+          /*is_value*/ true));
+
+  TF_ASSIGN_OR_RETURN(
+      auto mul_4_op, CreateBinaryPwOp(tensor_dp, tensor_dropout_scale,
+                                      tensor_dp_after_dropout_scale, mul_desc));
+  intermediate_ops.push_back(std::move(mul_4_op));
+
+  // dP_after_dropout_scale * alpha_scale -> dP_scaled
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_dp_scaled,
+      CreateCudnnTensor(p_dims, p_strides, CudnnfMHAUid::VIRTUAL_ID + 113,
+                        dnn::DataType::kFloat, 1, -1,
+                        /*is_virtual*/ true));
+  TF_ASSIGN_OR_RETURN(
+      auto mul_5_op,
+      CreateBinaryPwOp(tensor_dp_after_dropout_scale, tensor_alpha_scale,
+                       tensor_dp_scaled, mul_desc));
+  intermediate_ops.push_back(std::move(mul_5_op));
+
+  // K^T -> K
+  auto k_dims = k_transpose_dims;
+  auto k_strides = k_transpose_strides;
+  auto k_rank = k_dims.size();
+  std::swap(k_dims[k_rank - 1], k_dims[k_rank - 2]);
+  std::swap(k_strides[k_rank - 1], k_strides[k_rank - 2]);
+
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_k,
+      CreateCudnnTensor(k_dims, k_strides, CudnnfMHAUid::VIRTUAL_ID + 114,
+                        dtype, 1, -1, /*is_virtual*/ true));
+  auto reshape_1_op = cudnn_frontend::OperationBuilder(
+                          CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
+                          .setxDesc(tensor_kt)
+                          .setyDesc(tensor_k)
+                          .build();
+  RETURN_MSG_IF_CUDNN_ERROR(reshape_1_op);
+  intermediate_ops.push_back(std::move(reshape_1_op));
+
+  // dP_scaled @ K -> d_Q_accum
+  auto tensor_d_Q_accum =
+      cudnn_frontend::TensorBuilder()
+          .setDim(dq_dims.size(), dq_dims.data())
+          .setStride(dq_strides.size(), dq_strides.data())
+          .setId(CudnnfMHAUid::d_Q_accum_ID)
+          .setAlignment(16)
+          .setDataType(ToCudnnDataType(dnn::DataType::kFloat))
+          .setVectorCountAndDimension(1, -1)
+          .setVirtual(false)
+          .setReorderType(CUDNN_TENSOR_REORDERING_F16x16)
+          .setByValue(false)
+          .build();
+  RETURN_MSG_IF_CUDNN_ERROR(tensor_d_Q_accum);
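+  // d_Q_accum is a non-virtual fp32 workspace tensor: dQ contributions are
+  // accumulated into it with atomic adds (hence the memset in the runner
+  // below) and then cast to the output dtype by the identity op at the end.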
+
+  auto bmm1_grad_gemm1_desc = cudnn_frontend::MatMulDescBuilder()
+                                  .setComputeType(CUDNN_DATA_FLOAT)
+                                  .build();
+  RETURN_MSG_IF_CUDNN_ERROR(bmm1_grad_gemm1_desc);
+  auto bmm1_grad_gemm1_op = cudnn_frontend::OperationBuilder(
+                                CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
+                                .setaMatDesc(tensor_dp_scaled)
+                                .setbMatDesc(tensor_k)
+                                .setcMatDesc(tensor_d_Q_accum)
+                                .setmatmulDesc(bmm1_grad_gemm1_desc)
+                                .build();
+  RETURN_MSG_IF_CUDNN_ERROR(bmm1_grad_gemm1_op);
+  intermediate_ops.push_back(std::move(bmm1_grad_gemm1_op));
+
+  // dP_scaled.T @ Q -> dK
+  TF_ASSIGN_OR_RETURN(
+      auto tensor_dp_scaled_transpose,
+      CreateCudnnTensor(p_transpose_dims, p_transpose_strides,
+                        CudnnfMHAUid::VIRTUAL_ID + 115, dnn::DataType::kFloat,
+                        1, -1, /*is_virtual*/ true));
+  auto reshape_2_op = cudnn_frontend::OperationBuilder(
+                          CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
+                          .setxDesc(tensor_dp_scaled)
+                          .setyDesc(tensor_dp_scaled_transpose)
+                          .build();
+  RETURN_MSG_IF_CUDNN_ERROR(reshape_2_op);
+  intermediate_ops.push_back(std::move(reshape_2_op));
+
+  auto bmm1_grad_gemm2_desc = cudnn_frontend::MatMulDescBuilder()
+                                  .setComputeType(CUDNN_DATA_FLOAT)
+                                  .build();
+  RETURN_MSG_IF_CUDNN_ERROR(bmm1_grad_gemm2_desc);
+  auto bmm1_grad_gemm2_op = cudnn_frontend::OperationBuilder(
+                                CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
+                                .setaMatDesc(tensor_dp_scaled_transpose)
+                                .setbMatDesc(tensor_q)
+                                .setcMatDesc(tensor_dk)
+                                .setmatmulDesc(bmm1_grad_gemm2_desc)
+                                .build();
+  RETURN_MSG_IF_CUDNN_ERROR(bmm1_grad_gemm2_op);
+  intermediate_ops.push_back(std::move(bmm1_grad_gemm2_op));
+
+  // d_Q_accum -> dQ: a pointwise identity casts the fp32 accumulator to dQ's dtype
+  TF_ASSIGN_OR_RETURN(
+      auto identity_desc,
+      CreatePwDesc(dnn::DataType::kFloat, CUDNN_POINTWISE_IDENTITY));
+  TF_ASSIGN_OR_RETURN(
+      auto identity_op,
+      CreateUnaryPwOp(tensor_d_Q_accum, tensor_dq, identity_desc));
+  intermediate_ops.push_back(std::move(identity_op));
+
+  ops.reserve(intermediate_ops.size());
+  for (auto& intermediate_op : intermediate_ops) {
+    ops.emplace_back(&intermediate_op);
+  }
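+  // ops holds raw pointers into intermediate_ops, so intermediate_ops must
+  // outlive the construction of the operation graph below.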
+
+  auto op_graph = cudnn_frontend::OperationGraphBuilder()
+                      .setHandle(cudnn.handle())
+                      .setOperationGraph(ops.size(), ops.data())
+                      .build();
+  RETURN_MSG_IF_CUDNN_ERROR(op_graph);
+
+  VLOG(4) << "\nTensor_q: " << tensor_q.describe()
+          << "\nTensor_kt: " << tensor_kt.describe()
+          << "\nTensor_p: " << tensor_p.describe()
+          << "\nTensor_vt: " << tensor_vt.describe()
+          << "\nTensor_do: " << tensor_do.describe()
+          << "\nTensor_o: " << tensor_o.describe()
+          << "\nTensor_dq: " << tensor_dq.describe()
+          << "\nTensor_dk: " << tensor_dk.describe()
+          << "\nTensor_dv: " << tensor_dv.describe()
+          << "\nBMM2_grad_gemm1: " << bmm2_grad_gemm1_desc.describe()
+          << "\nBMM2_grad_gemm2: " << bmm2_grad_gemm2_desc.describe()
+          << "\nBMM1_grad_gemm1: " << bmm1_grad_gemm1_desc.describe()
+          << "\nBMM1_grad_gemm2: " << bmm1_grad_gemm2_desc.describe()
+          << "\nOpGraph: " << op_graph.describe();
+  return std::make_unique<cudnn_frontend::OperationGraph>(std::move(op_graph));
+}
+
 #endif  // CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND
 
 }  // namespace
@@ -6721,10 +7978,10 @@
     auto cudnn = cudnn_->GetHandle(parent_, stream);
 
     size_t workspace_size = plan_.getWorkspaceSize();
+
     RETURN_MSG_IF_CUDNN_ERROR(plan_);
     bool should_add_scalars =
         !scalar_input_uids_.empty() && !scalar_input_values_.empty();
-    RETURN_MSG_IF_CUDNN_ERROR(plan_);
 
     std::vector<int64_t> data_uids_vec = {data_uids_.cbegin(),
                                           data_uids_.cend()};
@@ -6754,9 +8011,8 @@
       data_ptrs_vec.pop_back();
     }
 
-    if (sizeof...(Args) == 7 || sizeof...(Args) == 11) {
-      // is fused attention fwd and bwd
-      // remove empty buffers from the list
+    if (sizeof...(Args) == 7 || sizeof...(Args) == 15) {
+      // is fused attention fwd or bwd; remove empty (nullptr) buffers from the list
       data_ptrs_vec.erase(
           std::remove(data_ptrs_vec.begin(), data_ptrs_vec.end(), nullptr),
           data_ptrs_vec.end());
@@ -6777,15 +8033,22 @@
       initial_offset_ += offset_increment_;
       data_uids_vec.push_back(CudnnfMHAUid::D_SEED_ID);
       data_uids_vec.push_back(CudnnfMHAUid::D_OFFSET_ID);
-      data_ptrs_vec.push_back((void*)(&rng_seed_));
-      data_ptrs_vec.push_back((void*)(&initial_offset_));
+      if (is_flash_attention_ && CUDNN_VERSION < 8903) {
+        // Flash attention on cuDNN < 8.9.3 only supports device pointers for
+        // the seed and offset.
+        data_ptrs_vec.push_back(scratch_memory.opaque());
+        data_ptrs_vec.push_back(static_cast<void*>(
+            static_cast<int64_t*>(scratch_memory.opaque()) + 1));
+      } else {
+        data_ptrs_vec.push_back((void*)(&rng_seed_));
+        data_ptrs_vec.push_back((void*)(&initial_offset_));
+      }
 #else
       return absl::UnimplementedError(
           "Cudnn dropout offset and seed are only supported with Cudnn >= "
           "8.8.");
 #endif  // CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND
     }
-
     auto variantPack =
         cudnn_frontend::VariantPackBuilder()
             .setWorkspacePointer(scratch_memory.opaque())
@@ -6793,7 +8056,6 @@
             .setUids(data_uids_vec.size(), data_uids_vec.data())
             .build();
     RETURN_MSG_IF_CUDNN_ERROR(variantPack);
-
     VLOG(4) << "\nDo cudnn execution plan with plan tag: " << plan_.getTag()
             << "\nWorkspace size in bytes: " << workspace_size
             << "\nVariantPack: " << variantPack.describe();
@@ -6803,6 +8065,16 @@
         std::optional<GpuTimer> timer,
         GpuTimer::CreateIfNeeded(AsGpuStream(stream), is_profiling));
 
+    if (sizeof...(Args) == 15) {
+      // is training
+      if (is_flash_attention_) {
+        // Zero dq_accum first because it is accumulated with atomic adds.
+        std::vector<DeviceMemoryBase> dev_mem{inputs...};
+        DeviceMemoryBase* dev_dq_accum = &(dev_mem[10]);
+        stream->ThenMemZero(dev_dq_accum, dev_dq_accum->size());
+      }
+    }
+
     cudnnStatus_t status = cudnnBackendExecute(
         cudnn.handle(), plan_.get_raw_desc(), variantPack.get_raw_desc());
     RETURN_IF_CUDNN_ERROR(status);
@@ -6837,7 +8109,8 @@
              {},
              {},
              0,
-             0}};
+             0,
+             false}};
   }
 
   static tsl::StatusOr<CudnnExecutionPlanRunner> Create(
@@ -6846,12 +8119,13 @@
       bool need_side_input, bool has_activation_output,
       std::vector<int64_t> scalar_input_uids,
       std::vector<ScalingParam> scalar_input_values, int64_t dropout_rng_seed,
-      int64_t dropout_rng_offset) {
+      int64_t dropout_rng_offset, bool is_flash_attention) {
     auto workspace_size = static_cast<uint64_t>(plan.getWorkspaceSize());
     RETURN_MSG_IF_CUDNN_ERROR(plan);
     return {{parent, cudnn, std::move(plan), workspace_size, uids,
              need_side_input, has_activation_output, scalar_input_uids,
-             scalar_input_values, dropout_rng_seed, dropout_rng_offset}};
+             scalar_input_values, dropout_rng_seed, dropout_rng_offset,
+             is_flash_attention}};
   }
 
  private:
@@ -6862,7 +8136,8 @@
                            bool has_activation_output,
                            std::vector<int64_t> scalar_input_uids,
                            std::vector<ScalingParam> scalar_input_values,
-                           int64_t dropout_rng_seed, int64_t dropout_rng_offset)
+                           int64_t dropout_rng_seed, int64_t dropout_rng_offset,
+                           bool is_flash_attention)
       : parent_(parent),
         cudnn_(cudnn),
         plan_(std::move(plan)),
@@ -6873,7 +8148,8 @@
         scalar_input_uids_(scalar_input_uids),
         scalar_input_values_(scalar_input_values),
         offset_increment_(dropout_rng_offset),
-        rng_seed_(dropout_rng_seed) {}
+        rng_seed_(dropout_rng_seed),
+        is_flash_attention_(is_flash_attention) {}
   GpuExecutor* parent_;
   CudnnAccess* cudnn_;
   cudnn_frontend::ExecutionPlan plan_;
@@ -6887,6 +8163,7 @@
   mutable int64_t initial_offset_ = 0;
   int64_t offset_increment_ = 0;
   int64_t rng_seed_;
+  bool is_flash_attention_;
 };
 #endif  // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
 
@@ -7740,20 +9017,31 @@
     std::optional<dnn::TensorDescriptor> activation_descriptor,
     std::optional<dnn::TensorDescriptor> mask_descriptor,
     std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
-    std::optional<double> dropout_rate, std::optional<int64_t> seed) {
+    std::optional<double> dropout_rate, std::optional<int64_t> seed,
+    bool is_flash_attention, bool is_causal_mask) {
 #if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND)
   auto cudnn = cudnn_->GetHandle(parent_, stream);
   bool use_dropout = dropout_rate && *dropout_rate > 0.0;
   std::vector<int64_t> intermediate_shape;
   TF_ASSIGN_OR_RETURN(
       auto op_graph,
-      GetCudnnFusedMHAOperationGraph(
-          bmm1_lhs_descriptor, bmm1_rhs_descriptor, bmm2_rhs_descriptor,
-          intermediate_bmm2_lhs_descriptor, output_descriptor, mask_descriptor,
-          bias_descriptor, activation_descriptor, kind, dropout_rate, seed,
-          cudnn, scale, intermediate_shape, use_dropout,
-          /*use_mask*/ mask_descriptor != std::nullopt,
-          /*use_bias*/ bias_descriptor != std::nullopt));
+      is_flash_attention
+          ? GetCudnnFlashAttentionOperationGraph(
+                bmm1_lhs_descriptor, bmm1_rhs_descriptor, bmm2_rhs_descriptor,
+                intermediate_bmm2_lhs_descriptor, output_descriptor,
+                mask_descriptor, bias_descriptor, activation_descriptor, kind,
+                dropout_rate, seed, cudnn, scale, intermediate_shape,
+                use_dropout,
+                /*use_mask*/ mask_descriptor != std::nullopt,
+                /*use_bias*/ bias_descriptor != std::nullopt, is_causal_mask)
+          : GetCudnnFusedMHAOperationGraph(
+                bmm1_lhs_descriptor, bmm1_rhs_descriptor, bmm2_rhs_descriptor,
+                intermediate_bmm2_lhs_descriptor, output_descriptor,
+                mask_descriptor, bias_descriptor, activation_descriptor, kind,
+                dropout_rate, seed, cudnn, scale, intermediate_shape,
+                use_dropout,
+                /*use_mask*/ mask_descriptor != std::nullopt,
+                /*use_bias*/ bias_descriptor != std::nullopt));
 
   TF_ASSIGN_OR_RETURN(auto execution_plan,
                       GetExecPlanFromHeuristics(std::move(*op_graph), cudnn));
@@ -7771,19 +9059,46 @@
     u_ids.push_back(CudnnfMHAUid::P_ID);
   }
 
-  ScalingParam alpha_scale(scale, bmm1_lhs_descriptor.type());
-  std::vector<ScalingParam> scalar_input_values = {alpha_scale};
-  std::vector<int64_t> scalar_input_uids = {CudnnfMHAUid::ALPHA_SCALE_ID};
+  std::vector<ScalingParam> scalar_input_values;
+  std::vector<int64_t> scalar_input_uids;
+
+  int64_t dropout_rng_seed = seed == std::nullopt ? 0 : *seed;
   int64_t dropout_rng_offset = 0;
 
-  if (use_dropout) {
+  if (is_flash_attention) {
+    ScalingParam alpha_scale(scale, dnn::DataType::kFloat);
+    scalar_input_values = {alpha_scale};
+    scalar_input_uids = {CudnnfMHAUid::ALPHA_SCALE_ID};
     scalar_input_uids.push_back(CudnnfMHAUid::DROPOUT_SCALE_ID);
-    double dropout_scale_value = (1.0 / (1.0 - *dropout_rate));
-    ScalingParam dropout_scale(dropout_scale_value, bmm1_lhs_descriptor.type());
+    // Before cuDNN 8.9.3 this had to be half/bf16; from 8.9.3 on any type is
+    // accepted, so use fp32 here.
+    double dropout_scale_value =
+        use_dropout ? (1.0f / (1.0f - *dropout_rate)) : 1.0f;
+    ScalingParam dropout_scale(dropout_scale_value, dnn::DataType::kFloat);
     scalar_input_values.push_back(dropout_scale);
     dropout_rng_offset = GetDropoutRngOffset(intermediate_shape);
+
+    if (bias_descriptor == std::nullopt) {
+      // Push negative infinity, used as the fill value for masked-out logits.
+      scalar_input_uids.push_back(CudnnfMHAUid::NEG_INFINITY_ID);
+      double negative_infinity_value = -std::numeric_limits<float>::infinity();
+      ScalingParam negative_infinity(negative_infinity_value,
+                                     dnn::DataType::kFloat);
+      scalar_input_values.push_back(negative_infinity);
+    }
+  } else {
+    ScalingParam alpha_scale(scale, bmm1_lhs_descriptor.type());
+    scalar_input_values = {alpha_scale};
+    scalar_input_uids = {CudnnfMHAUid::ALPHA_SCALE_ID};
+    if (use_dropout) {
+      scalar_input_uids.push_back(CudnnfMHAUid::DROPOUT_SCALE_ID);
+      double dropout_scale_value = 1.0f / (1.0f - *dropout_rate);
+      ScalingParam dropout_scale(dropout_scale_value,
+                                 bmm1_lhs_descriptor.type());
+      scalar_input_values.push_back(dropout_scale);
+      dropout_rng_offset = GetDropoutRngOffset(intermediate_shape);
+    }
   }
-  int64_t dropout_rng_seed = seed == std::nullopt ? 0 : *seed;
 
   TF_ASSIGN_OR_RETURN(
       auto runner,
@@ -7792,8 +9107,7 @@
           /*need_side_input*/ true,
           /*has_activation_output*/ (activation_descriptor != std::nullopt),
           scalar_input_uids, scalar_input_values, dropout_rng_seed,
-          dropout_rng_offset));
-
+          dropout_rng_offset, is_flash_attention));
   return {std::make_unique<CudnnExecutionPlanRunner<dnn::FusedMHASignature>>(
       std::move(runner))};
 #else
@@ -7814,74 +9128,128 @@
     const dnn::TensorDescriptor& d_bmm1_lhs_descriptor,
     const dnn::TensorDescriptor& d_bmm1_rhs_descriptor,
     const dnn::TensorDescriptor& d_bmm2_rhs_descriptor,
-    const dnn::TensorDescriptor& d_s_descriptor,
+    std::optional<dnn::TensorDescriptor> d_s_descriptor,
     std::optional<dnn::TensorDescriptor> mask_descriptor,
-    std::optional<dnn::TensorDescriptor> d_bias_descriptor, double scale,
-    std::optional<double> dropout_rate, std::optional<int64_t> seed) {
-#if (CUDNN_VERSION >= 8901 && TF_ENABLE_CUDNN_FRONTEND)
+    std::optional<dnn::TensorDescriptor> d_bias_descriptor,
+    std::optional<dnn::TensorDescriptor> fwd_output_descriptor,
+    std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
+    std::optional<double> dropout_rate, std::optional<int64_t> seed,
+    bool is_flash_attention, bool is_causal_mask) {
+#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND)
   auto cudnn = cudnn_->GetHandle(parent_, stream);
 
   bool use_dropout = dropout_rate && *dropout_rate > 0.0;
+  std::vector<int64_t> intermediate_shape;
   TF_ASSIGN_OR_RETURN(
       auto op_graph,
-      GetCudnnFusedMHABackwardOperationGraph(
-          bmm1_grad_gemm1_rhs_descriptor, bmm1_grad_gemm2_rhs_descriptor,
-          bmm2_grad_gemm1_lhs_descriptor, bmm2_grad_gemm2_rhs_descriptor,
-          d_output_descriptor, d_s_descriptor, d_bmm1_lhs_descriptor,
-          d_bmm1_rhs_descriptor, d_bmm2_rhs_descriptor, kind, dropout_rate,
-          seed, cudnn, scale, use_dropout,
-          /*use_mask*/ mask_descriptor != std::nullopt,
-          /*use_bias*/ d_bias_descriptor != std::nullopt));
+      is_flash_attention
+          ? GetCudnnFlashAttentionBackwardOperationGraph(
+                bmm1_grad_gemm1_rhs_descriptor, bmm1_grad_gemm2_rhs_descriptor,
+                bmm2_grad_gemm1_lhs_descriptor, bmm2_grad_gemm2_rhs_descriptor,
+                d_output_descriptor, d_bmm1_lhs_descriptor,
+                d_bmm1_rhs_descriptor, d_bmm2_rhs_descriptor, kind,
+                dropout_rate, seed, cudnn, scale, intermediate_shape,
+                use_dropout,
+                /*use_mask*/ mask_descriptor != std::nullopt,
+                /*use_bias*/ bias_descriptor != std::nullopt, is_causal_mask)
+          : GetCudnnFusedMHABackwardOperationGraph(
+                bmm1_grad_gemm1_rhs_descriptor, bmm1_grad_gemm2_rhs_descriptor,
+                bmm2_grad_gemm1_lhs_descriptor, bmm2_grad_gemm2_rhs_descriptor,
+                d_output_descriptor, d_bmm1_lhs_descriptor,
+                d_bmm1_rhs_descriptor, d_bmm2_rhs_descriptor, kind,
+                dropout_rate, seed, cudnn, scale, intermediate_shape,
+                use_dropout,
+                /*use_mask*/ mask_descriptor != std::nullopt,
+                /*use_bias*/ d_bias_descriptor != std::nullopt));
+  // GetExecPlanFromHeuristics uses cudnn_frontend::cudnnException, which is
+  // currently not recommended for use by Google.
+  // TODO - Wrap the exception in a status so that it is not exposed to the
+  // runtime.
 
   TF_ASSIGN_OR_RETURN(auto execution_plan,
                       GetExecPlanFromHeuristics(std::move(*op_graph), cudnn));
 
-  std::vector<int64_t> scalar_uids = {CudnnfMHAUid::ALPHA_SCALE_ID,
-                                      CudnnfMHAUid::ZERO_VAL_ID,
-                                      CudnnfMHAUid::ONE_VAL_ID};
-  ScalingParam alpha_scale(scale, dnn::DataType::kFloat);
-  double zero_value = 0.0f;
-  ScalingParam zero(zero_value, dnn::DataType::kFloat);
-  double one_value = 1.0f;
-  ScalingParam one(one_value, dnn::DataType::kFloat);
-  std::vector<ScalingParam> scalar_values = {alpha_scale, zero, one};
-
-  // TODO cudnn doesn't support no dropout, so setting dropout rate to 0
-  // here to mimic no dropout. Change this when cudnn graph is more
-  // flexible.
-  scalar_uids.push_back(CudnnfMHAUid::DROPOUT_SCALE_ID);
-  double dropout_scale_value =
-      use_dropout ? (1.0 / (1.0 - *dropout_rate)) : 1.0;
-  ScalingParam dropout_scale(dropout_scale_value, dnn::DataType::kFloat);
-  scalar_values.push_back(dropout_scale);
   int64_t dropout_rng_seed = seed == std::nullopt ? 0 : *seed;
+  int64_t dropout_rng_offset = 0;
+  std::vector<int64_t> scalar_uids;
+  std::vector<ScalingParam> scalar_values;
+  std::vector<int64_t> uids;
 
-  std::vector<int64_t> uids = {
-      CudnnfMHAUid::Q_ID,  CudnnfMHAUid::K_ID,  CudnnfMHAUid::P_ID,
-      CudnnfMHAUid::V_ID,  CudnnfMHAUid::dO_ID, CudnnfMHAUid::dQ_ID,
-      CudnnfMHAUid::dK_ID, CudnnfMHAUid::dV_ID, CudnnfMHAUid::dS_ID};
-  if (mask_descriptor != std::nullopt) {
-    uids.push_back(CudnnfMHAUid::MASK_ID);
-  }
-  if (d_bias_descriptor != std::nullopt) {
-    uids.push_back(CudnnfMHAUid::dBIAS_ID);
-  }
+  if (is_flash_attention) {
+    scalar_uids = {CudnnfMHAUid::ALPHA_SCALE_ID, CudnnfMHAUid::DROPOUT_SCALE_ID,
+                   CudnnfMHAUid::SCALE_PROB_ID};
+    // alpha scale
+    ScalingParam alpha_scale(scale, dnn::DataType::kFloat);
+    // dropout scale
+    double dropout_scale_value =
+        use_dropout ? (1.0f / (1.0f - *dropout_rate)) : 1.0f;
+    ScalingParam dropout_scale(dropout_scale_value, dnn::DataType::kFloat);
+    // scale prob
+    double scale_prob_value = 1.0 - *dropout_rate;
+    ScalingParam scale_prob(scale_prob_value, dnn::DataType::kFloat);
+    scalar_values = {alpha_scale, dropout_scale, scale_prob};
+    // Compute the dropout RNG offset increment; the seed and offset values
+    // themselves are bound at execution time.
+    dropout_rng_offset = GetDropoutRngOffset(intermediate_shape);
+    uids = {
+        CudnnfMHAUid::Q_ID,         CudnnfMHAUid::K_ID,  CudnnfMHAUid::P_ID,
+        CudnnfMHAUid::V_ID,         CudnnfMHAUid::dO_ID, CudnnfMHAUid::dQ_ID,
+        CudnnfMHAUid::dK_ID,        CudnnfMHAUid::dV_ID, CudnnfMHAUid::S_SUM_ID,
+        CudnnfMHAUid::d_Q_accum_ID, CudnnfMHAUid::O_ID};
+    if (bias_descriptor != std::nullopt) {
+      uids.push_back(CudnnfMHAUid::BIAS_ID);
+    } else {
+      // No bias means the causal-mask path: push negative infinity as the
+      // fill value for masked-out logits.
+      double negative_infinity_value = -std::numeric_limits<float>::infinity();
+      ScalingParam negative_infinity(negative_infinity_value,
+                                     dnn::DataType::kFloat);
+      scalar_values.push_back(negative_infinity);
+      scalar_uids.push_back(CudnnfMHAUid::NEG_INFINITY_ID);
+    }
+  } else {
+    // TODO: cudnn doesn't support disabling dropout, so set the dropout rate
+    // to 0 here to mimic no dropout. Change this when the cudnn graph is more
+    // flexible.
+    scalar_uids = {CudnnfMHAUid::ALPHA_SCALE_ID, CudnnfMHAUid::ZERO_VAL_ID,
+                   CudnnfMHAUid::ONE_VAL_ID, CudnnfMHAUid::DROPOUT_SCALE_ID};
+    ScalingParam alpha_scale(scale, dnn::DataType::kFloat);
+    double zero_value = 0.0f;
+    ScalingParam zero(zero_value, dnn::DataType::kFloat);
+    double one_value = 1.0f;
+    ScalingParam one(one_value, dnn::DataType::kFloat);
+    double dropout_scale_value =
+        use_dropout ? (1.0 / (1.0 - *dropout_rate)) : 1.0;
+    ScalingParam dropout_scale(dropout_scale_value, dnn::DataType::kFloat);
+    scalar_values = {alpha_scale, zero, one, dropout_scale};
 
+    uids = {CudnnfMHAUid::Q_ID,  CudnnfMHAUid::K_ID,  CudnnfMHAUid::P_ID,
+            CudnnfMHAUid::V_ID,  CudnnfMHAUid::dO_ID, CudnnfMHAUid::dQ_ID,
+            CudnnfMHAUid::dK_ID, CudnnfMHAUid::dV_ID, CudnnfMHAUid::dS_ID};
+    if (mask_descriptor != std::nullopt) {
+      uids.push_back(CudnnfMHAUid::MASK_ID);
+    }
+    if (d_bias_descriptor != std::nullopt) {
+      uids.push_back(CudnnfMHAUid::dBIAS_ID);
+    }
+  }
   TF_ASSIGN_OR_RETURN(
       auto runner,
       CudnnExecutionPlanRunner<dnn::FusedMHABackwardSignature>::Create(
           parent_, cudnn_.get(), std::move(execution_plan), uids,
           /*need_side_input*/ true, /*has_activation_output*/ false,
           scalar_uids, scalar_values, dropout_rng_seed,
-          /*dropout_rng_offset*/ 0));
+          /*dropout_rng_offset*/ dropout_rng_offset,
+          /*is_flash_attention*/ is_flash_attention));
   return {std::make_unique<
       CudnnExecutionPlanRunner<dnn::FusedMHABackwardSignature>>(
       std::move(runner))};
 #else
   return absl::UnimplementedError(
-      "Cudnn execution plans with mask input in bwd are only supported with "
-      "Cudnn >= 8.9.1");
-#endif  // CUDNN_VERSION >= 8901 && TF_ENABLE_CUDNN_FRONTEND
+      "Cudnn execution plans with dbias calculation in bwd are only "
+      "supported with Cudnn >= 8.8.");
+#endif  // CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND
 }
 
 bool CudnnSupport::GetRnnAlgorithms(
@@ -8648,7 +10016,9 @@
     if (!stream
              ->ThenBlasGemm(blas::Transpose::kNoTranspose,
                             blas::Transpose::kNoTranspose, m, n, k, weights, m,
-                            input_data, k, output_data, m, NumericOptions{})
+                            input_data, k, output_data, m, NumericOptions{},
+                            blas::CallContext::kNone)
+
              .ok()) {
       return false;
     }
@@ -8728,10 +10098,10 @@
       return ptrs;
     };
 
-    stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
-                                blas::Transpose::kNoTranspose, m, n, k, alpha,
-                                toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
-                                ldc, batch_count, NumericOptions{});
+    stream->ThenBlasGemmBatched(
+        blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n, k,
+        alpha, toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c), ldc,
+        batch_count, NumericOptions{}, blas::CallContext::kNone);
   }
 
   return stream->ok();
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h
index 1157093..f662ecf 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h
@@ -307,7 +307,8 @@
       std::optional<dnn::TensorDescriptor> activation_descriptor,
       std::optional<dnn::TensorDescriptor> mask_descriptor,
       std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
-      std::optional<double> dropout_rate, std::optional<int64_t> seed) override;
+      std::optional<double> dropout_rate, std::optional<int64_t> seed,
+      bool is_flash_attention, bool is_causal_mask) override;
 
   tsl::StatusOr<std::unique_ptr<const dnn::FusedMHABackwardRunner>>
   FusedMHABackwardRunnerFromDesc(
@@ -321,10 +322,13 @@
       const dnn::TensorDescriptor& d_bmm1_lhs_descriptor,
       const dnn::TensorDescriptor& d_bmm1_rhs_descriptor,
       const dnn::TensorDescriptor& d_bmm2_rhs_descriptor,
-      const dnn::TensorDescriptor& d_s_descriptor,
+      std::optional<dnn::TensorDescriptor> d_s_descriptor,
       std::optional<dnn::TensorDescriptor> mask_descriptor,
-      std::optional<dnn::TensorDescriptor> d_bias_descriptor, double scale,
-      std::optional<double> dropout_rate, std::optional<int64_t> seed) override;
+      std::optional<dnn::TensorDescriptor> d_bias_descriptor,
+      std::optional<dnn::TensorDescriptor> fwd_output_descriptor,
+      std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
+      std::optional<double> dropout_rate, std::optional<int64_t> seed,
+      bool is_flash_attention, bool is_causal_mask);
   bool GetRnnAlgorithms(
       std::vector<dnn::AlgorithmDesc>* out_algorithms) override;
 
@@ -449,32 +453,6 @@
       std::optional<const DeviceMemory<float>> bias_input,
       std::optional<DeviceMemory<float>> bias_output) override;
 
-  bool DoConvolveQuantized(
-      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
-      const DeviceMemory<float>& input_data,
-      const dnn::FilterDescriptor& filter_descriptor,
-      const DeviceMemory<int8_t>& filter_coefficients,
-      const DeviceMemory<float>& coefficient_scales,
-      const dnn::ConvolutionDescriptor& convolution_descriptor,
-      const dnn::BatchDescriptor& output_descriptor,
-      DeviceMemory<float>* output_data) override {
-    LOG(ERROR) << "DoConvolveQuantized not supported by cuDNN";
-    return false;
-  }
-
-  bool DoConvolveQuantized(
-      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
-      const DeviceMemory<float>& input_data,
-      const dnn::FilterDescriptor& filter_descriptor,
-      const DeviceMemory<int16>& filter_coefficients,
-      const DeviceMemory<float>& coefficient_scales,
-      const dnn::ConvolutionDescriptor& convolution_descriptor,
-      const dnn::BatchDescriptor& output_descriptor,
-      DeviceMemory<float>* output_data) override {
-    LOG(ERROR) << "DoConvolveQuantized not supported by cuDNN";
-    return false;
-  }
-
   bool DoSeparableConvolve(
       Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
       const DeviceMemory<float>& input_data,
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
index 2a23f48..bf16b27 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
@@ -37,7 +37,6 @@
 #include "absl/synchronization/notification.h"
 #include "third_party/gpus/cuda/include/cuda.h"
 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
-#include "third_party/gpus/cuda/include/driver_types.h"
 #include "xla/stream_executor/cuda/cuda_diagnostics.h"
 #include "xla/stream_executor/gpu/gpu_driver.h"
 #include "xla/stream_executor/gpu/gpu_types.h"
@@ -415,11 +414,10 @@
   if (context == nullptr) {
     return;
   }
-  CUcontext former_context = CurrentContext();
-  CUresult res = cuCtxSetCurrent(context->context());
+  CUresult res = cuCtxPushCurrent(context->context());
   CUdevice device;
   cuCtxGetDevice(&device);
-  cuCtxSetCurrent(former_context);
+  cuCtxPopCurrent(nullptr);
 
   res = cuDevicePrimaryCtxRelease(device);
 
@@ -1592,21 +1590,15 @@
   ScopedActivateContext activation(context);
   CUresult result;
 
-  // Check if the stream is doing graph capture.
-  cudaStreamCaptureStatus stream_capture_status;
-  cudaError_t err =
-      cudaStreamGetCaptureInfo(stream, &stream_capture_status, /*pId=*/nullptr);
-  if (err != cudaSuccess) {
-    LOG(ERROR) << "Failed to get stream capture info: "
-               << cudaGetErrorString(err);
+  // In graph capture mode we never have operations that access peer memory, so
+  // we can always make a call to cuMemcpyDtoDAsync.
+  tsl::StatusOr<bool> is_capturing = StreamIsCapturing(stream);
+  if (!is_capturing.ok()) {
+    LOG(ERROR) << is_capturing.status().message();
     return false;
   }
 
-  // In graph capture mode we never have operations that access peer memory, so
-  // we can always make a call to cuMemcpyDtoDAsync.
-  bool is_capturing = stream_capture_status == cudaStreamCaptureStatusActive;
-
-  if ((gpu_dst == 0 || gpu_src == 0) || is_capturing) {
+  if ((gpu_dst == 0 || gpu_src == 0) || (*is_capturing)) {
     // CreatedContexts::GetAnyContext() doesn't works when ptr == 0.
     // This happens when the size is 0.
     result = cuMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream);
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
index 534ef26..6b5bbab 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
@@ -37,7 +37,6 @@
 #include "xla/stream_executor/gpu/gpu_kernel.h"
 #include "xla/stream_executor/gpu/gpu_timer.h"
 #include "xla/stream_executor/gpu/gpu_types.h"
-#include "xla/stream_executor/kernel_cache_config.h"
 #include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/plugin_registry.h"
 #include "xla/stream_executor/stream.h"
diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h
index 84653c0..193abd2 100644
--- a/third_party/xla/xla/stream_executor/device_description.h
+++ b/third_party/xla/xla/stream_executor/device_description.h
@@ -68,6 +68,18 @@
     return !(*this < CudaComputeCapability{other_major, other_minor});
   }
 
+  bool IsAtLeastVolta() const {
+    return major >= CudaComputeCapabilities::VOLTA;
+  }
+
+  bool IsAtLeastAmpere() const {
+    return major >= CudaComputeCapabilities::AMPERE;
+  }
+
+  bool IsAtLeastHopper() const {
+    return major >= CudaComputeCapabilities::HOPPER;
+  }
+
   bool operator<(const CudaComputeCapability &other) const {
     return ToPair() < other.ToPair();
   }
diff --git a/third_party/xla/xla/stream_executor/dnn.cc b/third_party/xla/xla/stream_executor/dnn.cc
index 4262e15..0b2f6d7 100644
--- a/third_party/xla/xla/stream_executor/dnn.cc
+++ b/third_party/xla/xla/stream_executor/dnn.cc
@@ -229,7 +229,8 @@
     std::optional<dnn::TensorDescriptor> activation_descriptor,
     std::optional<dnn::TensorDescriptor> mask_descriptor,
     std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
-    std::optional<double> dropout_rate, std::optional<int64_t> seed) {
+    std::optional<double> dropout_rate, std::optional<int64_t> seed,
+    bool is_flash_attention, bool is_causal_mask) {
   return absl::UnimplementedError("FusedMHARunnerFromDesc not implemented.");
 }
 
@@ -245,10 +246,13 @@
     const TensorDescriptor& d_bmm1_lhs_descriptor,
     const TensorDescriptor& d_bmm1_rhs_descriptor,
     const TensorDescriptor& d_bmm2_rhs_descriptor,
-    const TensorDescriptor& d_s_descriptor,
+    std::optional<dnn::TensorDescriptor> d_s_descriptor,
     std::optional<dnn::TensorDescriptor> mask_descriptor,
-    std::optional<dnn::TensorDescriptor> d_bias_descriptor, double scale,
-    std::optional<double> dropout_rate, std::optional<int64_t> seed) {
+    std::optional<dnn::TensorDescriptor> d_bias_descriptor,
+    std::optional<dnn::TensorDescriptor> fwd_output_descriptor,
+    std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
+    std::optional<double> dropout_rate, std::optional<int64_t> seed,
+    bool is_flash_attention, bool is_causal_mask) {
   return absl::UnimplementedError(
       "FusedMHABackwardRunnerFromDesc not implemented.");
 }
diff --git a/third_party/xla/xla/stream_executor/dnn.h b/third_party/xla/xla/stream_executor/dnn.h
index 3561a1f..8e92fd5 100644
--- a/third_party/xla/xla/stream_executor/dnn.h
+++ b/third_party/xla/xla/stream_executor/dnn.h
@@ -35,7 +35,6 @@
 #include <vector>
 
 #include "google/protobuf/wrappers.pb.h"
-#include "absl/types/optional.h"
 #include "absl/types/span.h"
 #include "xla/stream_executor/data_type.h"
 #include "xla/stream_executor/device_description.h"
@@ -147,15 +146,6 @@
   kRnnBidirectional = 1,
 };
 
-// Relevant to DepthToSpace and SpaceToDepth. This is the write layout when
-// performing depth to space and the read layout when performing space to depth.
-// It's specified with most-major dimension first and most-minor dimension last.
-// In DepthToSpace, the D*M^2 values are read in and then, for DepthHeightWidth,
-// written out to the output patch, by varying first width, then height, then
-// depth. In C array format, it looks like [depth][height][width]. See
-// DepthToSpace comment for more information.
-enum class DepthToSpaceLayout { DepthHeightWidth };
-
 class TensorDescriptor {
  public:
   TensorDescriptor() = default;
@@ -1005,8 +995,11 @@
     DeviceMemoryBase /* d_output_data */,
     DeviceMemoryBase /* d_BMM1_inputA_data */,
     DeviceMemoryBase /* d_BMM1_inputB_data */,
-    DeviceMemoryBase /* d_BMM2_inputB_data */, DeviceMemoryBase /* d_s_data */,
-    DeviceMemoryBase /* mask_data */, DeviceMemoryBase /* d_bias_data */);
+    DeviceMemoryBase /* d_BMM2_inputB_data */, DeviceMemoryBase /* d_S_data */,
+    DeviceMemoryBase /* softmax_sum_data */,
+    DeviceMemoryBase /* d_Q_accum_data */, DeviceMemoryBase /* mask_data */,
+    DeviceMemoryBase /* d_bias_data */, DeviceMemoryBase /* fwd_output_data */,
+    DeviceMemoryBase /* bias_data */);
 using FusedMHABackwardRunner = OpRunner<FusedMHABackwardSignature>;
 
 // Describes the configuration for the algorithms that will used.
@@ -1667,7 +1660,8 @@
       std::optional<dnn::TensorDescriptor> activation_descriptor,
       std::optional<dnn::TensorDescriptor> mask_descriptor,
       std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
-      std::optional<double> dropout_rate, std::optional<int64_t> seed);
+      std::optional<double> dropout_rate, std::optional<int64_t> seed,
+      bool is_flash_attention, bool is_causal_mask);
 
   virtual tsl::StatusOr<std::unique_ptr<const dnn::FusedMHABackwardRunner>>
   FusedMHABackwardRunnerFromDesc(
@@ -1681,10 +1675,13 @@
       const TensorDescriptor& d_bmm1_lhs_descriptor,
       const TensorDescriptor& d_bmm1_rhs_descriptor,
       const TensorDescriptor& d_bmm2_rhs_descriptor,
-      const TensorDescriptor& d_s_descriptor,
+      std::optional<dnn::TensorDescriptor> d_s_descriptor,
       std::optional<dnn::TensorDescriptor> mask_descriptor,
-      std::optional<dnn::TensorDescriptor> d_bias_descriptor, double scale,
-      std::optional<double> dropout_rate, std::optional<int64_t> seed);
+      std::optional<dnn::TensorDescriptor> d_bias_descriptor,
+      std::optional<dnn::TensorDescriptor> fwd_output_descriptor,
+      std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
+      std::optional<double> dropout_rate, std::optional<int64_t> seed,
+      bool is_flash_attention, bool is_causal_mask);
 
   virtual bool GetMIOpenConvolveAlgorithms(
       dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
@@ -1700,32 +1697,6 @@
   // Returns a list of supported rnn algorithms.
   virtual bool GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms);
 
-  // Version of DoConvolve that uses pre-quantized 8 bit coefficients.
-  // coefficient_scales specifies the scaling of each column of coefficients:
-  // original float coefficient[row * num_columns + column] =
-  //     quantized coefficient[row * num_columns + column] *
-  //     coefficient_scales[column].
-  virtual bool DoConvolveQuantized(
-      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
-      const DeviceMemory<float>& input_data,
-      const dnn::FilterDescriptor& filter_descriptor,
-      const DeviceMemory<int8_t>& filter_coefficients,
-      const DeviceMemory<float>& coefficient_scales,
-      const dnn::ConvolutionDescriptor& convolution_descriptor,
-      const dnn::BatchDescriptor& output_descriptor,
-      DeviceMemory<float>* output_data) = 0;
-
-  // Same as DoConvolveQuantized above, but int8 filter coefficients.
-  virtual bool DoConvolveQuantized(
-      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
-      const DeviceMemory<float>& input_data,
-      const dnn::FilterDescriptor& filter_descriptor,
-      const DeviceMemory<int16>& filter_coefficients,
-      const DeviceMemory<float>& coefficient_scales,
-      const dnn::ConvolutionDescriptor& convolution_descriptor,
-      const dnn::BatchDescriptor& output_descriptor,
-      DeviceMemory<float>* output_data) = 0;
-
   // Variation of the above with the weight matrix split into two matrices.
   // first_weights: Coefficients of the first matrix.
   // second_weights: Coefficients of the second matrix.
@@ -1969,35 +1940,6 @@
       absl::Span<const DeviceMemory<float>* const> input_data,
       DeviceMemory<float>* output_data) = 0;
 
-  // Depth to space takes an X by Y image with depth D*M^2 and changes it to an
-  // MX x MY image with depth D. Each input location (x,y) with depth D*M^2 in
-  // the input image is changed to an MxM contiguous area in the output image,
-  // with the values being laid out in the raster order by DepthToSpaceLayout,
-  // and will have a new depth of D.
-  //
-  // Example.
-  // M=2, Din =8, Xin=2, Yin=2. Xout=4, Yout=4,  Dout=2
-  // DepthHeightWidth layout
-  // Values within a 'cell' are at different depths and same x & y.
-  // Input:
-  // abcdefgh  ijklmnop
-  // qrstuvwx  yz012345
-  // Output:
-  // ae bf im jn
-  // cg dh ko lp
-  // qu rv y2 z3
-  // sw tx 04 15
-  //
-  // sqrt_depth_reduction: 'M' in the comment above
-  virtual bool DoDepthToSpace(Stream* stream,
-                              const dnn::BatchDescriptor& input_dimensions,
-                              const DeviceMemory<float>& input_data,
-                              const DepthToSpaceLayout& depth_to_space_layout,
-                              const int& sqrt_depth_reduction,
-                              DeviceMemory<float>* output_data) {
-    return false;
-  }
-
   // Computes the specified operation (e.g. addition or multiplication)
   // between corresponding elements in the inputs and stores the result in the
   // output element.
diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD
index f133714..cfb6171 100644
--- a/third_party/xla/xla/stream_executor/gpu/BUILD
+++ b/third_party/xla/xla/stream_executor/gpu/BUILD
@@ -100,6 +100,7 @@
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:status",
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc
index 5f07de3..8818808 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc
@@ -27,30 +27,36 @@
 #include "tsl/platform/tensor_float_32_utils.h"
 #endif
 
-namespace stream_executor::gpu {
+namespace stream_executor {
 
-tsl::StatusOr<blas::DataType> AsBlasDataType(xla::PrimitiveType dtype) {
+namespace gpu {
+
+using blas::ComputationType;
+using blas::DataType;
+using xla::PrimitiveType;
+
+tsl::StatusOr<DataType> AsBlasDataType(PrimitiveType dtype) {
   switch (dtype) {
-    case xla::PrimitiveType::F8E5M2:
-      return blas::DataType::kF8E5M2;
-    case xla::PrimitiveType::F8E4M3FN:
-      return blas::DataType::kF8E4M3FN;
-    case xla::PrimitiveType::S8:
-      return blas::DataType::kInt8;
-    case xla::PrimitiveType::F16:
-      return blas::DataType::kHalf;
-    case xla::PrimitiveType::BF16:
-      return blas::DataType::kBF16;
-    case xla::PrimitiveType::F32:
-      return blas::DataType::kFloat;
-    case xla::PrimitiveType::S32:
-      return blas::DataType::kInt32;
-    case xla::PrimitiveType::F64:
-      return blas::DataType::kDouble;
-    case xla::PrimitiveType::C64:
-      return blas::DataType::kComplexFloat;
-    case xla::PrimitiveType::C128:
-      return blas::DataType::kComplexDouble;
+    case PrimitiveType::F8E5M2:
+      return DataType::kF8E5M2;
+    case PrimitiveType::F8E4M3FN:
+      return DataType::kF8E4M3FN;
+    case PrimitiveType::S8:
+      return DataType::kInt8;
+    case PrimitiveType::F16:
+      return DataType::kHalf;
+    case PrimitiveType::BF16:
+      return DataType::kBF16;
+    case PrimitiveType::F32:
+      return DataType::kFloat;
+    case PrimitiveType::S32:
+      return DataType::kInt32;
+    case PrimitiveType::F64:
+      return DataType::kDouble;
+    case PrimitiveType::C64:
+      return DataType::kComplexFloat;
+    case PrimitiveType::C128:
+      return DataType::kComplexDouble;
     default:
       return xla::InternalError(
           "AsBlasDataType: unsupported type: %s",
@@ -58,59 +64,59 @@
   }
 }
 
-tsl::StatusOr<xla::PrimitiveType> AsXlaPrimitiveType(blas::DataType dtype) {
+tsl::StatusOr<PrimitiveType> AsXlaPrimitiveType(DataType dtype) {
   switch (dtype) {
-    case blas::DataType::kF8E5M2:
-      return xla::PrimitiveType::F8E5M2;
-    case blas::DataType::kF8E4M3FN:
-      return xla::PrimitiveType::F8E4M3FN;
-    case blas::DataType::kInt8:
-      return xla::PrimitiveType::S8;
-    case blas::DataType::kHalf:
-      return xla::PrimitiveType::F16;
-    case blas::DataType::kBF16:
-      return xla::PrimitiveType::BF16;
-    case blas::DataType::kFloat:
-      return xla::PrimitiveType::F32;
-    case blas::DataType::kInt32:
-      return xla::PrimitiveType::S32;
-    case blas::DataType::kDouble:
-      return xla::PrimitiveType::F64;
-    case blas::DataType::kComplexFloat:
-      return xla::PrimitiveType::C64;
-    case blas::DataType::kComplexDouble:
-      return xla::PrimitiveType::C128;
+    case DataType::kF8E5M2:
+      return PrimitiveType::F8E5M2;
+    case DataType::kF8E4M3FN:
+      return PrimitiveType::F8E4M3FN;
+    case DataType::kInt8:
+      return PrimitiveType::S8;
+    case DataType::kHalf:
+      return PrimitiveType::F16;
+    case DataType::kBF16:
+      return PrimitiveType::BF16;
+    case DataType::kFloat:
+      return PrimitiveType::F32;
+    case DataType::kInt32:
+      return PrimitiveType::S32;
+    case DataType::kDouble:
+      return PrimitiveType::F64;
+    case DataType::kComplexFloat:
+      return PrimitiveType::C64;
+    case DataType::kComplexDouble:
+      return PrimitiveType::C128;
     default:
       return xla::InternalError("AsXlaPrimitiveType: unsupported dtype");
   }
 }
 
-tsl::StatusOr<blas::ComputationType> GetBlasComputationType(
-    xla::PrimitiveType lhs_dtype, xla::PrimitiveType output_dtype,
+tsl::StatusOr<ComputationType> GetBlasComputationType(
+    PrimitiveType lhs_dtype, PrimitiveType output_dtype,
     int64_t compute_precision) {
   switch (output_dtype) {
-    case xla::PrimitiveType::F8E5M2:    // fall-through
-    case xla::PrimitiveType::F8E4M3FN:  // fall-through
-    case xla::PrimitiveType::F16:       // fall-through
-    case xla::PrimitiveType::BF16:
+    case PrimitiveType::F8E5M2:    // fall-through
+    case PrimitiveType::F8E4M3FN:  // fall-through
+    case PrimitiveType::F16:       // fall-through
+    case PrimitiveType::BF16:
       // Accumulate in f32 precision.
-      return blas::ComputationType::kF32;
-    case xla::PrimitiveType::F32:  // fall-through
-    case xla::PrimitiveType::C64:
+      return ComputationType::kF32;
+    case PrimitiveType::F32:  // fall-through
+    case PrimitiveType::C64:
 #if GOOGLE_CUDA
       if (tsl::tensor_float_32_execution_enabled() && compute_precision <= 1 &&
           lhs_dtype == output_dtype) {
         // CublasLt requires compute type to be F32 for F8 matmul.
         // TF32 should only be chosen for FP32 or C64 gemm
-        return blas::ComputationType::kTF32AsF32;
+        return ComputationType::kTF32AsF32;
       }
 #endif
-      return blas::ComputationType::kF32;
-    case xla::PrimitiveType::F64:  // fall-through
-    case xla::PrimitiveType::C128:
-      return blas::ComputationType::kF64;
-    case xla::PrimitiveType::S32:
-      return blas::ComputationType::kI32;
+      return ComputationType::kF32;
+    case PrimitiveType::F64:  // fall-through
+    case PrimitiveType::C128:
+      return ComputationType::kF64;
+    case PrimitiveType::S32:
+      return ComputationType::kI32;
     default:
       return xla::InternalError("GetBlasComputationType: unsupported type");
   }
@@ -149,12 +155,13 @@
   return (blas != nullptr ? blas->GetBlasLt() : nullptr);
 }
 
-blas::DataType GetScaleType(blas::DataType c_type,
-                            blas::ComputationType computation_type) {
-  return ((computation_type == blas::ComputationType::kF32) &&
-          (c_type != blas::DataType::kComplexFloat))
-             ? blas::DataType::kFloat
-             : c_type;
+DataType GetScaleType(DataType c_type, ComputationType computation_type) {
+  return (computation_type == ComputationType::kF32 &&
+                  c_type != DataType::kComplexFloat
+              ? DataType::kFloat
+              : c_type);
 }
 
-}  // namespace stream_executor::gpu
+}  // namespace gpu
+
+}  // namespace stream_executor
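
The type-mapping helpers above encode two simple rules: reduced-precision outputs accumulate in f32, and the GEMM scale type follows the computation type unless C is complex. A compilable sketch of those rules with illustrative stand-in enums (Dtype/Compute are not the real stream_executor types, and the lhs/precision checks are simplified):

    enum class Dtype { kHalf, kBF16, kF32, kF64, kC64, kC128, kS32 };
    enum class Compute { kF32, kTF32AsF32, kF64, kI32 };

    Compute ChooseComputationType(Dtype out, bool tf32_enabled) {
      switch (out) {
        case Dtype::kHalf:
        case Dtype::kBF16:
          return Compute::kF32;  // f16/bf16 outputs accumulate in f32
        case Dtype::kF32:
        case Dtype::kC64:
          // TF32 is only chosen for fp32/c64 gemm, and only when enabled.
          return tf32_enabled ? Compute::kTF32AsF32 : Compute::kF32;
        case Dtype::kF64:
        case Dtype::kC128:
          return Compute::kF64;
        case Dtype::kS32:
          return Compute::kI32;
      }
      return Compute::kF32;
    }

    Dtype ChooseScaleType(Dtype c_type, Compute compute) {
      // Scale factors are f32 for f32 computations unless C is complex.
      return (compute == Compute::kF32 && c_type != Dtype::kC64) ? Dtype::kF32
                                                                 : c_type;
    }
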
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
index ab2412d..40adeec 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
@@ -86,6 +86,8 @@
   double beta;
   int64_t compute_precision;
   std::optional<int64_t> algorithm;
+  bool grad_x;
+  bool grad_y;
   std::optional<blas::ComputationType> compute_type;
 };
 
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
index 45398f4..6a5e006 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
@@ -24,6 +24,7 @@
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
 #include "xla/stream_executor/command_buffer.h"
 #include "xla/stream_executor/gpu/gpu_driver.h"
 #include "xla/stream_executor/gpu/gpu_executor.h"
@@ -140,6 +141,11 @@
   return tsl::OkStatus();
 }
 
+absl::Span<GpuGraphNodeHandle> GpuCommandBuffer::GetDependencies() {
+  return nodes_.empty() ? absl::Span<GpuGraphNodeHandle>()
+                        : absl::Span<GpuGraphNodeHandle>(&nodes_.back(), 1);
+}
+
 tsl::Status GpuCommandBuffer::CheckNotFinalized() {
   if (state_ == State::kFinalized)
     return absl::InternalError(
@@ -167,11 +173,12 @@
 
   // Adds a new kernel node to the graph under construction.
   if (state_ == State::kCreate) {
+    absl::Span<GpuGraphNodeHandle> deps = GetDependencies();
     GpuGraphNodeHandle* node = &nodes_.emplace_back();
     return GpuDriver::GraphAddKernelNode(
-        node, graph_, {}, kernel.name(), gpu_func, blocks.x, blocks.y, blocks.z,
-        threads.x, threads.y, threads.z, args.number_of_shared_bytes(),
-        kernel_params, /*extra=*/nullptr);
+        node, graph_, deps, kernel.name(), gpu_func, blocks.x, blocks.y,
+        blocks.z, threads.x, threads.y, threads.z,
+        args.number_of_shared_bytes(), kernel_params, /*extra=*/nullptr);
   }
 
   // Updates kernel node in the executable graph.
@@ -193,9 +200,10 @@
 
   // Adds a child graph node to the graph under construction.
   if (state_ == State::kCreate) {
+    absl::Span<GpuGraphNodeHandle> deps = GetDependencies();
     GpuGraphNodeHandle* node = &nodes_.emplace_back();
     return GpuDriver::GraphAddChildNode(
-        node, graph_, {}, GpuCommandBuffer::Cast(&nested)->graph());
+        node, graph_, deps, GpuCommandBuffer::Cast(&nested)->graph());
   }
 
   return UnsupportedStateError(state_);
@@ -208,9 +216,10 @@
 
   // Adds a new memcpy node to the graph under construction.
   if (state_ == State::kCreate) {
+    absl::Span<GpuGraphNodeHandle> deps = GetDependencies();
     GpuGraphNodeHandle* node = &nodes_.emplace_back();
     return GpuDriver::GraphAddMemcpyD2DNode(parent_->gpu_context(), node,
-                                            graph_, {}, AsDevicePtr(*dst),
+                                            graph_, deps, AsDevicePtr(*dst),
                                             AsDevicePtr(src), size);
   }
 
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h
index aade02c..0c761f4 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h
@@ -21,6 +21,7 @@
 #include <vector>
 
 #include "absl/functional/any_invocable.h"
+#include "absl/types/span.h"
 #include "xla/stream_executor/command_buffer.h"
 #include "xla/stream_executor/gpu/gpu_executor.h"
 #include "xla/stream_executor/gpu/gpu_types.h"
@@ -80,6 +81,11 @@
   }
 
  private:
+  // TODO(ezhulenev): Currently we serialize all Gpu nodes by adding a
+  // dependency between all nodes added to a command buffer. We need a concept
+  // of a barrier at a command buffer level.
+  absl::Span<GpuGraphNodeHandle> GetDependencies();
+
   // Returns OK status if command buffer is not finalized and it is still
   // possible to add new commands to it, otherwise returns internal error.
   tsl::Status CheckNotFinalized();
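
GetDependencies() above implements the TODO's interim scheme: every node depends only on the most recently appended node, so the command buffer executes as a linear chain until a real barrier concept exists. A minimal sketch of that scheme, using a placeholder NodeHandle instead of GpuGraphNodeHandle and std::vector instead of absl::Span:

    #include <vector>

    using NodeHandle = void*;  // placeholder for GpuGraphNodeHandle

    class LinearChainBuilder {
     public:
      // Dependencies for the next node: empty for the first node, otherwise
      // the single most recently appended node.
      std::vector<NodeHandle> Dependencies() const {
        if (nodes_.empty()) return {};
        return {nodes_.back()};
      }

      void Append(NodeHandle node) { nodes_.push_back(node); }

     private:
      std::vector<NodeHandle> nodes_;
    };
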
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc
index 01e1d1d..7735ae3 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc
@@ -121,8 +121,8 @@
   (void)reserve_memory_;
 
 #if TF_CUDA_MALLOC_ASYNC_SUPPORTED
-  stream_exec_ = DeviceIdUtil::ExecutorForPlatformDeviceId(
-                     GPUMachineManager(), platform_device_id)
+  stream_exec_ = DeviceIdUtil::ExecutorForPlatformDeviceId(GPUMachineManager(),
+                                                           platform_device_id)
                      .value();
   // Initialized here as it only exists if compiled with a recent
   // enough CUDA.
@@ -298,9 +298,17 @@
   }
   cuda::ScopedActivateExecutorContext scoped_activation{stream_exec_};
   void* ptr = nullptr;
-  if (auto result =
-          cuMemAllocFromPoolAsync(reinterpret_cast<CUdeviceptr*>(&ptr),
-                                  num_bytes, pool_, cuda_stream_)) {
+  auto result = cuMemAllocFromPoolAsync(reinterpret_cast<CUdeviceptr*>(&ptr),
+                                        num_bytes, pool_, cuda_stream_);
+  if (result == CUDA_ERROR_OUT_OF_MEMORY) {
+    // Doing a stream synchronization gives the driver more flexibility to
+    // coalesce blocks and remap memory, so it can resolve some OOM cases
+    // when memory is tight.
+    cuStreamSynchronize(cuda_stream_);
+    result = cuMemAllocFromPoolAsync(reinterpret_cast<CUdeviceptr*>(&ptr),
+                                     num_bytes, pool_, cuda_stream_);
+  }
+  if (result) {
     size_t free, total;
     cuMemGetInfo(&free, &total);
     LOG(ERROR) << Name() << " cuMemAllocAsync failed to allocate " << num_bytes
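
The retry added above synchronizes the stream once on CUDA_ERROR_OUT_OF_MEMORY and then re-issues the pool allocation. A sketch of that pattern, assuming a CUDA 11.2+ driver (stream-ordered memory pools); error handling and logging from the real allocator are trimmed:

    #include <cstddef>
    #include <cuda.h>

    CUresult AllocWithRetry(CUdeviceptr* ptr, size_t bytes, CUmemoryPool pool,
                            CUstream stream) {
      CUresult result = cuMemAllocFromPoolAsync(ptr, bytes, pool, stream);
      if (result == CUDA_ERROR_OUT_OF_MEMORY) {
        // Draining the stream lets the driver coalesce freed blocks and remap
        // memory, which can satisfy the request when memory is tight.
        cuStreamSynchronize(stream);
        result = cuMemAllocFromPoolAsync(ptr, bytes, pool, stream);
      }
      return result;
    }
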
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h
index c5ae523..b5fdb64 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h
@@ -288,6 +288,9 @@
     return it->second;
   }
 
+  int cc_major() const { return cc_major_; }
+  int cc_minor() const { return cc_minor_; }
+
  private:
   // Host callback landing routine invoked by CUDA.
   // data: User-provided callback provided to HostCallback() above, captured
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_graph.cc b/third_party/xla/xla/stream_executor/gpu/gpu_graph.cc
index c96ca81..ef72657 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_graph.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_graph.cc
@@ -127,8 +127,7 @@
   return tsl::errors::Internal("Unexpected value for GraphNodeType");
 }
 
-tsl::StatusOr<OwnedGpuGraphExec::UpdateResult> OwnedGpuGraphExec::Update(
-    OwnedGpuGraph graph) {
+tsl::Status OwnedGpuGraphExec::Update(OwnedGpuGraph graph) {
   VLOG(3) << "Update gpu graph exec with a new graph after " << num_launches_
           << " launches since last update"
           << " #" << num_updates_++;
@@ -141,33 +140,7 @@
   auto st = GpuDriver::GraphExecUpdate(get(), graph.get(), &result);
   uint64_t end_nanos = tsl::Env::Default()->NowNanos();
 
-  VLOG(5) << "Updated gpu graph exec #" << id_ << " (took "
-          << (end_nanos - start_nanos) / 1000 << " us)";
-
-  // TODO(b/297051365): We currently fallback to op-by-op mode because CUBLAS
-  // generate a memset node which causes graph update to fail. We should remove
-  // the fallback mechanism once cuBLAS completely works in gpu graphs.
-  auto compute_should_fallback = [&]() -> tsl::StatusOr<bool> {
-    if (result.result != GpuDriver::GraphExecUpdateResult::kError &&
-        result.result != GpuDriver::GraphExecUpdateResult::kParametersChanged)
-      return false;
-
-    if (result.error_node == nullptr) return false;
-    TF_ASSIGN_OR_RETURN(GpuDriver::GraphNodeType node_type,
-                        GpuDriver::GraphNodeGetType(result.error_node));
-    if (node_type != GpuDriver::GraphNodeType::kMemset) return false;
-
-    if (result.error_from_node != nullptr) return false;
-
-    return true;
-  };
-
-  if (!st.ok() || result.result != GpuDriver::GraphExecUpdateResult::kSuccess) {
-    TF_ASSIGN_OR_RETURN(bool should_fallback, compute_should_fallback());
-    if (should_fallback) {
-      return UpdateResult::kFallback;
-    }
-
+  if (!st.ok()) {
     TF_ASSIGN_OR_RETURN(std::string result_str,
                         GraphExecUpdateResultToString(result.result));
     std::string error_message = absl::StrCat(
@@ -193,7 +166,10 @@
     return tsl::errors::Internal(error_message);
   }
 
-  return UpdateResult::kSuccess;
+  VLOG(5) << "Updated gpu graph exec #" << id_ << " (took "
+          << (end_nanos - start_nanos) / 1000 << " us)";
+
+  return tsl::OkStatus();
 }
 
 tsl::Status OwnedGpuGraphExec::Launch(stream_executor::Stream* stream) {
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_graph.h b/third_party/xla/xla/stream_executor/gpu/gpu_graph.h
index af58a41..a897164 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_graph.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_graph.h
@@ -87,11 +87,9 @@
   OwnedGpuGraphExec(OwnedGpuGraphExec&&) = default;
   OwnedGpuGraphExec& operator=(OwnedGpuGraphExec&&) = default;
 
-  enum class UpdateResult { kSuccess, kFallback };
-
   // Updates executable graph instance with a newly captured graph. Returns an
   // error if the new graph is not compatible (see `cudaGraphExecUpdate`).
-  tsl::StatusOr<UpdateResult> Update(OwnedGpuGraph graph);
+  tsl::Status Update(OwnedGpuGraph graph);
 
   // Launches captured graph on a given stream.
   tsl::Status Launch(stream_executor::Stream* stream);
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h
index b7fa33c..146f7ff 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h
@@ -23,7 +23,7 @@
 #define XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_
 
 #include "xla/stream_executor/gpu/gpu_driver.h"
-#include "xla/stream_executor/kernel_cache_config.h"
+#include "xla/stream_executor/kernel.h"
 #include "xla/stream_executor/platform/port.h"
 #include "xla/stream_executor/stream_executor_internal.h"
 #include "tsl/platform/logging.h"
diff --git a/third_party/xla/xla/stream_executor/kernel.h b/third_party/xla/xla/stream_executor/kernel.h
index 4a9f7fb..9626990 100644
--- a/third_party/xla/xla/stream_executor/kernel.h
+++ b/third_party/xla/xla/stream_executor/kernel.h
@@ -81,7 +81,6 @@
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "xla/stream_executor/device_memory.h"
-#include "xla/stream_executor/kernel_cache_config.h"
 #include "xla/stream_executor/platform/port.h"
 
 namespace stream_executor {
@@ -95,6 +94,23 @@
 class KernelInterface;
 }  // namespace internal
 
+// This enum represents potential configurations of L1/shared memory when
+// running a particular kernel. These values represent user preference, and
+// the runtime is not required to respect these choices.
+enum class KernelCacheConfig {
+  // Indicates no preference for device L1/shared memory configuration.
+  kNoPreference,
+
+  // Indicates a preference for more shared memory than L1 cache.
+  kPreferShared,
+
+  // Indicates a preference for more L1 cache than shared memory.
+  kPreferL1,
+
+  // Indicates a preference for equal amounts of L1 cache and shared memory.
+  kPreferEqual,
+};
+
 // KernelMetadata holds runtime-queryable attributes of a loaded kernel, such as
 // registers allocated, shared memory used, etc.
 // Not all platforms support reporting of all information, so each accessor
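
KernelCacheConfig, now declared directly in kernel.h, only expresses a user preference for the L1/shared-memory split; backends may ignore it. A hypothetical sketch of how a backend could translate the preference into a shared-memory carveout percentage (the enum copy and the helper below are illustrative, not part of the header):

    enum class CacheConfig { kNoPreference, kPreferShared, kPreferL1, kPreferEqual };

    // -1 means "leave the L1/shared split to the driver default".
    int SharedMemCarveoutPercent(CacheConfig cfg) {
      switch (cfg) {
        case CacheConfig::kPreferShared: return 100;
        case CacheConfig::kPreferL1:     return 0;
        case CacheConfig::kPreferEqual:  return 50;
        case CacheConfig::kNoPreference: return -1;
      }
      return -1;
    }
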
diff --git a/third_party/xla/xla/stream_executor/kernel_cache_config.h b/third_party/xla/xla/stream_executor/kernel_cache_config.h
deleted file mode 100644
index f72b3fc..0000000
--- a/third_party/xla/xla/stream_executor/kernel_cache_config.h
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// This file contains declarations relating to kernel cache configuration
-// parameters recognized by the StreamExecutor.
-#ifndef XLA_STREAM_EXECUTOR_KERNEL_CACHE_CONFIG_H_
-#define XLA_STREAM_EXECUTOR_KERNEL_CACHE_CONFIG_H_
-
-namespace stream_executor {
-
-// This enum represents potential configurations of L1/shared memory when
-// running a particular kernel. These values represent user preference, and
-// the runtime is not required to respect these choices.
-enum class KernelCacheConfig {
-  // Indicates no preference for device L1/shared memory configuration.
-  kNoPreference,
-
-  // Indicates a preference for more shared memory than L1 cache.
-  kPreferShared,
-
-  // Indicates a preference for more L1 cache than shared memory.
-  kPreferL1,
-
-  // Indicates a preference for equal amounts of L1 cache and shared memory.
-  kPreferEqual,
-};
-
-}  // namespace stream_executor
-
-#endif  // XLA_STREAM_EXECUTOR_KERNEL_CACHE_CONFIG_H_
diff --git a/third_party/xla/xla/stream_executor/lazy_op_runner.h b/third_party/xla/xla/stream_executor/lazy_op_runner.h
index 2c390d7..609eb67 100644
--- a/third_party/xla/xla/stream_executor/lazy_op_runner.h
+++ b/third_party/xla/xla/stream_executor/lazy_op_runner.h
@@ -244,6 +244,8 @@
     std::optional<TensorDescriptor> activation_descriptor;
     std::optional<double> dropout_rate;
     std::optional<int64_t> seed;
+    bool is_flash_attention;
+    bool is_causal_mask;
   };
 
   static tsl::StatusOr<std::unique_ptr<const OpRunner<FusedMHASignature>>>
@@ -254,7 +256,8 @@
         config.bmm1_rhs_descriptor, config.bmm2_rhs_descriptor,
         config.intermediate_bmm2_lhs_descriptor, config.output_descriptor,
         config.activation_descriptor, config.mask_descriptor,
-        config.bias_descriptor, config.scale, config.dropout_rate, config.seed);
+        config.bias_descriptor, config.scale, config.dropout_rate, config.seed,
+        config.is_flash_attention, config.is_causal_mask);
   }
 };
 
@@ -272,11 +275,15 @@
     const TensorDescriptor& d_bmm1_lhs_descriptor;
     const TensorDescriptor& d_bmm1_rhs_descriptor;
     const TensorDescriptor& d_bmm2_rhs_descriptor;
-    const TensorDescriptor& d_s_descriptor;
+    std::optional<TensorDescriptor> d_s_descriptor;
     std::optional<TensorDescriptor> mask_descriptor;
     std::optional<TensorDescriptor> d_bias_descriptor;
+    std::optional<TensorDescriptor> fwd_output_descriptor;
+    std::optional<TensorDescriptor> bias_descriptor;
     std::optional<double> dropout_rate;
     std::optional<int64_t> seed;
+    bool is_flash_attention;
+    bool is_causal_mask;
   };
 
   static tsl::StatusOr<
@@ -290,8 +297,10 @@
         config.bmm2_grad_gemm2_rhs_descriptor, config.d_output_descriptor,
         config.d_bmm1_lhs_descriptor, config.d_bmm1_rhs_descriptor,
         config.d_bmm2_rhs_descriptor, config.d_s_descriptor,
-        config.mask_descriptor, config.d_bias_descriptor, config.scale,
-        config.dropout_rate, config.seed);
+        config.mask_descriptor, config.d_bias_descriptor,
+        config.fwd_output_descriptor, config.bias_descriptor, config.scale,
+        config.dropout_rate, config.seed, config.is_flash_attention,
+        config.is_causal_mask);
   }
 };
 
diff --git a/third_party/xla/xla/stream_executor/multi_platform_manager.h b/third_party/xla/xla/stream_executor/multi_platform_manager.h
index f428bbf..b326983 100644
--- a/third_party/xla/xla/stream_executor/multi_platform_manager.h
+++ b/third_party/xla/xla/stream_executor/multi_platform_manager.h
@@ -13,9 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-// This is a registration-oriented interface for multiple platforms. It will
-// replace the MachineManager singleton interface, as MachineManager does not
-// currently support simultaneous use of multiple platforms.
+// This is a registration-oriented interface for multiple platforms.
 //
 // Usage:
 //
@@ -71,7 +69,6 @@
 #include "absl/strings/string_view.h"
 #include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/platform/initialize.h"
-#include "xla/stream_executor/platform/port.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/statusor.h"
 
diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
index 241bbb8..6a20995 100644
--- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
+++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
@@ -21,8 +21,6 @@
 #include "rocm/rocm_config.h"
 #include "xla/primitive_util.h"
 #include "xla/status_macros.h"
-#include "xla/stream_executor/blas.h"
-#include "xla/stream_executor/rocm/hip_blas_utils.h"
 #include "xla/util.h"
 
 #if TF_HIPBLASLT
@@ -137,13 +135,13 @@
                              ? m.num_cols
                              : m.num_rows;
   }
-
+  auto hipblas_data_type = AsHipblasDataType(type);
   hipblasLtMatrixLayout_t hip_layout;
   SE_HIPBLAS_RETURN_IF_ERROR(wrap::hipblasLtMatrixLayoutCreate(
-      &hip_layout, AsHipblasDataType(type), m.num_rows, m.num_cols,
+      &hip_layout, hipblas_data_type, m.num_rows, m.num_cols,
       *leading_dim_stride));
   // Wrap hipblas handle immediately, so it is cleaned up if an error occurs.
-  BlasLt::MatrixLayout layout(hip_layout);
+  BlasLt::MatrixLayout layout(hip_layout, hipblas_data_type);
   if (m.order != gpu::MatrixLayout::Order::kColumnMajor)
     return tsl::errors::Internal(
         "HipblasLT does not support row-major matrices");
@@ -167,14 +165,16 @@
   VLOG(2) << "BlasLt::MatmulDesc::Create compute_type" << int(compute_type)
           << " scale_type " << int(scale_type) << " epilogue " << int(epilogue)
           << " pointer_mode " << int(pointer_mode);
+  auto hip_scale_type = AsHipblasDataType(scale_type);
+  auto hip_compute_type = AsHipblasComputeType(compute_type);
   SE_HIPBLAS_RETURN_IF_ERROR(wrap::hipblasLtMatmulDescCreate(
-      &hip_desc, AsHipblasComputeType(compute_type),
-      AsHipblasDataType(scale_type)));
+      &hip_desc, hip_compute_type, hip_scale_type));
   // Wrap hipblas handle immediately, so it is cleaned up if an error occurs.
-  BlasLt::MatmulDesc desc(hip_desc);
+  BlasLt::MatmulDesc desc(hip_desc, hip_compute_type, hip_scale_type);
   if (pointer_mode != PointerMode::kHost) {
     return tsl::errors::Internal("hipblaslt does not support device pointers");
   }
+
   TF_RETURN_IF_ERROR(SetAttr(hip_desc, HIPBLASLT_MATMUL_DESC_TRANSA,
                              AsHipblasOperation(trans_a)));
   TF_RETURN_IF_ERROR(SetAttr(hip_desc, HIPBLASLT_MATMUL_DESC_TRANSB,
@@ -421,39 +421,31 @@
 
 namespace {
 
-using cudaDataType_t = hipblasDatatype_t;
-#define CUDA_R_16BF HIPBLAS_R_16B
-#define CUDA_R_16F HIPBLAS_R_16F
-#define CUDA_R_32F HIPBLAS_R_32F
-#define CUDA_R_64F HIPBLAS_R_64F
-#define CUDA_C_32F HIPBLAS_C_32F
-#define CUDA_C_64F HIPBLAS_C_64F
-
-template <cudaDataType_t CudaT>
-struct CudaToNativeT;
+template <hipblasltDatatype_t>
+struct HipToNativeT;
 
 template <>
-struct CudaToNativeT<CUDA_R_16BF> {
+struct HipToNativeT<HIPBLASLT_R_16B> {
   using type = Eigen::bfloat16;
 };
 template <>
-struct CudaToNativeT<CUDA_R_16F> {
+struct HipToNativeT<HIPBLASLT_R_16F> {
   using type = Eigen::half;
 };
 template <>
-struct CudaToNativeT<CUDA_R_32F> {
+struct HipToNativeT<HIPBLASLT_R_32F> {
   using type = float;
 };
 template <>
-struct CudaToNativeT<CUDA_R_64F> {
+struct HipToNativeT<HIPBLASLT_R_64F> {
   using type = double;
 };
 template <>
-struct CudaToNativeT<CUDA_C_32F> {
+struct HipToNativeT<HIPBLASLT_C_32F> {
   using type = complex64;
 };
 template <>
-struct CudaToNativeT<CUDA_C_64F> {
+struct HipToNativeT<HIPBLASLT_C_64F> {
   using type = complex128;
 };
 
@@ -473,25 +465,33 @@
   std::tuple operand_types{a_desc_.type(), b_desc_.type(), c_desc_.type(),
                            d_desc_.type()};
 
-#define TYPED_MATMUL(SCALENTYPE, ATYPE, BTYPE, CTYPE, DTYPE)                \
-  if (operand_types == std::make_tuple(ATYPE, BTYPE, CTYPE, DTYPE)) {       \
-    return gpu::BlasLt::MatmulPlan::DoMatmul<                               \
-        SCALENTYPE, CudaToNativeT<ATYPE>::type, CudaToNativeT<BTYPE>::type, \
-        CudaToNativeT<CTYPE>::type, CudaToNativeT<DTYPE>::type>(            \
-        stream, alpha_, a, b, beta_, c, d, bias, aux, a_scale, b_scale,     \
-        c_scale, d_scale, d_amax, algorithm, scratch_allocator,             \
-        profile_result);                                                    \
+#define TYPED_MATMUL(SCALENTYPE, ATYPE, BTYPE, CTYPE, DTYPE)              \
+  if (operand_types == std::make_tuple(ATYPE, BTYPE, CTYPE, DTYPE)) {     \
+    return gpu::BlasLt::MatmulPlan::DoMatmul<                             \
+        SCALENTYPE, HipToNativeT<ATYPE>::type, HipToNativeT<BTYPE>::type, \
+        HipToNativeT<CTYPE>::type, HipToNativeT<DTYPE>::type>(            \
+        stream, alpha_, a, b, beta_, c, d, bias, aux, a_scale, b_scale,   \
+        c_scale, d_scale, d_amax, algorithm, scratch_allocator,           \
+        profile_result);                                                  \
   }
 
   // Other data types:
-  TYPED_MATMUL(float, CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF)
-  TYPED_MATMUL(float, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F)
-  TYPED_MATMUL(float, CUDA_R_16BF, CUDA_R_16BF, CUDA_R_32F, CUDA_R_32F)
-  TYPED_MATMUL(float, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F, CUDA_R_32F)
-  TYPED_MATMUL(float, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F)
-  TYPED_MATMUL(double, CUDA_R_64F, CUDA_R_64F, CUDA_R_64F, CUDA_R_64F)
-  TYPED_MATMUL(complex64, CUDA_C_32F, CUDA_C_32F, CUDA_C_32F, CUDA_C_32F)
-  TYPED_MATMUL(complex128, CUDA_C_64F, CUDA_C_64F, CUDA_C_64F, CUDA_C_64F)
+  TYPED_MATMUL(float, HIPBLASLT_R_16B, HIPBLASLT_R_16B, HIPBLASLT_R_16B,
+               HIPBLASLT_R_16B)
+  TYPED_MATMUL(float, HIPBLASLT_R_16F, HIPBLASLT_R_16F, HIPBLASLT_R_16F,
+               HIPBLASLT_R_16F)
+  TYPED_MATMUL(float, HIPBLASLT_R_16B, HIPBLASLT_R_16B, HIPBLASLT_R_32F,
+               HIPBLASLT_R_32F)
+  TYPED_MATMUL(float, HIPBLASLT_R_16F, HIPBLASLT_R_16F, HIPBLASLT_R_32F,
+               HIPBLASLT_R_32F)
+  TYPED_MATMUL(float, HIPBLASLT_R_32F, HIPBLASLT_R_32F, HIPBLASLT_R_32F,
+               HIPBLASLT_R_32F)
+  TYPED_MATMUL(double, HIPBLASLT_R_64F, HIPBLASLT_R_64F, HIPBLASLT_R_64F,
+               HIPBLASLT_R_64F)
+  TYPED_MATMUL(complex64, HIPBLASLT_C_32F, HIPBLASLT_C_32F, HIPBLASLT_C_32F,
+               HIPBLASLT_C_32F)
+  TYPED_MATMUL(complex128, HIPBLASLT_C_64F, HIPBLASLT_C_64F, HIPBLASLT_C_64F,
+               HIPBLASLT_C_64F)
 
 #undef TYPED_MATMUL
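
TYPED_MATMUL above dispatches DoMatmul on the runtime tuple of operand dtypes by mapping each hipblaslt datatype tag to a native C++ type via HipToNativeT. A stripped-down sketch of that trait-based dispatch with an illustrative enum in place of hipblasltDatatype_t:

    #include <complex>
    #include <cstddef>

    enum class Dt { kR16F, kR32F, kR64F, kC32F };

    template <Dt> struct ToNative;
    template <> struct ToNative<Dt::kR32F> { using type = float; };
    template <> struct ToNative<Dt::kR64F> { using type = double; };
    template <> struct ToNative<Dt::kC32F> { using type = std::complex<float>; };

    // Compile-time selection of the element type from a dtype tag, the same
    // idea the macro applies per operand combination.
    template <Dt A>
    std::size_t ElementSize() { return sizeof(typename ToNative<A>::type); }
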
 
diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h
index 0b11051..678608e 100644
--- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h
+++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h
@@ -13,6 +13,7 @@
 #ifndef XLA_STREAM_EXECUTOR_ROCM_HIP_BLAS_LT_H_
 #define XLA_STREAM_EXECUTOR_ROCM_HIP_BLAS_LT_H_
 
+#include "rocm/rocm_config.h"
 #include "xla/stream_executor/blas.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/gpu/gpu_blas_lt.h"
@@ -22,9 +23,7 @@
 
 #if TF_HIPBLASLT
 
-#include "rocm/rocm_config.h"
 #include "xla/stream_executor/rocm/hip_blas_utils.h"
-#include "xla/stream_executor/rocm/hipblaslt_wrapper.h"
 
 namespace stream_executor {
 
@@ -43,14 +42,16 @@
   struct MatrixLayout {
     static tsl::StatusOr<MatrixLayout> Create(const gpu::MatrixLayout& m);
 
-    hipblasDatatype_t type() const { return HIPBLAS_R_32F; }
+    hipblasltDatatype_t type() const { return datatype_; }
     hipblasLtMatrixLayout_t get() const { return handle_.get(); }
 
    private:
-    explicit MatrixLayout(hipblasLtMatrixLayout_t handle)
-        : handle_(handle, wrap::hipblasLtMatrixLayoutDestroy) {}
+    MatrixLayout(hipblasLtMatrixLayout_t handle, hipblasltDatatype_t datatype)
+        : handle_(handle, wrap::hipblasLtMatrixLayoutDestroy),
+          datatype_(datatype) {}
 
     Owned<hipblasLtMatrixLayout_t> handle_;
+    hipblasltDatatype_t datatype_;
   };
 
   class MatmulDesc {
@@ -62,20 +63,24 @@
         Epilogue epilogue = Epilogue::kDefault,
         PointerMode pointer_mode = PointerMode::kHost);
 
-    hipblasLtComputeType_t compute_type() const {
-      return HIPBLASLT_COMPUTE_F32;
-    }
-    hipblasDatatype_t scale_type() const { return HIPBLAS_R_32F; }
+    hipblasLtComputeType_t compute_type() const { return compute_type_; }
+    hipblasltDatatype_t scale_type() const { return datatype_; }
     hipblasPointerMode_t pointer_mode() const {
       return HIPBLAS_POINTER_MODE_HOST;
     }
     hipblasLtMatmulDesc_t get() const { return handle_.get(); }
 
    private:
-    explicit MatmulDesc(hipblasLtMatmulDesc_t handle)
-        : handle_(handle, wrap::hipblasLtMatmulDescDestroy) {}
+    MatmulDesc(hipblasLtMatmulDesc_t handle,
+               hipblasLtComputeType_t compute_type,
+               hipblasltDatatype_t datatype)
+        : handle_(handle, wrap::hipblasLtMatmulDescDestroy),
+          compute_type_(compute_type),
+          datatype_(datatype) {}
 
     Owned<hipblasLtMatmulDesc_t> handle_;
+    hipblasLtComputeType_t compute_type_;
+    hipblasltDatatype_t datatype_;
   };
 
   struct MatmulPlan : public gpu::BlasLt::MatmulPlan {
diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc
index 6073342..69d2a48 100644
--- a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc
+++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc
@@ -32,27 +32,27 @@
   return tsl::OkStatus();
 }
 
-hipblasDatatype_t AsHipblasDataType(blas::DataType type) {
+hipblasltDatatype_t AsHipblasDataType(blas::DataType type) {
   switch (type) {
     case blas::DataType::kF8E5M2:
     case blas::DataType::kF8E4M3FN:
       LOG(FATAL) << "hipblaslt does not support F8 yet";
     case blas::DataType::kHalf:
-      return HIPBLAS_R_16F;
+      return HIPBLASLT_R_16F;
     case blas::DataType::kBF16:
-      return HIPBLAS_R_16B;
+      return HIPBLASLT_R_16B;
     case blas::DataType::kFloat:
-      return HIPBLAS_R_32F;
+      return HIPBLASLT_R_32F;
     case blas::DataType::kDouble:
-      return HIPBLAS_R_64F;
+      return HIPBLASLT_R_64F;
     case blas::DataType::kInt8:
-      return HIPBLAS_R_8I;
+      return HIPBLASLT_R_8I;
     case blas::DataType::kInt32:
-      return HIPBLAS_R_32I;
+      return HIPBLASLT_R_32I;
     case blas::DataType::kComplexFloat:
-      return HIPBLAS_C_32F;
+      return HIPBLASLT_C_32F;
     case blas::DataType::kComplexDouble:
-      return HIPBLAS_C_64F;
+      return HIPBLASLT_C_64F;
     default:
       LOG(FATAL) << "unknown data type";
   }
diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.h b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.h
index 3c575dc..c4f7676 100644
--- a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.h
+++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.h
@@ -25,6 +25,18 @@
 
 #if TF_HIPBLASLT
 
+#if TF_ROCM_VERSION < 60000
+#define hipblasltDatatype_t hipblasDatatype_t
+#define HIPBLASLT_R_16F HIPBLAS_R_16F
+#define HIPBLASLT_R_16B HIPBLAS_R_16B
+#define HIPBLASLT_R_32F HIPBLAS_R_32F
+#define HIPBLASLT_R_64F HIPBLAS_R_64F
+#define HIPBLASLT_R_8I HIPBLAS_R_8I
+#define HIPBLASLT_R_32I HIPBLAS_R_32I
+#define HIPBLASLT_C_32F HIPBLAS_C_32F
+#define HIPBLASLT_C_64F HIPBLAS_C_64F
+#endif
+
 namespace stream_executor {
 namespace rocm {
 
@@ -32,7 +44,7 @@
   TF_RETURN_IF_ERROR(::stream_executor::rocm::ToStatus(expr, #expr))
 
 tsl::Status ToStatus(hipblasStatus_t status, const char* prefix);
-hipblasDatatype_t AsHipblasDataType(blas::DataType type);
+hipblasltDatatype_t AsHipblasDataType(blas::DataType type);
 hipblasLtComputeType_t AsHipblasComputeType(blas::ComputationType type);
 hipblasOperation_t AsHipblasOperation(blas::Transpose trans);
 
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc b/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc
index b426cd8..c9c6beb 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc
@@ -26,6 +26,7 @@
 #include "absl/strings/str_format.h"
 #include "absl/types/span.h"
 #include "unsupported/Eigen/CXX11/Tensor"  // from @eigen_archive
+#include "rocm/rocm_config.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/gpu/gpu_activation.h"
 #include "xla/stream_executor/gpu/gpu_executor.h"
@@ -425,7 +426,8 @@
                                  const void *alpha, const DeviceMemoryBase &a,
                                  int lda, const DeviceMemoryBase &b, int ldb,
                                  const void *beta, DeviceMemoryBase *c, int ldc,
-                                 const NumericOptions &numeric_options) {
+                                 const NumericOptions &numeric_options,
+                                 blas::CallContext context) {
   blas_log("DoBlasGemm");
   VLOG(1) << absl::StreamFormat(
       "doing rocBLAS GEMM: at=%d bt=%d m=%u n=%u "
@@ -463,6 +465,15 @@
       tsl::StatusOr<bool> maybe_hasXDLOPS = GpuDriver::GetMFMASupport();
       if (maybe_hasXDLOPS.ok() && maybe_hasXDLOPS.value()) {
         VLOG(1) << "Using rocblas_gemm_ex";
+        bool is_backprop = (context == blas::CallContext::kBackpropInput1) ||
+                           (context == blas::CallContext::kBackpropInput2);
+
+        uint32_t flags = rocblas_gemm_flags_none;
+#if TF_ROCM_VERSION >= 50000
+        if (is_backprop) {
+          flags = rocblas_gemm_flags_fp16_alt_impl;
+        }
+#endif
         return DoBlasInternalStatus(
             wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true,
             ROCMBlasTranspose(transa), ROCMBlasTranspose(transb),
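
The flag selection added above requests rocBLAS's fp16 alternate implementation only for backprop GEMMs on ROCm >= 5.0, leaving forward passes on the default path. A sketch of that decision in isolation; CallCtx and the flag parameters stand in for blas::CallContext and the rocblas_gemm_flags_* constants:

    #include <cstdint>

    enum class CallCtx { kNone, kForward, kBackpropInput1, kBackpropInput2 };

    // Returns the flags to pass to rocblas_gemm_ex for this call.
    uint32_t GemmFlagsFor(CallCtx ctx, bool rocm_ge_50, uint32_t flags_none,
                          uint32_t flags_fp16_alt_impl) {
      const bool is_backprop =
          ctx == CallCtx::kBackpropInput1 || ctx == CallCtx::kBackpropInput2;
      return (rocm_ge_50 && is_backprop) ? flags_fp16_alt_impl : flags_none;
    }
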
@@ -470,7 +481,7 @@
             rocblas_datatype_f16_r, lda, b.opaque(), rocblas_datatype_f16_r,
             ldb, beta, c->opaque(), rocblas_datatype_f16_r, ldc, c->opaque(),
             rocblas_datatype_f16_r, ldc, rocblas_datatype_f32_r,
-            rocblas_gemm_algo_standard, 0, 0);
+            rocblas_gemm_algo_standard, 0, flags);
       } else {
         VLOG(1) << "Using rocblas_hgemm";
         const Eigen::half alpha_half(*static_cast<const float *>(alpha));
@@ -549,7 +560,7 @@
     blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c,
     blas::DataType type_c, int ldc, blas::ComputationType computation_type,
     blas::AlgorithmType algorithm, const NumericOptions &numeric_options,
-    blas::ProfileResult *output_profile_result) {
+    blas::ProfileResult *output_profile_result, blas::CallContext context) {
   // ROCM TODO: properly implement the interface
   return tsl::errors::Internal("DoBlasGemmWithAlgorithm ",
                                "is not implemented on ROCm yet");
@@ -563,7 +574,7 @@
     DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c,
     int batch_count, blas::ComputationType computation_type,
     blas::AlgorithmType algorithm, const NumericOptions &numeric_options,
-    blas::ProfileResult *output_profile_result) {
+    blas::ProfileResult *output_profile_result, blas::CallContext context) {
   // ROCM TODO: properly implement the interface
   return tsl::errors::Internal("DoBlasGemmStridedBatchedWithAlgorithm ",
                                "is not implemented on ROCm yet");
@@ -854,15 +865,13 @@
   return tsl::OkStatus();
 }
 
-bool ROCMBlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
-                                 blas::Transpose transb, uint64_t m, uint64_t n,
-                                 uint64 k, float alpha,
-                                 DeviceMemorySlice<Eigen::half> a, int lda,
-                                 DeviceMemorySlice<Eigen::half> b, int ldb,
-                                 float beta, DeviceMemorySlice<Eigen::half> c,
-                                 int ldc, int batch_count,
-                                 const NumericOptions &numeric_options,
-                                 ScratchAllocator *scratch_allocator) {
+bool ROCMBlas::DoBlasGemmBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
+    uint64_t n, uint64 k, float alpha, DeviceMemorySlice<Eigen::half> a,
+    int lda, DeviceMemorySlice<Eigen::half> b, int ldb, float beta,
+    DeviceMemorySlice<Eigen::half> c, int ldc, int batch_count,
+    const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator,
+    blas::CallContext context) {
   blas_log("DoBlasGemmBatched");
   const Eigen::half alpha_half(alpha);
   const Eigen::half beta_half(beta);
@@ -884,8 +893,8 @@
     DeviceMemorySlice<Eigen::bfloat16> a_array, int lda,
     DeviceMemorySlice<Eigen::bfloat16> b_array, int ldb, float beta,
     DeviceMemorySlice<Eigen::bfloat16> c_array, int ldc, int batch_count,
-    const NumericOptions &numeric_options,
-    ScratchAllocator *scratch_allocator) {
+    const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator,
+    blas::CallContext context) {
   blas_log("DoBlasGemmBatched");
   const Eigen::bfloat16 alpha_bf16(alpha);
   const Eigen::bfloat16 beta_bf16(beta);
@@ -900,15 +909,13 @@
   return status.ok();
 }
 
-bool ROCMBlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
-                                 blas::Transpose transb, uint64_t m, uint64_t n,
-                                 uint64 k, float alpha,
-                                 DeviceMemorySlice<float> a_array, int lda,
-                                 DeviceMemorySlice<float> b_array, int ldb,
-                                 float beta, DeviceMemorySlice<float> c_array,
-                                 int ldc, int batch_count,
-                                 const NumericOptions &numeric_options,
-                                 ScratchAllocator *scratch_allocator) {
+bool ROCMBlas::DoBlasGemmBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
+    uint64_t n, uint64 k, float alpha, DeviceMemorySlice<float> a_array,
+    int lda, DeviceMemorySlice<float> b_array, int ldb, float beta,
+    DeviceMemorySlice<float> c_array, int ldc, int batch_count,
+    const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator,
+    blas::CallContext context) {
   blas_log("DoBlasGemmBatched");
   tsl::Status status = DoBlasGemmBatchedInternal(
       wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k,
@@ -920,15 +927,13 @@
   return status.ok();
 }
 
-bool ROCMBlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
-                                 blas::Transpose transb, uint64_t m, uint64_t n,
-                                 uint64 k, double alpha,
-                                 DeviceMemorySlice<double> a_array, int lda,
-                                 DeviceMemorySlice<double> b_array, int ldb,
-                                 double beta, DeviceMemorySlice<double> c_array,
-                                 int ldc, int batch_count,
-                                 const NumericOptions &numeric_options,
-                                 ScratchAllocator *scratch_allocator) {
+bool ROCMBlas::DoBlasGemmBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
+    uint64_t n, uint64 k, double alpha, DeviceMemorySlice<double> a_array,
+    int lda, DeviceMemorySlice<double> b_array, int ldb, double beta,
+    DeviceMemorySlice<double> c_array, int ldc, int batch_count,
+    const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator,
+    blas::CallContext context) {
   blas_log("DoBlasGemmBatched");
   tsl::Status status = DoBlasGemmBatchedInternal(
       wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k,
@@ -947,7 +952,7 @@
     DeviceMemorySlice<std::complex<float>> b_array, int ldb,
     std::complex<float> beta, DeviceMemorySlice<std::complex<float>> c_array,
     int ldc, int batch_count, const NumericOptions &numeric_options,
-    ScratchAllocator *scratch_allocator) {
+    ScratchAllocator *scratch_allocator, blas::CallContext context) {
   blas_log("DoBlasGemmBatched");
   tsl::Status status = DoBlasGemmBatchedInternal(
       wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k,
@@ -966,7 +971,7 @@
     DeviceMemorySlice<std::complex<double>> b_array, int ldb,
     std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c_array,
     int ldc, int batch_count, const NumericOptions &numeric_options,
-    ScratchAllocator *scratch_allocator) {
+    ScratchAllocator *scratch_allocator, blas::CallContext context) {
   blas_log("DoBlasGemmBatched");
   tsl::Status status = DoBlasGemmBatchedInternal(
       wrap::rocblas_zgemm_strided_batched, stream, transa, transb, m, n, k,
@@ -1096,7 +1101,7 @@
     const DeviceMemoryBase &a, int lda, int64_t stride_a,
     const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
     DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count,
-    const NumericOptions &numeric_options) {
+    const NumericOptions &numeric_options, blas::CallContext context) {
   VLOG(1) << absl::StreamFormat(
       "doing rocBLAS SGEMM Strided Batched<float>: at=%d bt=%d m=%u n=%u "
       "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p "
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc
index 3bda5c1..9a200cc 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc
@@ -28,6 +28,7 @@
 #include "absl/types/span.h"
 #include "Eigen/Core"  // from @eigen_archive
 #include "rocm/include/miopen/miopen.h"
+#include "rocm/rocm_config.h"
 #include "xla/stream_executor/dnn.h"
 #include "xla/stream_executor/gpu/gpu_activation.h"
 #include "xla/stream_executor/gpu/gpu_driver.h"
@@ -252,7 +253,7 @@
 
 #endif
 
-#if (TF_ROCM_VERSION >= 50300)
+#if (TF_ROCM_VERSION >= 50000)
 // clang-format off
 #define MIOPEN_DNN_ROUTINE_EACH(__macro)                             \
   __macro(miopenBatchNormalizationBackward)                          \
@@ -625,8 +626,8 @@
 dnn::ProfileResult GetProfileResultFromConvSolution(
     miopenConvSolution_t solution) {
   dnn::ProfileResult profile_result;
-  profile_result.set_algorithm(
-      {solution.solution_id, false, solution.workspace_size});
+  profile_result.set_algorithm({(dnn::AlgorithmDesc::Index)solution.solution_id,
+                                false, solution.workspace_size});
   profile_result.set_elapsed_time_in_ms(solution.time);
   profile_result.set_scratch_size(solution.workspace_size);
   return profile_result;
@@ -1312,7 +1313,7 @@
       const int op_idx, const float* alpha, const float* beta,
       const void* scale, const void* offset, void* running_mean,
       void* running_variance, void* saved_mean, void* saved_inv_variance,
-      double epsilon, double exponential_average_factor) {
+      double exponential_average_factor, double epsilon) {
     miopenFusionOpDescriptor_t batchnorm_op;
     auto status =
         wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op);
@@ -1323,8 +1324,8 @@
 
     status = wrap::miopenSetOpArgsBatchNormForward(
         fusion_args_, batchnorm_op, alpha, beta, scale, offset, saved_mean,
-        saved_inv_variance, running_mean, running_variance, epsilon,
-        exponential_average_factor);
+        saved_inv_variance, running_mean, running_variance,
+        exponential_average_factor, epsilon);
     if (status != miopenStatusSuccess) {
       LOG(FATAL) << "call to miopenSetOpArgsBatchNormForward failed: "
                  << ToString(status);
@@ -1684,7 +1685,7 @@
     float beta = 0.0;
     return ScopedFusionPlanBase::SetBatchNormForwardArgs(
         k_batchnorm_op_idx, &alpha, &beta, scale, offset, batch_mean, batch_var,
-        saved_mean, saved_var, epsilon, /*exponential_average_factor=*/1.0);
+        saved_mean, saved_var, /*exponential_average_factor=*/1.0, epsilon);
   }
 
   miopenStatus_t SetActivationForwardArgs(
@@ -1842,6 +1843,9 @@
     case dnn::DataType::kInt8:
       if (data_layout == dnn::DataLayout::kBatchDepthYX) return miopenInt8;
     case dnn::DataType::kDouble:
+      LOG(FATAL)
+          << "Unsupported DNN data type: tf.float64 (dnn::DataType::kDouble)";
+      break;
     default:
       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
   }
@@ -1956,6 +1960,7 @@
                       miopenRNNDirectionMode_t direction_mode,
                       miopenRNNMode_t rnn_mode, miopenDataType_t data_type,
                       float dropout, uint64_t seed,
+                      const dnn::AlgorithmConfig& algorithm_config,
                       ScratchAllocator* state_allocator)
       : rnn_desc_(nullptr),
         num_layers_(num_layers),
@@ -1964,7 +1969,8 @@
         input_mode_(input_mode),
         direction_mode_(direction_mode),
         rnn_mode_(rnn_mode),
-        data_type_(data_type) {
+        data_type_(data_type),
+        algorithm_config_(algorithm_config) {
     // Create the RNN handle
     auto status = wrap::miopenCreateRNNDescriptor(&rnn_desc_);
     RETURN_IF_MIOPEN_ERROR(status, "Unable to create RNN descriptor");
@@ -2000,6 +2006,9 @@
   miopenRNNDirectionMode_t direction_mode() const { return direction_mode_; }
   miopenRNNMode_t rnn_mode() const { return rnn_mode_; }
   miopenDataType_t data_type() const { return data_type_; }
+  const dnn::AlgorithmConfig& algorithm_config() const {
+    return algorithm_config_;
+  }
   int64_t ParamsSizeInBytes() const override {
     return miopen_params_desc_->params_size_in_bytes();
   }
@@ -2025,6 +2034,7 @@
   miopenRNNDirectionMode_t direction_mode_;
   miopenRNNMode_t rnn_mode_;
   miopenDataType_t data_type_;
+  dnn::AlgorithmConfig algorithm_config_;
   tsl::Status status_;
   // no dropout in MIOpen.
   // std::unique_ptr<miopenDropoutDescriptor> miopen_dropout_desc_;
@@ -2290,7 +2300,8 @@
     const MIOpenRnnStateTensorDescriptor& output_c_desc,
     DeviceMemory<T>* output_c_data, bool is_training,
     ScratchAllocator* reserve_space_allocator,
-    ScratchAllocator* workspace_allocator) {
+    ScratchAllocator* workspace_allocator,
+    dnn::ProfileResult* output_profile_result) {
   // extract model parameters
   RnnModelDims model_dims;
   bool res = ExtractAndCheckRnnForward(
@@ -2346,6 +2357,18 @@
     }
   }
 
+  std::optional<GpuTimer> timer;
+  const bool is_profiling = output_profile_result != nullptr;
+
+  if (is_profiling) {
+    auto timer_or_status = GpuTimer::Create(AsGpuStream(stream));
+    if (!timer_or_status.ok()) {
+      LOG(ERROR) << "Failed to create timer";
+      return false;
+    }
+    timer.emplace(std::move(*timer_or_status));
+  }
+
   // make the forward call
   if (!is_training) {
     auto status = wrap::miopenRNNForwardInference(
@@ -2385,6 +2408,19 @@
       return false;
     }
   }
+
+  if (is_profiling) {
+    tsl::StatusOr<absl::Duration> elapsed = timer->GetElapsedDuration();
+    if (!elapsed.ok()) {
+      LOG(ERROR) << "Failed to get elapsed duration";
+      return false;
+    }
+    auto algo_desc = *rnn_desc.algorithm_config().algorithm();
+    output_profile_result->set_algorithm(algo_desc);
+    output_profile_result->set_elapsed_time_in_ms(
+        absl::ToDoubleMilliseconds(*elapsed));
+  }
+
   return true;
 }
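
The forward path now honors output_profile_result: a GpuTimer is created only when profiling is requested, and on success the elapsed time plus the algorithm from the descriptor's AlgorithmConfig are written back. A simplified sketch of that shape using a host-side clock in place of GpuTimer (ProfileOut and RunMaybeProfiled are illustrative):

    #include <chrono>
    #include <optional>

    struct ProfileOut { double elapsed_ms = 0; };

    // Run a call and, only if the caller asked for profiling, record the
    // elapsed wall time into the result. The real code also copies the
    // algorithm from the RNN descriptor's AlgorithmConfig.
    template <typename Fn>
    bool RunMaybeProfiled(Fn&& rnn_call, ProfileOut* profile) {
      std::optional<std::chrono::steady_clock::time_point> start;
      if (profile != nullptr) start = std::chrono::steady_clock::now();
      if (!rnn_call()) return false;
      if (profile != nullptr) {
        auto end = std::chrono::steady_clock::now();
        profile->elapsed_ms =
            std::chrono::duration<double, std::milli>(end - *start).count();
      }
      return true;
    }
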
 
@@ -2411,7 +2447,8 @@
     DeviceMemory<T>* input_c_backprop_data,
     DeviceMemory<T>* params_backprop_data,
     DeviceMemory<uint8>* reserve_space_data,
-    ScratchAllocator* workspace_allocator) {
+    ScratchAllocator* workspace_allocator,
+    dnn::ProfileResult* output_profile_result) {
   // extract model parameters
   RnnModelDims model_dims;
   bool res = ExtractAndCheckRnnForward(
@@ -2458,6 +2495,18 @@
   if ((size_data > 0) && (input_c_backprop_data->opaque() != nullptr))
     stream->ThenMemZero(input_c_backprop_data, size_data * type_size);
 
+  std::optional<GpuTimer> timer;
+  const bool is_profiling = output_profile_result != nullptr;
+
+  if (is_profiling) {
+    auto timer_or_status = GpuTimer::Create(AsGpuStream(stream));
+    if (!timer_or_status.ok()) {
+      LOG(ERROR) << "Failed to create timer";
+      return false;
+    }
+    timer.emplace(std::move(*timer_or_status));
+  }
+
   // make the backward data call
   auto status = wrap::miopenRNNBackwardData(
       miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
@@ -2504,6 +2553,18 @@
     }
   }
 
+  if (is_profiling) {
+    tsl::StatusOr<absl::Duration> elapsed = timer->GetElapsedDuration();
+    if (!elapsed.ok()) {
+      LOG(ERROR) << "Failed to get elapsed duration";
+      return false;
+    }
+    auto algo_desc = *rnn_desc.algorithm_config().algorithm();
+    output_profile_result->set_algorithm(algo_desc);
+    output_profile_result->set_elapsed_time_in_ms(
+        absl::ToDoubleMilliseconds(*elapsed));
+  }
+
   return true;
 }
 
@@ -2733,7 +2794,8 @@
       miopen.handle(), num_layers, hidden_size, input_size,
       ToMIOpenRnnInputMode(input_mode),
       ToMIOpenRnnDirectionMode(direction_mode), ToMIOpenRnnMode(rnn_mode),
-      ToMIOpenDataType(data_type), dropout, seed, state_allocator));
+      ToMIOpenDataType(data_type), dropout, seed, algorithm_config,
+      state_allocator));
   if (!rnn_desc->ok()) {
     return rnn_desc->Status();
   }
@@ -2788,8 +2850,6 @@
     ScratchAllocator* reserve_space_allocator,
     ScratchAllocator* workspace_allocator,
     dnn::ProfileResult* output_profile_result) {
-  // ROCM TODO: output_profile_result is ignore for now
-
   const MIOpenRnnDescriptor& miopen_rnn_desc =
       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
@@ -2810,7 +2870,7 @@
       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
       params, miopen_output_desc, output_data, miopen_output_h_desc,
       output_h_data, miopen_output_c_desc, output_c_data, is_training,
-      reserve_space_allocator, workspace_allocator);
+      reserve_space_allocator, workspace_allocator, output_profile_result);
 }
 
 bool MIOpenSupport::DoRnnForward(
@@ -2831,8 +2891,6 @@
     ScratchAllocator* reserve_space_allocator,
     ScratchAllocator* workspace_allocator,
     dnn::ProfileResult* output_profile_result) {
-  // ROCM TODO: output_profile_result is ignore for now
-
   const MIOpenRnnDescriptor& miopen_rnn_desc =
       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
@@ -2853,7 +2911,7 @@
       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
       params, miopen_output_desc, output_data, miopen_output_h_desc,
       output_h_data, miopen_output_c_desc, output_c_data, is_training,
-      reserve_space_allocator, workspace_allocator);
+      reserve_space_allocator, workspace_allocator, output_profile_result);
 }
 
 bool MIOpenSupport::DoRnnForward(
@@ -2905,8 +2963,6 @@
     DeviceMemory<uint8>* reserve_space_data,
     ScratchAllocator* workspace_allocator,
     dnn::ProfileResult* output_profile_result) {
-  // ROCM TODO: output_profile_result is ignore for now
-
   const MIOpenRnnDescriptor& miopen_rnn_desc =
       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
@@ -2929,7 +2985,7 @@
       output_h_data, miopen_output_c_desc, output_c_data, output_backprop_data,
       output_h_backprop_data, output_c_backprop_data, input_backprop_data,
       input_h_backprop_data, input_c_backprop_data, params_backprop_data,
-      reserve_space_data, workspace_allocator);
+      reserve_space_data, workspace_allocator, output_profile_result);
 }
 
 bool MIOpenSupport::DoRnnBackward(
@@ -2957,8 +3013,6 @@
     DeviceMemory<uint8>* reserve_space_data,
     ScratchAllocator* workspace_allocator,
     dnn::ProfileResult* output_profile_result) {
-  // ROCM TODO: output_profile_result is ignore for now
-
   const MIOpenRnnDescriptor& miopen_rnn_desc =
       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
@@ -2981,7 +3035,7 @@
       output_h_data, miopen_output_c_desc, output_c_data, output_backprop_data,
       output_h_backprop_data, output_c_backprop_data, input_backprop_data,
       input_h_backprop_data, input_c_backprop_data, params_backprop_data,
-      reserve_space_data, workspace_allocator);
+      reserve_space_data, workspace_allocator, output_profile_result);
 }
 
 bool MIOpenSupport::DoRnnBackward(
@@ -3106,7 +3160,16 @@
         input_desc_{input_descriptor, ToMIOpenDataType(input_type)},
         output_desc_{output_descriptor, ToMIOpenDataType(input_type)},
         filter_desc_{filter_descriptor, ToMIOpenDataType(input_type)},
-        conv_desc_{conv_descriptor, ToMIOpenDataType(input_type)} {}
+        conv_desc_{conv_descriptor, ToMIOpenDataType(input_type)} {
+    bool is_backprop = ((kind == dnn::ConvolutionKind::BACKWARD_DATA) ||
+                        (kind == dnn::ConvolutionKind::BACKWARD_FILTER));
+    // #if TF_ROCM_VERSION >= 50000
+    if (is_backprop && (ToMIOpenDataType(input_type) == miopenHalf)) {
+      wrap::miopenSetConvolutionAttribute(
+          conv_desc_.handle(), MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL, 1);
+    }
+    // #endif
+  }
 
   std::string ToString() const override {
     return dnn::AlgorithmDesc{algo_id_, false, workspace_size_}.ToString();
@@ -3361,6 +3424,17 @@
   ScopedConvolutionDescriptor conv{convolution_descriptor,
                                    ToMIOpenDataType(element_type)};
 
+  bool is_backprop = ((kind == dnn::ConvolutionKind::BACKWARD_DATA) ||
+                      (kind == dnn::ConvolutionKind::BACKWARD_FILTER));
+  // bool is_backprop = (call_context == dnn::CallContext::kBackpropData) ||
+  //                   (call_context == dnn::CallContext::kBackpropFilter);
+
+#if TF_ROCM_VERSION >= 50000
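+  // For fp16 backprop convolutions, request MIOpen's alternate fp16
+  // implementation via MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL.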
+  if (is_backprop && (ToMIOpenDataType(element_type) == miopenHalf)) {
+    wrap::miopenSetConvolutionAttribute(
+        conv.handle(), MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL, 1);
+  }
+#endif
   // First determine the number of algorithms available
   size_t maxSolutionCount = 0;
 
@@ -3570,6 +3644,18 @@
   ScopedConvolutionDescriptor conv{convolution_descriptor,
                                    ToMIOpenDataType(element_type)};
 
+  bool is_backprop = ((kind == dnn::ConvolutionKind::BACKWARD_DATA) ||
+                      (kind == dnn::ConvolutionKind::BACKWARD_FILTER));
+  // bool is_backprop = (call_context == dnn::CallContext::kBackpropData) ||
+  //                    (call_context == dnn::CallContext::kBackpropFilter);
+
+#if TF_ROCM_VERSION >= 50000
+  if (is_backprop && (ToMIOpenDataType(element_type) == miopenHalf)) {
+    wrap::miopenSetConvolutionAttribute(
+        conv.handle(), MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL, 1);
+  }
+#endif
+
+  // Determine the workspace memory size that will be needed by the call to Find
   size_t scratch_memory_size = 0;
   switch (kind) {
@@ -3711,11 +3797,41 @@
 
 bool MIOpenSupport::GetRnnAlgorithms(
     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
-  // ROCM TODO: implement this with proper MIOpen API
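+  // Only the default MIOpen RNN algorithm is reported for now.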
+  std::vector<dnn::AlgorithmDesc::Index> algo_types = {
+      // clang-format off
+    miopenRNNdefault,
+      // clang-format on
+  };
+
+  out_algorithms->clear();
+  for (auto i : algo_types) {
+    out_algorithms->push_back({i, /*use_tensor_ops=*/false});
+  }
   return true;
 }
 
 bool MIOpenSupport::DoBatchNormalizationForward(
+    Stream* stream, const DeviceMemory<Eigen::bfloat16>& x,
+    const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
+    const DeviceMemory<float>& estimated_mean,
+    const DeviceMemory<float>& estimated_variance,
+    const DeviceMemory<Eigen::bfloat16>& side_input,
+    const dnn::BatchDescriptor& x_desc,
+    const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
+    const double exponential_average_factor,
+    dnn::ActivationMode activation_mode, DeviceMemory<Eigen::bfloat16>* y,
+    DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
+    DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
+    bool is_training, ScratchAllocator* reserve_space_allocator,
+    ScratchAllocator* workspace_allocator) {
+  return DoBatchNormalizationForwardImpl<Eigen::bfloat16, float>(
+      stream, dnn::DataType::kBF16, dnn::DataType::kFloat, x, scale, offset,
+      estimated_mean, estimated_variance, side_input, x_desc, scale_offset_desc,
+      epsilon, exponential_average_factor, activation_mode, y, batch_mean,
+      batch_var, saved_mean, saved_inv_var, is_training);
+}
+
+bool MIOpenSupport::DoBatchNormalizationForward(
     Stream* stream, const DeviceMemory<Eigen::half>& x,
     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
     const DeviceMemory<float>& estimated_mean,
@@ -3806,6 +3922,25 @@
 }
 
 bool MIOpenSupport::DoBatchNormalizationBackward(
+    Stream* stream, const DeviceMemory<Eigen::bfloat16>& y_backprop,
+    const DeviceMemory<Eigen::bfloat16>& x, const DeviceMemory<float>& scale,
+    const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
+    const DeviceMemory<float>& inv_var, const DeviceMemory<Eigen::bfloat16>& y,
+    const dnn::BatchDescriptor& x_desc,
+    const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
+    dnn::ActivationMode activation_mode,
+    DeviceMemory<Eigen::bfloat16>* x_backprop,
+    DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
+    DeviceMemory<Eigen::bfloat16>* side_input_backprop,
+    DeviceMemory<uint8_t>* reserve_space_data,
+    ScratchAllocator* workspace_allocator) {
+  return DoBatchNormalizationBackwardImpl<Eigen::bfloat16, float>(
+      stream, miopenBFloat16, miopenFloat, y_backprop, x, scale, mean, inv_var,
+      x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop,
+      offset_backprop);
+}
+
+bool MIOpenSupport::DoBatchNormalizationBackward(
     Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
     const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
     const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
@@ -3964,7 +4099,8 @@
   return stream->ThenBlasGemm<T, T>(
       tb, ta, _n, _m, _k, static_cast<DeviceMemory<T>>(b_data), _ldb,
       static_cast<DeviceMemory<T>>(a_data), _lda,
-      static_cast<DeviceMemory<T>*>(&c_data), _ldc, NumericOptions{});
+      static_cast<DeviceMemory<T>*>(&c_data), _ldc, NumericOptions{},
+      blas::CallContext::kNone);
 }
 
 template <typename T>
@@ -4131,7 +4267,9 @@
     if (!stream
              ->ThenBlasGemm(blas::Transpose::kNoTranspose,
                             blas::Transpose::kNoTranspose, m, n, k, weights, m,
-                            input_data, k, output_data, m, NumericOptions{})
+                            input_data, k, output_data, m, NumericOptions{},
+                            blas::CallContext::kNone)
              .ok()) {
       return false;
     }
@@ -4211,10 +4349,10 @@
       return ptrs;
     };
 
-    stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
-                                blas::Transpose::kNoTranspose, m, n, k, alpha,
-                                toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
-                                ldc, batch_count, NumericOptions{});
+    stream->ThenBlasGemmBatched(
+        blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n, k,
+        alpha, toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c), ldc,
+        batch_count, NumericOptions{}, blas::CallContext::kNone);
   }
 
   return stream->ok();
@@ -4814,6 +4952,20 @@
   return true;
 }
 
+bool UseNhwcLayoutForRocm() {
+#if TF_ROCM_VERSION >= 50100
+  static bool is_enabled = [] {
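+    // Read the TF_USE_ROCM_NHWC environment variable once and cache the
+    // result for the lifetime of the process.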
+    bool is_enabled = false;
+    TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_USE_ROCM_NHWC",
+                                        /*default_val=*/false, &is_enabled));
+    return is_enabled;
+  }();
+  return is_enabled;
+#else  // TF_ROCM_VERSION < 50100
+  return false;
+#endif
+}
+
 }  // namespace gpu
 
 void initialize_miopen() {
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.h b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.h
index 30cf2df..8cd6f43 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.h
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.h
@@ -294,6 +294,21 @@
       bool is_training, ScratchAllocator* reserve_space_allocator,
       ScratchAllocator* workspace_allocator) override;
 
+  bool DoBatchNormalizationForward(
+      Stream* stream, const DeviceMemory<Eigen::bfloat16>& x,
+      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
+      const DeviceMemory<float>& estimated_mean,
+      const DeviceMemory<float>& estimated_variance,
+      const DeviceMemory<Eigen::bfloat16>& side_input,
+      const dnn::BatchDescriptor& x_desc,
+      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
+      const double exponential_average_factor,
+      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::bfloat16>* y,
+      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
+      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
+      bool is_training, ScratchAllocator* reserve_space_allocator,
+      ScratchAllocator* workspace_allocator) override;
+
   bool DoBatchNormalizationBackward(
       Stream* stream, const DeviceMemory<float>& y_backprop,
       const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
@@ -321,6 +336,21 @@
       DeviceMemory<uint8>* reserve_space_data,
       ScratchAllocator* workspace_allocator) override;
 
+  bool DoBatchNormalizationBackward(
+      Stream* stream, const DeviceMemory<Eigen::bfloat16>& y_backprop,
+      const DeviceMemory<Eigen::bfloat16>& x, const DeviceMemory<float>& scale,
+      const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
+      const DeviceMemory<float>& inv_var,
+      const DeviceMemory<Eigen::bfloat16>& y,
+      const dnn::BatchDescriptor& x_desc,
+      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
+      dnn::ActivationMode activation_mode,
+      DeviceMemory<Eigen::bfloat16>* x_backprop,
+      DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
+      DeviceMemory<Eigen::bfloat16>* side_input_backprop,
+      DeviceMemory<uint8_t>* reserve_space_data,
+      ScratchAllocator* workspace_allocator) override;
+
   tsl::Status DoConvolve(
       dnn::ConvolutionKind kind, dnn::DataType element_type,
       dnn::DataType output_type, Stream* stream,
@@ -349,32 +379,6 @@
       const dnn::AlgorithmConfig& algorithm_config,
       dnn::ProfileResult* output_profile_result) override;
 
-  bool DoConvolveQuantized(
-      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
-      const DeviceMemory<float>& input_data,
-      const dnn::FilterDescriptor& filter_descriptor,
-      const DeviceMemory<int8>& filter_coefficients,
-      const DeviceMemory<float>& coefficient_scales,
-      const dnn::ConvolutionDescriptor& convolution_descriptor,
-      const dnn::BatchDescriptor& output_descriptor,
-      DeviceMemory<float>* output_data) override {
-    LOG(ERROR) << "DoConvolveQuantized not supported by MIOpen";
-    return false;
-  }
-
-  bool DoConvolveQuantized(
-      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
-      const DeviceMemory<float>& input_data,
-      const dnn::FilterDescriptor& filter_descriptor,
-      const DeviceMemory<int16>& filter_coefficients,
-      const DeviceMemory<float>& coefficient_scales,
-      const dnn::ConvolutionDescriptor& convolution_descriptor,
-      const dnn::BatchDescriptor& output_descriptor,
-      DeviceMemory<float>* output_data) override {
-    LOG(ERROR) << "DoConvolveQuantized not supported by MIOpen";
-    return false;
-  }
-
   bool DoSeparableConvolve(
       Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
       const DeviceMemory<float>& input_data,
@@ -586,7 +590,8 @@
                         const MIOpenRnnStateTensorDescriptor& output_c_desc,
                         DeviceMemory<T>* output_c_data, bool is_training,
                         ScratchAllocator* reserve_space_allocator,
-                        ScratchAllocator* workspace_allocator);
+                        ScratchAllocator* workspace_allocator,
+                        dnn::ProfileResult* output_profile_result);
   template <class T>
   bool DoRnnBackwardImpl(Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
                          const MIOpenRnnSequenceTensorDescriptor& input_desc,
@@ -610,7 +615,8 @@
                          DeviceMemory<T>* input_c_backprop_data,
                          DeviceMemory<T>* params_backprop_data,
                          DeviceMemory<uint8>* reserve_space_data,
-                         ScratchAllocator* workspace_allocator);
+                         ScratchAllocator* workspace_allocator,
+                         dnn::ProfileResult* output_profile_result);
 
   tsl::Status DoPrepareForConvolution(
       dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
@@ -670,6 +676,16 @@
   void operator=(const MIOpenSupport&) = delete;
 };
 
+// A helper function for frontend frameworks,
+// e.g. TF (tensorflow/core/kernels/conv_ops.cc, fused_batch_norm_op.cc
+// and tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc).
+// It decides whether to use NHWC in Convolution/Batchnorm.
+// This mode can be faster for FP16 workloads on gfx908 and beyond.
+// Requires ROCm 5.0+.
+// TODO (ROCm): Use autotune to choose between this mode and NCHW
+// when MIOpen has more optimized kernels.
+bool UseNhwcLayoutForRocm();
+
 }  // namespace gpu
 }  // namespace stream_executor
 
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
index a15d868..0941efa 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
@@ -748,6 +748,18 @@
                      "hipDrvGraphAddMemcopyNode is not available on ROCm yet"};
 }
 
+/* static */ tsl::Status GpuDriver::GraphAddChildNode(
+    hipGraphNode_t* node, hipGraph_t graph, absl::Span<hipGraphNode_t> deps,
+    hipGraph_t child) {
+  VLOG(2) << "Create a new node by cloning the child graph " << child
+          << " and add it to " << graph << "; deps: " << deps.size();
+
+  RETURN_IF_ROCM_ERROR(
+      wrap::hipGraphAddChildGraphNode(node, graph, deps.data(), deps.size(),
+                                      child),
+      "Failed to create a child graph node and add it to a HIP graph");
+  return ::tsl::OkStatus();
+}
+
 /* static */ tsl::Status GpuDriver::LaunchKernel(
     GpuContext* context, absl::string_view kernel_name, hipFunction_t function,
     unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z,
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h
index 5808f4c2..8c3bfb7 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h
@@ -99,6 +99,7 @@
   __macro(hipGetDeviceProperties)                   \
   __macro(hipGetErrorString)                        \
   __macro(hipGraphAddKernelNode)                    \
+  __macro(hipGraphAddChildGraphNode)                \
   __macro(hipGraphAddMemcpyNode)                    \
   __macro(hipGraphCreate)                           \
   __macro(hipGraphDebugDotPrint)                    \
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc
index 76864db..235a4f7 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc
@@ -31,7 +31,6 @@
 #include "xla/stream_executor/gpu/gpu_kernel.h"
 #include "xla/stream_executor/gpu/gpu_stream.h"
 #include "xla/stream_executor/gpu/gpu_timer.h"
-#include "xla/stream_executor/kernel_cache_config.h"
 #include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/platform/dso_loader.h"
 #include "xla/stream_executor/platform/initialize.h"
diff --git a/third_party/xla/xla/stream_executor/stream.cc b/third_party/xla/xla/stream_executor/stream.cc
index dfe91b1..d607aaa 100644
--- a/third_party/xla/xla/stream_executor/stream.cc
+++ b/third_party/xla/xla/stream_executor/stream.cc
@@ -180,14 +180,6 @@
   return ToVlogString(absl::Span<const T>(elements));
 }
 
-std::string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
-  switch (depth_to_space_layout) {
-    case dnn::DepthToSpaceLayout::DepthHeightWidth:
-      return "DepthToSpaceLayout::DepthHeightWidth";
-  }
-  return "unknown DepthToSpaceLayout";
-}
-
 std::string ToVlogString(dnn::DataType data_type) {
   switch (data_type) {
     case dnn::DataType::kFloat:
@@ -531,33 +523,6 @@
   return *this;
 }
 
-Stream &Stream::ThenConvolveQuantized(
-    const dnn::BatchDescriptor &input_descriptor,
-    const DeviceMemory<float> &input_data,
-    const dnn::FilterDescriptor &filter_descriptor,
-    const DeviceMemory<int8_t> &filter_coefficients,
-    const DeviceMemory<float> &coefficient_scales,
-    const dnn::ConvolutionDescriptor &convolution_descriptor,
-    const dnn::BatchDescriptor &output_descriptor,
-    DeviceMemory<float> *output) {
-  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
-            PARAM(filter_descriptor), PARAM(filter_coefficients),
-            PARAM(coefficient_scales), PARAM(convolution_descriptor),
-            PARAM(output_descriptor), PARAM(output));
-
-  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
-    CheckError(dnn->DoConvolveQuantized(
-        this, input_descriptor, input_data, filter_descriptor,
-        filter_coefficients, coefficient_scales, convolution_descriptor,
-        output_descriptor, output));
-  } else {
-    SetError();
-    LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
-                    "without DNN support";
-  }
-  return *this;
-}
-
 Stream &Stream::ThenSeparableConvolve(
     const dnn::BatchDescriptor &batch_descriptor,
     const DeviceMemory<float> &input_data,
@@ -747,25 +712,6 @@
   return *this;
 }
 
-Stream &Stream::ThenDepthToSpace(
-    const dnn::BatchDescriptor &input_dimensions,
-    const DeviceMemory<float> &input_data,
-    const dnn::DepthToSpaceLayout &depth_to_space_layout,
-    const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
-  VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
-            PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
-            PARAM(output_data));
-
-  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
-    CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
-                                   depth_to_space_layout, sqrt_depth_reduction,
-                                   output_data));
-  } else {
-    SetErrorAndLogNoDnnSupport();
-  }
-  return *this;
-}
-
 Stream &Stream::ThenElementwiseOperate(
     dnn::ElementwiseOperation operation,
     absl::Span<const dnn::BatchDescriptor> input_dimensions,
@@ -1356,8 +1302,8 @@
     uint64_t k, float alpha, DeviceMemorySlice<Eigen::half> a, int lda,
     DeviceMemorySlice<Eigen::half> b, int ldb, float beta,
     DeviceMemorySlice<Eigen::half> c, int ldc, int batch_count,
-    const NumericOptions &numeric_options,
-    ScratchAllocator *scratch_allocator) {
+    const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator,
+    blas::CallContext context) {
   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
@@ -1366,11 +1312,11 @@
                float, DeviceMemorySlice<Eigen::half>, int,
                DeviceMemorySlice<Eigen::half>, int, float,
                DeviceMemorySlice<Eigen::half>, int, int, const NumericOptions &,
-               ScratchAllocator *>
+               ScratchAllocator *, blas::CallContext>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
-              numeric_options, scratch_allocator);
+              numeric_options, scratch_allocator, context);
 }
 
 Stream &Stream::ThenBlasGemmBatchedWithScratch(
@@ -1378,8 +1324,8 @@
     uint64_t k, float alpha, DeviceMemorySlice<Eigen::bfloat16> a, int lda,
     DeviceMemorySlice<Eigen::bfloat16> b, int ldb, float beta,
     DeviceMemorySlice<Eigen::bfloat16> c, int ldc, int batch_count,
-    const NumericOptions &numeric_options,
-    ScratchAllocator *scratch_allocator) {
+    const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator,
+    blas::CallContext context) {
   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
@@ -1388,22 +1334,23 @@
                float, DeviceMemorySlice<Eigen::bfloat16>, int,
                DeviceMemorySlice<Eigen::bfloat16>, int, float,
                DeviceMemorySlice<Eigen::bfloat16>, int, int,
-               const NumericOptions &, ScratchAllocator *>
+               const NumericOptions &, ScratchAllocator *, blas::CallContext>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
-              numeric_options, scratch_allocator);
+              numeric_options, scratch_allocator, context);
 }
 
 Stream &Stream::ThenBlasGemmBatched(
     blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
     uint64_t k, float alpha, DeviceMemorySlice<float> a, int lda,
     DeviceMemorySlice<float> b, int ldb, float beta, DeviceMemorySlice<float> c,
-    int ldc, int batch_count, const NumericOptions &numeric_options) {
+    int ldc, int batch_count, const NumericOptions &numeric_options,
+    blas::CallContext context) {
   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
                                         b, ldb, beta, c, ldc, batch_count,
                                         numeric_options,
-                                        /*scratch_allocator=*/nullptr);
+                                        /*scratch_allocator=*/nullptr, context);
 }
 
 Stream &Stream::ThenBlasGemmBatchedWithScratch(
@@ -1411,7 +1358,7 @@
     uint64_t k, float alpha, DeviceMemorySlice<float> a, int lda,
     DeviceMemorySlice<float> b, int ldb, float beta, DeviceMemorySlice<float> c,
     int ldc, int batch_count, const NumericOptions &numeric_options,
-    ScratchAllocator *scratch_allocator) {
+    ScratchAllocator *scratch_allocator, blas::CallContext context) {
   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
@@ -1419,11 +1366,11 @@
   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
                float, DeviceMemorySlice<float>, int, DeviceMemorySlice<float>,
                int, float, DeviceMemorySlice<float>, int, int,
-               const NumericOptions &, ScratchAllocator *>
+               const NumericOptions &, ScratchAllocator *, blas::CallContext>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
-              numeric_options, scratch_allocator);
+              numeric_options, scratch_allocator, context);
 }
 
 Stream &Stream::ThenBlasGemmBatchedWithScratch(
@@ -1431,8 +1378,8 @@
     uint64_t k, double alpha, DeviceMemorySlice<double> a, int lda,
     DeviceMemorySlice<double> b, int ldb, double beta,
     DeviceMemorySlice<double> c, int ldc, int batch_count,
-    const NumericOptions &numeric_options,
-    ScratchAllocator *scratch_allocator) {
+    const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator,
+    blas::CallContext context) {
   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
@@ -1441,11 +1388,11 @@
                double, DeviceMemorySlice<double>, int,
                DeviceMemorySlice<double>, int, double,
                DeviceMemorySlice<double>, int, int, const NumericOptions &,
-               ScratchAllocator *>
+               ScratchAllocator *, blas::CallContext>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
-              numeric_options, scratch_allocator);
+              numeric_options, scratch_allocator, context);
 }
 
 Stream &Stream::ThenBlasGemmBatched(
@@ -1454,11 +1401,11 @@
     DeviceMemorySlice<std::complex<float>> a, int lda,
     DeviceMemorySlice<std::complex<float>> b, int ldb, std::complex<float> beta,
     DeviceMemorySlice<std::complex<float>> c, int ldc, int batch_count,
-    const NumericOptions &numeric_options) {
+    const NumericOptions &numeric_options, blas::CallContext context) {
   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
                                         b, ldb, beta, c, ldc, batch_count,
                                         numeric_options,
-                                        /*scratch_allocator=*/nullptr);
+                                        /*scratch_allocator=*/nullptr, context);
 }
 
 Stream &Stream::ThenBlasGemmBatchedWithScratch(
@@ -1467,8 +1414,8 @@
     DeviceMemorySlice<std::complex<float>> a, int lda,
     DeviceMemorySlice<std::complex<float>> b, int ldb, std::complex<float> beta,
     DeviceMemorySlice<std::complex<float>> c, int ldc, int batch_count,
-    const NumericOptions &numeric_options,
-    ScratchAllocator *scratch_allocator) {
+    const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator,
+    blas::CallContext context) {
   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
@@ -1477,11 +1424,11 @@
                std::complex<float>, DeviceMemorySlice<std::complex<float>>, int,
                DeviceMemorySlice<std::complex<float>>, int, std::complex<float>,
                DeviceMemorySlice<std::complex<float>>, int, int,
-               const NumericOptions &, ScratchAllocator *>
+               const NumericOptions &, ScratchAllocator *, blas::CallContext>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
-              numeric_options, scratch_allocator);
+              numeric_options, scratch_allocator, context);
 }
 
 Stream &Stream::ThenBlasGemmBatchedWithScratch(
@@ -1491,7 +1438,7 @@
     DeviceMemorySlice<std::complex<double>> b, int ldb,
     std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c,
     int ldc, int batch_count, const NumericOptions &numeric_options,
-    ScratchAllocator *scratch_allocator) {
+    ScratchAllocator *scratch_allocator, blas::CallContext context) {
   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
@@ -1500,11 +1447,12 @@
                std::complex<double>, DeviceMemorySlice<std::complex<double>>,
                int, DeviceMemorySlice<std::complex<double>>, int,
                std::complex<double>, DeviceMemorySlice<std::complex<double>>,
-               int, int, const NumericOptions &, ScratchAllocator *>
+               int, int, const NumericOptions &, ScratchAllocator *,
+               blas::CallContext>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
-              numeric_options, scratch_allocator);
+              numeric_options, scratch_allocator, context);
 }
 
 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
diff --git a/third_party/xla/xla/stream_executor/stream.h b/third_party/xla/xla/stream_executor/stream.h
index 3234c67..2d3633e 100644
--- a/third_party/xla/xla/stream_executor/stream.h
+++ b/third_party/xla/xla/stream_executor/stream.h
@@ -316,16 +316,6 @@
                        const dnn::BatchDescriptor &output_descriptor,
                        DeviceMemory<float> *output);
 
-  Stream &ThenConvolveQuantized(
-      const dnn::BatchDescriptor &input_descriptor,
-      const DeviceMemory<float> &input_data,
-      const dnn::FilterDescriptor &filter_descriptor,
-      const DeviceMemory<int8_t> &filter_coefficients,
-      const DeviceMemory<float> &coefficient_scales,
-      const dnn::ConvolutionDescriptor &convolution_descriptor,
-      const dnn::BatchDescriptor &output_descriptor,
-      DeviceMemory<float> *output_data);
-
   template <typename InputType, typename OutputType>
   tsl::Status ConvolveWithAlgorithm(
       dnn::ConvolutionKind kind, const dnn::BatchDescriptor &input_descriptor,
@@ -467,7 +457,8 @@
       std::optional<dnn::TensorDescriptor> activation_descriptor,
       std::optional<dnn::TensorDescriptor> mask_descriptor,
       std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
-      std::optional<double> dropout_rate, std::optional<int64_t> seed) {
+      std::optional<double> dropout_rate, std::optional<int64_t> seed,
+      bool is_flash_attention, bool is_causal_mask) {
     dnn::DnnSupport *dnn_support = parent_->AsDnn();
     if (!dnn_support) {
       return absl::UnimplementedError("DNN library is not found.");
@@ -476,7 +467,8 @@
         this, algorithm_desc, kind, bmm1_lhs_descriptor, bmm1_rhs_descriptor,
         bmm2_rhs_descriptor, intermediate_bmm2_lhs_descriptor,
         output_descriptor, activation_descriptor, mask_descriptor,
-        bias_descriptor, scale, dropout_rate, seed);
+        bias_descriptor, scale, dropout_rate, seed, is_flash_attention,
+        is_causal_mask);
   }
 
   tsl::StatusOr<std::unique_ptr<const dnn::FusedMHABackwardRunner>>
@@ -490,10 +482,13 @@
       const dnn::TensorDescriptor &d_bmm1_lhs_descriptor,
       const dnn::TensorDescriptor &d_bmm1_rhs_descriptor,
       const dnn::TensorDescriptor &d_bmm2_rhs_descriptor,
-      const dnn::TensorDescriptor &d_s_descriptor,
+      std::optional<dnn::TensorDescriptor> d_s_descriptor,
       std::optional<dnn::TensorDescriptor> mask_descriptor,
-      std::optional<dnn::TensorDescriptor> d_bias_descriptor, double scale,
-      std::optional<double> dropout_rate, std::optional<int64_t> seed) {
+      std::optional<dnn::TensorDescriptor> d_bias_descriptor,
+      std::optional<dnn::TensorDescriptor> fwd_output_descriptor,
+      std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
+      std::optional<double> dropout_rate, std::optional<int64_t> seed,
+      bool is_flash_attention, bool is_causal_mask) {
     dnn::DnnSupport *dnn_support = parent_->AsDnn();
     if (!dnn_support) {
       return absl::UnimplementedError("DNN library is not found.");
@@ -503,8 +498,9 @@
         bmm1_grad_gemm2_rhs_descriptor, bmm2_grad_gemm1_lhs_descriptor,
         bmm2_grad_gemm2_rhs_descriptor, d_output_descriptor,
         d_bmm1_lhs_descriptor, d_bmm1_rhs_descriptor, d_bmm2_rhs_descriptor,
-        d_s_descriptor, mask_descriptor, d_bias_descriptor, scale, dropout_rate,
-        seed);
+        d_s_descriptor, mask_descriptor, d_bias_descriptor,
+        fwd_output_descriptor, bias_descriptor, scale, dropout_rate, seed,
+        is_flash_attention, is_causal_mask);
   }
 
   Stream &ThenSeparableConvolve(
@@ -627,18 +623,6 @@
       absl::Span<const DeviceMemory<float> *const> input_data,
       DeviceMemory<float> *output_data);
 
-  // Depth to space takes an X by Y image with depth D*M² and changes it to an
-  // MX x MY image with depth D. Each input location (x,y) with depth D*M² in
-  // the input image is changed to an MxM contiguous area in the output image,
-  // with the values being laid out in raster order specified by
-  // DepthToSpaceLayout, and will have a new depth of D.
-  // See the DoDepthToSpace comment for more information.
-  Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions,
-                           const DeviceMemory<float> &input_data,
-                           const dnn::DepthToSpaceLayout &depth_to_space_layout,
-                           const int sqrt_depth_reduction,
-                           DeviceMemory<float> *output_data);
-
   Stream &ThenElementwiseOperate(
       dnn::ElementwiseOperation operation,
       absl::Span<const dnn::BatchDescriptor> input_dimensions,
@@ -760,11 +744,12 @@
                            const DeviceMemory<InputType> &a, int lda,
                            const DeviceMemory<InputType> &b, int ldb,
                            DeviceMemory<OutputType> *c, int ldc,
-                           const NumericOptions &numeric_options) {
+                           const NumericOptions &numeric_options,
+                           blas::CallContext context) {
     InputType alpha{1.0};
     InputType beta{0.0};
     return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
-                        ldc, numeric_options);
+                        ldc, numeric_options, context);
   }
 
   template <typename InputType, typename OutputType, typename ConstantType>
@@ -773,7 +758,8 @@
                            const DeviceMemory<InputType> &a, int lda,
                            const DeviceMemory<InputType> &b, int ldb,
                            ConstantType beta, DeviceMemory<OutputType> *c,
-                           int ldc, const NumericOptions &numeric_options) {
+                           int ldc, const NumericOptions &numeric_options,
+                           blas::CallContext context) {
     static_assert(
         detail::is_any_of<InputType, int8_t, Eigen::half, Eigen::bfloat16,
                           float, double, std::complex<float>,
@@ -802,9 +788,9 @@
     UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
                                     &beta_storage);
 
-    return blas->DoBlasGemm(this, transa, transb, m, n, k,
-                            blas::ToDataType<InputType>::value, alpha_ptr, a,
-                            lda, b, ldb, beta_ptr, c, ldc, numeric_options);
+    return blas->DoBlasGemm(
+        this, transa, transb, m, n, k, blas::ToDataType<InputType>::value,
+        alpha_ptr, a, lda, b, ldb, beta_ptr, c, ldc, numeric_options, context);
   }
 
   // TODO(reedwm): Update all callers to pass correct NumericOptions.
@@ -814,9 +800,9 @@
                            const DeviceMemory<InputType> &a, int lda,
                            const DeviceMemory<InputType> &b, int ldb,
                            ConstantType beta, DeviceMemory<OutputType> *c,
-                           int ldc) {
+                           int ldc, blas::CallContext context) {
     return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
-                        ldc, NumericOptions{});
+                        ldc, NumericOptions{}, context);
   }
 
   template <typename InputType, typename OutputType>
@@ -825,13 +811,14 @@
       uint64_t k, const DeviceMemory<InputType> &a, int lda,
       const DeviceMemory<InputType> &b, int ldb, DeviceMemory<OutputType> *c,
       int ldc, blas::ComputationType computation_type,
-      blas::AlgorithmType algorithm,
-      blas::ProfileResult *output_profile_result) {
+      blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result,
+      blas::CallContext context) {
     OutputType alpha{1};
     OutputType beta{0};
-    return ThenBlasGemmWithAlgorithm(
-        transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
-        computation_type, algorithm, NumericOptions{}, output_profile_result);
+    return ThenBlasGemmWithAlgorithm(transa, transb, m, n, k, alpha, a, lda, b,
+                                     ldb, beta, c, ldc, computation_type,
+                                     algorithm, NumericOptions{},
+                                     output_profile_result, context);
   }
 
   template <typename InputType, typename OutputType, typename ConstantType>
@@ -842,7 +829,7 @@
       DeviceMemory<OutputType> *c, int ldc,
       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
       const NumericOptions &numeric_options,
-      blas::ProfileResult *output_profile_result) {
+      blas::ProfileResult *output_profile_result, blas::CallContext context) {
     TF_RETURN_IF_ERROR(
         CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
             computation_type));
@@ -865,7 +852,8 @@
         blas::ToDataType<InputType>::value, lda, b,
         blas::ToDataType<InputType>::value, ldb, beta_ptr, c,
         blas::ToDataType<OutputType>::value, ldc, computation_type, algorithm,
-        numeric_options, output_profile_result);
+        numeric_options, output_profile_result, context);
+
     if (output_profile_result) {
       // The error is recorded in the profile.
       return ::tsl::OkStatus();
@@ -881,7 +869,7 @@
       int64_t stride_b, ConstantType beta, DeviceMemory<OutputType> *c, int ldc,
       int64_t stride_c, int batch_count, blas::ComputationType computation_type,
       blas::AlgorithmType algorithm, const NumericOptions &numeric_options,
-      blas::ProfileResult *output_profile_result) {
+      blas::ProfileResult *output_profile_result, blas::CallContext context) {
     TF_RETURN_IF_ERROR(
         CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
             computation_type));
@@ -902,7 +890,8 @@
         blas::ToDataType<InputType>::value, lda, stride_a, b,
         blas::ToDataType<InputType>::value, ldb, stride_b, beta_ptr, c,
         blas::ToDataType<OutputType>::value, ldc, stride_c, batch_count,
-        computation_type, algorithm, numeric_options, output_profile_result);
+        computation_type, algorithm, numeric_options, output_profile_result,
+        context);
     if (output_profile_result) {
       // The error is recorded in the profile.
       return ::tsl::OkStatus();
@@ -920,7 +909,9 @@
                               DeviceMemorySlice<float> b, int ldb, float beta,
                               DeviceMemorySlice<float> c, int ldc,
                               int batch_count,
-                              const NumericOptions &numeric_options);
+                              const NumericOptions &numeric_options,
+                              blas::CallContext context);
+
   Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                               uint64_t m, uint64 n, uint64_t k,
                               std::complex<float> alpha,
@@ -929,37 +920,49 @@
                               std::complex<float> beta,
                               DeviceMemorySlice<std::complex<float>> c, int ldc,
                               int batch_count,
-                              const NumericOptions &numeric_options);
+                              const NumericOptions &numeric_options,
+                              blas::CallContext context);
+  Stream &ThenBlasGemmBatched(
+      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
+      uint64_t k, std::complex<double> alpha,
+      DeviceMemorySlice<std::complex<double>> a, int lda,
+      DeviceMemorySlice<std::complex<double>> b, int ldb,
+      std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c,
+      int ldc, int batch_count, const NumericOptions &numeric_options,
+      blas::CallContext context);
+
   Stream &ThenBlasGemmBatchedWithScratch(
       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
       uint64_t k, float alpha, DeviceMemorySlice<Eigen::half> a, int lda,
       DeviceMemorySlice<Eigen::half> b, int ldb, float beta,
       DeviceMemorySlice<Eigen::half> c, int ldc, int batch_count,
       const NumericOptions &numeric_options,
-      ScratchAllocator *scratch_allocator);
+      ScratchAllocator *scratch_allocator, blas::CallContext context);
+
   Stream &ThenBlasGemmBatchedWithScratch(
       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
       uint64_t k, float alpha, DeviceMemorySlice<Eigen::bfloat16> a, int lda,
       DeviceMemorySlice<Eigen::bfloat16> b, int ldb, float beta,
       DeviceMemorySlice<Eigen::bfloat16> c, int ldc, int batch_count,
       const NumericOptions &numeric_options,
-      ScratchAllocator *scratch_allocator);
-  Stream &ThenBlasGemmBatchedWithScratch(blas::Transpose transa,
-                                         blas::Transpose transb, uint64_t m,
-                                         uint64 n, uint64_t k, float alpha,
-                                         DeviceMemorySlice<float> a, int lda,
-                                         DeviceMemorySlice<float> b, int ldb,
-                                         float beta, DeviceMemorySlice<float> c,
-                                         int ldc, int batch_count,
-                                         const NumericOptions &numeric_options,
-                                         ScratchAllocator *scratch_allocator);
+      ScratchAllocator *scratch_allocator, blas::CallContext context);
+
+  Stream &ThenBlasGemmBatchedWithScratch(
+      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
+      uint64_t k, float alpha, DeviceMemorySlice<float> a, int lda,
+      DeviceMemorySlice<float> b, int ldb, float beta,
+      DeviceMemorySlice<float> c, int ldc, int batch_count,
+      const NumericOptions &numeric_options,
+      ScratchAllocator *scratch_allocator, blas::CallContext context);
+
   Stream &ThenBlasGemmBatchedWithScratch(
       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
       uint64_t k, double alpha, DeviceMemorySlice<double> a, int lda,
       DeviceMemorySlice<double> b, int ldb, double beta,
       DeviceMemorySlice<double> c, int ldc, int batch_count,
       const NumericOptions &numeric_options,
-      ScratchAllocator *scratch_allocator);
+      ScratchAllocator *scratch_allocator, blas::CallContext context);
+
   Stream &ThenBlasGemmBatchedWithScratch(
       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
       uint64_t k, std::complex<float> alpha,
@@ -967,7 +970,8 @@
       DeviceMemorySlice<std::complex<float>> b, int ldb,
       std::complex<float> beta, DeviceMemorySlice<std::complex<float>> c,
       int ldc, int batch_count, const NumericOptions &numeric_options,
-      ScratchAllocator *scratch_allocator);
+      ScratchAllocator *scratch_allocator, blas::CallContext context);
+
   Stream &ThenBlasGemmBatchedWithScratch(
       blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
       uint64_t k, std::complex<double> alpha,
@@ -975,7 +979,7 @@
       DeviceMemorySlice<std::complex<double>> b, int ldb,
       std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c,
       int ldc, int batch_count, const NumericOptions &numeric_options,
-      ScratchAllocator *scratch_allocator);
+      ScratchAllocator *scratch_allocator, blas::CallContext context);
 
   template <typename InputType, typename OutputType, typename ConstantType>
   tsl::Status ThenBlasGemmStridedBatched(
@@ -983,8 +987,8 @@
       uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
       int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
       int64_t stride_b, ConstantType beta, DeviceMemory<OutputType> *c, int ldc,
-      int64_t stride_c, int batch_count,
-      const NumericOptions &numeric_options) {
+      int64_t stride_c, int batch_count, const NumericOptions &numeric_options,
+      blas::CallContext context) {
     static_assert(
         detail::is_any_of<InputType, int8_t, float, Eigen::half,
                           Eigen::bfloat16, double, std::complex<float>,
@@ -1011,7 +1015,7 @@
     return blas->DoBlasGemmStridedBatched(
         this, transa, transb, m, n, k, blas::ToDataType<InputType>::value,
         alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc,
-        stride_c, batch_count, numeric_options);
+        stride_c, batch_count, numeric_options, context);
   }
 
   // See BlasSupport::DoBlasTrsm.
diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h
index 68a9b35..424b31c 100644
--- a/third_party/xla/xla/stream_executor/stream_executor_internal.h
+++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h
@@ -41,7 +41,6 @@
 #include "xla/stream_executor/event.h"
 #include "xla/stream_executor/fft.h"
 #include "xla/stream_executor/kernel.h"
-#include "xla/stream_executor/kernel_cache_config.h"
 #include "xla/stream_executor/kernel_spec.h"
 #include "xla/stream_executor/launch_dim.h"
 #include "xla/stream_executor/module_spec.h"
diff --git a/third_party/xla/xla/stream_executor/stream_executor_pimpl.h b/third_party/xla/xla/stream_executor/stream_executor_pimpl.h
index bfe545b..dc15118 100644
--- a/third_party/xla/xla/stream_executor/stream_executor_pimpl.h
+++ b/third_party/xla/xla/stream_executor/stream_executor_pimpl.h
@@ -46,10 +46,6 @@
 #include "tsl/platform/threadpool.h"
 #include "tsl/protobuf/dnn.pb.h"
 
-// TODO(ezhulenev): Remove include of internal header. Currently we have too
-// many targets depending on transitive dependencies.
-#include "xla/stream_executor/stream_executor_internal.h"
-
 namespace stream_executor {
 
 class Stream;
diff --git a/third_party/xla/xla/stream_executor/tpu/BUILD b/third_party/xla/xla/stream_executor/tpu/BUILD
index a12a625..06f5988 100644
--- a/third_party/xla/xla/stream_executor/tpu/BUILD
+++ b/third_party/xla/xla/stream_executor/tpu/BUILD
@@ -191,7 +191,6 @@
         ":tpu_stream_interface",
         ":tpu_topology_external",
         "//xla/stream_executor",
-        "//xla/stream_executor:allocator_stats",
         "//xla/stream_executor/platform",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/functional:any_invocable",
@@ -248,7 +247,6 @@
         ":tpu_executor_api",
         ":tpu_executor_c_api_hdrs",
         ":tpu_topology_external",
-        "//xla/stream_executor:allocator_stats",
         "//xla/stream_executor:stream_executor_headers",
         "//xla/stream_executor:stream_executor_internal",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -298,7 +296,6 @@
         ":tpu_topology_external",
         "//xla:status",
         "//xla/stream_executor",
-        "//xla/stream_executor:allocator_stats",
         "//xla/stream_executor:stream_executor_internal",
         "@com_google_absl//absl/cleanup",
         "@com_google_absl//absl/container:flat_hash_map",
diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD
index 4eb977f..7dc92ec 100644
--- a/third_party/xla/xla/tests/BUILD
+++ b/third_party/xla/xla/tests/BUILD
@@ -1389,6 +1389,24 @@
 )
 
 xla_test(
+    name = "int4_test",
+    srcs = ["int4_test.cc"],
+    backends = [
+        "cpu",
+        "gpu",
+        "interpreter",
+    ],
+    deps = [
+        ":client_library_test_base",
+        ":hlo_test_base",
+        ":xla_internal_test_main",
+        "//xla:test",
+        "//xla/client:xla_builder",
+        "@local_tsl//tsl/platform:errors",
+    ],
+)
+
+xla_test(
     name = "slice_test",
     timeout = "long",
     srcs = ["slice_test.cc"],
@@ -1948,6 +1966,7 @@
         ":test_macros_header",
         ":xla_internal_test_main",
         "//xla:shape_util",
+        "//xla:types",
         "//xla:xla_data_proto_cc",
         "//xla/client:local_client",
         "//xla/client:xla_builder",
diff --git a/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc b/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc
index 838ac34..536e837 100644
--- a/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc
+++ b/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc
@@ -39,7 +39,11 @@
     HloPredicate should_process = HloPredicateIsOp<HloOpcode::kNegate>,
     CollectivePipeliner::PipeliningDirection pipelining_direction =
         CollectivePipeliner::PipeliningDirection::kForward,
-    bool pipeline_use_tree = false) {
+    bool pipeline_use_tree = false,
+    HloPredicate acceptable_formatting =
+        [](const HloInstruction*) { return true; },
+    HloPredicate reuse_pipelined_op_buffer =
+        [](const HloInstruction*) { return true; }) {
   CollectivePipeliner::Config config = {
       /*level_to_operate_on=*/level_to_operate_on,
       /*max_pipelining_per_loop=*/INT64_MAX,
@@ -49,6 +53,8 @@
       /*direction=*/
       pipelining_direction,
       /*should_process=*/should_process,
+      /*acceptable_formatting=*/acceptable_formatting,
+      /*reuse_pipelined_op_buffer=*/reuse_pipelined_op_buffer,
   };
 
   HloPassPipeline pass("optimizer");
diff --git a/third_party/xla/xla/tests/convert_test.cc b/third_party/xla/xla/tests/convert_test.cc
index 8f895da..4330d29 100644
--- a/third_party/xla/xla/tests/convert_test.cc
+++ b/third_party/xla/xla/tests/convert_test.cc
@@ -29,6 +29,7 @@
 #include "xla/tests/client_library_test_base.h"
 #include "xla/tests/literal_test_util.h"
 #include "xla/tests/test_macros.h"
+#include "xla/types.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/float8.h"
 #include "tsl/platform/test.h"
@@ -536,6 +537,70 @@
   ComputeAndCompareR1<uint64_t>(&builder, unsigned_x, {});
 }
 
+TEST_F(ConvertTest, ConvertR1S4ToR1S8) {
+  XlaBuilder builder(TestName());
+  auto a = ConstantR1<s4>(&builder, {s4(0), s4(1), s4(2), s4(-8)});
+  ConvertElementType(a, S8);
+
+  std::vector<int8_t> expected = {0, 1, 2, -8};
+  ComputeAndCompareR1<int8_t>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1S4ParameterToR1S8) {
+  XlaBuilder builder(TestName());
+  Literal arg_literal =
+      LiteralUtil::CreateR1<s4>({s4(0), s4(1), s4(2), s4(-8)});
+  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
+  std::unique_ptr<GlobalData> arg_data =
+      client_->TransferToServer(arg_literal).value();
+
+  ConvertElementType(arg_param, S8);
+
+  std::vector<int8_t> expected = {0, 1, 2, -8};
+  ComputeAndCompareR1<int8_t>(&builder, expected, {arg_data.get()});
+}
+
+TEST_F(ConvertTest, ConvertR1U4ToR1U8) {
+  XlaBuilder builder(TestName());
+  auto a = ConstantR1<u4>(&builder, {u4(0), u4(1), u4(2), u4(15)});
+  ConvertElementType(a, U8);
+
+  std::vector<uint8_t> expected = {0, 1, 2, 15};
+  ComputeAndCompareR1<uint8_t>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1U4ParameterToR1U8) {
+  XlaBuilder builder(TestName());
+  Literal arg_literal =
+      LiteralUtil::CreateR1<u4>({u4(0), u4(1), u4(2), u4(15)});
+  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
+  std::unique_ptr<GlobalData> arg_data =
+      client_->TransferToServer(arg_literal).value();
+
+  ConvertElementType(arg_param, U8);
+
+  std::vector<uint8_t> expected = {0, 1, 2, 15};
+  ComputeAndCompareR1<uint8_t>(&builder, expected, {arg_data.get()});
+}
+
+TEST_F(ConvertTest, ConvertR1S8ToR1S4) {
+  XlaBuilder builder(TestName());
+  auto a = ConstantR1<int8_t>(&builder, {0, 1, 2, -8});
+  ConvertElementType(a, S4);
+
+  std::vector<s4> expected = {s4(0), s4(1), s4(2), s4(-8)};
+  ComputeAndCompareR1<s4>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1U8ToR1U4) {
+  XlaBuilder builder(TestName());
+  auto a = ConstantR1<uint8_t>(&builder, {0, 1, 2, 15});
+  ConvertElementType(a, U4);
+
+  std::vector<u4> expected = {u4(0), u4(1), u4(2), u4(15)};
+  ComputeAndCompareR1<u4>(&builder, expected, {});
+}
+
 XLA_TEST_F(ConvertTest, ConvertBF16F32) {
   XlaBuilder builder(TestName());
 
diff --git a/third_party/xla/xla/tests/int4_test.cc b/third_party/xla/xla/tests/int4_test.cc
new file mode 100644
index 0000000..d0f0267
--- /dev/null
+++ b/third_party/xla/xla/tests/int4_test.cc
@@ -0,0 +1,100 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <optional>
+#include <string>
+
+#include "xla/test.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/tests/test_macros.h"
+
+namespace xla {
+namespace {
+
+XLA_TEST_F(HloTestBase, InputIsOutput) {
+  const std::string hlo_text = R"(
+  HloModule InputIsOutput
+  ENTRY main {
+    ROOT p = s4[8] parameter(0)
+  }
+)";
+  EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt));
+}
+
+XLA_TEST_F(HloTestBase, Reshape) {
+  // Tests that the convert is not moved after the reshape. Currently, reshape
+  // and most other ops are unsupported for int4.
+  const std::string hlo_text = R"(
+  HloModule Reshape
+  ENTRY main {
+    x = s4[2,3] parameter(0)
+    y = s8[2,3] convert(x)
+    ROOT reshape = s8[3,2] reshape(y)
+  }
+)";
+  EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt));
+}
+
+XLA_TEST_F(HloTestBase, Slice) {
+  // Tests indexing s4 arrays in the presence of a slice instruction. On
+  // CPUs/GPUs, the slice is fused with the s4 array.
+  const std::string hlo_text = R"(
+  HloModule Slice
+  ENTRY main {
+    x = s4[4,5] parameter(0)
+    y = s8[4,5] convert(x)
+    ROOT s = s8[3,2] slice(y), slice={[0:3],[2:4]}
+  }
+)";
+  EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt));
+}
+
+XLA_TEST_F(HloTestBase, NonMajorToMinorLayout) {
+  // Tests transposing a matrix with a non-major-to-minor layout.
+  const std::string hlo_text = R"(
+  HloModule NonMajorToMinorLayout
+  ENTRY main {
+    x = s4[2,2]{0,1} parameter(0)
+    y = s8[2,2]{0,1} convert(x)
+    ROOT transpose = s8[2,2]{0,1} transpose(y), dimensions={1,0}
+  })";
+  EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt));
+}
+
+XLA_TEST_F(HloTestBase, Int4Output2d) {
+  // Tests outputting a 2D int4 array.
+  const std::string hlo_text = R"(
+  HloModule Int4Output2d
+  ENTRY main {
+    x = s8[2,2] parameter(0)
+    ROOT y = s4[2,2] convert(x)
+  })";
+  EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt));
+}
+
+XLA_TEST_F(HloTestBase, TupleOutput) {
+  // Tests tuple output with an int4 array.
+  const std::string hlo_text = R"(
+  HloModule TupleOutput
+  ENTRY main {
+    x = s4[2,2] parameter(0)
+    y = s8[2,2] convert(x)
+    ROOT t = (s4[2,2], s8[2,2]) tuple(x, y)
+  })";
+  EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt));
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/third_party/xla/xla/tests/tuple_test.cc b/third_party/xla/xla/tests/tuple_test.cc
index d6789f6..741105b 100644
--- a/third_party/xla/xla/tests/tuple_test.cc
+++ b/third_party/xla/xla/tests/tuple_test.cc
@@ -360,7 +360,7 @@
 
     ENTRY test {
       parameter = f32[3]{0} parameter(0)
-      ROOT tuple = (f32[3]{0}, f32[3]{0}) tuple(parameter)
+      ROOT tuple = (f32[3]{0}, f32[2]{0}) tuple(parameter, parameter)
     }
   )";
 
diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening.cc b/third_party/xla/xla/tools/hlo_control_flow_flattening.cc
index 5c112f1..933c33b4 100644
--- a/third_party/xla/xla/tools/hlo_control_flow_flattening.cc
+++ b/third_party/xla/xla/tools/hlo_control_flow_flattening.cc
@@ -393,7 +393,9 @@
 Status HloControlFlowFlattening::RemoveId(HloInstruction* hlo) const {
   HloComputation* computation = hlo->parent();
   HloInstruction* zero = CreateConstant(hlo->shape(), computation);
+  std::string original_op_name(hlo->name());
   TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, zero));
+  zero->SetAndSanitizeName(original_op_name);
   return OkStatus();
 }
 
@@ -456,6 +458,7 @@
                    instruction->custom_call_target() == "SliceId"))) {
         VLOG(1) << "Remove " << instruction->name();
         TF_RETURN_IF_ERROR(RemoveId(instruction));
+        changed = true;
       }
     }
   }
diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc b/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc
index de9f4d7..a4d018d 100644
--- a/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc
+++ b/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc
@@ -494,6 +494,27 @@
             "collective-permute");
 }
 
+TEST_F(HloControlFlowFlatteningTest, ReplicaIdSucceedsWithChange) {
+  absl::string_view hlo_string = R"(
+  HloModule ReplicaId
+
+  ENTRY ReplicaId {
+    ROOT replica-id.18600 = u32[]{:T(128)} replica-id()
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  HloControlFlowFlattening flattening(HloControlFlowFlattening::Options{});
+  EXPECT_TRUE(flattening.Run(module.get()).value());
+  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
+                           /*allow_mixed_precision=*/true)
+                   .Run(module.get())
+                   .status());
+  EXPECT_THAT(module->entry_computation()->root_instruction(), op::Constant());
+  EXPECT_EQ(module->entry_computation()->root_instruction()->name(),
+            "replica-id.18600");
+}
+
 TEST_F(HloControlFlowFlatteningTest, CollectivePermuteInPlaceUpdate) {
   absl::string_view hlo_string = R"(
   HloModule CollectivePermuteInPlaceUpdate
diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h
index 415e9b6..93d22bf 100644
--- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h
+++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h
@@ -40,9 +40,9 @@
 
 // Supported input formats for the input HLO module.
 enum class InputFormat {
-  kText,                 // Text format.
-  kProtoText,            // Protobuf text format.
-  kProtoBinary,          // Protobuf binary format.
+  kText,                 // Text format returned by HloModule::ToString().
+  kProtoText,            // Protobuf text format of an xla::HloProto message.
+  kProtoBinary,          // Protobuf binary format of an xla::HloProto message.
   kSnapshotProtoBinary,  // HloSnapshot protobuf binary format. Can be dumped by
                          // TensorFlow by setting the environment variable
                          // xla_dump_hlo_snapshots.
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD
index dd8214c..a9a72de 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD
@@ -83,12 +83,14 @@
         ":operator_writer_inc",
         ":stack_frame_index_builder",
         ":type_to_shape",
+        "//xla:array",
         "//xla:comparison_util",
         "//xla:literal",
         "//xla:literal_util",
         "//xla:shape_util",
         "//xla:status",
         "//xla:status_macros",
+        "//xla:types",
         "//xla:xla_data_proto_cc",
         "//xla/client:xla_builder",
         "//xla/client/lib:approx_topk",
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
index ed3525f..5e74c80 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
@@ -25,6 +25,7 @@
 #include <vector>
 
 #include "absl/types/span.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
@@ -57,6 +58,7 @@
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
 #include "stablehlo/dialect/StablehloOps.h"  // from @stablehlo
+#include "xla/array.h"
 #include "xla/client/lib/approx_topk.h"
 #include "xla/client/lib/approx_topk_shape.h"
 #include "xla/client/lib/matrix.h"
@@ -76,6 +78,7 @@
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/hlo.pb.h"
 #include "xla/service/hlo_parser.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/status.h"
 #include "xla/status_macros.h"
@@ -83,6 +86,7 @@
 #include "xla/translate/mhlo_to_hlo/location_exporter.h"
 #include "xla/translate/mhlo_to_hlo/stack_frame_index_builder.h"
 #include "xla/translate/mhlo_to_hlo/type_to_shape.h"
+#include "xla/types.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/float8.h"
 #include "tsl/platform/statusor.h"
@@ -155,6 +159,30 @@
   return true;
 }
 
+template <typename T>
+xla::Array<T> ArrayFromDenseElementsAttr(mlir::DenseElementsAttr dense_attr) {
+  constexpr xla::PrimitiveType type =
+      xla::primitive_util::NativeToPrimitiveType<T>();
+  xla::Shape shape = xla::TypeToShape(dense_attr.getType());
+  xla::Array<T> array(shape.dimensions());
+  if constexpr (!xla::primitive_util::Is4BitType(type)) {
+    array.SetValues(dense_attr.getValues<T>());
+  } else {
+    // The only way to get subbyte integers from getValues() is to get them as
+    // APInts.
+    auto values = dense_attr.getValues<llvm::APInt>();
+    for (int i = 0; i < values.size(); i++) {
+      if constexpr (type == xla::U4) {
+        array.data()[i] = xla::u4{values[i].getZExtValue()};
+      } else {
+        static_assert(type == xla::S4);
+        array.data()[i] = xla::s4(values[i].getSExtValue());
+      }
+    }
+  }
+  return array;
+}
+
 StatusOr<xla::Literal> CreateArrayLiteralFromAttr(mlir::ElementsAttr attr,
                                                   xla::Layout layout) {
   auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
@@ -169,8 +197,8 @@
                           primitive_type_constant)) {
           using cpp_type =
               xla::primitive_util::NativeTypeOf<primitive_type_constant>;
-          xla::Array<cpp_type> source_data(shape.dimensions());
-          source_data.SetValues(dense_attr.getValues<cpp_type>());
+          xla::Array<cpp_type> source_data =
+              ArrayFromDenseElementsAttr<cpp_type>(dense_attr);
           return xla::LiteralUtil::CreateFromArrayWithLayout(source_data,
                                                              layout);
         }
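For reference, a minimal standalone sketch (not part of this change) of the APInt extraction that ArrayFromDenseElementsAttr relies on for 4-bit values; it assumes only LLVM's APInt and shows why the S4 branch must sign-extend while the U4 branch must zero-extend:

    #include <cstdint>

    #include "llvm/ADT/APInt.h"

    void Int4ExtractionSketch() {
      // The 4-bit pattern 0b1000 is -8 when read as s4 but 8 when read as u4.
      llvm::APInt bits(/*numBits=*/4, /*val=*/0b1000);
      int64_t as_signed = bits.getSExtValue();     // -8, matches the S4 branch
      uint64_t as_unsigned = bits.getZExtValue();  // 8, matches the U4 branch
      (void)as_signed;
      (void)as_unsigned;
    }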
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/int4.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/int4.mlir
new file mode 100644
index 0000000..615fe12
--- /dev/null
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/int4.mlir
@@ -0,0 +1,27 @@
+// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
+
+// Test int4 constants and conversions.
+
+// CHECK-LABEL: ENTRY %main.{{.*}} () -> s4[6]
+func.func @main() -> tensor<6xi4> {
+  // CHECK-NEXT: %[[CONSTANT:.*]] = s4[6] constant({1, -2, -3, 4, -8, 7})
+  %0 = mhlo.constant dense<[1, -2, -3, 4, -8, 7]> : tensor<6xi4>
+  // CHECK-NEXT: %[[CONVERT1:.*]] = s8[6] convert(s4[6] %[[CONSTANT]])
+  %1 = "mhlo.convert"(%0) : (tensor<6xi4>) -> tensor<6xi8>
+  // CHECK-NEXT: ROOT %[[CONVERT2:.*]] = s4[6] convert(s8[6] %[[CONVERT1]])
+  %2 = "mhlo.convert"(%1) : (tensor<6xi8>) -> tensor<6xi4>
+  func.return %2 : tensor<6xi4>
+}
+
+// -----
+
+// CHECK-LABEL: ENTRY %main.{{.*}} () -> u4[4]
+func.func @main() -> tensor<4xui4> {
+  // CHECK-NEXT: %[[CONSTANT:.*]] = u4[4] constant({1, 2, 3, 15})
+  %0 = mhlo.constant dense<[1, 2, 3, 15]> : tensor<4xui4>
+  // CHECK-NEXT: %[[CONVERT1:.*]] = u8[4] convert(u4[4] %[[CONSTANT]])
+  %1 = "mhlo.convert"(%0) : (tensor<4xui4>) -> tensor<4xui8>
+  // CHECK-NEXT: ROOT %[[CONVERT2:.*]] = u4[4] convert(u8[4] %[[CONVERT1]])
+  %2 = "mhlo.convert"(%1) : (tensor<4xui8>) -> tensor<4xui4>
+  func.return %2 : tensor<4xui4>
+}
diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc
index 5a21319..94dd2d3 100644
--- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc
+++ b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc
@@ -159,8 +159,9 @@
   module.getBody()->clear();
   OpBuilder builder(module);
 
+  std::vector<const BufferAllocation*> ordered_allocations;
   TF_RETURN_WITH_CONTEXT_IF_ERROR(
-      HloToLhloModule(**assignment, *hlo_module, module),
+      HloToLhloModule(**assignment, *hlo_module, module, &ordered_allocations),
       "converting HLO to LHLO");
 
   return ::tsl::OkStatus();
@@ -700,6 +701,8 @@
   }
   op.setPrecisionConfigAttr(
       xla::ConvertPrecisionConfig(&config.precision_config(), &builder));
+  op.setGradXAttr(builder.getBoolAttr(config.grad_x()));
+  op.setGradYAttr(builder.getBoolAttr(config.grad_y()));
 }
 
 tsl::StatusOr<lmhlo_gpu::CublasLtMatmulEpilogue> AsLhloEpilogue(
@@ -2152,7 +2155,8 @@
                              token_mode);
 }
 
-tsl::Status LhloDialectEmitter::Initialize() {
+tsl::Status LhloDialectEmitter::Initialize(
+    std::vector<const BufferAllocation*>* ordered_allocations) {
   TF_RET_CHECK(computation_.IsEntryComputation());
 
   mlir::IntegerAttr unique_id =
@@ -2178,9 +2182,11 @@
   }
   Block* block = func_op.addEntryBlock();
 
-  llvm::SmallVector<const BufferAllocation*, 8> ordered_allocations;
-  for (const BufferAllocation& alloc : assignment_.Allocations())
-    ordered_allocations.push_back(&alloc);
+  for (const BufferAllocation& alloc : assignment_.Allocations()) {
+    if (!alloc.is_thread_local()) {
+      ordered_allocations->push_back(&alloc);
+    }
+  }
 
   if (computation_.IsEntryComputation()) {
     // Sort the rather arbitrarily ordered allocations to match the input/output
@@ -2208,7 +2214,7 @@
       return false;
     };
 
-    std::stable_sort(ordered_allocations.begin(), ordered_allocations.end(),
+    std::stable_sort(ordered_allocations->begin(), ordered_allocations->end(),
                      allocation_comparator);
   }
 
@@ -2232,11 +2238,9 @@
   // - one memref for each of the parameters.
   // - one memref for each other buffer allocation.
   llvm::SmallVector<DictionaryAttr, 8> args_attrs;
-  for (const BufferAllocation* alloc : ordered_allocations) {
-    if (alloc->is_thread_local()) {
-      continue;
-    }
-
+  auto it = ordered_allocations->begin();
+  while (it != ordered_allocations->end()) {
+    const BufferAllocation* alloc = *it;
     // There are optional attributes to help the program run through XLA. XLA
     // defines ExecutionInput and ExecutionOutput structures to carry
     // input-output type and buffer information, therefore any information they
@@ -2280,6 +2284,7 @@
       const Shape* sub_shape = iter->second.first;
       const xla::ShapeIndex& shape_index = iter->second.second;
       if (!sub_shape->IsArray()) {
+        it = ordered_allocations->erase(it);
         continue;
       }
       arg_attr_list.set("lmhlo.output_index",
@@ -2296,6 +2301,7 @@
     block->addArgument(arg_type, loc);
     allocations_[alloc] = block->getArguments().back();
     args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext()));
+    it++;
   }
 
   FunctionType function_type =
@@ -2315,7 +2321,7 @@
 
 tsl::Status HloToLhloModule(
     const BufferAssignment& assignment, const HloModule& hlo_module,
-    ModuleOp module,
+    ModuleOp module, std::vector<const BufferAllocation*>* ordered_allocations,
     absl::flat_hash_map<const mlir::Operation*, const xla::HloInstruction*>*
         lhlo_to_hlo_map) {
   module.getContext()
@@ -2334,7 +2340,7 @@
   const HloComputation* computation = hlo_module.entry_computation();
 
   LhloDialectEmitter emitter(assignment, *computation, module);
-  TF_RETURN_IF_ERROR(emitter.Initialize());
+  TF_RETURN_IF_ERROR(emitter.Initialize(ordered_allocations));
 
   const xla::HloInstructionSequence* schedule =
       assignment.hlo_ordering().SequentialOrder(*computation);
diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h
index 052ae9d..a512bdf 100644
--- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h
+++ b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h
@@ -43,7 +43,8 @@
  public:
   // Initializes internal data structures. It must be called before calling any
   // of the visitors.
-  tsl::Status Initialize();
+  tsl::Status Initialize(
+      std::vector<const xla::BufferAllocation*>* ordered_allocations);
 
   LhloDialectEmitter(const xla::BufferAssignment& assignment,
                      const xla::HloComputation& computation, ModuleOp module)
@@ -328,9 +329,12 @@
 // `lhlo_to_hlo_map`, if non-null, is populated with a mapping from generated
 // top-level MLIR operations to the original HLO instructions. "top-level" means
 // that ops inside the bodies of fusions are not included (but all fusions are).
+// On success, `ordered_allocations` is populated with the buffer allocations
+// from the buffer assignment, in the order of the inputs to the LMHLO entry
+// function.
 tsl::Status HloToLhloModule(
     const xla::BufferAssignment& assignment, const xla::HloModule& hlo_module,
     ModuleOp module,
+    std::vector<const xla::BufferAllocation*>* ordered_allocations,
     absl::flat_hash_map<const mlir::Operation*, const xla::HloInstruction*>*
         lhlo_to_hlo_map = nullptr);
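A rough caller-side sketch of the new out-parameter, mirroring the call site updated in mhlo_to_lhlo_with_xla.cc above (`buffer_assignment`, `hlo_module`, and `mlir_module` are assumed to be set up by the caller, and namespace qualification is omitted):

    std::vector<const xla::BufferAllocation*> ordered_allocations;
    TF_RETURN_IF_ERROR(HloToLhloModule(buffer_assignment, hlo_module,
                                       mlir_module, &ordered_allocations));
    // ordered_allocations now lists the non-thread-local allocations in the
    // order of the LMHLO entry function's arguments.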
 
diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt
index 79bbb3c..a1722f6 100644
--- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt
+++ b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt
@@ -26,14 +26,14 @@
 
 // CHECK-LABEL: func @main
 // CHECK: "lmhlo.scatter"
-// CHECK: ^bb0(%[[ARG5:.*]]: tensor<i32>, %[[ARG6:.*]]: tensor<i32>):
-// CHECK:  mhlo.return %[[ARG6]]
 // CHECK: indices_are_sorted = false
 // CHECK: update_window_dims = [1]
 // CHECK: inserted_window_dims = [0]
 // CHECK: scatter_dims_to_operand_dims = [0]
 // CHECK: index_vector_dim = 1
 // CHECK: unique_indices = false
+// CHECK: ^bb0(%[[ARG5:.*]]: tensor<i32>, %[[ARG6:.*]]: tensor<i32>):
+// CHECK:  mhlo.return %[[ARG6]]
 // CHECK: (memref<3x3xi32>, memref<2xi32>, memref<2x3xi32>, memref<3x3xi32>) -> ()
 ENTRY main {
   operand = s32[3,3] parameter(0)
@@ -68,15 +68,15 @@
 // CHECK-LABEL: func @main
 // CHECK: %[[GLOBAL_MEMREF:.*]] = memref.get_global @[[$GLOBAL]] : memref<f32>
 // CHECK: "lmhlo.select_and_scatter"(%{{.*}}, %{{.*}}, %[[GLOBAL_MEMREF]], %{{.*}})
+// CHECK: padding = dense<0> : tensor<1xi64>
+// CHECK: window_dimensions = dense<3> : tensor<1xi64>
+// CHECK: window_strides = dense<3> : tensor<1xi64>
 // CHECK: ^bb0(%[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>):
 // CHECK: %[[COMPARE:.*]] = mhlo.compare GE, %[[ARG0]], %[[ARG1]]
 // CHECK: mhlo.return %[[COMPARE]] : tensor<i1>
 // CHECK: ^bb0(%[[ARG2:.*]]: tensor<f32>, %[[ARG3:.*]]: tensor<f32>):
 // CHECK: %[[ADD:.*]] = mhlo.add %[[ARG2]], %[[ARG3]]
 // CHECK: mhlo.return %[[ADD]] : tensor<f32>
-// CHECK: padding = dense<0> : tensor<1xi64>
-// CHECK: window_dimensions = dense<3> : tensor<1xi64>
-// CHECK: window_strides = dense<3> : tensor<1xi64>
 // CHECK: (memref<6xf32>, memref<2xf32>, memref<f32>, memref<6xf32>) -> ()
 ENTRY main () -> f32[6] {
   %operand = f32[6]{0} parameter(0)
@@ -215,12 +215,12 @@
   param0 = f32[8] parameter(0)
   // CHECK:  [[VIEW:%.*]] = memref.view [[BUFFER]]{{.*}} : memref<32xi8> to memref<8xf32>
   // CHECK:  [[TOKEN:%.*]] = "lmhlo_gpu.all_reduce_start"([[VIEW]], [[VIEW]])
+  // CHECK-SAME:  channel_id = #mhlo.channel_handle<handle = 1, type = 0>
+  // CHECK-SAME{LITERAL}:  replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>
   // CHECK:  ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
   // CHECK:    [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
   // CHECK:    mhlo.return [[ADD]] : tensor<f32>
-  // CHECK:  }) {
-  // CHECK-SAME:  channel_id = #mhlo.channel_handle<handle = 1, type = 0>
-  // CHECK-SAME{LITERAL}:  replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>
+  // CHECK:  })
   // CHECK:  "lmhlo_gpu.all_reduce_done"([[TOKEN]])
   start = f32[8] all-reduce-start(param0),
       channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_apply=add
@@ -247,12 +247,12 @@
   // CHECK:  [[VIEW0:%.*]] = memref.view [[BUFFER0]]{{.*}} : memref<32xi8> to memref<8xf32>
   // CHECK:  [[VIEW1:%.*]] = memref.view [[BUFFER1]]{{.*}} : memref<36xi8> to memref<9xf32>
   // CHECK:  [[TOKEN:%.*]] = "lmhlo_gpu.all_reduce_start"([[VIEW0]], [[VIEW1]], [[VIEW0]], [[VIEW1]])
+  // CHECK-SAME:  channel_id = #mhlo.channel_handle<handle = 1, type = 0>
+  // CHECK-SAME{LITERAL}:  replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>
   // CHECK:  ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
   // CHECK:    [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
   // CHECK:    mhlo.return [[ADD]] : tensor<f32>
-  // CHECK:  }) {
-  // CHECK-SAME:  channel_id = #mhlo.channel_handle<handle = 1, type = 0>
-  // CHECK-SAME{LITERAL}:  replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>
+  // CHECK:  })
   // CHECK:  "lmhlo_gpu.all_reduce_done"([[TOKEN]])
   start = (f32[8], f32[9]) all-reduce-start(param0, param1),
       channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_apply=add
@@ -388,7 +388,7 @@
 HloModule TestModule
 
 // CHECK: func @main
-// CHECK:   "lmhlo.rng_get_and_update_state"(%{{.*}}) {delta = 131072 : i64} : (memref<2xui64>) -> ()
+// CHECK:   "lmhlo.rng_get_and_update_state"(%{{.*}}) <{delta = 131072 : i64}> : (memref<2xui64>) -> ()
 ENTRY main {
   ROOT %rng-get-and-update-state = u64[2]{0} rng-get-and-update-state(), delta=131072
 }
@@ -564,9 +564,9 @@
 HloModule CustomCallWithComputation
 
 // CHECK: "lmhlo.custom_call"
+// CHECK: call_target_name = "__custom"
 // CHECK: %0 = mhlo.add
 // CHECK: mhlo.return %0
-// CHECK: call_target_name = "__custom"
 
 computation1 {
   param_0 = f32[] parameter(0)
@@ -586,17 +586,17 @@
 // CHECK: func @main
 // CHECK: %[[ARG1:arg[0-9]+]]: memref<16xi8> {lmhlo.params = 1 : index}
 // CHECK: %[[VIEW:.*]] = memref.view %[[ARG1]][%c0][]
-// CHECK: %[[TOKEN:.*]] = "lmhlo.send"(%[[VIEW]]) {
+// CHECK: %[[TOKEN:.*]] = "lmhlo.send"(%[[VIEW]])
 // CHECK:   channel_handle = #mhlo.channel_handle<handle = 1, type = 2>,
 // CHECK:   frontend_attributes = {_xla_dcn_recv_channel = "2",
 // CHECK:                          _xla_host_transfer_handler_name = "undef",
 // CHECK:                          _xla_host_transfer_rendezvous = "undef"}
 // CHECK:   is_host_transfer = true
-// CHECK: } : (memref<4xf32>) -> !mhlo.token
-// CHECK: "lmhlo.send_done"(%0) {
+// CHECK: : (memref<4xf32>) -> !mhlo.token
+// CHECK: "lmhlo.send_done"(%0)
 // CHECK:   channel_handle = #mhlo.channel_handle<handle = 1, type = 2>,
 // CHECK    is_host_transfer = true
-// CHECK: } : (!mhlo.token) -> ()
+// CHECK: : (!mhlo.token) -> ()
 ENTRY main {
   %tok = token[] parameter(0)
   %buf = f32[4]{0} parameter(1)
@@ -611,16 +611,16 @@
 // CHECK: func @main
 // CHECK: %[[ARG1:arg[0-9]+]]: memref<16xi8> {lmhlo.output_index = dense<0> : tensor<1xi64>}
 // CHECK: %[[VIEW:.*]] = memref.view %[[ARG1]][%c0][]
-// CHECK: %[[TOKEN:.*]] = "lmhlo.recv"(%[[VIEW]]) {
+// CHECK: %[[TOKEN:.*]] = "lmhlo.recv"(%[[VIEW]])
 // CHECK:   channel_handle = #mhlo.channel_handle<handle = 1, type = 3>,
 // CHECK:    frontend_attributes = {_xla_host_transfer_handler_name = "undef",
 // CHECK:                           _xla_host_transfer_rendezvous = "undef"}
 // CHECK:   is_host_transfer = true
-// CHECK: } : (memref<4xf32>) -> !mhlo.token
-// CHECK: "lmhlo.recv_done"(%0) {
+// CHECK: : (memref<4xf32>) -> !mhlo.token
+// CHECK: "lmhlo.recv_done"(%0)
 // CHECK:   channel_handle = #mhlo.channel_handle<handle = 1, type = 3>,
 // CHECK    is_host_transfer = true
-// CHECK: } : (!mhlo.token) -> ()
+// CHECK: : (!mhlo.token) -> ()
 ENTRY main {
   %tok = token[] parameter(0)
   %recv = (f32[4]{0}, u32[], token[]) recv(token[] %tok), channel_id=1, is_host_transfer=true, frontend_attributes={_xla_host_transfer_handler_name="undef",_xla_host_transfer_rendezvous="undef"}
@@ -632,7 +632,7 @@
 HloModule TestAllGatherAsync
 
 // CHECK: func @main
-// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_gather_start"(%{{.*}}, %{{.*}}) {
+// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_gather_start"(%{{.*}}, %{{.*}}) <
 // CHECK-SAME: all_gather_dimension = 1 : i64
 // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
 // CHECK-SAME: use_global_device_ids = false
@@ -649,14 +649,14 @@
 HloModule AsyncReduceScatter
 
 // CHECK: func @main
-// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.reduce_scatter_start"(%{{.*}}, %{{.*}}) ({
-// CHECK:  ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
-// CHECK:    [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
-// CHECK:    mhlo.return [[ADD]] : tensor<f32>
-// CHECK:  }) {
+// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.reduce_scatter_start"(%{{.*}}, %{{.*}}) <
 // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
 // CHECK-SAME: scatter_dimension = 0
 // CHECK-SAME: use_global_device_ids = false
+// CHECK:  ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
+// CHECK:    [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
+// CHECK:    mhlo.return [[ADD]] : tensor<f32>
+// CHECK:  }) :
 // CHECK ""lmhlo_gpu.reduce_scatter_done"(%[[TOKEN]])
 
 add {
@@ -682,7 +682,7 @@
 HloModule AsyncAllToAll
 
 // CHECK: func @main
-// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_to_all_start"(%{{.*}}, %{{.*}}) {
+// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_to_all_start"(%{{.*}}, %{{.*}}) <
 // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
 // CHECK ""lmhlo_gpu.all_to_all_done"(%[[TOKEN]])
 
@@ -702,7 +702,7 @@
 HloModule TestAllGatherAsyncWithSyncFlagFalse
 
 // CHECK: func @main
-// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_gather_start"(%{{.*}}, %{{.*}}) {
+// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_gather_start"(%{{.*}}, %{{.*}}) <
 // CHECK-SAME: all_gather_dimension = 1 : i64
 // CHECK-SAME: is_sync = false
 // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
@@ -720,7 +720,7 @@
 HloModule TestAllGatherAsyncWithSyncFlagTrue
 
 // CHECK: func @main
-// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_gather_start"(%{{.*}}, %{{.*}}) {
+// CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_gather_start"(%{{.*}}, %{{.*}}) <
 // CHECK-SAME: all_gather_dimension = 1 : i64
 // CHECK-SAME: is_sync = true
 // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt
index 56a60ec..7967070 100644
--- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt
+++ b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt
@@ -14,7 +14,7 @@
   x = f32[3, 2]{1,0} parameter(0)
 
   // CHECK:   %[[VIEW:.*]] = memref.view {{.*}} : memref<24xi8> to memref<3x2xf32>
-  // CHECK: "lmhlo.fusion"() ({
+  // CHECK: "lmhlo.fusion"() <{backend_config = "{{.*}}"}> ({
   // CHECK:   %[[VAL2:.*]] = bufferization.to_tensor %[[VIEW]] : memref<3x2xf32>
   // CHECK:   %[[VAL3:.*]] = mhlo.copy %[[VAL2]] {
   // CHECK-SAME:               result_layout = dense<[0, 1]>
@@ -22,6 +22,6 @@
   // CHECK-SAME:             } : tensor<3x2xf32>
   // CHECK:   memref.tensor_store %[[VAL3:.*]], %{{.*}} : memref<3x2xf32, #[[MAP]]>
   // CHECK:   "lmhlo.terminator"() : () -> ()
-  // CHECK: }) {backend_config = "{{.*}}"} : () -> ()
+  // CHECK: }) : () -> ()
   ROOT fusion = f32[3, 2]{0,1} fusion(f32[3, 2]{1,0} x), kind=kLoop, calls=Fusion
 }
diff --git a/third_party/xla/xla/util.h b/third_party/xla/xla/util.h
index c55d694..ac708f0 100644
--- a/third_party/xla/xla/util.h
+++ b/third_party/xla/xla/util.h
@@ -491,20 +491,32 @@
 // Returns `base` multiplied by itself `exponent` number of times.
 //
 // Note: returns 1 when `exponent` is zero.
-// Precondition: `exponent` is non-negative.
-template <typename T>
-constexpr T IPow(T base, int exponent) {
-  // A negative `exponent` is indicative of a logic bug for integral `base`.
-  // We disallow it for floating-point types for symmetry.
-  ABSL_ASSERT(exponent >= 0);
+// Precondition: `exponent` is non-negative for integral `T`.
+template <typename T, typename ExpType>
+constexpr T IPow(T base, ExpType exponent) {
+  static_assert(std::numeric_limits<ExpType>::is_integer);
+  if constexpr (std::numeric_limits<T>::is_integer) {
+    // A negative `exponent` is indicative of a logic bug for integral `base`;
+    // for floating-point `base`, a negative exponent is handled via the
+    // reciprocal below.
+    ABSL_ASSERT(exponent >= 0);
+  }
+  const bool take_reciprocal = exponent < 0;
   // We use the right-to-left binary exponentiation algorithm.
-  T result{1};
-  while (exponent > 0) {
+  T result(1);
+  for (;;) {
     if ((exponent & 1) != 0) {
       result *= base;
     }
+    exponent /= 2;
+    if (exponent == 0) {
+      break;
+    }
     base *= base;
-    exponent >>= 1;
+  }
+  if constexpr (std::numeric_limits<ExpType>::is_signed) {
+    if (take_reciprocal) {
+      return T(1) / result;
+    }
   }
   return result;
 }
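A small usage sketch of the updated IPow (illustrative only; IPow lives in xla/util.h and the exact checks are examples, not part of this change):

    #include <cassert>
    #include <cstdint>

    #include "xla/util.h"

    void IPowSketch() {
      // Integral base: plain right-to-left binary exponentiation.
      assert(xla::IPow<int64_t>(3, 5) == 243);
      // A zero exponent still returns 1.
      assert(xla::IPow<int64_t>(7, 0) == 1);
      // A floating-point base with a negative signed exponent now takes the
      // reciprocal: 2^-3 == 1 / 2^3.
      assert(xla::IPow<double>(2.0, -3) == 0.125);
    }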
diff --git a/third_party/xla/xla/xla.bzl b/third_party/xla/xla/xla.bzl
index 108b44f..9cf3704 100644
--- a/third_party/xla/xla/xla.bzl
+++ b/third_party/xla/xla/xla.bzl
@@ -6,7 +6,7 @@
 )
 load(
     "@local_tsl//tsl:tsl.bzl",
-    "if_tsl_link_protobuf",
+    "if_oss",
     "tsl_copts",
     _tsl_clean_dep = "clean_dep",
 )
@@ -39,87 +39,54 @@
 def xla_py_test_deps():
     return []
 
-def xla_cc_binary(deps = None, copts = tsl_copts(), **kwargs):
-    if not deps:
-        deps = []
+# TODO(ddunleavy): Some of these should be removed from here and added to
+# specific targets.
+# We shouldn't need this anymore post-vendoring. If we build without
+# `framework_shared_object` in the bazelrc, all of this should be able to go
+# away. The problem is making sure that all of these impl deps are
+# `if_static`'d appropriately throughout XLA.
+_XLA_SHARED_OBJECT_SENSITIVE_DEPS = if_oss([_tsl_clean_dep("@com_google_protobuf//:protobuf")]) + [
+    clean_dep("//xla:xla_proto_cc_impl"),
+    clean_dep("//xla:xla_data_proto_cc_impl"),
+    clean_dep("//xla/service:hlo_proto_cc_impl"),
+    clean_dep("//xla/service:buffer_assignment_proto_cc_impl"),
+    clean_dep("//xla/service/memory_space_assignment:memory_space_assignment_proto_cc_impl"),
+    clean_dep("//xla/service/gpu:backend_configs_cc_impl"),
+    clean_dep("//xla/service/gpu/model:hlo_op_profile_proto_cc_impl"),
+    clean_dep("//xla/stream_executor:device_description_proto_cc_impl"),
+    clean_dep("//xla/stream_executor:device_id_utils"),
+    clean_dep("//xla/stream_executor:stream_executor_impl"),
+    clean_dep("//xla/stream_executor/gpu:gpu_cudamallocasync_allocator"),
+    clean_dep("//xla/stream_executor/gpu:gpu_init_impl"),
+    clean_dep("@local_tsl//tsl/profiler/utils:time_utils_impl"),
+    clean_dep("@local_tsl//tsl/profiler/backends/cpu:annotation_stack_impl"),
+    clean_dep("@local_tsl//tsl/profiler/backends/cpu:traceme_recorder_impl"),
+    clean_dep("@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl"),
+    clean_dep("@local_tsl//tsl/profiler/protobuf:xplane_proto_cc_impl"),
+    clean_dep("//xla:autotune_results_proto_cc_impl"),
+    clean_dep("//xla:autotuning_proto_cc_impl"),
+    clean_dep("@local_tsl//tsl/protobuf:protos_all_cc_impl"),
+    clean_dep("@local_tsl//tsl/platform:env_impl"),
+    clean_dep("@local_tsl//tsl/framework:allocator"),
+    clean_dep("@local_tsl//tsl/framework:allocator_registry_impl"),
+    clean_dep("@local_tsl//tsl/util:determinism"),
+] + if_cuda_is_configured([
+    clean_dep("//xla/stream_executor/cuda:cuda_stream"),
+    clean_dep("//xla/stream_executor/cuda:all_runtime"),
+    clean_dep("//xla/stream_executor/cuda:stream_executor_cuda"),
+]) + if_rocm_is_configured([
+    clean_dep("//xla/stream_executor/gpu:gpu_stream"),
+    clean_dep("//xla/stream_executor/rocm:all_runtime"),
+    clean_dep("//xla/stream_executor/rocm:stream_executor_rocm"),
+])
 
-    # TODO(ddunleavy): some of these should be removed from here and added to
-    # specific targets.
-    deps += [
-        _tsl_clean_dep("@com_google_protobuf//:protobuf"),
-        "//xla:xla_proto_cc_impl",
-        "//xla:xla_data_proto_cc_impl",
-        "//xla/service:hlo_proto_cc_impl",
-        "//xla/service:buffer_assignment_proto_cc_impl",
-        "//xla/service/memory_space_assignment:memory_space_assignment_proto_cc_impl",
-        "//xla/service/gpu:backend_configs_cc_impl",
-        "//xla/service/gpu/model:hlo_op_profile_proto_cc_impl",
-        "//xla/stream_executor:device_description_proto_cc_impl",
-        "//xla/stream_executor:stream_executor_impl",
-        "//xla/stream_executor/gpu:gpu_init_impl",
-        "@local_tsl//tsl/platform:env_impl",
-        "@local_tsl//tsl/platform:tensor_float_32_utils",
-        "@local_tsl//tsl/profiler/utils:time_utils_impl",
-        "@local_tsl//tsl/profiler/backends/cpu:annotation_stack_impl",
-        "@local_tsl//tsl/profiler/backends/cpu:traceme_recorder_impl",
-        "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl",
-        "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc_impl",
-        "//xla:autotune_results_proto_cc_impl",
-        "//xla:autotuning_proto_cc_impl",
-        "@local_tsl//tsl/protobuf:protos_all_cc_impl",
-        "@local_tsl//tsl/framework:allocator",
-        "@local_tsl//tsl/framework:allocator_registry_impl",
-        "@local_tsl//tsl/util:determinism",
-    ]
-    native.cc_binary(deps = deps, copts = copts, **kwargs)
+def xla_cc_binary(deps = [], copts = tsl_copts(), **kwargs):
+    native.cc_binary(deps = deps + _XLA_SHARED_OBJECT_SENSITIVE_DEPS, copts = copts, **kwargs)
 
-def xla_cc_test(
-        name,
-        deps = [],
-        **kwargs):
+def xla_cc_test(name, deps = [], **kwargs):
     native.cc_test(
         name = name,
-        deps = deps + if_tsl_link_protobuf(
-                   [],
-                   [
-                       _tsl_clean_dep("@com_google_protobuf//:protobuf"),
-                       # TODO(zacmustin): remove these in favor of more granular dependencies in each test.
-                       clean_dep("//xla:xla_proto_cc_impl"),
-                       clean_dep("//xla:xla_data_proto_cc_impl"),
-                       clean_dep("//xla/service:hlo_proto_cc_impl"),
-                       clean_dep("//xla/service:buffer_assignment_proto_cc_impl"),
-                       clean_dep("//xla/service/memory_space_assignment:memory_space_assignment_proto_cc_impl"),
-                       clean_dep("//xla/service/gpu:backend_configs_cc_impl"),
-                       clean_dep("//xla/service/gpu/model:hlo_op_profile_proto_cc_impl"),
-                       clean_dep("//xla/stream_executor:device_description_proto_cc_impl"),
-                       clean_dep("//xla/stream_executor:device_id_utils"),
-                       clean_dep("//xla/stream_executor:stream_executor_impl"),
-                       clean_dep("//xla/stream_executor/gpu:gpu_cudamallocasync_allocator"),
-                       clean_dep("//xla/stream_executor/gpu:gpu_init_impl"),
-                       clean_dep("@local_tsl//tsl/profiler/utils:time_utils_impl"),
-                       clean_dep("@local_tsl//tsl/profiler/backends/cpu:annotation_stack_impl"),
-                       clean_dep("@local_tsl//tsl/profiler/backends/cpu:traceme_recorder_impl"),
-                       clean_dep("@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl"),
-                       clean_dep("@local_tsl//tsl/profiler/protobuf:xplane_proto_cc_impl"),
-                       clean_dep("//xla:autotune_results_proto_cc_impl"),
-                       clean_dep("//xla:autotuning_proto_cc_impl"),
-                       clean_dep("@local_tsl//tsl/protobuf:protos_all_cc_impl"),
-                       clean_dep("@local_tsl//tsl/platform:env_impl"),
-                       clean_dep("@local_tsl//tsl/framework:allocator"),
-                       clean_dep("@local_tsl//tsl/framework:allocator_registry_impl"),
-                       clean_dep("@local_tsl//tsl/util:determinism"),
-                   ],
-               ) +
-               if_cuda_is_configured([
-                   clean_dep("//xla/stream_executor/cuda:cuda_stream"),
-                   clean_dep("//xla/stream_executor/cuda:all_runtime"),
-                   clean_dep("//xla/stream_executor/cuda:stream_executor_cuda"),
-               ]) +
-               if_rocm_is_configured([
-                   clean_dep("//xla/stream_executor/gpu:gpu_stream"),
-                   clean_dep("//xla/stream_executor/rocm:all_runtime"),
-                   clean_dep("//xla/stream_executor/rocm:stream_executor_rocm"),
-               ]),
+        deps = deps + _XLA_SHARED_OBJECT_SENSITIVE_DEPS,
         exec_properties = tf_exec_properties(kwargs),
         **kwargs
     )
@@ -135,3 +102,9 @@
 
 def xla_nvml_deps():
     return ["@local_config_cuda//cuda:nvml_headers"]
+
+def xla_cub_deps():
+    return ["@local_config_cuda//cuda:cub_headers"]
+
+def xla_symbol_repository_deps():
+    return []
diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto
index 7399600..892eeb4 100644
--- a/third_party/xla/xla/xla.proto
+++ b/third_party/xla/xla/xla.proto
@@ -458,13 +458,18 @@
   // Whether to use cuBLASLt for GEMMs on GPUs.
   bool xla_gpu_enable_cublaslt = 166;
 
-  // 0:   Disable GPU graph capture.
-  // 1:   Enable GPU graphs for fusions and memcpy (safest ones).
-  // 2:   Enable GPU graphs for gemms.
-  // 3:   Enable GPU graphs for convolutions.
-  //
-  // Default: 0.
-  int32 xla_gpu_graph_level = 194;
+  // Commands are categorized into four types: FUSION represents regular
+  // fusion kernels; CUBLAS, CUDNN, and NCCL represent library calls.
+  enum CommandBufferCmdType {
+    INVALID = 0;
+    FUSION = 1;
+    CUBLAS = 2;
+    CUDNN = 3;
+    NCCL = 4;
+  }
+
+  // Determines which command types are recorded into command buffers.
+  repeated CommandBufferCmdType xla_gpu_enable_command_buffer = 258;
 
   // Only instantiates a GPU graph after the captured function execution count
   // reaches the threshold. This constant is a heuristic to avoid creating a
@@ -550,6 +555,7 @@
 
   bool xla_gpu_lhs_enable_gpu_async_tracker = 204;
   string xla_gpu_pgle_profile_file_or_directory_path = 210;
+  int32 xla_gpu_memory_limit_slop_factor = 260;
 
   bool xla_gpu_enable_pipelined_collectives = 239;
   bool xla_gpu_enable_pipelined_all_reduce = 217;
@@ -646,7 +652,10 @@
 
   int32 xla_gpu_llvm_verification_level = 256;
 
-  // Next id: 258
+  // Enable radix sort using CUB.
+  bool xla_gpu_enable_cub_radix_sort = 259;
+
+  // Next id: 261
 
   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.
@@ -659,7 +668,8 @@
   // xla_gpu_enable_cuda_graphs
   // xla_gpu_allow_all_reduce_kernel
   // xla_gpu_enable_experimental_block_size
-  reserved 5, 117, 133, 139, 176, 178, 180, 193, 214;
+  // xla_gpu_graph_level
+  reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194;
 }
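A hedged C++ sketch of setting the new fields from code, assuming (as with the other xla_gpu_* flags) that they live on the DebugOptions message and using the standard protobuf-generated accessors; the include path and the slop-factor value are illustrative:

    #include "xla/xla.pb.h"  // Assumed generated header for xla.proto.

    xla::DebugOptions CommandBufferDebugOptionsSketch() {
      xla::DebugOptions opts;
      // Record fusion kernels and cuBLAS calls into command buffers (field 258).
      opts.add_xla_gpu_enable_command_buffer(xla::DebugOptions::FUSION);
      opts.add_xla_gpu_enable_command_buffer(xla::DebugOptions::CUBLAS);
      // Newly added knobs (fields 259 and 260); 95 is an illustrative value.
      opts.set_xla_gpu_enable_cub_radix_sort(true);
      opts.set_xla_gpu_memory_limit_slop_factor(95);
      return opts;
    }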
 
 // Contains flags which affects the GPU compilation result.